Rpc
1. 启动流程
首先看@RpcScan注解
1
2
3
4
5
6
7
8
9
/**
 * Enables RPC scanning. The imported CustomScannerRegistrar scans the packages given in
 * {@link #basePackage()} for RPC service classes and registers them as bean definitions.
 */
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Import(CustomScannerRegistrar.class)
@Documented
public @interface RpcScan {
/**
 * Packages to scan. If an empty array is given, the registrar falls back to the package of the
 * class carrying this annotation.
 */
String[] basePackage();
}使用时将其加在服务端启动类上,并指定扫描路径
1
2
3
4
5
6
7
/**
 * Spring Boot entry point of the RPC server demo. {@code @RpcScan} makes the registrar scan the
 * {@code github.javaguide} package for RPC services.
 */
@SpringBootApplication
@RpcScan(basePackage = {"github.javaguide"})
public class ServerApplication {
public static void main(String[] args) {
SpringApplication.run(ServerApplication.class, args);
}
}接下来聚焦于@Import(CustomScannerRegistrar.class),直接看完整代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
@Slf4j
public class CustomScannerRegistrar implements ImportBeanDefinitionRegistrar, ResourceLoaderAware {
private static final String SPRING_BEAN_BASE_PACKAGE = "github.javaguide";
private static final String BASE_PACKAGE_ATTRIBUTE_NAME = "basePackage";
private ResourceLoader resourceLoader;
/**
 * Callback from {@link ResourceLoaderAware}: stores the ResourceLoader so that
 * registerBeanDefinitions can hand it to both scanners (only when it is non-null).
 */
@Override
public void setResourceLoader(ResourceLoader resourceLoader) {
this.resourceLoader = resourceLoader;
}
/**
 * Registers bean definitions for the framework's own components and for the user's
 * {@code @RpcService} classes.
 *
 * <p>The packages to scan come from {@code @RpcScan(basePackage = ...)}. If none are given, the
 * package of the class carrying {@code @RpcScan} is used. If at least one {@code @RpcService}
 * bean definition is registered, {@code NettyRpcServer.isNeedStart} is set. Without it,
 * {@code NettyRpcServer.start()} returns immediately.
 *
 * @param annotationMetadata     metadata of the class annotated with {@code @RpcScan}
 * @param beanDefinitionRegistry registry that receives the scanned bean definitions
 */
@Override
public void registerBeanDefinitions(AnnotationMetadata annotationMetadata, BeanDefinitionRegistry beanDefinitionRegistry) {
    String[] rpcScanBasePackages = resolveBasePackages(annotationMetadata);
    // Scans classes annotated with @RpcService (service providers)
    CustomScanner rpcServiceScanner = new CustomScanner(beanDefinitionRegistry, RpcService.class);
    // Scans the framework's own @Component classes
    CustomScanner springBeanScanner = new CustomScanner(beanDefinitionRegistry, Component.class);
    if (resourceLoader != null) {
        rpcServiceScanner.setResourceLoader(resourceLoader);
        springBeanScanner.setResourceLoader(resourceLoader);
    }
    int springBeanAmount = springBeanScanner.scan(SPRING_BEAN_BASE_PACKAGE);
    log.info("springBeanScanner扫描的数量 [{}]", springBeanAmount);
    int rpcServiceCount = rpcServiceScanner.scan(rpcScanBasePackages);
    if (rpcServiceCount != 0) {
        // Only start the Netty server when this application actually provides services
        NettyRpcServer.isNeedStart = true;
    }
    log.info("rpcServiceScanner扫描的数量 [{}]", rpcServiceCount);
}

/**
 * Returns the packages declared in {@code @RpcScan}. If none are declared, returns the package
 * of the annotated class.
 */
private static String[] resolveBasePackages(AnnotationMetadata annotationMetadata) {
    AnnotationAttributes attributes = AnnotationAttributes.fromMap(
            annotationMetadata.getAnnotationAttributes(RpcScan.class.getName()));
    String[] basePackages = new String[0];
    if (attributes != null) {
        basePackages = attributes.getStringArray(BASE_PACKAGE_ATTRIBUTE_NAME);
    }
    if (basePackages.length == 0) {
        // Derive the package from the class name. This works for both reflection-based and
        // class-file-based metadata; a cast to StandardAnnotationMetadata does not.
        String className = annotationMetadata.getClassName();
        int lastDot = className.lastIndexOf('.');
        basePackages = new String[]{lastDot == -1 ? "" : className.substring(0, lastDot)};
    }
    return basePackages;
}
}看registerBeanDefinitions方法,形参AnnotationMetadata封装了@RpcScan所在类的所有注解信息
第一行
1
AnnotationAttributes rpcScanAnnotationAttributes = AnnotationAttributes.fromMap(annotationMetadata.getAnnotationAttributes(RpcScan.class.getName()));fromMap里面得到的是个Map类型,key为注解属性名,value为其相应的值,最终将其封装为AnnotationAttributes

这里的value是个String数组,所以接下来提取这个字符串数组,获得需要扫描的包的路径集,如果没有指定路径,就扫描启动类所在的包
1
2
3
4
5
6
7
8
9
String[] rpcScanBasePackages = new String[0];
if (rpcScanAnnotationAttributes != null) {
// get the value of the basePackage property
rpcScanBasePackages = rpcScanAnnotationAttributes.getStringArray(BASE_PACKAGE_ATTRIBUTE_NAME);
}
if (rpcScanBasePackages.length == 0) {
//如果没有指定包路径,默认扫描当前类所在的包
rpcScanBasePackages = new String[]{((StandardAnnotationMetadata) annotationMetadata).getIntrospectedClass().getPackage().getName()};
}接下来实例化了两个扫描器
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// Scan the RpcService annotation
CustomScanner rpcServiceScanner = new CustomScanner(beanDefinitionRegistry, RpcService.class);
// Scan the Component annotation
CustomScanner springBeanScanner = new CustomScanner(beanDefinitionRegistry, Component.class);
if (resourceLoader != null) {
rpcServiceScanner.setResourceLoader(resourceLoader);
springBeanScanner.setResourceLoader(resourceLoader);
}
int springBeanAmount = springBeanScanner.scan(SPRING_BEAN_BASE_PACKAGE);
log.info("springBeanScanner扫描的数量 [{}]", springBeanAmount);
int rpcServiceCount = rpcServiceScanner.scan(rpcScanBasePackages);
if (rpcServiceCount != 0) {
NettyRpcServer.isNeedStart = true;
}
log.info("rpcServiceScanner扫描的数量 [{}]", rpcServiceCount);1
2
3
4
5
6
7
8
9
10
11
12
public class CustomScanner extends ClassPathBeanDefinitionScanner {
/**
 * Creates a scanner that registers only classes carrying {@code annoType}, either directly or
 * as a meta-annotation.
 *
 * <p>Spring's default filters are disabled. Otherwise the default {@code @Component} filter
 * would also apply, so a scanner created for {@code @RpcService} would register, and count,
 * ordinary Spring components too.
 *
 * @param registry registry that receives the scanned bean definitions
 * @param annoType annotation a class must carry to be registered
 */
public CustomScanner(BeanDefinitionRegistry registry, Class<? extends Annotation> annoType) {
    super(registry, false);
    addIncludeFilter(new AnnotationTypeFilter(annoType));
}
/**
 * Scans the given packages and returns how many bean definitions were newly registered.
 * Delegates straight to the superclass implementation.
 */
@Override
public int scan(String... basePackages) {
    final int registeredCount = super.scan(basePackages);
    return registeredCount;
}
}这里会将扫描到的满足条件的类注册到Spring中,scan返回了注册的类的数量
从实际代码来看就是扫描@Component和@RpcService所在的类并加入到IOC容器
@Component主要是在RPC内部实现类上,后者是在服务提供者上。
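当Spring容器初始化完成(收到ContextRefreshedEvent)后,会调用NettyRpcServer的start方法;只有在扫描到@RpcService(即NettyRpcServer.isNeedStart为true)时,才会真正启动Netty服务。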
1
2
3
4
5
6
7
8
9
@Component
public class StartupApplicationListener {
@Autowired
private NettyRpcServer nettyRpcServer;
/**
 * Asks the Netty RPC server to start whenever the application context is refreshed.
 * NOTE(review): ContextRefreshedEvent may be published more than once (e.g. with child
 * contexts). NettyRpcServer.start() returns early if it was already started or if no
 * @RpcService was registered.
 */
@EventListener(ContextRefreshedEvent.class)
public void rpcStart() {
nettyRpcServer.start();
}
}2. SPI机制
首先通过ExtensionLoader的getExtensionLoader方法获得加载器
如下
1
2
3
LoadBalance loadBalance = ExtensionLoader
.getExtensionLoader(LoadBalance.class)
.getExtension(StringUtil.isBlank(rpcReferenceConfig.getLoadBalance()) ? LoadBalanceEnum.LOADBALANCE.getName() : rpcReferenceConfig.getLoadBalance());看看getExtensionLoader方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
private static final Map<Class<?>, ExtensionLoader<?>> EXTENSION_LOADERS = new ConcurrentHashMap<>();
/**
 * Returns the ExtensionLoader bound to the given SPI interface. One loader is created and cached
 * on first request; later calls for the same type get the cached loader.
 *
 * @param type extension interface; must be a non-null interface annotated with {@code @SPI}
 * @param <S>  extension type
 * @return the loader for {@code type}
 * @throws IllegalArgumentException if {@code type} is null, not an interface, or lacks {@code @SPI}
 */
@SuppressWarnings("unchecked")
public static <S> ExtensionLoader<S> getExtensionLoader(Class<S> type) {
    if (type == null) {
        throw new IllegalArgumentException("Extension type should not be null.");
    }
    if (!type.isInterface()) {
        throw new IllegalArgumentException("Extension type must be an interface.");
    }
    if (type.getAnnotation(SPI.class) == null) {
        throw new IllegalArgumentException("Extension type must be annotated by @SPI");
    }
    ExtensionLoader<S> cached = (ExtensionLoader<S>) EXTENSION_LOADERS.get(type);
    if (cached != null) {
        return cached;
    }
    // Cache miss: offer a fresh loader. If another thread got there first, its loader is kept.
    EXTENSION_LOADERS.putIfAbsent(type, new ExtensionLoader<S>(type));
    return (ExtensionLoader<S>) EXTENSION_LOADERS.get(type);
}如果EXTENSION_LOADERS里面没有当前类的加载器,就创建一个并缓存,下次直接返回。
接下来通过与接口绑定的ExtensionLoader的getExtension方法获得实现类。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
private final Map<String, Holder<Object>> cachedInstances = new ConcurrentHashMap<>();
/**
 * Returns the extension instance registered under {@code name}, creating it on first request.
 *
 * <p>Each name maps to a {@link Holder}. The instance is created with double-checked locking on
 * that holder (whose value is volatile), so concurrent first callers do not both create it.
 *
 * @param name extension name as declared in the SPI configuration file
 * @return the extension instance
 * @throws IllegalArgumentException if {@code name} is null or blank
 */
@SuppressWarnings("unchecked")
public T getExtension(String name) {
    if (StringUtil.isBlank(name)) {
        throw new IllegalArgumentException("Extension name should not be null or empty.");
    }
    // computeIfAbsent creates the holder atomically; it replaces the get / putIfAbsent / get sequence
    Holder<Object> holder = cachedInstances.computeIfAbsent(name, key -> new Holder<>());
    Object instance = holder.get();
    if (instance == null) {
        synchronized (holder) {
            instance = holder.get();
            if (instance == null) {
                instance = createExtension(name);
                holder.set(instance);
            }
        }
    }
    return (T) instance;
}Holder看作具体实现类的容器
1
2
3
4
5
6
7
8
9
10
11
12
public class Holder<T> {
// Held value. It is volatile so a value set by one thread is visible to other threads;
// getExtension's double-checked locking relies on this.
private volatile T value;
/** Returns the current value, or null if nothing has been set yet. */
public T get() {
return value;
}
/** Stores the value (a volatile write, visible to later get() calls from any thread). */
public void set(T value) {
this.value = value;
}
}如果通过Holder获取到的实例为空,则通过双重检查加锁(double-checked locking)的方式创建实例,具体实现在createExtension方法中
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
/**
 * Returns the shared instance of the extension class registered under {@code name}. On first
 * use the instance is created through the class's no-arg constructor.
 *
 * @param name extension name as declared in the SPI configuration file
 * @return the extension instance (never null)
 * @throws RuntimeException      if no extension is registered under {@code name}
 * @throws IllegalStateException if the extension class cannot be instantiated
 */
@SuppressWarnings("unchecked")
private T createExtension(String name) {
    // load all extension classes of type T from file and get specific one by name
    Class<?> clazz = getExtensionClasses().get(name);
    if (clazz == null) {
        throw new RuntimeException("No such extension of name " + name);
    }
    T instance = (T) EXTENSION_INSTANCES.get(clazz);
    if (instance == null) {
        try {
            // Use the Constructor API instead of the deprecated Class.newInstance(), which
            // lets checked constructor exceptions escape undeclared.
            EXTENSION_INSTANCES.putIfAbsent(clazz, clazz.getDeclaredConstructor().newInstance());
        } catch (ReflectiveOperationException e) {
            // Fail loudly with the cause, so getExtension never hands a null extension to its caller
            throw new IllegalStateException(
                    "Failed to instantiate extension [" + name + "] of class " + clazz.getName(), e);
        }
        instance = (T) EXTENSION_INSTANCES.get(clazz);
    }
    return instance;
}
/**
 * Returns the extension-name to class mapping for this loader's type. The SPI configuration
 * files are read on first access and the result is cached in {@code cachedClasses}, using
 * double-checked locking on that holder.
 */
public Map<String, Class<?>> getExtensionClasses() {
    Map<String, Class<?>> loaded = cachedClasses.get();
    if (loaded != null) {
        return loaded;
    }
    synchronized (cachedClasses) {
        loaded = cachedClasses.get();
        if (loaded == null) {
            loaded = new HashMap<>();
            // read every configuration file for this extension type into the map
            loadDirectory(loaded);
            cachedClasses.set(loaded);
        }
        return loaded;
    }
}
/**
 * Finds every classpath resource named {@code SERVICE_DIRECTORY + type.getName()} visible to
 * this class's ClassLoader and passes each one to loadResource, which presumably parses the
 * name-to-class entries into {@code extensionClasses} (TODO confirm).
 * I/O failures are logged; the map keeps whatever was loaded before the failure.
 *
 * @param extensionClasses map receiving extension name to class entries
 */
private void loadDirectory(Map<String, Class<?>> extensionClasses) {
    String fileName = ExtensionLoader.SERVICE_DIRECTORY + type.getName();
    ClassLoader classLoader = ExtensionLoader.class.getClassLoader();
    try {
        // getResources never returns null; it returns one URL per matching file (jar/directory)
        Enumeration<URL> urls = classLoader.getResources(fileName);
        while (urls.hasMoreElements()) {
            loadResource(extensionClasses, classLoader, urls.nextElement());
        }
    } catch (IOException e) {
        // Log the resource name and the full stack trace; e.getMessage() alone may be null
        log.error("Failed to load SPI configuration [{}]", fileName, e);
    }
}首先通过getExtensionClasses().get(name);获得实现类的Class对象,如果是第一次,即存储这种映射关系的Map对象为空,则会调用loadDirectory方法,将文件中的键值关系存到map里面。

最终通过反射创建实现类对象,毕竟刚刚拿到了实现类的Class对象,并放入EXTENSION_INSTANCES中
1
private static final Map<Class<?>, Object> EXTENSION_INSTANCES = new ConcurrentHashMap<>();该实例对象作为createExtension的返回值返回,调用holder.set(instance);
下次再来就是通过name拿到Holder对象,通过Holder对象拿到具体实现类对象。
总结
- 首先通过name去cachedInstances里面找具体对象,没有找到,就往里面放入一个空的Holder对象
- 调用createExtension方法创建对象
- 通过name去cachedClasses里面找相应的Class对象,没找到,就调用loadDirectory方法加载当前接口的所有实现类并存储映射关系
- 拿到Class对象后,去EXTENSION_INSTANCES里面找实例对象,没找到,就通过反射创建对象存入其中并返回该对象
- 将返回的对象设置到Holder里面去,下次再来就是通过name拿到相应的Holder,调用其get方法获取实例对象。
3. Netty配置
当双方发送消息时,会先经过RpcMessageEncoder,这里会按照协议填充数据包,并对数据进行序列化然后压缩最后发送。

收到消息后,经过RpcMessageDecoder进行解码
1
2
3
4
5
6
7
8
9
10
11
@Slf4j
public class RpcMessageDecoder extends LengthFieldBasedFrameDecoder {
/**
 * Creates a LengthFieldBasedFrameDecoder configured for this RPC protocol's header layout.
 */
public RpcMessageDecoder() {
// lengthFieldOffset: the magic code is 4B and the version is 1B, followed by the full length, so the value is 5
// lengthFieldLength: the full-length field is 4B, so the value is 4
// lengthAdjustment: the full length covers the whole frame and 9 bytes have already been read before its end, so the remaining length is (fullLength - 9); the value is -9
// initialBytesToStrip: the magic code and version are checked manually, so no bytes are stripped; the value is 0
this(RpcConstants.MAX_FRAME_LENGTH, 5, 4, -9, 0);
}
......
}这里分别指定了:
- 最大帧长度(RpcConstants.MAX_FRAME_LENGTH;在guide-rpc-framework原项目中定义为8 * 1024 * 1024,即8MB,而不是8k)
- 长度偏移量
- 长度字段长度
- 数据长度修正(-9):长度字段记录的是整个数据包的长度(包含头部),而LengthFieldBasedFrameDecoder把长度值理解为长度字段之后还剩余的字节数;在此之前一共读了魔数(4B)、版本(1B)、长度字段(4B)共9个字节,所以要减去9
- 跳过的字节数(0):因为需要检查魔数等字段,所以不跳过任何字节。
3.1 服务端
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@Slf4j
@Component
public class NettyRpcServer {
// Provider used to publish services (the ZooKeeper-based implementation).
private final ServiceProvider serviceProvider = SingletonFactory.getInstance(ZkServiceProviderImpl.class);
// Set once the server has been started, so repeated start() calls do nothing.
private static boolean isStart = false;
// Set by CustomScannerRegistrar when at least one @RpcService bean was registered;
// start() returns immediately while this is false.
public static boolean isNeedStart = false;
/**
 * Publishes a service through the service provider.
 *
 * @param rpcServiceConfig the service object together with its group/version metadata
 */
public void registerService(RpcServiceConfig rpcServiceConfig) {
serviceProvider.publishService(rpcServiceConfig);
}
/**
 * Starts the Netty RPC server on a background thread. It runs at most once per JVM, and only
 * when at least one {@code @RpcService} was registered ({@code isNeedStart}).
 *
 * <p>The server binds to the local host address and the port from {@code RpcProperties}. The
 * background thread then blocks until the server channel closes, after which all event loop
 * groups are shut down.
 */
@SneakyThrows
public void start() {
    // start() may be triggered by several context-refresh events; check-and-set atomically
    synchronized (NettyRpcServer.class) {
        if (isStart || !isNeedStart) {
            return;
        }
        isStart = true;
    }
    new Thread(() -> {
        CustomShutdownHook.getCustomShutdownHook().clearAll();
        String host;
        try {
            host = InetAddress.getLocalHost().getHostAddress();
        } catch (UnknownHostException e) {
            throw new RuntimeException(e);
        }
        EventLoopGroup bossGroup = new NioEventLoopGroup(1);
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        DefaultEventExecutorGroup serviceHandlerGroup = new DefaultEventExecutorGroup(
                RuntimeUtil.cpus() * 2,
                ThreadPoolFactoryUtil.createThreadFactory("service-handler-group", false)
        );
        try {
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    // Disable Nagle's algorithm so small RPC responses are sent without delay
                    .childOption(ChannelOption.TCP_NODELAY, true)
                    // Enable TCP-level keep-alive probes
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    // Max length of the queue of completed handshakes waiting to be accepted;
                    // raise it if connections arrive faster than the server accepts them
                    .option(ChannelOption.SO_BACKLOG, 128)
                    .handler(new LoggingHandler(LogLevel.INFO))
                    // initChannel runs for each accepted client connection
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) {
                            ChannelPipeline p = ch.pipeline();
                            // Fire READER_IDLE (the handler then closes the connection) after 30s without client data
                            p.addLast(new IdleStateHandler(30, 0, 0, TimeUnit.SECONDS));
                            p.addLast(new RpcMessageEncoder());
                            p.addLast(new RpcMessageDecoder());
                            // Business logic runs on its own executor group, off the I/O threads
                            p.addLast(serviceHandlerGroup, new NettyRpcServerHandler());
                        }
                    });
            int port = SpringUtils.getBean(RpcProperties.class).getServerPort();
            // Bind and wait until the bind has completed
            ChannelFuture f = b.bind(host, port).sync();
            // Block until the server channel is closed
            f.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            log.error("occur exception when start server:", e);
            // Restore the interrupt flag so the interruption is not lost
            Thread.currentThread().interrupt();
        } finally {
            // Normal shutdown path, not an error
            log.info("shutdown bossGroup and workerGroup");
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
            serviceHandlerGroup.shutdownGracefully();
        }
    }).start();
}
}可以看到,bossGroup用来处理连接事件,线程数为1,workerGroup处理已建立连接的 Socket 的 I/O 读写操作,多线程提升并发处理能力。这里没有传参,默认为cpu核心数 * 2,serviceHandlerGroup用来处理业务操作,这里也是cpu核心数 * 2。
同时也禁用了Nagle算法,确保 RPC 请求的响应数据立即发送,减少网络传输延迟。
Nagle算法:
当应用层发送数据时,若存在未确认的已发送小包,则新数据会暂存缓冲区。
直到收到前一个包的 ACK 确认,或缓冲区数据积累到最大报文长度(MSS),才会一次性发送缓冲区的数据。作用:
- 减少小数据包:强制发送端将多个小数据包合并成一个较大的数据包后再发送。
- 降低网络拥塞:避免网络中充斥大量小包(如仅含 1 字节数据的包),减少带宽浪费和网络拥堵
这里配置了心跳检测处理器:
1
p.addLast(new IdleStateHandler(30, 0, 0, TimeUnit.SECONDS));关于前面3个参数:
1
IdleStateHandler(int readerIdleTime, int writerIdleTime, int allIdleTime, TimeUnit unit)- readerIdleTime: 读空闲超时,多久未收到客户端数据时触发空闲事件
- writerIdleTime: 写空闲超时,多久未向客户端发送数据时触发空闲事件
- allIdleTime:全局空闲超时,多久既未读也未写数据时触发空闲事件
这里配置的是:服务端30s内未收到客户端的任何数据,就会触发READER_IDLE事件并断开连接:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
@Slf4j
public class NettyRpcServerHandler extends ChannelInboundHandlerAdapter {
......
/**
 * Closes the connection when nothing has been read from the client within the reader-idle
 * window of the pipeline's IdleStateHandler. Non-idle user events are passed to the next
 * handler; other idle states are ignored.
 */
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
    if (!(evt instanceof IdleStateEvent)) {
        super.userEventTriggered(ctx, evt);
        return;
    }
    IdleStateEvent idleEvent = (IdleStateEvent) evt;
    // IdleStateHandler fires READER_IDLE when no data arrived within the configured time
    if (IdleState.READER_IDLE == idleEvent.state()) {
        log.info("idle check happen, so close the connection");
        ctx.close();
    }
}
......
}请求到达服务端,首先经过RpcMessageDecoder处理,根据messageType构造相应的RpcMessage对象。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
if (messageType == RpcConstants.HEARTBEAT_REQUEST_TYPE) {
rpcMessage.setData(RpcConstants.PING);
return rpcMessage;
}
if (messageType == RpcConstants.HEARTBEAT_RESPONSE_TYPE) {
rpcMessage.setData(RpcConstants.PONG);
return rpcMessage;
}
//先解压再反序列化
if (messageType == RpcConstants.REQUEST_TYPE) {
RpcRequest tmpValue = serializer.deserialize(bs, RpcRequest.class);
rpcMessage.setData(tmpValue);
} else {
RpcResponse tmpValue = serializer.deserialize(bs, RpcResponse.class);
rpcMessage.setData(tmpValue);
}之后数据传递给NettyRpcServerHandler,通过动态代理执行本地方法。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/**
 * Looks up the service registered under the request's service name and invokes the requested
 * method on it.
 *
 * @param rpcRequest the decoded client request
 * @return the target method's return value
 */
public Object handle(RpcRequest rpcRequest) {
    String serviceName = rpcRequest.getRpcServiceName();
    Object target = serviceProvider.getService(serviceName);
    return invokeTargetMethod(rpcRequest, target);
}
/**
 * Invokes the method named in the request on the given service object via reflection.
 *
 * @param rpcRequest client request carrying the method name, parameter types and arguments
 * @param service    service object to invoke the method on
 * @return the result of the target method execution
 * @throws RpcException if the method cannot be found or invoked, or if it throws. The cause is
 *                      the service method's own exception, or the reflective failure itself.
 */
private Object invokeTargetMethod(RpcRequest rpcRequest, Object service) {
    Object result;
    try {
        Method method = service.getClass().getMethod(rpcRequest.getMethodName(), rpcRequest.getParamTypes());
        result = method.invoke(service, rpcRequest.getParameters());
        log.info("service:[{}] successful invoke method:[{}]", rpcRequest.getInterfaceName(), rpcRequest.getMethodName());
    } catch (InvocationTargetException e) {
        // The service method itself threw. Report its exception, not the reflective wrapper,
        // whose own message is null.
        Throwable target = e.getTargetException();
        throw new RpcException(target.getMessage(), target);
    } catch (NoSuchMethodException | IllegalArgumentException | IllegalAccessException e) {
        // These usually have no cause of their own, so the exception itself becomes the cause
        throw new RpcException(e.getMessage(), e);
    }
    return result;
}最后将结果写回客户端。
3.2 客户端
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
/**
 * Creates the Netty client. It resolves the shared collaborators and prepares a Bootstrap whose
 * channels encode/decode RPC messages and raise a writer-idle event after 5 seconds without
 * writes.
 *
 * @param rpcProperties client configuration; {@code getRegistry()} selects the ServiceDiscovery extension
 */
public NettyRpcClient(RpcProperties rpcProperties) {
    this.unprocessedRequests = SingletonFactory.getInstance(UnprocessedRequests.class);
    this.channelProvider = SingletonFactory.getInstance(ChannelProvider.class);
    this.rpcProperties = rpcProperties;
    this.serviceDiscovery = ExtensionLoader.getExtensionLoader(ServiceDiscovery.class).getExtension(rpcProperties.getRegistry());
    NettyRpcClient client = this;
    // initialize resources such as EventLoopGroup, Bootstrap
    eventLoopGroup = new NioEventLoopGroup();
    bootstrap = new Bootstrap();
    // A Bootstrap holds a single handler and handler(...) replaces it, so a LoggingHandler set
    // before the ChannelInitializer would be discarded. If connection logging is wanted, add
    // it to the pipeline in initChannel.
    bootstrap.group(eventLoopGroup)
            .channel(NioSocketChannel.class)
            // A connection attempt that takes longer than 5 seconds fails.
            .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
            .handler(new ChannelInitializer<SocketChannel>() {
                @Override
                protected void initChannel(SocketChannel ch) {
                    ChannelPipeline p = ch.pipeline();
                    // If no data is written to the server within 5 seconds, a WRITER_IDLE event
                    // is fired and the client handler sends a heartbeat request
                    p.addLast(new IdleStateHandler(0, 5, 0, TimeUnit.SECONDS));
                    p.addLast(new RpcMessageEncoder());
                    p.addLast(new RpcMessageDecoder());
                    p.addLast(new NettyRpcClientHandler(client));
                }
            });
}这里如果5s内客户端没有向服务端写数据就会触发写空闲事件
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/**
 * Sends a heartbeat (PING) on a channel that has not written anything within the writer-idle
 * window of the pipeline's IdleStateHandler. If the heartbeat cannot be written, the channel is
 * closed. Non-idle user events are passed to the next handler.
 */
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
    if (evt instanceof IdleStateEvent) {
        IdleState state = ((IdleStateEvent) evt).state();
        // IdleStateHandler fires WRITER_IDLE when nothing was written to the channel in time
        if (state == IdleState.WRITER_IDLE) {
            log.info("write idle happen [{}]", ctx.channel().remoteAddress());
            // Send the heartbeat on the idle channel itself. Looking the channel up again via
            // nettyRpcClient.getChannel(...) could return a different channel, or try to open a
            // new connection from inside this I/O thread.
            Channel channel = ctx.channel();
            RpcMessage rpcMessage = new RpcMessage();
            rpcMessage.setCodec(SerializationTypeEnum.PROTOSTUFF.getCode());
            rpcMessage.setCompress(CompressTypeEnum.getCode(rpcProperties.getCompress()));
            rpcMessage.setMessageType(RpcConstants.HEARTBEAT_REQUEST_TYPE);
            rpcMessage.setData(RpcConstants.PING);
            channel.writeAndFlush(rpcMessage).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
        }
    } else {
        super.userEventTriggered(ctx, evt);
    }
}3.3 方法调用流程
1
2
3
/** Demo RPC service interface, implemented by the provider and referenced by the consumer. */
public interface HelloService {
String hello(Hello hello);
}这个实现类就是服务提供者
1
2
3
4
5
6
7
8
9
10
/**
 * Demo provider of HelloService, published under group "test1" and version "version1".
 * NOTE(review): hello() always throws ArithmeticException at {@code 1 / 0} before it can return,
 * presumably on purpose to exercise error handling / retries on the consumer side. Confirm.
 */
@Slf4j
@RpcService(group = "test1", version = "version1")
public class HelloServiceImpl1 implements HelloService {
@Override
public String hello(Hello hello) {
log.info("这是hello1");
int a = 1 / 0;
return hello.toString();
}
}这里是消费者
1
2
3
4
5
6
7
8
9
10
/**
 * Demo consumer. At startup the @RpcReference field is replaced by a client proxy, so calls on it
 * become remote calls to the "test1"/"version1" HelloService using the RANDOM load-balancing
 * strategy.
 */
@RestController
public class TestController {
@RpcReference(version = "version1", group = "test1", loadBalance = LoadBalanceStrategy.RANDOM)
private HelloService helloService;
/** GET endpoint that calls the remote hello() and returns its result. */
@GetMapping
public String test(){
String hello = helloService.hello(new Hello("111", "222"));
return hello;
}
}@RpcReference可以指定一些基础配置
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
/**
 * Marks a field to be injected with an RPC client proxy for the field's interface type.
 * NOTE(review): @Inherited only affects annotations on classes, so it has no effect on this
 * FIELD-targeted annotation.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD})
@Inherited
public @interface RpcReference {
/**
 * Service version, default value is empty string
 */
String version() default "";
/**
 * Service group, default value is empty string
 */
String group() default "";
/**
 * Number of retries
 */
int retries() default 3;
/**
 * Load-balancing strategy name; empty means the framework's default strategy
 */
String loadBalance() default "";
/**
 * Cluster fault-tolerance strategy
 */
String cluster() default "";
/**
 * Whether the call is asynchronous; if so, the result is obtained through FutureContext
 */
boolean async() default false;
}当消费者发起方法调用时,会被动态代理拦截,执行RpcClientProxy的invoke方法。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
/**
 * Intercepts calls on an {@code @RpcReference} proxy and turns them into remote requests.
 *
 * <p>For synchronous references the call blocks until the response arrives, validates it and
 * returns its data. For asynchronous references ({@code async = true}) the response future is
 * stored in FutureContext for the current thread and {@code null} is returned.
 * If no reference configuration is registered for this consumer/interface, the call is treated
 * as synchronous.
 */
@SneakyThrows
@SuppressWarnings("unchecked")
@Override
public Object invoke(Object proxy, Method method, Object[] args) {
    log.info("invoked method: [{}]", method.getName());
    RpcRequest rpcRequest = RpcRequest.builder()
            .methodName(method.getName())
            .parameters(args)
            .interfaceName(method.getDeclaringClass().getName())
            .paramTypes(method.getParameterTypes())
            .requestId(UUID.randomUUID().toString())
            .group(rpcServiceConfig.getGroup())
            .version(rpcServiceConfig.getVersion())
            .consumerName(o.getClass().getName())
            .build();
    RpcReferenceConfig rpcReferenceConfig = RpcReferenceConfig.getRpcReferenceConfig(rpcRequest.getConsumerName(), rpcRequest.getInterfaceName());
    // send the request
    Object responseFuture = rpcRequestTransport.sendRpcRequest(rpcRequest);
    if (rpcReferenceConfig != null && rpcReferenceConfig.isAsync()) {
        CompletableFuture<RpcResponse<?>> future = (CompletableFuture<RpcResponse<?>>) responseFuture;
        FutureContext.setFuture(future);
        return null;
    }
    RpcResponse<Object> rpcResponse;
    try {
        rpcResponse = ((CompletableFuture<RpcResponse<Object>>) responseFuture).get();
    } catch (java.util.concurrent.ExecutionException e) {
        // Rethrow the real failure (e.g. RpcException once retries are exhausted) rather than
        // the checked wrapper, which the proxy would re-wrap in UndeclaredThrowableException.
        Throwable cause = e.getCause();
        if (cause instanceof RuntimeException) {
            throw (RuntimeException) cause;
        }
        if (cause instanceof Error) {
            throw (Error) cause;
        }
        throw e;
    }
    check(rpcResponse, rpcRequest);
    return rpcResponse.getData();
}主要看sendRpcRequest方法,里面会根据用户的配置选择相应的集群容错方式进行服务调用,为了测试方便我这里是直接new了一个FixedIntervalRetryCluster,即重试固定次数。
1
2
3
4
/**
 * Sends the request through the FixedIntervalRetryCluster fault-tolerance layer and returns
 * whatever its invoke() returns (the response future).
 * NOTE(review): this does not consult the reference's configured cluster strategy and always
 * uses a default ClusterConfig. Whether that config reflects @RpcReference(retries) is not
 * visible here; confirm before relying on those settings.
 */
@Override
public Object sendRpcRequest(RpcRequest rpcRequest) {
return new FixedIntervalRetryCluster(this).invoke(rpcRequest,new ClusterConfig());
}FixedIntervalRetryCluster:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
@Slf4j
public class FixedIntervalRetryCluster extends AbstractClusterFaultTolerant {
// Schedules delayed retries. It uses daemon threads because Executors' default thread factory
// creates non-daemon threads, which would keep the JVM alive after the application finishes.
private static final ScheduledExecutorService scheduler =
        Executors.newScheduledThreadPool(Runtime.getRuntime().availableProcessors(), runnable -> {
            Thread thread = new Thread(runnable, "rpc-retry-scheduler");
            thread.setDaemon(true);
            return thread;
        });

public FixedIntervalRetryCluster(NettyRpcClient rpcClient) {
    super(rpcClient);
}

/**
 * Sends the request. On a retryable failure it resends after a fixed delay
 * ({@code clusterConfig.getRetryIntervals()} ms), at most {@code clusterConfig.getRetryTimes()}
 * times. Delays are scheduled rather than slept, so no thread blocks between attempts.
 *
 * @return a {@code CompletableFuture<RpcResponse<Object>>} completed with the first non-null
 *         response. It completes exceptionally with the failure once retries are exhausted or the
 *         failure is not retryable.
 */
@Override
public Object invoke(RpcRequest rpcRequest, ClusterConfig clusterConfig) {
    CompletableFuture<RpcResponse<Object>> resultFuture = new CompletableFuture<>();
    AtomicInteger remainingRetries = new AtomicInteger(clusterConfig.getRetryTimes());
    Consumer<RpcRequest> attempt = new Consumer<RpcRequest>() {
        @Override
        public void accept(RpcRequest request) {
            CompletableFuture<RpcResponse<Object>> single;
            try {
                single = sendSingleRequest(request);
            } catch (RuntimeException e) {
                // Route a synchronous failure (e.g. no provider found) through the same retry
                // decision as an asynchronous one. On a retry this runs on the scheduler, where an
                // escaping exception would be swallowed and resultFuture would never complete.
                single = new CompletableFuture<>();
                single.completeExceptionally(e);
            }
            single.whenComplete((response, ex) -> {
                if (response != null) {
                    // success: complete the caller's future
                    resultFuture.complete(response);
                } else if (ex != null && isRetryable(ex) && remainingRetries.get() > 0) {
                    int retryCount = clusterConfig.getRetryTimes() - remainingRetries.decrementAndGet();
                    log.info("第 {} 次重试请求: {}", retryCount, request);
                    // non-blocking delay via the scheduler
                    scheduler.schedule(() -> {
                        try {
                            // rebuild the request so a retry does not reuse state from the previous attempt
                            accept(reBuildRequest(request));
                        } catch (RuntimeException e) {
                            resultFuture.completeExceptionally(e);
                        }
                    }, clusterConfig.getRetryIntervals(), TimeUnit.MILLISECONDS);
                } else {
                    // final failure: no retries left, not retryable, or an empty response
                    resultFuture.completeExceptionally(
                            ex != null ? ex : new RpcException(RpcErrorMessageEnum.SERVICE_INVOCATION_FAILURE)
                    );
                }
            });
        }
    };
    attempt.accept(rpcRequest);
    return resultFuture;
}
}这里对单独消息发送的逻辑抽取成了一个方法sendSingleRequest
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
/**
 * Sends one request, without retrying, to a provider chosen by service discovery and load
 * balancing.
 *
 * <p>Discovery, connection and write failures are reported through the returned future rather
 * than thrown, so the caller's retry logic sees every failure the same way.
 *
 * @return future that is completed when the matching response arrives, or completed
 *         exceptionally if the request could not be sent
 */
public CompletableFuture<RpcResponse<Object>> sendSingleRequest(RpcRequest rpcRequest) {
    CompletableFuture<RpcResponse<Object>> future = new CompletableFuture<>();
    Channel channel;
    try {
        // lookupService throws when no provider is registered, so keep it inside the try
        InetSocketAddress address = rpcClient.getServiceDiscovery().lookupService(rpcRequest);
        channel = rpcClient.getChannel(address);
    } catch (Exception e) {
        future.completeExceptionally(e);
        return future;
    }
    if (!channel.isActive()) {
        future.completeExceptionally(new RpcException(RpcErrorMessageEnum.CLIENT_CONNECT_SERVER_FAILURE));
        return future;
    }
    // Register the future before writing, so a fast response can always find it
    rpcClient.getUnprocessedRequests().put(rpcRequest.getRequestId(), future);
    RpcProperties rpcProperties = rpcClient.getRpcProperties();
    RpcMessage rpcMessage = RpcMessage.builder()
            .data(rpcRequest)
            .codec(SerializationTypeEnum.getCode(rpcProperties.getSerializer()))
            .compress(CompressTypeEnum.getCode(rpcProperties.getCompress()))
            .messageType(RpcConstants.REQUEST_TYPE).build();
    channel.writeAndFlush(rpcMessage).addListener((ChannelFutureListener) sendFuture -> {
        if (!sendFuture.isSuccess()) {
            future.completeExceptionally(sendFuture.cause());
            log.error("Send request failed", sendFuture.cause());
        }
    });
    return future;
}首先通过负载均衡选出一个提供者地址,通过getChannel查看当前客户端与服务端是否有已经建立并且还活跃的连接(Channel),没有则进行连接建立。如果没有异常,则建立requestId与CompletableFuture之间的映射关系并构造RpcMessage对象,最后直接返回future。
最终请求到达服务端并执行目标方法,返回结果,到达NettyRpcClientHandler,取出执行结果,根据里面的requestId找到对应的future对象,调用其complete方法,此时触发集群容错里面重写的invoke方法的回调函数。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
sendSingleRequest(request).whenComplete((response, ex) -> {
if (response != null) {
// 成功时完成 Future
resultFuture.complete(response);
} else if (ex != null && isRetryable(ex) && remainingRetries.get() > 0) {
// 执行重试逻辑
int retryCount = clusterConfig.getRetryTimes() - remainingRetries.decrementAndGet();
log.info("第 {} 次重试请求: {}", retryCount, request);
// 6. 使用调度器实现非阻塞延迟
scheduler.schedule(() -> {
// 7. 创建新请求对象避免状态污染
RpcRequest newRequest = reBuildRequest(request);
accept(newRequest); // 迭代调用
}, clusterConfig.getRetryIntervals(), TimeUnit.MILLISECONDS);
} else {
// 8. 最终异常处理
resultFuture.completeExceptionally(
ex != null ? ex : new RpcException(RpcErrorMessageEnum.SERVICE_INVOCATION_FAILURE)
);
}
});当重试次数耗尽就调用completeExceptionally方法。一开始的动态代理的invoke方法就通过这里返回的future对象获取原始数据。
这里的返回的future对象作为sendRpcRequest的返回值
1
2
3
4
5
6
7
8
9
10
Object responseFuture = rpcRequestTransport.sendRpcRequest(rpcRequest);
if (rpcReferenceConfig.isAsync()) {
CompletableFuture<RpcResponse<?>> future = (CompletableFuture<RpcResponse<?>>) responseFuture;
FutureContext.setFuture(future);
} else {
RpcResponse rpcResponse = ((CompletableFuture<RpcResponse<Object>>)responseFuture).get();
check(rpcResponse, rpcRequest);
return rpcResponse.getData();
}如果不是异步调用,就在这里通过get方法阻塞获取最终值。
4. 服务注册流程
服务通过@RpcScan指定扫描路径,通过自定义扫描器扫描指定路径,将被@RpcService修饰的类注册为Spring Bean.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/**
 * Marks a class as an RPC service provider. SpringBeanPostProcessor publishes beans whose class
 * carries this annotation, using its group and version. Because of @Inherited, subclasses of an
 * annotated class also report the annotation.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
@Inherited
public @interface RpcService {
/**
 * Service version, default value is empty string
 */
String version() default "";
/**
 * Service group, default value is empty string
 */
String group() default "";
}这里自定义了Spring的后置处理器SpringBeanPostProcessor并重写了postProcessBeforeInitialization和postProcessAfterInitialization方法,前者用来处理@RpcService注解
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/**
 * Publishes every bean whose class carries {@code @RpcService} through the service provider,
 * using the annotation's group and version. Runs before the bean's initialization callbacks.
 *
 * @return the bean, unchanged
 */
@SneakyThrows
@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
    Class<?> beanClass = bean.getClass();
    RpcService rpcService = beanClass.getAnnotation(RpcService.class);
    if (rpcService == null) {
        return bean;
    }
    log.info("[{}] is annotated with [{}]", beanClass.getName(), RpcService.class.getCanonicalName());
    serviceProvider.publishService(RpcServiceConfig.builder()
            .group(rpcService.group())
            .version(rpcService.version())
            .service(bean)
            .build());
    return bean;
}这里会根据注解属性构造RpcServiceConfig对象,这里的bean变量就是某个接口的具体实现类,通过publishService进行服务发布,这里默认使用Zookeeper作为注册中心。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
/**
 * Registers the service locally, then publishes {@code <local host address>:<server port>}
 * under the service name in the registry. If the local host address cannot be resolved, the
 * error is logged and nothing is registered.
 */
@Override
public void publishService(RpcServiceConfig rpcServiceConfig) {
    int port = SpringUtils.getBean(RpcProperties.class).getServerPort();
    String host;
    try {
        host = InetAddress.getLocalHost().getHostAddress();
    } catch (UnknownHostException e) {
        log.error("occur exception when getHostAddress", e);
        return;
    }
    this.addService(rpcServiceConfig);
    // publish to the registry: key = rpcServiceName, value = host:port
    serviceRegistry.registerService(rpcServiceConfig.getRpcServiceName(), new InetSocketAddress(host, port));
}
/**
 * Records a service locally. It adds the service name (e.g.
 * github.javaguide.HelloServicetest1version1) to the registered-name set and maps that name to
 * the service object (e.g. HelloServiceImpl). Later registrations of the same name are ignored.
 *
 * @param rpcServiceConfig rpc service related attributes
 */
@Override
public void addService(RpcServiceConfig rpcServiceConfig) {
    String rpcServiceName = rpcServiceConfig.getRpcServiceName();
    // Set.add returns false if the name is already present. Using its result instead of a
    // separate contains() check makes "first registration wins" atomic when registeredService
    // is a concurrent set.
    if (!registeredService.add(rpcServiceName)) {
        return;
    }
    serviceMap.put(rpcServiceName, rpcServiceConfig.getService());
    log.info("Add service: {} and interfaces:{}", rpcServiceName, rpcServiceConfig.getService().getClass().getInterfaces());
}这里的addService方法主要是添加服务名、以及服务名和实现类的映射关系。
服务名为接口名加上版本号和分组信息(github.javaguide.HelloServicetest1version1)
这里的registerService就是将提供者信息注册到注册中心上去。
1
2
3
4
5
6
7
8
9
/**
 * ZooKeeper-based ServiceRegistry. Each provider address is stored as a persistent node at
 * {@code ZK_REGISTER_ROOT_PATH/<rpcServiceName><inetSocketAddress>}.
 */
public class ZkServiceRegistryImpl implements ServiceRegistry {
    @Override
    public void registerService(String rpcServiceName, InetSocketAddress inetSocketAddress) {
        String serviceRoot = CuratorUtils.ZK_REGISTER_ROOT_PATH + "/" + rpcServiceName;
        // For an address built from an IP literal, toString() typically yields "/ip:port",
        // so the address becomes a child node of the service path
        String servicePath = serviceRoot + inetSocketAddress.toString();
        CuratorFramework client = CuratorUtils.getZkClient();
        CuratorUtils.createPersistentNode(client, servicePath);
    }
}
同时还会在本地保存一份该path。
5. 消费者注册流程
这里主要看SpringBeanPostProcessor的postProcessAfterInitialization方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/**
 * Replaces every {@code @RpcReference} field declared directly on the bean's class with a
 * client proxy for the field's interface type. It also records the reference's settings under
 * {@code <bean class name>.<interface name>} so the proxy can look them up per call.
 * NOTE(review): only getDeclaredFields() is inspected, so annotated fields inherited from
 * superclasses are not injected.
 *
 * @return the bean, with its {@code @RpcReference} fields populated
 */
@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
    Class<?> targetClass = bean.getClass();
    for (Field declaredField : targetClass.getDeclaredFields()) {
        RpcReference rpcReference = declaredField.getAnnotation(RpcReference.class);
        if (rpcReference == null) {
            continue;
        }
        RpcServiceConfig rpcServiceConfig = RpcServiceConfig.builder()
                .group(rpcReference.group())
                .version(rpcReference.version())
                .build();
        RpcReferenceConfig.addRpcReferenceConfig(targetClass.getName(), declaredField.getType().getName(), rpcReference);
        RpcClientProxy rpcClientProxy = new RpcClientProxy(rpcClient, rpcServiceConfig, bean);
        Object clientProxy = rpcClientProxy.getProxy(declaredField.getType());
        declaredField.setAccessible(true);
        try {
            declaredField.set(bean, clientProxy);
        } catch (IllegalAccessException e) {
            // Report through the application logger instead of printStackTrace()
            log.error("Failed to inject RPC proxy into field [{}] of [{}]", declaredField.getName(), targetClass.getName(), e);
        }
    }
    return bean;
}会检查每个Bean里面是否有被@RpcReference修饰的字段,如果有的话就创建一个代理对象来替代这个字段,这就是为什么进行方法调用会被拦截。这里同时也会建立consumerName + "." + interfaceName与RpcReferenceConfig之间的映射关系。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
/**
 * Per-reference settings copied from {@code @RpcReference}. They are kept in a static registry
 * keyed by {@code <consumer class name>.<interface name>}, where the client proxy and service
 * discovery look them up for each call.
 * NOTE(review): two @RpcReference fields of the same interface type in one consumer class share
 * a key, so the later registration overwrites the earlier one.
 * NOTE(review): builder() and the getters are presumably generated by Lombok annotations that
 * this excerpt does not show. Confirm.
 */
public class RpcReferenceConfig {
    public static Map<String,RpcReferenceConfig> rpcReferenceConfigMap = new ConcurrentHashMap<>();
    /**
     * Service version, default value is empty string
     */
    private String version;
    /**
     * Service group, default value is empty string
     */
    private String group;
    // Number of retries
    private int retries;
    // Load-balancing strategy name; blank means the default strategy
    private String loadBalance;
    // Cluster fault-tolerance strategy
    private String cluster;
    // Whether the call is asynchronous (result obtained via FutureContext)
    private boolean async;

    /**
     * Stores the settings of {@code rpcReference} for the given consumer/interface pair.
     *
     * @param consumerName  class name of the bean that declares the reference
     * @param interfaceName name of the referenced interface type (the field's declared type)
     * @param rpcReference  the annotation to copy settings from
     */
    public static void addRpcReferenceConfig(String consumerName, String interfaceName, RpcReference rpcReference) {
        rpcReferenceConfigMap.put(key(consumerName, interfaceName), RpcReferenceConfig.builder()
                .group(rpcReference.group())
                .version(rpcReference.version())
                .retries(rpcReference.retries())
                .loadBalance(rpcReference.loadBalance())
                .cluster(rpcReference.cluster())
                .async(rpcReference.async())
                .build());
    }

    /**
     * Returns the settings stored for the consumer/interface pair, or null if none were registered.
     */
    public static RpcReferenceConfig getRpcReferenceConfig(String consumerName, String interfaceName) {
        return rpcReferenceConfigMap.get(key(consumerName, interfaceName));
    }

    /** Builds the registry key used by both add and get, so the two always agree. */
    private static String key(String consumerName, String interfaceName) {
        return consumerName + "." + interfaceName;
    }
}
/**
 * Marks a field to be injected with an RPC client proxy for the field's interface type.
 * NOTE(review): @Inherited only affects annotations on classes, so it has no effect on this
 * FIELD-targeted annotation.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD})
@Inherited
public @interface RpcReference {
/**
 * Service version, default value is empty string
 */
String version() default "";
/**
 * Service group, default value is empty string
 */
String group() default "";
/**
 * Number of retries
 */
int retries() default 3;
/**
 * Load-balancing strategy name; empty means the framework's default strategy
 */
String loadBalance() default "";
/**
 * Cluster fault-tolerance strategy
 */
String cluster() default "";
/**
 * Whether the call is asynchronous; if so, the result is obtained through FutureContext
 */
boolean async() default false;
}6. 负载均衡策略
在发送请求之前,要先得到服务提供者的ip地址
1
InetSocketAddress address = rpcClient.getServiceDiscovery().lookupService(rpcRequest);目前只实现了Zookeeper作为注册中心
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
@Slf4j
public class ZkServiceDiscoveryImpl implements ServiceDiscovery {
/**
 * Finds a provider address for the request's service. It lists the provider nodes registered
 * under the service name and picks one with the reference's load-balancing strategy, or the
 * default strategy when none is configured.
 *
 * @return the chosen provider address
 * @throws RpcException          if no provider is registered for the service
 * @throws IllegalStateException if the selected node is not of the form {@code host:port}
 */
@Override
public InetSocketAddress lookupService(RpcRequest rpcRequest) {
    String rpcServiceName = rpcRequest.getRpcServiceName();
    CuratorFramework zkClient = CuratorUtils.getZkClient();
    List<String> serviceUrlList = CuratorUtils.getChildrenNodes(zkClient, rpcServiceName);
    if (CollectionUtil.isEmpty(serviceUrlList)) {
        throw new RpcException(RpcErrorMessageEnum.SERVICE_CAN_NOT_BE_FOUND, rpcServiceName);
    }
    // load balancing: the config may be absent (e.g. a proxy not created by the post-processor)
    RpcReferenceConfig rpcReferenceConfig = RpcReferenceConfig.getRpcReferenceConfig(rpcRequest.getConsumerName(), rpcRequest.getInterfaceName());
    String loadBalanceName = rpcReferenceConfig == null || StringUtil.isBlank(rpcReferenceConfig.getLoadBalance())
            ? LoadBalanceEnum.LOADBALANCE.getName()
            : rpcReferenceConfig.getLoadBalance();
    LoadBalance loadBalance = ExtensionLoader.getExtensionLoader(LoadBalance.class).getExtension(loadBalanceName);
    String targetServiceUrl = loadBalance.selectServiceAddress(serviceUrlList, rpcRequest);
    log.info("Successfully found the service address:[{}]", targetServiceUrl);
    // Split at the last ':' so a host that itself contains ':' (an IPv6 literal) stays intact
    int separator = targetServiceUrl.lastIndexOf(':');
    if (separator <= 0) {
        throw new IllegalStateException("Malformed service address: " + targetServiceUrl);
    }
    String host = targetServiceUrl.substring(0, separator);
    int port = Integer.parseInt(targetServiceUrl.substring(separator + 1));
    return new InetSocketAddress(host, port);
}
}这里会先通过rpcServiceName获取可用服务提供者ip地址,再通过配置的负载均衡算法选择出ip地址。
7. 异步调用
jdk动态代理里面的invoke方法里面调用sendRpcRequest发送请求,返回一个future对象,如果不是异步调用,则会通过其get方法阻塞获取返回值,否则将这个future对象绑定到当前线程的ThreadLocal上。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
/**
 * Hands the response future of an asynchronous RPC call ({@code @RpcReference(async = true)})
 * from the client proxy to the calling code through a thread-local.
 *
 * <p>Usage: call the proxied method (it returns null), then, on the same thread, call
 * {@link #getFuture(Class)} once to obtain a future of the result.
 */
public class FutureContext {
    private static final ThreadLocal<CompletableFuture<RpcResponse<?>>> FUTURE_THREAD_LOCAL = new ThreadLocal<>();

    /**
     * Takes and clears the current thread's pending response future, and adapts it into a
     * future of the response data.
     * NOTE(review): unlike the synchronous path, the response code is not checked here.
     *
     * @param targetType expected type of the returned data
     * @return future completed with the data (possibly null). It completes exceptionally with the
     *         call's failure, or with a ClassCastException if the data is not a {@code targetType}.
     * @throws IllegalStateException if the current thread has no pending asynchronous call
     */
    public static <T> CompletableFuture<T> getFuture(Class<T> targetType) {
        CompletableFuture<RpcResponse<?>> responseFuture = FUTURE_THREAD_LOCAL.get();
        if (responseFuture == null) {
            throw new IllegalStateException("No future in context for current thread");
        }
        clearFuture();
        CompletableFuture<T> resultFuture = new CompletableFuture<>();
        responseFuture.whenComplete((rpcResponse, throwable) -> {
            if (throwable != null) {
                resultFuture.completeExceptionally(throwable);
            } else if (rpcResponse == null) {
                // Without this guard getData() would throw inside the callback and
                // resultFuture would never complete
                resultFuture.completeExceptionally(new IllegalStateException("Empty RPC response"));
            } else {
                Object data = rpcResponse.getData();
                // null is a valid value for any reference type, but isInstance(null) is false,
                // so accept it explicitly (Class.cast(null) returns null)
                if (data == null || targetType.isInstance(data)) {
                    resultFuture.complete(targetType.cast(data));
                } else {
                    resultFuture.completeExceptionally(new ClassCastException("Cannot cast response data to " + targetType.getName()));
                }
            }
        });
        return resultFuture;
    }

    /** Stores the response future of the asynchronous call just made on this thread. */
    public static void setFuture(CompletableFuture<RpcResponse<?>> future) {
        FUTURE_THREAD_LOCAL.set(future);
    }

    /** Removes any pending future from this thread. */
    public static void clearFuture() {
        FUTURE_THREAD_LOCAL.remove();
    }
}后续用户可以通过FutureContext#getFuture方法获取返回的future对象。


