在上一个版本的基础上进行优化:RPC 通信的简单实现(二)
1. 使用 Netty 进行网络通信;
2. 使用 Protostuff 进行序列化。
相关依赖包(Maven):
<!-- 序列化:protostuff -->
<dependency>
    <groupId>io.protostuff</groupId>
    <artifactId>protostuff-core</artifactId>
    <version>1.7.2</version>
</dependency>
<dependency>
    <groupId>io.protostuff</groupId>
    <artifactId>protostuff-runtime</artifactId>
    <version>1.7.2</version>
</dependency>
<!-- 网络通信:netty -->
<dependency>
    <groupId>io.netty</groupId>
    <artifactId>netty-all</artifactId>
    <version>4.1.51.Final</version>
</dependency>
<!-- 另外,代码中还用到了 Spring(spring-context)和 commons-lang3(StringUtils),需要一并引入。 -->
import org.springframework.stereotype.Component;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a field (or type) that should receive an RPC client proxy.
 *
 * <p>{@link #value()} names the remote interface explicitly. When it is left at the sentinel
 * default {@code RpcReference.class}, MyApplicationContext falls back to the field's declared
 * type. {@link #version()} selects a service version; an empty string means "no version".
 */
@Component
@Target({ElementType.TYPE, ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RpcReference {

    /** Remote interface to proxy; the sentinel {@code RpcReference.class} means "not specified". */
    Class<?> value() default RpcReference.class;

    /** Version of the remote service to bind to; empty means unversioned. */
    String version() default "";
}
import org.springframework.stereotype.Component;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a class as an RPC service implementation exposed by the server.
 *
 * <p>Being meta-annotated with {@link Component}, annotated classes are also picked up by
 * component scanning. The server registers each such bean under {@code value().getName()},
 * suffixed with {@code "-" + version()} when a version is given.
 */
@Component
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RpcService {

    /** The remote interface this bean implements; used as the registry key. */
    Class<?> value();

    /** Optional service version appended to the registry key; empty means unversioned. */
    String version() default "";
}
import java.io.Serializable;
import java.util.Arrays;

/**
 * RPC call message sent from client to server: which interface, which method, with which
 * arguments, and which service version.
 *
 * <p>Keep the field declaration order unchanged: protostuff-runtime assigns field numbers in
 * declaration order (unless {@code @Tag} is used), so reordering breaks wire compatibility.
 */
public class RpcRequest implements Serializable {

    /** Request identifier. */
    private String reqId;
    /** Fully qualified name of the remote interface. */
    private String className;
    /** Name of the method to invoke. */
    private String methodName;
    /** Call arguments (may be null for no-arg methods). */
    private Object[] parameters;
    /** Service version; empty/null means unversioned. */
    private String version;

    /** No-arg constructor required for deserialization. */
    public RpcRequest() {
    }

    /**
     * Creates a request without a request id.
     */
    public RpcRequest(String className, String methodName, Object[] parameters, String version) {
        this.className = className;
        this.methodName = methodName;
        this.parameters = parameters;
        this.version = version;
    }

    public String getReqId() {
        return reqId;
    }

    public String getClassName() {
        return className;
    }

    public String getMethodName() {
        return methodName;
    }

    public Object[] getParameters() {
        return parameters;
    }

    public String getVersion() {
        return version;
    }

    public void setReqId(String reqId) {
        this.reqId = reqId;
    }

    public void setClassName(String className) {
        this.className = className;
    }

    public void setMethodName(String methodName) {
        this.methodName = methodName;
    }

    public void setParameters(Object[] parameters) {
        this.parameters = parameters;
    }

    public void setVersion(String version) {
        this.version = version;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("RpcRequest{");
        sb.append("reqId='").append(reqId).append('\'');
        sb.append(", className='").append(className).append('\'');
        sb.append(", methodName='").append(methodName).append('\'');
        sb.append(", parameters=").append(Arrays.toString(parameters));
        sb.append(", version='").append(version).append('\'');
        return sb.append('}').toString();
    }
}
import java.io.Serializable;
import java.util.Arrays;

/**
 * RPC reply message sent from server to client, carrying the id of the request it answers and
 * the invocation result.
 *
 * <p>Keep the field declaration order unchanged: protostuff-runtime assigns field numbers in
 * declaration order (unless {@code @Tag} is used).
 */
public class RpcResponse implements Serializable {

    /** Id copied from the corresponding request. */
    private String reqId;
    /** Return value of the remote method (may be null). */
    private Object result;

    /** No-arg constructor required for deserialization. */
    public RpcResponse() {
    }

    public String getReqId() {
        return reqId;
    }

    public Object getResult() {
        return result;
    }

    public void setReqId(String reqId) {
        this.reqId = reqId;
    }

    public void setResult(Object result) {
        this.result = result;
    }

    @Override
    public String toString() {
        return new StringBuilder("RpcResponse{")
                .append("reqId='").append(reqId).append('\'')
                .append(", result=").append(result)
                .append('}')
                .toString();
    }
}
package com.meihaocloud.rpc;

/**
 * Remote user-name service contract, shared by client and server.
 */
public interface IUserService {

    /**
     * Returns a user name derived from the given first name.
     *
     * @param firstName first name supplied by the caller
     * @return the name string produced by the implementation
     */
    String getUserName(String firstName);
}
package com.meihaocloud.rpc;

/**
 * Remote greeting service contract, shared by client and server.
 */
public interface IHelloService {

    /**
     * Sends a message to the service and returns its reply.
     *
     * @param content message text
     * @return the reply produced by the implementation
     */
    String sayHello(String content);
}
package com.meihaocloud.rpc;

import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;

/**
 * Server entry point: builds the Spring context from {@link SpringConfig} and starts it.
 *
 * <p>Note: RpcProxyServer.afterPropertiesSet blocks on the server channel's close future, so
 * constructing the context does not return while the server is running.
 */
public class App {

    public static void main(String[] args) {
        AnnotationConfigApplicationContext springContext =
                new AnnotationConfigApplicationContext(SpringConfig.class);
        springContext.start();
    }
}
package com.meihaocloud.rpc;

/**
 * Server-side {@link IHelloService} implementation, exposed under version "1.0".
 */
@RpcService(value = IHelloService.class, version = "1.0")
public class HelloService implements IHelloService {

    /**
     * Logs the incoming content, pauses for one second to simulate work, and echoes it back.
     *
     * <p>NOTE(review): ServerChannelHandler calls this directly from channelRead, i.e. on the
     * Netty I/O thread, so the sleep stalls every connection served by that event loop.
     *
     * @param content message from the client
     * @return {@code "RpcResponse : " + content}
     */
    @Override
    public String sayHello(String content) {
        System.out.println("RpcRequest : " + content);
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so the owning thread can still observe the interruption.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
        return "RpcResponse : " + content;
    }
}
package com.meihaocloud.rpc;

import java.util.List;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;

/**
 * Custom length-prefixed decoder.
 *
 * <p>Each frame is a 4-byte big-endian length followed by that many bytes of Protostuff data,
 * which is deserialized into {@code genericClass}.
 */
public class MyDecoder extends ByteToMessageDecoder {

    /** Type every decoded frame is deserialized into. */
    private final Class<?> genericClass;

    public MyDecoder(Class<?> genericClass) {
        this.genericClass = genericClass;
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        // Wait until at least the 4-byte length header is available.
        if (in.readableBytes() < 4) {
            return;
        }
        in.markReaderIndex();
        int dataLength = in.readInt();
        if (dataLength < 0) {
            // Corrupt header: discard what is buffered so the decode loop does not keep
            // interpreting garbage as new frames, then drop the connection.
            in.skipBytes(in.readableBytes());
            ctx.close();
            return;
        }
        if (in.readableBytes() < dataLength) {
            // Partial frame: rewind to the header and wait for more bytes to arrive.
            in.resetReaderIndex();
            return;
        }
        byte[] data = new byte[dataLength];
        in.readBytes(data);
        out.add(ProtoStuffSerializer.deserialize(data, genericClass));
    }
}
package com.meihaocloud.rpc;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;

/**
 * Custom encoder: writes a 4-byte length header followed by the Protostuff-serialized message.
 *
 * <p>Only messages that are instances of {@code genericClass} are encoded.
 */
public class MyEncoder extends MessageToByteEncoder<Object> {

    /** Type of message this encoder serializes. */
    private final Class<?> genericClass;

    public MyEncoder(Class<?> genericClass) {
        this.genericClass = genericClass;
    }

    @Override
    protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf out) throws Exception {
        if (!genericClass.isInstance(msg)) {
            // NOTE(review): a message of any other type produces no bytes, i.e. it is silently dropped.
            return;
        }
        byte[] data = ProtoStuffSerializer.serialize(msg);
        out.writeInt(data.length);
        out.writeBytes(data);
    }
}
package com.meihaocloud.rpc;

import io.protostuff.LinkedBuffer;
import io.protostuff.ProtostuffIOUtil;
import io.protostuff.Schema;
import io.protostuff.runtime.RuntimeSchema;

import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Protostuff serialization helpers with a per-class schema cache.
 */
public class ProtoStuffSerializer {

    /**
     * One reusable write buffer per thread, so each call avoids allocating a new buffer.
     *
     * <p>A single shared LinkedBuffer is not safe: serialize() is called from MyEncoder on every
     * Netty I/O thread, and concurrent calls would write into and clear the same buffer.
     */
    private static final ThreadLocal<LinkedBuffer> BUFFER = new ThreadLocal<LinkedBuffer>() {
        @Override
        protected LinkedBuffer initialValue() {
            return LinkedBuffer.allocate(LinkedBuffer.DEFAULT_BUFFER_SIZE);
        }
    };

    /** Cache of runtime schemas so they are built only once per class. */
    private static final Map<Class<?>, Schema<?>> schemaCache = new ConcurrentHashMap<>();

    /**
     * Serializes {@code t} with the runtime schema of its concrete class.
     *
     * @throws NullPointerException if {@code t} is null
     */
    @SuppressWarnings("unchecked")
    public static <T> byte[] serialize(T t) {
        Objects.requireNonNull(t, "t");
        Class<T> clazz = (Class<T>) t.getClass();
        Schema<T> schema = getSchema(clazz);
        LinkedBuffer buffer = BUFFER.get();
        try {
            return ProtostuffIOUtil.toByteArray(t, schema, buffer);
        } finally {
            // Reset the thread's buffer so the next call starts from an empty state.
            buffer.clear();
        }
    }

    /**
     * Creates a new instance of {@code clazz} and merges the serialized bytes into it.
     */
    public static <T> T deserialize(byte[] b, Class<T> clazz) {
        Schema<T> schema = getSchema(clazz);
        T obj = schema.newMessage();
        ProtostuffIOUtil.mergeFrom(b, obj, schema);
        return obj;
    }

    /**
     * Returns the cached schema for {@code clazz}, creating it on first use. Two threads may
     * race to create the same schema; both results are equivalent, so the race is harmless.
     */
    @SuppressWarnings("unchecked")
    private static <T> Schema<T> getSchema(Class<T> clazz) {
        Schema<T> schema = (Schema<T>) schemaCache.get(clazz);
        if (schema == null) {
            schema = RuntimeSchema.createFrom(clazz);
            if (schema != null) {
                schemaCache.put(clazz, schema);
            }
        }
        return schema;
    }
}
package com.meihaocloud.rpc; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.BeansException; import org.springframework.beans.factory.InitializingBean; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; public class RpcProxyServer implements ApplicationContextAware, InitializingBean { private int port; public RpcProxyServer(int port) { this.port = port; } /** * Spring中Bean初始化完成后执行 * @throws Exception */ @Override public void afterPropertiesSet() throws Exception { // 配置服务端NIO线程组 EventLoopGroup parentGroup = new NioEventLoopGroup(); EventLoopGroup childGroup = new NioEventLoopGroup(); try { ServerBootstrap b = new ServerBootstrap(); b.group(parentGroup, childGroup) .channel(NioServerSocketChannel.class) // 非阻塞模式 .option(ChannelOption.SO_BACKLOG, 128) .option(ChannelOption.TCP_NODELAY, true) .childHandler(new ChannelInitializer<SocketChannel>() { @Override protected void initChannel(SocketChannel channel) { // 服务端接收的数据是RpcRequest ,所以decode RpcRequest channel.pipeline().addLast(new MyDecoder(RpcRequest.class)); // 服务端发送出去的数据是RpcResponse,所以encode RpcResponse channel.pipeline().addLast(new MyEncoder(RpcResponse.class)); // 在管道中添加我们自己的接收数据实现方法 channel.pipeline().addLast(new ServerChannelHandler()); } }); ChannelFuture f = b.bind(port).sync(); System.out.println("server started ..."); f.channel().closeFuture().sync(); } catch (InterruptedException e) { e.printStackTrace(); } finally { 
childGroup.shutdownGracefully(); parentGroup.shutdownGracefully(); } } @Override public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { Map<String, Object> serviceBeanMap = applicationContext.getBeansWithAnnotation(RpcService.class); if (serviceBeanMap != null && serviceBeanMap.size() > 0) { for (Object servcieBean : serviceBeanMap.values()) { //拿到注解 RpcService rpcService = servcieBean.getClass().getAnnotation((RpcService.class)); String serviceName = rpcService.value().getName();//拿到接口类名 String version = rpcService.version(); //拿到版本号 if (!StringUtils.isEmpty(version)) { serviceName += "-" + version; } SpringConfig.handlerMap.put(serviceName, servcieBean); } } } }
package com.meihaocloud.rpc; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.util.ReferenceCountUtil; import org.apache.commons.lang3.StringUtils; import java.io.*; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.Socket; import java.util.HashMap; import java.util.Map; public class ServerChannelHandler extends ChannelInboundHandlerAdapter { public ServerChannelHandler() { } /**. * 接收数据时触发 */ @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { try { // 接收参数,调用方法,返回结果 RpcRequest request = (RpcRequest) msg; System.out.println("request : " + request ); Object result = invoke(request); System.out.println("return result : " + result); RpcResponse rpcResponse = new RpcResponse() ; rpcResponse.setReqId(request.getReqId()); rpcResponse.setResult(result); ctx.channel().writeAndFlush(rpcResponse); } finally { ReferenceCountUtil.release(msg); } } /** * 利用反射调用方法 * * @param rpcRequest * @return * @throws ClassNotFoundException * @throws NoSuchMethodException * @throws InvocationTargetException * @throws IllegalAccessException */ public Object invoke(RpcRequest rpcRequest) throws Exception { String className = rpcRequest.getClassName(); // 类名 String methodName = rpcRequest.getMethodName(); // 方法名 Object[] parameters = rpcRequest.getParameters(); // 参数 String version = rpcRequest.getVersion(); // 版本号 //增加版本号的判断 String serviceName = className; if (!StringUtils.isEmpty(version)) { serviceName += "-" + version; } Object service = SpringConfig.handlerMap.get(serviceName); // 获取实例对象 if (service == null) { throw new RuntimeException("service not found : " + serviceName); } Class<?> aClass = Class.forName(className); // 根据类名反射对象 Object invoke = null; if (parameters == null || parameters.length == 0) { //无参数 Method method = aClass.getMethod(methodName); // 获取到方法对象 invoke = method.invoke(service); //调用方法 } else { Class<?>[] type = 
new Class<?>[parameters.length]; // 参数的类型 for (int i = 0; i < parameters.length; i++) { type[i] = parameters[i].getClass(); } Method method = aClass.getMethod(methodName, type); // 获取到方法对象 invoke = method.invoke(service, parameters); //调用方法 } return invoke; } }
package com.meihaocloud.rpc;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Server-side Spring configuration: scans the package and defines the RPC server bean.
 */
@Configuration
@ComponentScan(basePackages = "com.meihaocloud.rpc")
public class SpringConfig {

    /**
     * Registry of RPC service beans keyed by "interfaceName" or "interfaceName-version".
     * RpcProxyServer fills it during startup and ServerChannelHandler reads it on Netty I/O
     * threads, so a concurrent map is used.
     */
    public static final Map<String, Object> handlerMap = new ConcurrentHashMap<>();

    /**
     * RPC server bean listening on port 8080; its afterPropertiesSet callback starts the server.
     */
    @Bean(name = "rpcProxyServer")
    public RpcProxyServer rpcProxyServer() {
        return new RpcProxyServer(8080);
    }
}
package com.meihaocloud.rpc;

/**
 * Server-side {@link IUserService} implementation, exposed under version "1.0".
 */
@RpcService(value = IUserService.class, version = "1.0")
public class UserService implements IUserService {

    /**
     * Combines the given first name with the fixed last name "zhang".
     */
    @Override
    public String getUserName(String firstName) {
        final String lastNamePart = " , lastName: zhang";
        return "firstName: " + firstName + lastNamePart;
    }
}
package com.meihaocloud.rpc;

import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;

/**
 * Client entry point: builds the Spring context from {@link SpringConfig} and starts it.
 * The client configuration registers MyApplicationContext, which injects RPC proxies into
 * {@link RpcReference} fields.
 */
public class App {

    public static void main(String[] args) {
        AnnotationConfigApplicationContext springContext =
                new AnnotationConfigApplicationContext(SpringConfig.class);
        springContext.start();
    }
}
package com.meihaocloud.rpc; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.socket.SocketChannel; import io.netty.util.ReferenceCountUtil; import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; public class ClientChannelHandler extends ChannelInboundHandlerAdapter { private Object obj = new Object(); private RpcResponse response; @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { try { this.response = (RpcResponse) msg; synchronized (obj) { obj.notifyAll(); } } finally { ReferenceCountUtil.release(msg); } } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { ctx.close(); System.out.println("异常信息:" + cause.getMessage()); } public RpcResponse getResponse() { return response; } public void setResponse(RpcResponse response) { this.response = response; } public Object getObj() { return obj; } public void setObj(Object obj) { this.obj = obj; } }
package com.meihaocloud.rpc;

import org.springframework.stereotype.Component;

/**
 * Client-side facade over the remote services. MyApplicationContext fills the
 * {@link RpcReference} fields with RPC proxies.
 */
@Component
public class HelloService {

    /** Interface given explicitly on the annotation. */
    @RpcReference(value = IHelloService.class, version = "1.0")
    private IHelloService helloService;

    /** No interface on the annotation: the field type (IUserService) is used instead. */
    @RpcReference(version = "1.0")
    private IUserService userService;

    /** Forwards {@code msg} to the remote hello service and returns its reply. */
    public String hello(String msg) {
        IHelloService remote = helloService;
        return remote.sayHello(msg);
    }

    /** Asks the remote user service for a name built from {@code name}. */
    public String getUserName(String name) {
        IUserService remote = userService;
        return remote.getUserName(name);
    }
}
package com.meihaocloud.rpc; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.BeansException; import org.springframework.beans.factory.InitializingBean; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.context.ApplicationContextException; import java.io.IOException; import java.lang.reflect.Field; import java.net.ServerSocket; import java.net.Socket; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; public class MyApplicationContext implements ApplicationContextAware { @Override public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { // 获取IOC容器中所有的bean String[] beanDefinitionNames = applicationContext.getBeanDefinitionNames(); if (beanDefinitionNames != null && beanDefinitionNames.length > 0) { for (String beanName : beanDefinitionNames) { Object bean = applicationContext.getBean(beanName); // 遍历所有bean的属性 Field[] declaredFields = bean.getClass().getDeclaredFields(); if (declaredFields != null && declaredFields.length > 0) { for (Field field : declaredFields) { RpcReference annotation = field.getAnnotation(RpcReference.class); // 获取对应的注解 if (annotation != null) { Class<?> type = null ; if(annotation.value() != RpcReference.class){ type = annotation.value() ; } else { type = field.getType(); // 获取属性的类型, } String serviceName = type.getName(); //接口名称 String version = annotation.version();//版本号 if (StringUtils.isNotBlank(version)) { serviceName += "-" + version; } // 生成代理对象,可以从配置文件中读取IP+port Object proxy = RpcClientProxy.getProxy(type, "127.0.0.1",8080,version); try { //将代理对象赋值给属性 field.setAccessible(true); field.set(bean, proxy); } catch (IllegalAccessException e) { e.printStackTrace(); } } } } } } } }
下面 3 个工具类与服务端的代码完全相同(客户端同样要用它们进行编解码和序列化):
MyDecoder.java
MyEncoder.java
ProtoStuffSerializer.java
package com.meihaocloud.rpc; import io.netty.bootstrap.Bootstrap; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import java.io.InputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.OutputStream; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.net.Socket; import java.util.Date; import java.util.List; import java.util.concurrent.BrokenBarrierException; public class RpcClientInvocation implements InvocationHandler { private String ip ; private int port ; private String version ; public RpcClientInvocation(String ip , int port , String version) { this.version = version ; this.ip = ip ; this.port = port ; } public RpcClientInvocation(){} @Override public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { //如果传进来是一个已实现的具体类(本次演示略过此逻辑) if (Object.class.equals(method.getDeclaringClass())) { try { return method.invoke(this, args); } catch (Throwable t) { t.printStackTrace(); } //如果传进来的是一个接口(核心) return null; } else { try { return startNetty(method, args); } catch (Exception e) { e.printStackTrace(); } return null; } } public Object startNetty(Method method, Object[] args) throws BrokenBarrierException, InterruptedException { //这里可以使用轮训,随机等策略 RpcRequest request = new RpcRequest(); request.setReqId(""+new Date().getTime()); request.setClassName(method.getDeclaringClass().getName()); request.setMethodName(method.getName()); request.setParameters(args); request.setVersion(version); System.out.println("request : " + request); EventLoopGroup workerGroup = new NioEventLoopGroup(); final ClientChannelHandler clientChannelHandler = new ClientChannelHandler(); try { Bootstrap b = new Bootstrap(); b.group(workerGroup) 
.channel(NioSocketChannel.class) .option(ChannelOption.AUTO_READ, true) .option(ChannelOption.TCP_NODELAY, true) .handler(new ChannelInitializer<SocketChannel>() { @Override protected void initChannel(SocketChannel channel) throws Exception { //将接收到的数据进行decode , 接收到的对象是RpcResponse channel.pipeline().addLast(new MyDecoder(RpcResponse.class)); //对发出去的数据进行encode , 发出去的对象是RpcRequest channel.pipeline().addLast(new MyEncoder(RpcRequest.class)); // 在管道中添加我们自己的接收数据实现方法 channel.pipeline().addLast(clientChannelHandler); } }); ChannelFuture f = b.connect(ip, port).sync(); f.channel().writeAndFlush(request); synchronized (clientChannelHandler.getObj()) { clientChannelHandler.getObj().wait(); } if (clientChannelHandler.getResponse() != null) { f.channel().close(); } } catch (InterruptedException e) { e.printStackTrace(); } finally { workerGroup.shutdownGracefully(); } RpcResponse response = clientChannelHandler.getResponse(); if (response == null) { System.out.println(" RpcResponse is null"); return null; } Object result = clientChannelHandler.getResponse().getResult(); return result; } }
package com.meihaocloud.rpc;

import java.lang.reflect.Proxy;
import java.util.List;

/**
 * Factory for JDK dynamic proxies whose calls are forwarded to an RPC server.
 */
public class RpcClientProxy {

    /**
     * Creates a proxy implementing {@code clazz} that sends every call to {@code ip:port}.
     *
     * @param clazz   interface to implement
     * @param ip      server host
     * @param port    server port
     * @param version service version to request
     * @return the proxy instance
     */
    @SuppressWarnings("unchecked")
    public static <T> T getProxy(Class<T> clazz, String ip, int port, String version) {
        ClassLoader loader = clazz.getClassLoader();
        Class<?>[] interfaces = {clazz};
        RpcClientInvocation handler = new RpcClientInvocation(ip, port, version);
        Object proxy = Proxy.newProxyInstance(loader, interfaces, handler);
        return (T) proxy;
    }
}
package com.meihaocloud.rpc;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;

/**
 * Client-side Spring configuration: scans the package and registers the bean that injects
 * RPC proxies into {@link RpcReference} fields.
 */
@Configuration
@ComponentScan(basePackages = "com.meihaocloud.rpc")
public class SpringConfig {

    /**
     * Proxy-injection bean. It keeps the bean name "rpcProxyServer" even though it is not a
     * server, so existing lookups by that name keep working.
     */
    @Bean(name = "rpcProxyServer")
    public MyApplicationContext rpcProxyServer() {
        MyApplicationContext injector = new MyApplicationContext();
        return injector;
    }
}
分别启动服务端和客户端后,调用结果正常,整个流程基本跑通。
服务端打印:
request : RpcRequest{reqId='1614486131409', className='com.meihaocloud.rpc.IHelloService', methodName='sayHello', parameters=[你好啊!张三 ], version='1.0'}
return result : RpcResponse : 你好啊!张三
request : RpcRequest{reqId='1614486132949', className='com.meihaocloud.rpc.IUserService', methodName='getUserName', parameters=[lisi], version='1.0'}
return result : firstName: lisi , lastName: zhang
客户端打印:
request : RpcRequest{reqId='1614486131409', className='com.meihaocloud.rpc.IHelloService', methodName='sayHello', parameters=[你好啊!张三 ], version='1.0'}
hello : RpcResponse : 你好啊!张三
request : RpcRequest{reqId='1614486132949', className='com.meihaocloud.rpc.IUserService', methodName='getUserName', parameters=[lisi], version='1.0'}
name : firstName: lisi , lastName: zhang