1. Startup flow

First, look at the @RpcScan annotation:

@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Import(CustomScannerRegistrar.class)
@Documented
public @interface RpcScan {

    String[] basePackage();

}

To use it, put it on the server's startup class and specify the packages to scan:

@SpringBootApplication
@RpcScan(basePackage = {"github.javaguide"})
public class ServerApplication {
    public static void main(String[] args) {
        SpringApplication.run(ServerApplication.class, args);
    }
}

Next, focus on @Import(CustomScannerRegistrar.class). Here is the complete code:

@Slf4j
public class CustomScannerRegistrar implements ImportBeanDefinitionRegistrar, ResourceLoaderAware {
    private static final String SPRING_BEAN_BASE_PACKAGE = "github.javaguide";
    private static final String BASE_PACKAGE_ATTRIBUTE_NAME = "basePackage";
    private ResourceLoader resourceLoader;

    @Override
    public void setResourceLoader(ResourceLoader resourceLoader) {
        this.resourceLoader = resourceLoader;

    }

    @Override
    public void registerBeanDefinitions(AnnotationMetadata annotationMetadata, BeanDefinitionRegistry beanDefinitionRegistry) {
        //get the attributes and values of RpcScan annotation
        AnnotationAttributes rpcScanAnnotationAttributes = AnnotationAttributes.fromMap(annotationMetadata.getAnnotationAttributes(RpcScan.class.getName()));
        String[] rpcScanBasePackages = new String[0];
        if (rpcScanAnnotationAttributes != null) {
            // get the value of the basePackage property
            rpcScanBasePackages = rpcScanAnnotationAttributes.getStringArray(BASE_PACKAGE_ATTRIBUTE_NAME);
        }
        if (rpcScanBasePackages.length == 0) {
            // if no package path is specified, scan the package of the annotated class by default
            rpcScanBasePackages = new String[]{((StandardAnnotationMetadata) annotationMetadata).getIntrospectedClass().getPackage().getName()};
        }
        // Scan the RpcService annotation
        CustomScanner rpcServiceScanner = new CustomScanner(beanDefinitionRegistry, RpcService.class);
        // Scan the Component annotation
        CustomScanner springBeanScanner = new CustomScanner(beanDefinitionRegistry, Component.class);
        if (resourceLoader != null) {
            rpcServiceScanner.setResourceLoader(resourceLoader);
            springBeanScanner.setResourceLoader(resourceLoader);
        }
        int springBeanAmount = springBeanScanner.scan(SPRING_BEAN_BASE_PACKAGE);
        log.info("number of beans registered by springBeanScanner [{}]", springBeanAmount);
        int rpcServiceCount = rpcServiceScanner.scan(rpcScanBasePackages);
        log.info("number of beans registered by rpcServiceScanner [{}]", rpcServiceCount);

    }

}

In registerBeanDefinitions, the AnnotationMetadata parameter holds all the annotation information of the class that carries @RpcScan.

The first statement:

AnnotationAttributes rpcScanAnnotationAttributes = AnnotationAttributes.fromMap(annotationMetadata.getAnnotationAttributes(RpcScan.class.getName()));

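The argument passed to fromMap is a Map whose keys are the annotation's attribute names and whose values are the corresponding attribute values; fromMap wraps it in an AnnotationAttributes object.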

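Here the value is a String array, so the next step extracts it to get the set of packages to scan. If no package was specified, the package of the startup class is scanned instead.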

String[] rpcScanBasePackages = new String[0];
if (rpcScanAnnotationAttributes != null) {
    // get the value of the basePackage property
    rpcScanBasePackages = rpcScanAnnotationAttributes.getStringArray(BASE_PACKAGE_ATTRIBUTE_NAME);
}
if (rpcScanBasePackages.length == 0) {
    // if no package path is specified, scan the package of the annotated class by default
    rpcScanBasePackages = new String[]{((StandardAnnotationMetadata) annotationMetadata).getIntrospectedClass().getPackage().getName()};
}

Next, two scanners are instantiated:

// Scan the RpcService annotation
CustomScanner rpcServiceScanner = new CustomScanner(beanDefinitionRegistry, RpcService.class);
// Scan the Component annotation
CustomScanner springBeanScanner = new CustomScanner(beanDefinitionRegistry, Component.class);
if (resourceLoader != null) {
    rpcServiceScanner.setResourceLoader(resourceLoader);
    springBeanScanner.setResourceLoader(resourceLoader);
}
int springBeanAmount = springBeanScanner.scan(SPRING_BEAN_BASE_PACKAGE);
log.info("number of beans registered by springBeanScanner [{}]", springBeanAmount);
int rpcServiceCount = rpcServiceScanner.scan(rpcScanBasePackages);
if (rpcServiceCount != 0) {
    NettyRpcServer.isNeedStart = true;
}
log.info("number of beans registered by rpcServiceScanner [{}]", rpcServiceCount);

The CustomScanner used above is a thin subclass of ClassPathBeanDefinitionScanner:

public class CustomScanner extends ClassPathBeanDefinitionScanner {

    public CustomScanner(BeanDefinitionRegistry registry, Class<? extends Annotation> annoType) {
        super(registry);
        super.addIncludeFilter(new AnnotationTypeFilter(annoType));
    }

    @Override
    public int scan(String... basePackages) {
        return super.scan(basePackages);
    }
}

The scanner registers every class that matches the filter with Spring, and scan returns the number of classes it registered.

In practice, this means scanning for classes annotated with @Component or @RpcService and adding them to the IoC container.

@Component is mainly used on the RPC framework's internal implementation classes, while @RpcService is used on service providers.

Once the Spring container has finished initializing, the Netty server is started:

@Component
public class StartupApplicationListener {
    @Autowired
    private NettyRpcServer nettyRpcServer;
    @EventListener(ContextRefreshedEvent.class)
    public void rpcStart() {
        nettyRpcServer.start();
    }
}

2. SPI mechanism

First, a loader is obtained through ExtensionLoader's getExtensionLoader method, as follows:

LoadBalance loadBalance = ExtensionLoader
        .getExtensionLoader(LoadBalance.class)
        .getExtension(StringUtil.isBlank(rpcReferenceConfig.getLoadBalance()) ? LoadBalanceEnum.LOADBALANCE.getName() : rpcReferenceConfig.getLoadBalance());

Now look at the getExtensionLoader method:

private static final Map<Class<?>, ExtensionLoader<?>> EXTENSION_LOADERS = new ConcurrentHashMap<>();

public static <S> ExtensionLoader<S> getExtensionLoader(Class<S> type) {
    if (type == null) {
        throw new IllegalArgumentException("Extension type should not be null.");
    }
    if (!type.isInterface()) {
        throw new IllegalArgumentException("Extension type must be an interface.");
    }
    if (type.getAnnotation(SPI.class) == null) {
        throw new IllegalArgumentException("Extension type must be annotated by @SPI");
    }
    // firstly get from cache, if not hit, create one
    ExtensionLoader<S> extensionLoader = (ExtensionLoader<S>) EXTENSION_LOADERS.get(type);
    if (extensionLoader == null) {
        EXTENSION_LOADERS.putIfAbsent(type, new ExtensionLoader<S>(type));
        extensionLoader = (ExtensionLoader<S>) EXTENSION_LOADERS.get(type);
    }
    return extensionLoader;
}

If EXTENSION_LOADERS has no loader for the given type, one is created and cached; later calls return the cached loader directly.

Next, the implementation is obtained through the getExtension method of the ExtensionLoader bound to that interface.

private final Map<String, Holder<Object>> cachedInstances = new ConcurrentHashMap<>();

public T getExtension(String name) {
    if (StringUtil.isBlank(name)) {
        throw new IllegalArgumentException("Extension name should not be null or empty.");
    }
    // firstly get from cache, if not hit, create one
    Holder<Object> holder = cachedInstances.get(name);
    if (holder == null) {
        cachedInstances.putIfAbsent(name, new Holder<>());
        holder = cachedInstances.get(name);
    }
    // create a singleton if no instance exists
    Object instance = holder.get();
    if (instance == null) {
        synchronized (holder) {
            instance = holder.get();
            if (instance == null) {
                instance = createExtension(name);
                holder.set(instance);
            }
        }
    }
    return (T) instance;
}

Holder can be regarded as a container for the concrete implementation instance:

public class Holder<T> {

    private volatile T value;

    public T get() {
        return value;
    }

    public void set(T value) {
        this.value = value;
    }
}

If the instance obtained from the Holder is null, it is created using double-checked locking; the actual work is done in the createExtension method:

private T createExtension(String name) {
    // load all extension classes of type T from file and get specific one by name
    Class<?> clazz = getExtensionClasses().get(name);
    if (clazz == null) {
        throw new RuntimeException("No such extension of name " + name);
    }
    T instance = (T) EXTENSION_INSTANCES.get(clazz);
    if (instance == null) {
        try {
            EXTENSION_INSTANCES.putIfAbsent(clazz, clazz.newInstance());
            instance = (T) EXTENSION_INSTANCES.get(clazz);
        } catch (Exception e) {
            log.error(e.getMessage());
        }
    }
    return instance;
}



public Map<String, Class<?>> getExtensionClasses() {
    // get the loaded extension class from the cache
    Map<String, Class<?>> classes = cachedClasses.get();
    // double check
    if (classes == null) {
        synchronized (cachedClasses) {
            classes = cachedClasses.get();
            if (classes == null) {
                classes = new HashMap<>();
                // load all extensions from our extensions directory
                loadDirectory(classes);
                cachedClasses.set(classes);
            }
        }
    }
    return classes;
}

private void loadDirectory(Map<String, Class<?>> extensionClasses) {
    String fileName = ExtensionLoader.SERVICE_DIRECTORY + type.getName();
    try {
        Enumeration<URL> urls;
        ClassLoader classLoader = ExtensionLoader.class.getClassLoader();
        urls = classLoader.getResources(fileName);
        if (urls != null) {
            while (urls.hasMoreElements()) {
                URL resourceUrl = urls.nextElement();
                loadResource(extensionClasses, classLoader, resourceUrl);
            }
        }
    } catch (IOException e) {
        log.error(e.getMessage());
    }
}

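First, getExtensionClasses().get(name) returns the Class object of the implementation. On the first call, when the Map that stores the name-to-class mapping has not been built yet, loadDirectory is called to load the key-value pairs from the configuration file into the map.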

Finally, the implementation is instantiated via reflection from the Class object just obtained, and the instance is stored in EXTENSION_INSTANCES:

private static final Map<Class<?>, Object> EXTENSION_INSTANCES = new ConcurrentHashMap<>();

That instance is the return value of createExtension, after which getExtension calls holder.set(instance).

On later calls, the name yields the Holder, and the Holder yields the concrete implementation instance.

Summary

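  1. Look up the instance in cachedInstances by name; if it is not found, put an empty Holder there.
  2. Call createExtension to create the instance:
    1. Look up the Class object in cachedClasses by name; if it is not found, call loadDirectory to load all implementations of the interface and store the mapping.
    2. With the Class object in hand, look up the instance in EXTENSION_INSTANCES; if it is not found, create it via reflection, store it there, and return it.
  3. Store the returned instance in the Holder; on later calls, the name yields the Holder, whose get method returns the instance.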
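
The lookup path summarized above can be sketched with the JDK alone. This is only a minimal illustration, not the project's ExtensionLoader: the class name MiniExtensionLoader and the in-memory register method (which stands in for loadDirectory reading the mapping file) are hypothetical, and computeIfAbsent is swapped in for the get / putIfAbsent / get sequence.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

// Hypothetical stand-in for the project's ExtensionLoader: same caching idea, no file loading.
final class MiniExtensionLoader<T> {
    // Simplified holder: a volatile slot that also serves as the lock object.
    private static final class Holder<V> {
        volatile V value;
    }

    // name -> factory; stands in for the name -> Class map that loadDirectory fills.
    private final Map<String, Supplier<? extends T>> factories = new ConcurrentHashMap<>();
    // name -> holder; mirrors cachedInstances.
    private final Map<String, Holder<T>> cachedInstances = new ConcurrentHashMap<>();

    void register(String name, Supplier<? extends T> factory) {
        factories.put(name, factory);
    }

    T getExtension(String name) {
        if (name == null || name.trim().isEmpty()) {
            throw new IllegalArgumentException("Extension name should not be null or empty.");
        }
        // One atomic call instead of get / putIfAbsent / get.
        Holder<T> holder = cachedInstances.computeIfAbsent(name, k -> new Holder<>());
        T instance = holder.value;
        if (instance == null) {
            // Double-checked locking: only the first caller creates the instance.
            synchronized (holder) {
                instance = holder.value;
                if (instance == null) {
                    Supplier<? extends T> factory = factories.get(name);
                    if (factory == null) {
                        throw new IllegalStateException("No such extension of name " + name);
                    }
                    instance = factory.get();
                    holder.value = instance;
                }
            }
        }
        return instance;
    }
}
```

Calling getExtension twice with the same name returns the same object, which is the singleton behaviour the Holder is there to guarantee.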

3. Netty configuration

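When either side sends a message, it first passes through RpcMessageEncoder, which fills in the packet according to the protocol, serializes the data, compresses it, and then sends it.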

When a message is received, it is decoded by RpcMessageDecoder:

@Slf4j
public class RpcMessageDecoder extends LengthFieldBasedFrameDecoder {
    public RpcMessageDecoder() {
        // lengthFieldOffset: magic code is 4B, and version is 1B, and then full length. so value is 5
        // lengthFieldLength: full length is 4B. so value is 4
        // lengthAdjustment: full length include all data and read 9 bytes before, so the left length is (fullLength-9). so values is -9
        // initialBytesToStrip: we will check magic code and version manually, so do not strip any bytes. so values is 0
        this(RpcConstants.MAX_FRAME_LENGTH, 5, 4, -9, 0);
    }
    ......
}

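The constructor arguments specify, in order:

  • Maximum frame length (8k)
  • Length field offset (5): the length field comes after the magic code (4B) and the version (1B)
  • Length field length (4)
  • Length adjustment (-9): the length field holds the full frame length, and by the end of that field 9 bytes have already been read: magic code (4B), version (1B), and the length field itself (4B)
  • Bytes to strip (0): the magic code and other header fields still need to be checked, so no bytes are skipped.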
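
To make the -9 adjustment concrete, here is a small sketch of the arithmetic using only java.nio. It is not Netty code: the class FrameLengthDemo, the placeholder magic bytes, and the frame layout after the length field are assumptions for illustration. A length-field based decoder computes the frame size as (end of the length field) + (value of the length field) + (adjustment), which here is 9 + fullLength - 9.

```java
import java.nio.ByteBuffer;

// Reproduces the frame-length arithmetic of a length-field based decoder with plain JDK types.
final class FrameLengthDemo {
    static final int LENGTH_FIELD_OFFSET = 5;  // magic code (4B) + version (1B)
    static final int LENGTH_FIELD_LENGTH = 4;  // full length is stored as an int
    static final int LENGTH_ADJUSTMENT = -9;   // full length already counts the first 9 bytes

    // Returns the total number of bytes in the frame that starts at index 0.
    static int frameLength(ByteBuffer buf) {
        int fullLength = buf.getInt(LENGTH_FIELD_OFFSET);               // value of the length field
        int lengthFieldEnd = LENGTH_FIELD_OFFSET + LENGTH_FIELD_LENGTH; // 9 bytes consumed so far
        return lengthFieldEnd + fullLength + LENGTH_ADJUSTMENT;         // 9 + fullLength - 9
    }

    // Builds a frame: 4B magic, 1B version, 4B full length, then the remaining bytes.
    static ByteBuffer buildFrame(int remainingBytes) {
        int fullLength = 9 + remainingBytes;
        ByteBuffer buf = ByteBuffer.allocate(fullLength);
        buf.put(new byte[]{'g', 'r', 'p', 'c'}); // placeholder magic code (assumption)
        buf.put((byte) 1);                        // version
        buf.putInt(fullLength);
        buf.put(new byte[remainingBytes]);
        buf.flip();
        return buf;
    }
}
```

With the adjustment in place the computed frame length equals the value stored in the length field, so the decoder hands the whole packet, header included, to the next step.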

3.1 Server

@Slf4j
@Component
public class NettyRpcServer {
    

    private final ServiceProvider serviceProvider = SingletonFactory.getInstance(ZkServiceProviderImpl.class);
    
    private static boolean isStart = false;
    
    public static boolean isNeedStart = false;

    public void registerService(RpcServiceConfig rpcServiceConfig) {
        serviceProvider.publishService(rpcServiceConfig);
    }
    
    

    @SneakyThrows
    public void start() {
        if (isStart || !isNeedStart) return;
        isStart = true;
        new Thread(() -> {
            CustomShutdownHook.getCustomShutdownHook().clearAll();
            String host = null;
            try {
                host = InetAddress.getLocalHost().getHostAddress();
            } catch (UnknownHostException e) {
                throw new RuntimeException(e);
            }
            EventLoopGroup bossGroup = new NioEventLoopGroup(1);
            EventLoopGroup workerGroup = new NioEventLoopGroup();
            DefaultEventExecutorGroup serviceHandlerGroup = new DefaultEventExecutorGroup(
                    RuntimeUtil.cpus() * 2,
                    ThreadPoolFactoryUtil.createThreadFactory("service-handler-group", false)
            );
            try {
                ServerBootstrap b = new ServerBootstrap();
                b.group(bossGroup, workerGroup)
                        .channel(NioServerSocketChannel.class)
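                        // Nagle's algorithm is enabled by default in TCP; it batches small writes into larger segments to reduce the number of packets sent. TCP_NODELAY controls whether Nagle's algorithm is used; true disables it.
                        .childOption(ChannelOption.TCP_NODELAY, true)
                        // whether to enable TCP's built-in keep-alive mechanism
                        .childOption(ChannelOption.SO_KEEPALIVE, true)
                        // maximum length of the queue that temporarily holds connections that have completed the three-way handshake; increase it if connections are established frequently and the server is slow to accept them
                        .option(ChannelOption.SO_BACKLOG, 128)
                        .handler(new LoggingHandler(LogLevel.INFO))
                        // initialization runs only when a client connects for the first time
                        .childHandler(new ChannelInitializer<SocketChannel>() {
                            @Override
                            protected void initChannel(SocketChannel ch) {
                                // close the connection if nothing is received from the client within 30 seconds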
                                ChannelPipeline p = ch.pipeline();
                                p.addLast(new IdleStateHandler(30, 0, 0, TimeUnit.SECONDS));
                                p.addLast(new RpcMessageEncoder());
                                p.addLast(new RpcMessageDecoder());
                                p.addLast(serviceHandlerGroup, new NettyRpcServerHandler());
                            }
                        });
                int port = SpringUtils.getBean(RpcProperties.class).getServerPort();
                // bind the port and wait synchronously until the bind succeeds
                ChannelFuture f = b.bind(host, port).sync();
                // wait until the server's listening channel is closed
                f.channel().closeFuture().sync();
            } catch (InterruptedException e) {
                log.error("occur exception when start server:", e);
            } finally {
                log.error("shutdown bossGroup and workerGroup");
                bossGroup.shutdownGracefully();
                workerGroup.shutdownGracefully();
                serviceHandlerGroup.shutdownGracefully();
            }
        }).start();
    }
}

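As shown, bossGroup handles connection events with a single thread, while workerGroup handles the I/O reads and writes of established sockets, using multiple threads to increase concurrency. No argument is passed to workerGroup, so it defaults to CPU cores * 2 threads. serviceHandlerGroup runs the business logic and is also sized at CPU cores * 2.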

Nagle's algorithm is also disabled, so that RPC requests and responses are sent immediately, reducing network latency.

Nagle's algorithm

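When the application sends data while a previously sent small packet is still unacknowledged, the new data is held in a buffer.
The buffered data is sent in one go only when the ACK for the previous packet arrives or the buffer fills up to the maximum segment size (MSS).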

Effects:

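  • Fewer small packets: the sender is forced to merge multiple small packets into one larger packet before sending.
  • Less network congestion: the network is not flooded with tiny packets (for example, packets carrying a single byte of data), which reduces wasted bandwidth and congestion.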
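
Outside Netty, the same switch exists on a plain java.net.Socket as setTcpNoDelay. The sketch below only illustrates what ChannelOption.TCP_NODELAY maps to at the JDK level; the class NoDelayDemo is hypothetical, and no connection is made.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.Socket;

final class NoDelayDemo {
    // Disables Nagle's algorithm on a not-yet-connected socket and reports the resulting setting.
    static boolean disableNagle() {
        try (Socket socket = new Socket()) {
            socket.setTcpNoDelay(true);    // true = TCP_NODELAY on = Nagle's algorithm off
            return socket.getTcpNoDelay();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

For an RPC framework that exchanges many small request and response packets, trading a little bandwidth for lower latency is usually the reasonable choice.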

A heartbeat-detection handler is configured here:

p.addLast(new IdleStateHandler(30, 0, 0, TimeUnit.SECONDS));

About the first three parameters:

IdleStateHandler(int readerIdleTime, int writerIdleTime, int allIdleTime, TimeUnit unit)
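  • readerIdleTime: read-idle timeout; an idle event fires when no data has been received from the client for this long
  • writerIdleTime: write-idle timeout; an idle event fires when no data has been sent to the client for this long
  • allIdleTime: all-idle timeout; an idle event fires when data has been neither read nor written for this long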

Here the connection is closed if nothing has been received for 30 seconds:

@Slf4j
public class NettyRpcServerHandler extends ChannelInboundHandlerAdapter {
    ......
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof IdleStateEvent) {
            IdleState state = ((IdleStateEvent) evt).state();
            // if no data is received within the configured time, IdleStateHandler fires a READER_IDLE event
            if (state == IdleState.READER_IDLE) {
                log.info("idle check happen, so close the connection");
                ctx.close();
            }
        } else {
            super.userEventTriggered(ctx, evt);
        }
    }
    ......
}

When a request reaches the server, it is first handled by RpcMessageDecoder, which builds the corresponding RpcMessage object according to messageType.

if (messageType == RpcConstants.HEARTBEAT_REQUEST_TYPE) {
    rpcMessage.setData(RpcConstants.PING);
    return rpcMessage;
}
if (messageType == RpcConstants.HEARTBEAT_RESPONSE_TYPE) {
    rpcMessage.setData(RpcConstants.PONG);
    return rpcMessage;
}
// decompress first, then deserialize
if (messageType == RpcConstants.REQUEST_TYPE) {
    RpcRequest tmpValue = serializer.deserialize(bs, RpcRequest.class);
    rpcMessage.setData(tmpValue);
} else {
    RpcResponse tmpValue = serializer.deserialize(bs, RpcResponse.class);
    rpcMessage.setData(tmpValue);
}

The data is then passed to NettyRpcServerHandler, which executes the local method via reflection.

public Object handle(RpcRequest rpcRequest) {
    Object service = serviceProvider.getService(rpcRequest.getRpcServiceName());
    return invokeTargetMethod(rpcRequest, service);
}

/**
 * get method execution results
 *
 * @param rpcRequest client request
 * @param service    service object
 * @return the result of the target method execution
 */
private Object invokeTargetMethod(RpcRequest rpcRequest, Object service) {
    Object result;
    try {
        Method method = service.getClass().getMethod(rpcRequest.getMethodName(), rpcRequest.getParamTypes());
        result = method.invoke(service, rpcRequest.getParameters());
        log.info("service:[{}] successful invoke method:[{}]", rpcRequest.getInterfaceName(), rpcRequest.getMethodName());
    } catch (NoSuchMethodException | IllegalArgumentException | InvocationTargetException | IllegalAccessException e) {
        throw new RpcException(e.getMessage(), e.getCause());
    }
    return result;
}

Finally, the result is written back to the client.
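
The reflective call in invokeTargetMethod can be tried in isolation. The following is a minimal sketch under my own assumptions: ReflectiveInvokeDemo and EchoService are hypothetical names, and IllegalStateException stands in for the project's RpcException.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

final class ReflectiveInvokeDemo {
    // Hypothetical service used only for this illustration.
    public static final class EchoService {
        public String echo(String text) {
            return "echo:" + text;
        }
    }

    // Looks up a public method by name and parameter types, then invokes it on the target.
    static Object invoke(Object service, String methodName, Class<?>[] paramTypes, Object[] args) {
        try {
            Method method = service.getClass().getMethod(methodName, paramTypes);
            return method.invoke(service, args);
        } catch (InvocationTargetException e) {
            // The target method itself threw; surface its exception rather than the wrapper.
            throw new IllegalStateException(e.getCause());
        } catch (NoSuchMethodException | IllegalAccessException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

The method name, parameter types, and arguments correspond to the fields the client puts into RpcRequest, which is why the request has to carry paramTypes in addition to the parameters themselves.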

3.2 Client

public NettyRpcClient(RpcProperties rpcProperties) {
    this.unprocessedRequests = SingletonFactory.getInstance(UnprocessedRequests.class);
    this.channelProvider = SingletonFactory.getInstance(ChannelProvider.class);
    this.rpcProperties = rpcProperties;
    this.serviceDiscovery = ExtensionLoader.getExtensionLoader(ServiceDiscovery.class).getExtension(rpcProperties.getRegistry());
    NettyRpcClient client = this;
    // initialize resources such as EventLoopGroup, Bootstrap
    eventLoopGroup = new NioEventLoopGroup();
    bootstrap = new Bootstrap();
    bootstrap.group(eventLoopGroup)
            .channel(NioSocketChannel.class)
            .handler(new LoggingHandler(LogLevel.INFO))
            //  The timeout period of the connection.
            //  If this time is exceeded or the connection cannot be established, the connection fails.
            .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
            .handler(new ChannelInitializer<SocketChannel>() {
                @Override
                protected void initChannel(SocketChannel ch) {
                    ChannelPipeline p = ch.pipeline();
                    // If no data is sent to the server within 5 seconds, a heartbeat request is sent
                    p.addLast(new IdleStateHandler(0, 5, 0, TimeUnit.SECONDS));
                    p.addLast(new RpcMessageEncoder());
                    p.addLast(new RpcMessageDecoder());
                    p.addLast(new NettyRpcClientHandler(client));
                }
            });
}

Here, if the client has not written any data to the server within 5 seconds, a writer-idle event is triggered:

@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
    if (evt instanceof IdleStateEvent) {
        IdleState state = ((IdleStateEvent) evt).state();
        // if no data is written to the channel within the configured time, IdleStateHandler fires a WRITER_IDLE event
        if (state == IdleState.WRITER_IDLE) {
            log.info("write idle happen [{}]", ctx.channel().remoteAddress());
            Channel channel = nettyRpcClient.getChannel((InetSocketAddress) ctx.channel().remoteAddress());
            RpcMessage rpcMessage = new RpcMessage();
            rpcMessage.setCodec(SerializationTypeEnum.PROTOSTUFF.getCode());
            rpcMessage.setCompress(CompressTypeEnum.getCode(rpcProperties.getCompress()));
            rpcMessage.setMessageType(RpcConstants.HEARTBEAT_REQUEST_TYPE);
            rpcMessage.setData(RpcConstants.PING);
            channel.writeAndFlush(rpcMessage).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
        }
    } else {
        super.userEventTriggered(ctx, evt);
    }
}

3.3 Method invocation flow

public interface HelloService {
    String hello(Hello hello);
}

The following implementation class is the service provider:

@Slf4j
@RpcService(group = "test1", version = "version1")
public class HelloServiceImpl1 implements HelloService {
    @Override
    public String hello(Hello hello) {
        log.info("this is hello1");
        int a = 1 / 0;
        return hello.toString();
    }
}

This is the consumer:

@RestController
public class TestController {
    @RpcReference(version = "version1", group = "test1", loadBalance = LoadBalanceStrategy.RANDOM)
    private HelloService helloService;
    @GetMapping
    public String test(){
        String hello = helloService.hello(new Hello("111", "222"));
        return hello;
    }
}

@RpcReference lets you specify some basic settings:

@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD})
@Inherited
public @interface RpcReference {

    /**
     * Service version, default value is empty string
     */
    String version() default "";

    /**
     * Service group, default value is empty string
     */
    String group() default "";

    /**
     * Number of retries
     */
    int retries() default 3;

    /**
     * Load balancing strategy
     */
    String loadBalance() default  "";

    /**
     * Cluster fault-tolerance strategy
     */
    String cluster() default  "";

    /**
     * Whether to invoke asynchronously
     */
    boolean async() default false;

}

When the consumer calls a method, the call is intercepted by the dynamic proxy, which executes RpcClientProxy's invoke method.

@SneakyThrows
@SuppressWarnings("unchecked")
@Override
public Object invoke(Object proxy, Method method, Object[] args) {
    log.info("invoked method: [{}]", method.getName());
    RpcRequest rpcRequest = RpcRequest.builder()
            .methodName(method.getName())
            .parameters(args)
            .interfaceName(method.getDeclaringClass().getName())
            .paramTypes(method.getParameterTypes())
            .requestId(UUID.randomUUID().toString())
            .group(rpcServiceConfig.getGroup())
            .version(rpcServiceConfig.getVersion())
            .consumerName(o.getClass().getName())
            .build();
    RpcReferenceConfig rpcReferenceConfig = RpcReferenceConfig.getRpcReferenceConfig(rpcRequest.getConsumerName(), rpcRequest.getInterfaceName());
    // send the request
    Object responseFuture = rpcRequestTransport.sendRpcRequest(rpcRequest);
    
    if (rpcReferenceConfig.isAsync()) {
        CompletableFuture<RpcResponse<?>> future = (CompletableFuture<RpcResponse<?>>) responseFuture;
        FutureContext.setFuture(future);
    } else {
        RpcResponse rpcResponse = ((CompletableFuture<RpcResponse<Object>>)responseFuture).get();
        check(rpcResponse, rpcRequest);
        return rpcResponse.getData();
    }
    return null;
}

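The key part is the sendRpcRequest method, which selects a cluster fault-tolerance strategy based on the user's configuration and uses it to make the call. To make testing easier, I simply new a FixedIntervalRetryCluster here, which retries a fixed number of times.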

@Override
public Object sendRpcRequest(RpcRequest rpcRequest) {
    return new FixedIntervalRetryCluster(this).invoke(rpcRequest,new ClusterConfig());
}

FixedIntervalRetryCluster

@Slf4j
public class FixedIntervalRetryCluster extends AbstractClusterFaultTolerant {
    private static final ScheduledExecutorService scheduler =
            Executors.newScheduledThreadPool(Runtime.getRuntime().availableProcessors());
    
    public FixedIntervalRetryCluster(NettyRpcClient rpcClient) {
        super(rpcClient);
    }

    @Override
    public Object invoke(RpcRequest rpcRequest, ClusterConfig clusterConfig) {
        CompletableFuture<RpcResponse<Object>> resultFuture = new CompletableFuture<>();
        AtomicInteger remainingRetries = new AtomicInteger(clusterConfig.getRetryTimes());
        Consumer<RpcRequest> attempt = new Consumer<RpcRequest>() {
            @Override
            public void accept(RpcRequest request) {
                sendSingleRequest(request).whenComplete((response, ex) -> {
                    if (response != null) {
                        // on success, complete the future
                        resultFuture.complete(response);
                    } else if (ex != null && isRetryable(ex) && remainingRetries.get() > 0) {
                        // run the retry logic
                        int retryCount = clusterConfig.getRetryTimes() - remainingRetries.decrementAndGet();
                        log.info("retry attempt {} for request: {}", retryCount, request);

                        // use the scheduler for a non-blocking delay
                        scheduler.schedule(() -> {
                            // build a new request object to avoid carrying over state
                            RpcRequest newRequest = reBuildRequest(request);
                            accept(newRequest); // recursive call
                        }, clusterConfig.getRetryIntervals(), TimeUnit.MILLISECONDS);
                    } else {
                        // final failure handling
                        resultFuture.completeExceptionally(
                                ex != null ? ex : new RpcException(RpcErrorMessageEnum.SERVICE_INVOCATION_FAILURE)
                        );
                    }
                });
            }
        };
        attempt.accept(rpcRequest);
        return resultFuture;
    }
}

The logic for sending a single message is extracted into the method sendSingleRequest:

public CompletableFuture<RpcResponse<Object>> sendSingleRequest(RpcRequest rpcRequest) {
    CompletableFuture<RpcResponse<Object>> future = new CompletableFuture<>();
    InetSocketAddress address = rpcClient.getServiceDiscovery().lookupService(rpcRequest);
    Channel channel = null;
    try {
        channel = rpcClient.getChannel(address);
    } catch (Exception e) {
        future.completeExceptionally(e);
        return future;
    }
    if (channel.isActive()) {
        rpcClient.getUnprocessedRequests().put(rpcRequest.getRequestId(), future);
        RpcProperties rpcProperties = rpcClient.getRpcProperties();
        RpcMessage rpcMessage = RpcMessage.builder()
                .data(rpcRequest)
                .codec(SerializationTypeEnum.getCode(rpcProperties.getSerializer()))
                .compress(CompressTypeEnum.getCode(rpcProperties.getCompress()))
                .messageType(RpcConstants.REQUEST_TYPE).build();
        channel.writeAndFlush(rpcMessage).addListener((ChannelFutureListener) sendFuture -> {
            if (!sendFuture.isSuccess()) {
                future.completeExceptionally(sendFuture.cause());
                log.error("Send request failed", sendFuture.cause());
            }
        });
    } else {
        future.completeExceptionally(new RpcException(RpcErrorMessageEnum.CLIENT_CONNECT_SERVER_FAILURE));
    }
    return future;
}

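First, load balancing selects a provider address. getChannel then checks whether the client already has an established, still-active connection (Channel) to that server, and establishes one if not. If no exception occurs, the mapping between the requestId and its CompletableFuture is recorded, the RpcMessage object is built and written to the channel, and the future is returned immediately.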

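Eventually the request reaches the server, the target method is executed, and the result travels back to NettyRpcClientHandler. The handler takes out the result, finds the matching future by the requestId it carries, and calls the future's complete method, which triggers the callback registered in the cluster fault-tolerance class's overridden invoke method: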

sendSingleRequest(request).whenComplete((response, ex) -> {
    if (response != null) {
        // on success, complete the future
        resultFuture.complete(response);
    } else if (ex != null && isRetryable(ex) && remainingRetries.get() > 0) {
        // run the retry logic
        int retryCount = clusterConfig.getRetryTimes() - remainingRetries.decrementAndGet();
        log.info("retry attempt {} for request: {}", retryCount, request);

        // use the scheduler for a non-blocking delay
        scheduler.schedule(() -> {
            // build a new request object to avoid carrying over state
            RpcRequest newRequest = reBuildRequest(request);
            accept(newRequest); // recursive call
        }, clusterConfig.getRetryIntervals(), TimeUnit.MILLISECONDS);
    } else {
        // final failure handling
        resultFuture.completeExceptionally(
                ex != null ? ex : new RpcException(RpcErrorMessageEnum.SERVICE_INVOCATION_FAILURE)
        );
    }
});

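Once the retries are exhausted, completeExceptionally is called. The dynamic proxy's invoke method from the beginning obtains the response data through the future returned by the cluster's invoke method.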
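
The retry pattern itself does not depend on Netty, so it can be reduced to a few lines of JDK code. This is only a sketch of the idea: RetryDemo and withRetry are hypothetical names, every failure is treated as retryable (there is no isRetryable check), and the caller supplies the scheduler.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

final class RetryDemo {
    // Runs the async call, retrying up to maxRetries times with a fixed delay between attempts.
    static <T> CompletableFuture<T> withRetry(Supplier<CompletableFuture<T>> call, int maxRetries,
                                              long intervalMillis, ScheduledExecutorService scheduler) {
        CompletableFuture<T> result = new CompletableFuture<>();
        attempt(call, maxRetries, intervalMillis, scheduler, result);
        return result;
    }

    private static <T> void attempt(Supplier<CompletableFuture<T>> call, int retriesLeft,
                                    long intervalMillis, ScheduledExecutorService scheduler,
                                    CompletableFuture<T> result) {
        call.get().whenComplete((value, ex) -> {
            if (ex == null) {
                // Success: hand the value to the caller's future.
                result.complete(value);
            } else if (retriesLeft > 0) {
                // Non-blocking delay: the next attempt is scheduled instead of sleeping.
                scheduler.schedule(
                        () -> attempt(call, retriesLeft - 1, intervalMillis, scheduler, result),
                        intervalMillis, TimeUnit.MILLISECONDS);
            } else {
                // Retries exhausted: propagate the last failure.
                result.completeExceptionally(ex);
            }
        });
    }
}
```

The caller only ever sees the outer future, so whether the value came from the first attempt or the third makes no difference to it.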

The future returned by the cluster's invoke method becomes the return value of sendRpcRequest:

Object responseFuture = rpcRequestTransport.sendRpcRequest(rpcRequest);

if (rpcReferenceConfig.isAsync()) {
    CompletableFuture<RpcResponse<?>> future = (CompletableFuture<RpcResponse<?>>) responseFuture;
    FutureContext.setFuture(future);
} else {
    RpcResponse rpcResponse = ((CompletableFuture<RpcResponse<Object>>)responseFuture).get();
    check(rpcResponse, rpcRequest);
    return rpcResponse.getData();
}

If the call is not asynchronous, the get method blocks here until the final value is available.
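
The link between the sending side and the receiving handler is the map from requestId to CompletableFuture. A stripped-down sketch of that idea follows; PendingRequests is a hypothetical simplification of the project's UnprocessedRequests, with String standing in for RpcResponse.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical, simplified counterpart of UnprocessedRequests.
final class PendingRequests {
    private final Map<String, CompletableFuture<String>> pending = new ConcurrentHashMap<>();

    // Called by the sender before the request is written to the channel.
    CompletableFuture<String> register(String requestId) {
        CompletableFuture<String> future = new CompletableFuture<>();
        pending.put(requestId, future);
        return future;
    }

    // Called by the inbound handler when a response carrying the same requestId arrives.
    boolean complete(String requestId, String response) {
        CompletableFuture<String> future = pending.remove(requestId);
        return future != null && future.complete(response);
    }
}
```

A synchronous caller blocks on the registered future, while an asynchronous caller keeps the future and reads it later; the handler thread completes it either way.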

4. Service registration flow

The service specifies the scan path with @RpcScan; a custom scanner scans that path and registers every class annotated with @RpcService as a Spring bean.

@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
@Inherited
public @interface RpcService {

    /**
     * Service version, default value is empty string
     */
    String version() default "";

    /**
     * Service group, default value is empty string
     */
    String group() default "";

}

A custom Spring bean post-processor, SpringBeanPostProcessor, overrides the postProcessBeforeInitialization and postProcessAfterInitialization methods; the former handles the @RpcService annotation:

@SneakyThrows
@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
    if (bean.getClass().isAnnotationPresent(RpcService.class)) {
        log.info("[{}] is annotated with  [{}]", bean.getClass().getName(), RpcService.class.getCanonicalName());
        // get RpcService annotation
        RpcService rpcService = bean.getClass().getAnnotation(RpcService.class);
        // build RpcServiceProperties
        RpcServiceConfig rpcServiceConfig = RpcServiceConfig.builder()
                .group(rpcService.group())
                .version(rpcService.version())
                .service(bean)
                .build();
        serviceProvider.publishService(rpcServiceConfig);
    }
    return bean;
}

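An RpcServiceConfig object is built from the annotation's attributes; the bean variable is the concrete implementation of some interface. The service is then published through publishService, which uses Zookeeper as the registry by default.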

@Override
public void publishService(RpcServiceConfig rpcServiceConfig) {
    int port = SpringUtils.getBean(RpcProperties.class).getServerPort();
    try {
        String host = InetAddress.getLocalHost().getHostAddress();
        this.addService(rpcServiceConfig);
        // publish to the registry: the key is rpcServiceName, the value is the address (ip and port)
        serviceRegistry.registerService(rpcServiceConfig.getRpcServiceName(), new InetSocketAddress(host, port));
    } catch (UnknownHostException e) {
        log.error("occur exception when getHostAddress", e);
    }
}

/**
 * Adds rpcServiceName (e.g. github.javaguide.HelloServicetest1version1) to the Set
 * and puts rpcServiceName and the service object (HelloServiceImpl) into the map
 * @param rpcServiceConfig rpc service related attributes
 */
@Override
public void addService(RpcServiceConfig rpcServiceConfig) {
    String rpcServiceName = rpcServiceConfig.getRpcServiceName();
    if (registeredService.contains(rpcServiceName)) {
        return;
    }
    registeredService.add(rpcServiceName);
    serviceMap.put(rpcServiceName, rpcServiceConfig.getService());
    log.info("Add service: {} and interfaces:{}", rpcServiceName, rpcServiceConfig.getService().getClass().getInterfaces());
}

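The addService method records the service name, together with the mapping from the service name to its implementation object.

The service name is the interface name plus the group and version (for example github.javaguide.HelloServicetest1version1).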
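The naming rule and the "register once" behaviour of addService can be sketched in isolation. This is a minimal sketch, not the framework's class: ServiceNameSketch and its methods are hypothetical, and the order interface name, group, version is an assumption inferred from the example name above.

```java
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// Illustrative only: a stripped-down provider-side registry (hypothetical class).
class ServiceNameSketch {
    private final Set<String> registeredService = ConcurrentHashMap.newKeySet();
    private final Map<String, Object> serviceMap = new ConcurrentHashMap<>();

    // Assumed composition: interface name, then group, then version, with no separators.
    static String rpcServiceName(String interfaceName, String group, String version) {
        return interfaceName + group + version;
    }

    // Returns true if the service was newly added, false if the name was already registered.
    boolean addService(String rpcServiceName, Object service) {
        if (!registeredService.add(rpcServiceName)) {
            return false;
        }
        serviceMap.put(rpcServiceName, service);
        return true;
    }

    // Looks up the implementation object that was registered under the given name.
    Object getService(String rpcServiceName) {
        return serviceMap.get(rpcServiceName);
    }

    public static void main(String[] args) {
        ServiceNameSketch provider = new ServiceNameSketch();
        String name = rpcServiceName("github.javaguide.HelloService", "test1", "version1");
        System.out.println(name);
        System.out.println(provider.addService(name, new Object()));
        System.out.println(provider.addService(name, new Object()));
    }
}
```

Because the name has no separators, two different group and version pairs could in principle collide (for instance group "a" with version "bc" and group "ab" with version "c"); a delimiter would avoid that, at the cost of changing the key format.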

The registerService call registers the provider's information with the registry.

public class ZkServiceRegistryImpl implements ServiceRegistry {

    @Override
    public void registerService(String rpcServiceName, InetSocketAddress inetSocketAddress) {
        String servicePath = CuratorUtils.ZK_REGISTER_ROOT_PATH + "/" + rpcServiceName + inetSocketAddress.toString();
        CuratorFramework zkClient = CuratorUtils.getZkClient();
        CuratorUtils.createPersistentNode(zkClient, servicePath);
    }
}

A copy of this path is also saved locally.
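To make the node path concrete, here is a minimal sketch of how the string is assembled. The root path "/my-rpc", the port 9998, and the REGISTERED_PATHS set are assumptions for illustration; the actual value of CuratorUtils.ZK_REGISTER_ROOT_PATH and the local cache are not shown in this post. No ZooKeeper connection is made.

```java
import java.net.InetSocketAddress;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// Illustrative only: path assembly without talking to ZooKeeper (hypothetical class).
class ZkPathSketch {
    // Hypothetical root path, assumed for this example.
    static final String ROOT = "/my-rpc";
    // Hypothetical local cache of the paths this process has already registered.
    static final Set<String> REGISTERED_PATHS = ConcurrentHashMap.newKeySet();

    static String servicePath(String rpcServiceName, InetSocketAddress address) {
        // For an address built from an IPv4 literal, toString() yields "/ip:port",
        // so its leading slash doubles as the separator after the service name.
        return ROOT + "/" + rpcServiceName + address.toString();
    }

    // Returns true when the path is new, i.e. when a node would still have to be created.
    static boolean remember(String path) {
        return REGISTERED_PATHS.add(path);
    }

    public static void main(String[] args) {
        String path = servicePath("github.javaguide.HelloServicetest1version1",
                new InetSocketAddress("127.0.0.1", 9998));
        System.out.println(path);
        System.out.println(remember(path));
        System.out.println(remember(path));
    }
}
```

Keeping the registered paths locally lets a provider skip creating a node it has already created, and gives it a list of nodes to delete on shutdown.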

5. Consumer registration flow

The key piece here is the postProcessAfterInitialization method of SpringBeanPostProcessor.

@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
    Class<?> targetClass = bean.getClass();
    Field[] declaredFields = targetClass.getDeclaredFields();
    for (Field declaredField : declaredFields) {
        RpcReference rpcReference = declaredField.getAnnotation(RpcReference.class);
        if (rpcReference != null) {
            RpcServiceConfig rpcServiceConfig = RpcServiceConfig.builder()
                    .group(rpcReference.group())
                    .version(rpcReference.version())
                    .build();
            RpcReferenceConfig.addRpcReferenceConfig(bean.getClass().getName(),declaredField.getType().getName(), rpcReference);
            RpcClientProxy rpcClientProxy = new RpcClientProxy(rpcClient, rpcServiceConfig,bean);
            Object clientProxy = rpcClientProxy.getProxy(declaredField.getType());
            declaredField.setAccessible(true);
            try {
                declaredField.set(bean, clientProxy);
            } catch (IllegalAccessException e) {
                log.error("failed to inject RPC proxy into field [{}]", declaredField.getName(), e);
            }
        }

    }
    return bean;
}

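This method checks every bean for fields annotated with @RpcReference. If it finds one, it creates a proxy object and assigns it to that field, which is why method calls on the field are intercepted. It also records the mapping from consumerName + "." + interfaceName to the corresponding RpcReferenceConfig.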

public class RpcReferenceConfig {
    
    public static Map<String,RpcReferenceConfig> rpcReferenceConfigMap = new ConcurrentHashMap<>();

    /**
     * Service version, default value is empty string
     */
    private String version;

    /**
     * Service group, default value is empty string
     */
    private String group;

    private int retries;

    private String loadBalance;

    private String cluster;
    
    private boolean async;
    
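    // The second argument is the field's type name, i.e. the service interface name, matching the lookup key below
    public static void addRpcReferenceConfig(String consumerName, String interfaceName, RpcReference rpcReference) {
        rpcReferenceConfigMap.put(consumerName + "." + interfaceName,RpcReferenceConfig.builder()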
                .group(rpcReference.group())
                .version(rpcReference.version())
                .retries(rpcReference.retries())
                .loadBalance(rpcReference.loadBalance())
                .cluster(rpcReference.cluster())
                .async(rpcReference.async())
                .build());
    }
    
    public static RpcReferenceConfig getRpcReferenceConfig(String consumerName, String interfaceName) {
        return rpcReferenceConfigMap.get(consumerName + "." + interfaceName);
    }
}

@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD})
@Inherited
public @interface RpcReference {

    /**
     * Service version, default value is empty string
     */
    String version() default "";

    /**
     * Service group, default value is empty string
     */
    String group() default "";

    /**
     * Number of retries
     */
    int retries() default 3;

    /**
     * Load balancing strategy
     */
    String loadBalance() default  "";

    /**
     * Cluster fault-tolerance strategy
     */
    String cluster() default  "";

    /**
     * Whether to invoke asynchronously
     */
    boolean async() default false;

}
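The field-replacement step of this section can be tried out without Spring or a network. The sketch below is only an illustration under assumptions: @Ref stands in for @RpcReference, HelloService and Consumer are hypothetical, and the InvocationHandler returns a string instead of sending an RPC request.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;

// Illustrative only: annotation-driven proxy injection (hypothetical class).
class ProxyInjectionSketch {
    // Hypothetical stand-in for @RpcReference.
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.FIELD)
    @interface Ref { }

    interface HelloService {
        String hello(String name);
    }

    static class Consumer {
        @Ref
        private HelloService helloService; // stays null until inject() runs

        String call(String name) {
            return helloService.hello(name);
        }
    }

    // Replaces every @Ref field of the bean with a JDK dynamic proxy of the field's interface type.
    static void inject(Object bean) {
        for (Field field : bean.getClass().getDeclaredFields()) {
            if (field.getAnnotation(Ref.class) == null) {
                continue;
            }
            // A real handler would build and send an RPC request here; this one only reports the call.
            InvocationHandler handler = (proxy, method, args) -> {
                String arg = (args == null || args.length == 0) ? "" : String.valueOf(args[0]);
                return "intercepted " + method.getName() + "(" + arg + ")";
            };
            Object clientProxy = Proxy.newProxyInstance(
                    field.getType().getClassLoader(), new Class<?>[]{field.getType()}, handler);
            field.setAccessible(true);
            try {
                field.set(bean, clientProxy);
            } catch (IllegalAccessException e) {
                throw new IllegalStateException("cannot inject proxy into " + field.getName(), e);
            }
        }
    }

    public static void main(String[] args) {
        Consumer consumer = new Consumer();
        inject(consumer);
        System.out.println(consumer.call("Tom"));
    }
}
```

Since the field holds a proxy rather than a real implementation, every call on it reaches the handler, which is the hook the framework uses to turn a local method call into a remote request.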

6. Load balancing strategy

Before sending a request, the client first has to obtain the service provider's address.

InetSocketAddress address = rpcClient.getServiceDiscovery().lookupService(rpcRequest);

Currently ZooKeeper is the only registry implementation.

@Slf4j
public class ZkServiceDiscoveryImpl implements ServiceDiscovery {


    @Override
    public InetSocketAddress lookupService(RpcRequest rpcRequest) {
        String rpcServiceName = rpcRequest.getRpcServiceName();
        CuratorFramework zkClient = CuratorUtils.getZkClient();
        List<String> serviceUrlList = CuratorUtils.getChildrenNodes(zkClient, rpcServiceName);
        if (CollectionUtil.isEmpty(serviceUrlList)) {
            throw new RpcException(RpcErrorMessageEnum.SERVICE_CAN_NOT_BE_FOUND, rpcServiceName);
        }
        // load balancing
        RpcReferenceConfig rpcReferenceConfig = RpcReferenceConfig.getRpcReferenceConfig(rpcRequest.getConsumerName(), rpcRequest.getInterfaceName());
        LoadBalance loadBalance = ExtensionLoader
                .getExtensionLoader(LoadBalance.class)
                .getExtension(StringUtil.isBlank(rpcReferenceConfig.getLoadBalance()) ? LoadBalanceEnum.LOADBALANCE.getName() : rpcReferenceConfig.getLoadBalance());
        String targetServiceUrl = loadBalance.selectServiceAddress(serviceUrlList, rpcRequest);
        log.info("Successfully found the service address:[{}]", targetServiceUrl);
        String[] socketAddressArray = targetServiceUrl.split(":");
        String host = socketAddressArray[0];
        int port = Integer.parseInt(socketAddressArray[1]);
        return new InetSocketAddress(host, port);
    }
}

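Here the addresses of the available service providers are first looked up by rpcServiceName, and then the configured load balancing algorithm selects one of them.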
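The three steps in lookupService (fall back to a default strategy name, pick one address, split it into host and port) can be sketched on their own. This is a minimal sketch under assumptions: LoadBalanceSketch is hypothetical, random selection is used only as an example strategy, and the default name "random" is made up; the framework's actual default behind LoadBalanceEnum.LOADBALANCE is not shown in this post.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

// Illustrative only: strategy fallback, address selection and parsing (hypothetical class).
class LoadBalanceSketch {
    // Uses the configured strategy name unless it is null or blank.
    static String strategyName(String configured, String defaultName) {
        return (configured == null || configured.trim().isEmpty()) ? defaultName : configured;
    }

    // Example strategy: pick one provider address uniformly at random.
    static String selectRandom(List<String> serviceUrls) {
        if (serviceUrls == null || serviceUrls.isEmpty()) {
            throw new IllegalArgumentException("no available provider");
        }
        return serviceUrls.get(ThreadLocalRandom.current().nextInt(serviceUrls.size()));
    }

    // Splits on the last colon so that only the trailing port is cut off.
    static String hostOf(String url) {
        return url.substring(0, url.lastIndexOf(':'));
    }

    static int portOf(String url) {
        return Integer.parseInt(url.substring(url.lastIndexOf(':') + 1));
    }

    public static void main(String[] args) {
        List<String> urls = Arrays.asList("127.0.0.1:9998", "127.0.0.1:9999");
        String target = selectRandom(urls);
        System.out.println(strategyName("", "random") + " -> " + hostOf(target) + " / " + portOf(target));
    }
}
```

Note that the lookup in the original code dereferences rpcReferenceConfig directly, so a request whose consumer and interface were never registered through @RpcReference would fail with a NullPointerException; a null check with the same default fallback would make that path safer.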

7. Asynchronous invocation

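Inside the JDK dynamic proxy's invoke method, sendRpcRequest is called to send the request and returns a future object. For a synchronous call, invoke blocks on the future's get method to obtain the return value; for an asynchronous call, the future is instead bound to the current thread's ThreadLocal.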

public class FutureContext {
    private static final ThreadLocal<CompletableFuture<RpcResponse<?>>> FUTURE_THREAD_LOCAL = new ThreadLocal<>();

    public static <T> CompletableFuture<T> getFuture(Class<T> targetType) {
        CompletableFuture<RpcResponse<?>> responseFuture = FUTURE_THREAD_LOCAL.get();
        if (responseFuture == null) {
            throw new IllegalStateException("No future in context for current thread");
        }
        clearFuture();
        CompletableFuture<T> resultFuture = new CompletableFuture<>();
        responseFuture.whenComplete((rpcResponse, throwable) -> {
            if (throwable != null) {
                resultFuture.completeExceptionally(throwable);
            } else {
                Object data = rpcResponse.getData();
                if (targetType.isInstance(data)) {
                    resultFuture.complete(targetType.cast(data));
                } else {
                    resultFuture.completeExceptionally(new ClassCastException("Cannot cast response data to " + targetType.getName()));
                }
            }
        });
        return resultFuture;
    }

    public static void setFuture(CompletableFuture<RpcResponse<?>> future) {
        FUTURE_THREAD_LOCAL.set(future);
    }

    public static void clearFuture() {
        FUTURE_THREAD_LOCAL.remove();
    }
}

Afterwards the caller can retrieve the returned future object through FutureContext#getFuture.
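The sync/async branching described in this section can be sketched end to end as follows. Everything here is a simplified stand-in: AsyncInvokeSketch, send, invoke and takeFuture are hypothetical, the "response" is a plain String completed immediately, and join() is used instead of get() only to avoid checked exceptions in the example.

```java
import java.util.concurrent.CompletableFuture;

// Illustrative only: sync versus async handling of a response future (hypothetical class).
class AsyncInvokeSketch {
    private static final ThreadLocal<CompletableFuture<String>> CONTEXT = new ThreadLocal<>();

    // Stand-in for sendRpcRequest: completes immediately with a fake response.
    static CompletableFuture<String> send(String request) {
        return CompletableFuture.completedFuture("echo:" + request);
    }

    // Sync calls block for the result; async calls park the future in the ThreadLocal and return null.
    static String invoke(String request, boolean async) {
        CompletableFuture<String> future = send(request);
        if (async) {
            CONTEXT.set(future);
            return null;
        }
        return future.join();
    }

    // Hands out the parked future once and clears the slot, so it cannot leak into a later call.
    static CompletableFuture<String> takeFuture() {
        CompletableFuture<String> future = CONTEXT.get();
        if (future == null) {
            throw new IllegalStateException("No future in context for current thread");
        }
        CONTEXT.remove();
        return future;
    }

    public static void main(String[] args) {
        System.out.println(invoke("a", false));
        System.out.println(invoke("b", true));
        System.out.println(takeFuture().join());
    }
}
```

Because the future lives in a ThreadLocal, it has to be fetched on the same thread that made the call, and before that thread issues another asynchronous call, which would overwrite it.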