
Netty: How RPC Works and How to Implement It

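RPC (Remote Procedure Call) means calling a method in a process on a remote computer over the network in order to fetch or pass data or state, with a calling style that looks just like calling a local method.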

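What follows is a simple RPC framework. It uses Netty as the underlying transport to send and receive Java objects, and ZooKeeper as the service center; ZooKeeper can be left disabled when debugging on a single machine.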

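The project consists of a common part, a server side and a client side, plus a ZooKeeper service center.

The common part holds what both the client and the server need: the definitions of the RPC request and response structures.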

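RpcRequest wraps the class name and method name the client wants to call, together with the types and values of the arguments that method needs. Its Id uniquely identifies each request.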

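import java.io.Serializable;
import java.util.Arrays;

public class RpcRequest implements Serializable {
    private static final long serialVersionUID = 1L;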
    private Long id;
    private String className;
    private String methodName;
    private Class<?>[] parameterTypes;
    private Object[] parameters;

    public Long getId() {
        return id;
    }

    public void setId(Long id) {
        this.id = id;
    }

    public String getClassName() {
        return className;
    }

    public void setClassName(String className) {
        this.className = className;
    }

    public String getMethodName() {
        return methodName;
    }

    public void setMethodName(String methodName) {
        this.methodName = methodName;
    }

    public Class<?>[] getParameterTypes() {
        return parameterTypes;
    }

    public void setParameterTypes(Class<?>[] parameterTypes) {
        this.parameterTypes = parameterTypes;
    }

    public Object[] getParameters() {
        return parameters;
    }

    public void setParameters(Object[] parameters) {
        this.parameters = parameters;
    }

    @Override
    public String toString() {
        return "RpcRequest{" +
                "requestId='" + id + '\'' +
                ", className='" + className + '\'' +
                ", methodName='" + methodName + '\'' +
                ", parameterTypes=" + Arrays.toString(parameterTypes) +
                ", parameters=" + Arrays.toString(parameters) +
                '}';
    }
}
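RpcRequest crosses the network as a Java-serialized object, so every field it holds must be serializable; java.lang.Class is, which is why the Class<?>[] parameter types can travel as they are. The JDK-only sketch below shows that round trip, using ObjectOutputStream/ObjectInputStream as a stand-in for Netty's ObjectEncoder/ObjectDecoder (which use their own compact variant of Java serialization, so the bytes on the wire differ). MiniRequest is a hypothetical, stripped-down copy of RpcRequest.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;

public class SerializationDemo {

    // Hypothetical, stripped-down stand-in for RpcRequest
    static class MiniRequest implements Serializable {
        private static final long serialVersionUID = 1L;
        Long id;
        String className;
        String methodName;
        Class<?>[] parameterTypes;
        Object[] parameters;
    }

    // Serialize to bytes and back, as an encoder/decoder pair does across the wire
    static MiniRequest roundTrip(MiniRequest in) throws IOException, ClassNotFoundException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(in);
        }
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
            return (MiniRequest) ois.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        MiniRequest req = new MiniRequest();
        req.id = 1L;
        req.className = "DemoService";
        req.methodName = "hello";
        req.parameterTypes = new Class<?>[]{String.class, String.class};
        req.parameters = new Object[]{"Hi, ", "luangeng"};

        MiniRequest copy = roundTrip(req);
        // Class objects are resolved back to the classes loaded on the receiving side
        System.out.println(copy.methodName + " " + Arrays.toString(copy.parameterTypes));
    }
}
```

Because the receiver resolves each Class by name, both sides need the same classes (here, the service interface and argument types) on their class paths.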

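RpcResponse wraps the result returned by the invoked method, or the exception it threw. Its Id is the Id of the corresponding request, so requests and responses match one to one.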

import java.io.Serializable;

public class RpcResponse implements Serializable {
    private static final long serialVersionUID = 1L;
    private Long id;
    private Exception exception;
    private Object result;

    public Object getResult() throws Exception {
        if (this.exception != null) {
            throw this.exception;
        }
        return result;
    }

    public void setResult(Object result) {
        this.result = result;
    }

    public Long getId() {
        return id;
    }

    public void setId(Long id) {
        this.id = id;
    }

    public Exception getException() {
        return exception;
    }

    public void setException(Exception exception) {
        this.exception = exception;
    }

    @Override
    public String toString() {
        return "RpcResponse{" +
                "id=" + id +
                ", exception=" + exception +
                ", result=" + result +
                '}';
    }
}

ResponseHolder holds an RpcResponse and lets the caller block until the result comes back.

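public class ResponseHolder {
    private RpcResponse response;

    // Called by the I/O thread when the matching response arrives
    public synchronized void setResponse(RpcResponse response) {
        this.response = response;
        notifyAll();
    }

    // Waits up to 10 seconds by default
    public RpcResponse getResponse() throws InterruptedException {
        return getResponse(10000);
    }

    // Returns the response, or null if none arrived within timeout milliseconds.
    // Waiting in a loop on the condition avoids losing a notify() that happens
    // before the caller starts waiting, and guards against spurious wake-ups.
    public synchronized RpcResponse getResponse(long timeout) throws InterruptedException {
        long deadline = System.currentTimeMillis() + timeout;
        while (response == null) {
            long remaining = deadline - System.currentTimeMillis();
            if (remaining <= 0) {
                break;
            }
            wait(remaining);
        }
        return response;
    }
}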
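A plain wait()/notify() pair only works if the waiter is already waiting when notify() runs; otherwise the signal is lost and the caller sleeps until the timeout. The usual remedy is to wait in a loop on a condition. The sketch below (GuardedHolder is a hypothetical class that holds a String instead of an RpcResponse) covers three cases: the value is set before anyone waits, the value arrives from another thread, and nothing arrives at all.

```java
public class GuardedHolderDemo {

    // Hypothetical holder: the value is written once and read by a waiting caller
    static class GuardedHolder {
        private String value;

        synchronized void set(String v) {
            value = v;
            notifyAll(); // wake any caller blocked in get()
        }

        // Waits until a value is present or the timeout elapses; returns null on timeout
        synchronized String get(long timeoutMillis) throws InterruptedException {
            long deadline = System.currentTimeMillis() + timeoutMillis;
            while (value == null) { // re-checking guards against lost and spurious wake-ups
                long remaining = deadline - System.currentTimeMillis();
                if (remaining <= 0) {
                    return null;
                }
                wait(remaining);
            }
            return value;
        }
    }

    public static void main(String[] args) throws InterruptedException {
        // Case 1: the response arrives before the caller waits; get() returns at once
        GuardedHolder early = new GuardedHolder();
        early.set("early");
        long t0 = System.currentTimeMillis();
        System.out.println(early.get(5000) + " after " + (System.currentTimeMillis() - t0) + " ms");

        // Case 2: the response arrives from another thread while the caller waits
        GuardedHolder late = new GuardedHolder();
        new Thread(() -> late.set("late")).start();
        System.out.println(late.get(5000));

        // Case 3: no response at all, so the call times out and returns null
        System.out.println(new GuardedHolder().get(100));
    }
}
```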

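Server side

RpcServer starts Netty in a separate thread and installs the handlers for framing (splitting the TCP byte stream into messages) and for object serialization and deserialization.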

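import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class RpcServer extends Thread {

    private static final Logger LOGGER = LoggerFactory.getLogger(RpcServer.class);

    // volatile: read and written from different threads
    private static volatile boolean start = false;

    private static volatile Channel serverChannel;

    private static final RpcServer server = new RpcServer();

    private RpcServer() {
    }

    // synchronized so that two concurrent callers cannot both reach server.start(),
    // which would throw IllegalThreadStateException
    public static synchronized void startUp() {
        if (start) {
            LOGGER.info("Server already started.");
            return;
        }
        // A Thread can only be started once, so this flag is never reset
        start = true;
        server.start();
    }

    private static void bind(int port) {
        EventLoopGroup bossGroup = new NioEventLoopGroup(1);
        EventLoopGroup workerGroup = new NioEventLoopGroup(4);
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(bossGroup, workerGroup);
            bootstrap.channel(NioServerSocketChannel.class);
            bootstrap.option(ChannelOption.SO_BACKLOG, 128);
            bootstrap.childOption(ChannelOption.SO_KEEPALIVE, true);
            bootstrap.childHandler(new MyChannelInitializer());

            ChannelFuture future = bootstrap.bind(port).sync();
            serverChannel = future.channel();
            LOGGER.info("server started on port:" + port);
            // Blocks this thread until the server channel is closed (see shutDown)
            serverChannel.closeFuture().sync();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.error("server interrupted", e);
        } finally {
            workerGroup.shutdownGracefully();
            bossGroup.shutdownGracefully();
        }
    }

    // Closing the server channel releases bind(), whose finally block shuts down the event loops
    public static void shutDown() {
        Channel channel = serverChannel;
        if (channel != null) {
            channel.close();
        }
    }

    @Override
    public void run() {
        int port = 8080;
        bind(port);
    }

    private static class MyChannelInitializer extends ChannelInitializer<SocketChannel> {
        @Override
        public void initChannel(SocketChannel channel) throws Exception {
            // Note: ObjectEncoder/ObjectDecoder already write and read their own 4-byte length
            // prefix, so the outer length-field pair is redundant, though consistent on both sides
            channel.pipeline()
                    // Inbound: split frames by a 4-byte length field, then strip that header
                    .addLast(new LengthFieldBasedFrameDecoder(65535, 0, 4, 0, 4))
                    // Inbound: deserialize each frame into an RpcRequest
                    .addLast(new ObjectDecoder(ClassResolvers.weakCachingConcurrentResolver(this.getClass().getClassLoader())))
                    // Outbound: prepend a 4-byte length to each encoded message
                    .addLast(new LengthFieldPrepender(4))
                    // Outbound: serialize RpcResponse objects
                    .addLast(new ObjectEncoder())
                    .addLast(new ServerHandler());
        }
    }

}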
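LengthFieldPrepender(4) writes a 4-byte length before each message, and LengthFieldBasedFrameDecoder(65535, 0, 4, 0, 4) (maxFrameLength, lengthFieldOffset, lengthFieldLength, lengthAdjustment, initialBytesToStrip) reads that field at offset 0, waits until the whole body has arrived, and strips the 4 header bytes. The JDK-only sketch below imitates that behaviour with ByteBuffer so the effect on a TCP stream, where messages can arrive glued together or cut in half, is visible. It is an illustration under those assumptions, not Netty's implementation; frame and decode are hypothetical helpers.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

public class FramingDemo {

    // Like LengthFieldPrepender(4): big-endian int length, then the body
    static byte[] frame(byte[] body) {
        return ByteBuffer.allocate(4 + body.length).putInt(body.length).put(body).array();
    }

    // Like LengthFieldBasedFrameDecoder(max, 0, 4, 0, 4): consumes complete frames,
    // strips the 4-byte header, and leaves a partial frame in the buffer for later
    static List<byte[]> decode(ByteBuffer in, int maxFrameLength) {
        List<byte[]> frames = new ArrayList<>();
        while (in.remaining() >= 4) {
            int length = in.getInt(in.position()); // peek without consuming
            // The limit applies to the whole frame, header included
            if (length < 0 || 4 + length > maxFrameLength) {
                throw new IllegalStateException("bad frame length: " + length);
            }
            if (in.remaining() < 4 + length) {
                break; // body not fully received yet
            }
            in.position(in.position() + 4);
            byte[] body = new byte[length];
            in.get(body);
            frames.add(body);
        }
        return frames;
    }

    public static void main(String[] args) {
        byte[] a = frame("hello".getBytes(StandardCharsets.UTF_8));
        byte[] b = frame("world!".getBytes(StandardCharsets.UTF_8));

        // Two whole frames glued together, followed by the first 3 bytes of another one
        ByteBuffer stream = ByteBuffer.allocate(a.length + b.length + 3);
        stream.put(a).put(b).put(b, 0, 3).flip();

        for (byte[] f : decode(stream, 65535)) {
            System.out.println(new String(f, StandardCharsets.UTF_8));
        }
        System.out.println("bytes left waiting for more data: " + stream.remaining());
    }
}
```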

ServerHandler receives a request, calls the corresponding method by reflection, puts the result or exception into a response object, and sends it back.

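import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

public class ServerHandler extends SimpleChannelInboundHandler<RpcRequest> {

    private static final Logger LOGGER = LoggerFactory.getLogger(ServerHandler.class);

    @Override
    public void channelRead0(final ChannelHandlerContext ctx, RpcRequest request) throws Exception {
        RpcResponse response = new RpcResponse();
        LOGGER.info("server read:" + request.toString());
        response.setId(request.getId());
        try {
            Object result = invoke(request);
            response.setResult(result);
        } catch (Exception e) {
            response.setException(e);
        }
        ctx.writeAndFlush(response);
        LOGGER.info("server send:" + response.toString());
    }

    public Object invoke(RpcRequest request) throws Exception {
        String classname = request.getClassName();
        String methodname = request.getMethodName();
        Class<?>[] parameterTypes = request.getParameterTypes();
        Object[] parameters = request.getParameters();

        Object o = ServiceMng.getService(classname);
        if (o == null) {
            throw new IllegalStateException("No service registered for " + classname);
        }
        // parameterTypes selects the right overload, e.g. hello(String) vs hello(String, String)
        Method method = o.getClass().getMethod(methodname, parameterTypes);
        try {
            return method.invoke(o, parameters);
        } catch (InvocationTargetException e) {
            // Return the exception thrown by the service method itself, not the reflection wrapper
            Throwable cause = e.getCause();
            if (cause instanceof Exception) {
                throw (Exception) cause;
            }
            throw e;
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        LOGGER.error("server caught exception", cause);
        ctx.close();
    }
}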
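ServerHandler.invoke boils down to three reflection steps: look up the target object by interface name, resolve the overload with getMethod(name, parameterTypes), and call Method.invoke. Exceptions thrown by the target method come back wrapped in InvocationTargetException, so they are usually unwrapped before being sent to the client. The sketch below runs the same steps in-process, with a plain HashMap in place of ServiceMng; Greeter, GreeterImpl and dispatch are hypothetical names.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

public class ReflectiveDispatchDemo {

    public interface Greeter {
        String hello(String name);
        String hello(String msg, String name);
    }

    public static class GreeterImpl implements Greeter {
        public String hello(String name) {
            return "Hello! " + name;
        }

        public String hello(String msg, String name) {
            if (name == null) {
                throw new IllegalArgumentException("name is null");
            }
            return msg + name;
        }
    }

    // Interface name -> implementation, standing in for ServiceMng's serviceMap
    static final Map<String, Object> SERVICES = new HashMap<>();
    static {
        SERVICES.put(Greeter.class.getName(), new GreeterImpl());
    }

    static Object dispatch(String className, String methodName,
                           Class<?>[] parameterTypes, Object[] args) throws Exception {
        Object target = SERVICES.get(className);
        if (target == null) {
            throw new IllegalStateException("no service registered for " + className);
        }
        // parameterTypes picks the overload of hello(...)
        Method method = target.getClass().getMethod(methodName, parameterTypes);
        try {
            return method.invoke(target, args);
        } catch (InvocationTargetException e) {
            // Rethrow what the service method itself threw, not the reflection wrapper
            Throwable cause = e.getCause();
            if (cause instanceof Exception) {
                throw (Exception) cause;
            }
            throw e;
        }
    }

    public static void main(String[] args) throws Exception {
        String name = Greeter.class.getName();
        System.out.println(dispatch(name, "hello", new Class<?>[]{String.class}, new Object[]{"luangeng"}));
        System.out.println(dispatch(name, "hello", new Class<?>[]{String.class, String.class},
                new Object[]{"Hi, ", "luangeng"}));
        try {
            dispatch(name, "hello", new Class<?>[]{String.class, String.class}, new Object[]{"Hi, ", null});
        } catch (IllegalArgumentException e) {
            System.out.println("service threw: " + e.getMessage());
        }
    }
}
```

One caveat worth knowing: getMethod on the implementation class returns a method that reflection can only call if that class is accessible, which is one reason service implementations like DemoServiceImpl are declared public.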

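RpcClass is the annotation that marks a class as an RPC service; its value() is the service interface the class implements.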

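import org.springframework.stereotype.Component;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

// @Component as a meta-annotation lets Spring's component scan register annotated classes as beans
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Component
public @interface RpcClass {
    // The service interface implemented by the annotated class
    Class<?> value();
}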

ServiceMng registers and looks up service classes: it relies on Spring's scanning and caches a mapping from interface name to service object.

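import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.stereotype.Service;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Service
public class ServiceMng implements ApplicationContextAware {

    private static ApplicationContext context;

    //All provided services, keyed by interface name; read concurrently by Netty I/O threads
    private static final Map<String, Object> serviceMap = new ConcurrentHashMap<>();

    public static Object getService(String name) {
        return serviceMap.get(name);
    }

    public static <T> T getBean(Class<T> clazz) {
        return context.getBean(clazz);
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        context = applicationContext;
        Map<String, Object> beansMap = context.getBeansWithAnnotation(RpcClass.class);
        for (Object bean : beansMap.values()) {
            // findAnnotation also searches superclasses, so the annotation is still found
            // when Spring has wrapped the bean in a CGLIB proxy subclass
            RpcClass rpcClass = AnnotationUtils.findAnnotation(bean.getClass(), RpcClass.class);
            if (rpcClass != null) {
                serviceMap.put(rpcClass.value().getName(), bean);
            }
        }
        if (!serviceMap.isEmpty()) {
            //ServiceCenter.register("DemoService", IPUtil.getLoaclIP() + ":8080");
            RpcServer.startUp();
        }
    }

}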
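Stripped of Spring, what ServiceMng does is: find objects whose class carries the RPC annotation, read the annotation's value() (the interface), and key the object by that interface's name, which is exactly the className the client proxy sends. This only works because the annotation has RUNTIME retention. The sketch below does the same over a hand-written list of objects instead of getBeansWithAnnotation; Rpc, Echo, EchoImpl, Helper and buildRegistry are hypothetical names.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class AnnotationRegistryDemo {

    // Same shape as RpcClass, minus Spring's @Component meta-annotation
    @Target(ElementType.TYPE)
    @Retention(RetentionPolicy.RUNTIME)
    @interface Rpc {
        Class<?> value();
    }

    interface Echo {
        String echo(String s);
    }

    @Rpc(Echo.class)
    static class EchoImpl implements Echo {
        public String echo(String s) {
            return s;
        }
    }

    // A class without the annotation is skipped
    static class Helper {
    }

    static Map<String, Object> buildRegistry(List<?> beans) {
        Map<String, Object> registry = new HashMap<>();
        for (Object bean : beans) {
            Rpc rpc = bean.getClass().getAnnotation(Rpc.class);
            if (rpc != null) {
                // Key by interface name, which is what the client sends as className
                registry.put(rpc.value().getName(), bean);
            }
        }
        return registry;
    }

    public static void main(String[] args) {
        Map<String, Object> registry = buildRegistry(Arrays.asList(new EchoImpl(), new Helper()));
        System.out.println(registry.keySet());
    }
}
```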

Service interface and implementation:

public interface DemoService {

    String hello(String name);

    String hello(String msg, String name);

    Pojo test(Pojo pojo);
}

@RpcClass(DemoService.class)
@Service
public class DemoServiceImpl implements DemoService {

    @Override
    public String hello(String name) {
        return "Hello! " + name;
    }

    @Override
    public String hello(String msg, String name) {
        return msg + name;
    }

    @Override
    public Pojo test(Pojo pojo) {
        pojo.setId(-pojo.getId());
        pojo.setName(pojo.getName().toUpperCase());
        pojo.setMan(!pojo.isMan());
        pojo.getList().add("last");
        return pojo;
    }

}

Client side

RpcClient wraps one TCP connection from the client to a server, and caches the mapping from server address to that connection.

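import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.Assert;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class RpcClient {

    private static final Logger LOGGER = LoggerFactory.getLogger(RpcClient.class);

    //Cache of connected hosts, keyed by "host:port"
    private static final Map<String, RpcClient> clientMap = new ConcurrentHashMap<>();

    private Channel channel;

    private EventLoopGroup group;

    private String ip;

    private int port;

    private RpcClient(String ip, int port) {
        this.ip = ip;
        this.port = port;
    }

    // The ":" separator keeps e.g. ("10.0.0.1", 180) and ("10.0.0.11", 80) from sharing a key
    private static String key(String host, int port) {
        return host + ":" + port;
    }

    // synchronized so that concurrent callers do not open duplicate connections
    public static synchronized RpcClient getConnect(String host, int port) throws InterruptedException {
        RpcClient cached = clientMap.get(key(host, port));
        if (cached != null) {
            return cached;
        }
        RpcClient con = connect(host, port);
        clientMap.put(key(host, port), con);
        return con;
    }

    private static RpcClient connect(String host, int port) throws InterruptedException {
        RpcClient client = new RpcClient(host, port);

        EventLoopGroup group = new NioEventLoopGroup();
        Bootstrap bootstrap = new Bootstrap();
        bootstrap.group(group);
        bootstrap.channel(NioSocketChannel.class);
        bootstrap.option(ChannelOption.SO_KEEPALIVE, true);
        bootstrap.handler(new ChannelInitializer<SocketChannel>() {
            @Override
            public void initChannel(SocketChannel channel) throws Exception {
                // Same framing and serialization handlers as on the server
                channel.pipeline()
                        .addLast(new LengthFieldBasedFrameDecoder(65535, 0, 4, 0, 4))
                        .addLast(new ObjectDecoder(ClassResolvers.weakCachingConcurrentResolver(this.getClass().getClassLoader())))
                        .addLast(new LengthFieldPrepender(4))
                        .addLast(new ObjectEncoder())
                        .addLast(new ClientHandler());
            }
        });

        ChannelFuture future;
        try {
            future = bootstrap.connect(host, port).sync();
        } catch (Exception e) {
            // Release the event loop threads if the connection fails
            group.shutdownGracefully();
            throw e;
        }
        LOGGER.info("client connect to " + host + ":" + port);
        Channel c = future.channel();

        client.setChannel(c);
        client.setGroup(group);
        return client;
    }

    public RpcResponse invoke(RpcRequest request) throws Exception {
        ClientHandler handle = channel.pipeline().get(ClientHandler.class);
        Assert.notNull(handle, "ClientHandler is missing from the pipeline");
        return handle.invoke(request);
    }

    public void closeConnect() {
        // Drop the cache entry so a later getConnect() opens a fresh connection
        clientMap.remove(key(ip, port), this);
        this.group.shutdownGracefully();
    }


    public void setChannel(Channel channel) {
        this.channel = channel;
    }

    public void setGroup(EventLoopGroup group) {
        this.group = group;
    }

    public String getIp() {
        return ip;
    }

    public void setIp(String ip) {
        this.ip = ip;
    }

    public int getPort() {
        return port;
    }

    public void setPort(int port) {
        this.port = port;
    }
}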
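A per-address connection cache has two easy pitfalls: plain concatenation of host and port can collide ("10.0.0.1" + 180 and "10.0.0.11" + 80 both give "10.0.0.1180"), and a HashMap that is checked and then filled without locking can open two connections when two threads race. The JDK-only sketch below keys the cache as host:port and uses ConcurrentHashMap.computeIfAbsent, which runs the factory at most once per key; FakeConnection and the OPENED counter are hypothetical placeholders for a real Netty channel.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

public class ConnectionCacheDemo {

    // Hypothetical stand-in for an opened channel
    static class FakeConnection {
        final String address;

        FakeConnection(String address) {
            this.address = address;
        }
    }

    static final AtomicInteger OPENED = new AtomicInteger();
    static final ConcurrentMap<String, FakeConnection> CACHE = new ConcurrentHashMap<>();

    static String key(String host, int port) {
        return host + ":" + port; // separator keeps distinct addresses apart
    }

    static FakeConnection getConnect(String host, int port) {
        // The factory runs at most once per key, even when many threads ask at the same time
        return CACHE.computeIfAbsent(key(host, port), k -> {
            OPENED.incrementAndGet();
            return new FakeConnection(k);
        });
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(("10.0.0.1" + 180).equals("10.0.0.11" + 80)); // naive keys collide
        System.out.println(key("10.0.0.1", 180).equals(key("10.0.0.11", 80)));

        // Release 8 threads at once, all asking for the same address
        CountDownLatch go = new CountDownLatch(1);
        Thread[] threads = new Thread[8];
        for (int i = 0; i < threads.length; i++) {
            threads[i] = new Thread(() -> {
                try {
                    go.await();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
                getConnect("localhost", 8080);
            });
            threads[i].start();
        }
        go.countDown();
        for (Thread t : threads) {
            t.join();
        }
        System.out.println("connections opened: " + OPENED.get());
    }
}
```

With a real, blocking connect inside computeIfAbsent, other callers for the same address simply wait for it; a synchronized getConnect is an equally valid and simpler choice when connections are opened rarely.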

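ClientHandler sends requests and receives the server's responses. responseMap maps each pending request Id to its ResponseHolder, and an entry is removed when the matching response arrives. If responses are lost, for example when a large number of connections drop, entries that are never removed make responseMap keep growing until it can exhaust memory.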

import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class ClientHandler extends SimpleChannelInboundHandler<RpcResponse> {

    private static final Logger LOGGER = LoggerFactory.getLogger(ClientHandler.class);

    // Set on the I/O thread, read by caller threads
    private volatile Channel channel;

    //Mapping from request Id to the holder waiting for its response
    private final Map<Long, ResponseHolder> responseMap = new ConcurrentHashMap<>();

    @Override
    public void channelRead0(ChannelHandlerContext ctx, RpcResponse response) throws Exception {
        ResponseHolder holder = responseMap.remove(response.getId());
        if (holder != null) {
            holder.setResponse(response);
        }
    }

    @Override
    public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
        super.channelRegistered(ctx);
        channel = ctx.channel();
    }


    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        LOGGER.error("exceptionCaught", cause);
        ctx.close();
    }

    public RpcResponse invoke(RpcRequest request) throws Exception {
        ResponseHolder holder = new ResponseHolder();
        responseMap.put(request.getId(), holder);
        try {
            channel.writeAndFlush(request);
            // Returns null if no response arrives before the holder's timeout
            return holder.getResponse();
        } finally {
            // Remove the entry even on timeout so unanswered requests do not accumulate
            responseMap.remove(request.getId());
        }
    }

}
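ClientHandler matches responses to waiting callers purely by request Id, which is what lets several calls share one channel and complete out of order. The same correlation can be written with java.util.concurrent.CompletableFuture in place of ResponseHolder; this is a swapped-in technique, not what the framework above uses. In the sketch, PENDING, register, onResponse and awaitResult are hypothetical names, and the "server" answers in reverse order to show that matching by Id still works.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;

public class CorrelationDemo {

    static final AtomicLong IDS = new AtomicLong();
    // request Id -> future completed when the matching response arrives
    static final Map<Long, CompletableFuture<String>> PENDING = new ConcurrentHashMap<>();

    // Registers a pending call (a real client would then write the request to the channel)
    static CompletableFuture<String> register(long id) {
        CompletableFuture<String> f = new CompletableFuture<>();
        PENDING.put(id, f);
        return f;
    }

    // Called from the I/O thread when a response with this Id is read
    static void onResponse(long id, String result) {
        CompletableFuture<String> f = PENDING.remove(id);
        if (f != null) {
            f.complete(result);
        }
    }

    // Blocks the caller; the finally block removes the entry even on timeout
    static String awaitResult(long id, CompletableFuture<String> f, long millis) throws Exception {
        try {
            return f.get(millis, TimeUnit.MILLISECONDS);
        } finally {
            PENDING.remove(id);
        }
    }

    public static void main(String[] args) throws Exception {
        long id1 = IDS.incrementAndGet();
        long id2 = IDS.incrementAndGet();
        CompletableFuture<String> f1 = register(id1);
        CompletableFuture<String> f2 = register(id2);

        // The server answers the second request first
        onResponse(id2, "second");
        onResponse(id1, "first");
        System.out.println(awaitResult(id1, f1, 1000) + ", " + awaitResult(id2, f2, 1000));

        // A request whose response never arrives times out and leaves no entry behind
        long id3 = IDS.incrementAndGet();
        CompletableFuture<String> lost = register(id3);
        try {
            awaitResult(id3, lost, 50);
        } catch (TimeoutException e) {
            System.out.println("timed out; pending entries left: " + PENDING.size());
        }
    }
}
```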

RpcProxy uses a dynamic proxy to hide building the request object and connecting to the server; it then sends the request and waits for the result.

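import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.concurrent.atomic.AtomicLong;

public class RpcProxy implements InvocationHandler {

    private static final Logger LOGGER = LoggerFactory.getLogger(RpcProxy.class);

    private static final AtomicLong id = new AtomicLong(0);

    private volatile RpcClient client = null;

    public static <T> T get(Class<T> interfaceClass) {
        return interfaceClass.cast(Proxy.newProxyInstance(
                interfaceClass.getClassLoader(),
                new Class<?>[]{interfaceClass},
                new RpcProxy()
        ));
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        RpcRequest request = new RpcRequest();
        request.setId(id.incrementAndGet());
        request.setClassName(method.getDeclaringClass().getName());
        request.setMethodName(method.getName());
        request.setParameterTypes(method.getParameterTypes());
        request.setParameters(args);

        if (client == null) {
            client = RpcClient.getConnect("localhost", 8080);
        }
        RpcResponse r = client.invoke(request);
        if (r == null) {
            // No response before the timeout; fail clearly instead of with a NullPointerException
            throw new IllegalStateException("RPC call timed out: " + request);
        }
        // Rethrows the server-side exception, if there was one
        return r.getResult();
    }

    public void close() {
        if (this.client != null) {
            this.client.closeConnect();
        }
    }
}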
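RpcProxy relies on java.lang.reflect.Proxy: every call on the interface lands in invoke(), which already has everything a request needs (declaring interface name, method name, parameter types, arguments). To see that part in isolation, the sketch below swaps the Netty client for a hypothetical in-process Transport that dispatches by reflection on a local object. Transport, Calculator, proxyFor and inProcess are illustrative names, not part of the framework above.

```java
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.Arrays;

public class ProxyDemo {

    public interface Calculator {
        int add(int a, int b);
    }

    // Hypothetical transport; in the framework above this role is played by RpcClient over Netty
    interface Transport {
        Object call(String className, String methodName, Class<?>[] types, Object[] args) throws Exception;
    }

    // Like RpcProxy.invoke: turn each interface call into (className, methodName, types, args)
    static <T> T proxyFor(Class<T> iface, Transport transport) {
        Object proxy = Proxy.newProxyInstance(
                iface.getClassLoader(),
                new Class<?>[]{iface},
                (self, method, args) -> transport.call(
                        method.getDeclaringClass().getName(),
                        method.getName(),
                        method.getParameterTypes(),
                        args));
        return iface.cast(proxy);
    }

    // Stand-in for the server side: resolves the method on the named interface and invokes it locally
    static Transport inProcess(Object service) {
        return (className, methodName, types, args) -> {
            Method m = Class.forName(className).getMethod(methodName, types);
            return m.invoke(service, args);
        };
    }

    public static void main(String[] args) throws Exception {
        Transport server = inProcess((Calculator) (a, b) -> a + b);
        // Wrap the transport to print what a real client would put into an RpcRequest
        Transport logging = (className, methodName, types, callArgs) -> {
            System.out.println("request: " + className + "." + methodName
                    + Arrays.toString(types) + " args=" + Arrays.toString(callArgs));
            return server.call(className, methodName, types, callArgs);
        };
        Calculator calc = proxyFor(Calculator.class, logging);
        System.out.println(calc.add(2, 3));
    }
}
```

Because the primitive result travels as an Integer, the proxy unboxes it for the int return type; a transport that returned null for such a method would surface as a NullPointerException at the call site.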

ZooKeeper service center

ServiceCenter is used by the server side to register a service, and by the client side to look up the address of an available server.

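import org.apache.zookeeper.CreateMode;
import org.apache.zookeeper.KeeperException;
import org.apache.zookeeper.Watcher;
import org.apache.zookeeper.ZooDefs;
import org.apache.zookeeper.ZooKeeper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.CountDownLatch;

public class ServiceCenter {

    private static final Logger LOGGER = LoggerFactory.getLogger(ServiceCenter.class);

    private static final String APPS_PATH = "/__apps__";
    private static ZooKeeper zk;

    private static synchronized void connect() throws InterruptedException, KeeperException, IOException {
        if (zk != null) {
            return;
        }
        CountDownLatch latch = new CountDownLatch(1);
        //Service center address; should be read from a configuration file
        String serverCenterAddress = "localhost:2181";

        // The watcher releases the latch once the session is connected;
        // without a countDown() somewhere, latch.await() would block forever
        zk = new ZooKeeper(serverCenterAddress, 30000, event -> {
            if (event.getState() == Watcher.Event.KeeperState.SyncConnected) {
                latch.countDown();
            }
        });
        latch.await();
        if (zk.exists(APPS_PATH, false) == null) {
            zk.create(APPS_PATH, null, ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);
        }
    }

    public static void register(String serviceName, String address) {
        try {
            connect();
            String servicePath = APPS_PATH + "/" + serviceName;
            if (zk.exists(servicePath, false) == null) {
                zk.create(servicePath, null, ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);
            }
            // Ephemeral: removed automatically when this server's session ends
            String path = zk.create(servicePath + "/", address.getBytes(StandardCharsets.UTF_8),
                    ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL_SEQUENTIAL);
            LOGGER.info("registered " + serviceName + " at " + path);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.error("register interrupted", e);
        } catch (KeeperException | IOException e) {
            LOGGER.error("register failed", e);
        }
    }

    public static String queryService(String serviceName) throws KeeperException, InterruptedException, IOException {
        connect();
        String servicePath = APPS_PATH + "/" + serviceName;
        List<String> apps = zk.getChildren(servicePath, false);
        if (apps.isEmpty()) {
            return null;
        }
        //Collections.sort(apps);
        // getChildren returns names relative to servicePath, so build the full path
        byte[] data = zk.getData(servicePath + "/" + apps.get(0), false, null);
        return new String(data, StandardCharsets.UTF_8);
    }

}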
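register creates EPHEMERAL_SEQUENTIAL children under /__apps__/<serviceName>/, so ZooKeeper appends a 10-digit, zero-padded sequence number to each child's name and deletes the node when the registering server's session ends. getChildren returns those names in no guaranteed order, which is why sorting matters if "the first child" is meant to be "the oldest registration". The JDK-only sketch below simulates the client-side choice over child names and the addresses stored in them; pickOldest, pickRoundRobin and the sample addresses are made up for illustration, and round-robin is one common alternative rather than something the framework above implements.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

public class ServiceSelectionDemo {

    // children: child names as getChildren might return them; data: child name -> stored address
    static String pickOldest(List<String> children, Map<String, String> data) {
        if (children.isEmpty()) {
            return null;
        }
        List<String> sorted = new ArrayList<>(children);
        Collections.sort(sorted); // zero-padded counters sort correctly as strings
        return data.get(sorted.get(0));
    }

    static final AtomicInteger NEXT = new AtomicInteger();

    // Spreads calls across all registered servers in turn
    static String pickRoundRobin(List<String> children, Map<String, String> data) {
        if (children.isEmpty()) {
            return null;
        }
        List<String> sorted = new ArrayList<>(children);
        Collections.sort(sorted);
        int i = Math.floorMod(NEXT.getAndIncrement(), sorted.size());
        return data.get(sorted.get(i));
    }

    public static void main(String[] args) {
        Map<String, String> data = new HashMap<>();
        data.put("0000000002", "192.168.1.12:8080");
        data.put("0000000000", "192.168.1.10:8080");
        data.put("0000000001", "192.168.1.11:8080");
        List<String> children = new ArrayList<>(data.keySet());
        Collections.shuffle(children); // no order guarantee from getChildren

        System.out.println("oldest: " + pickOldest(children, data));
        for (int i = 0; i < 4; i++) {
            System.out.println("round-robin: " + pickRoundRobin(children, data));
        }
    }
}
```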

Test code:

POJO

import java.io.Serializable;
import java.util.Date;
import java.util.List;

public class Pojo implements Serializable {

    private int id;

    private String name;

    private Date birth;

    private boolean man;

    private List<String> list;

    public int getId() {
        return id;
    }

    public void setId(int id) {
        this.id = id;
    }

    // ... (remaining getters and setters omitted)

    @Override
    public String toString() {
        return "Pojo{" +
                "id=" + id +
                ", name='" + name + '\'' +
                ", birth=" + birth +
                ", man=" + man +
                ", list=" + list +
                '}';
    }
}

RpcTest

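import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.support.ClassPathXmlApplicationContext;

import java.util.ArrayList;
import java.util.Date;

public class RpcTest {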

    private static final Logger LOGGER = LoggerFactory.getLogger(RpcTest.class);

    public static void main(String[] args) throws InterruptedException {

        // Start Spring; ServiceMng picks up the @RpcClass beans and starts the RPC server
        ApplicationContext ctx = new ClassPathXmlApplicationContext("rpcApplication.xml");

        // Create a client-side proxy; each call is sent to the server as an RpcRequest
        DemoService demoService = RpcProxy.get(DemoService.class);

        for (int i = 0; i < 100; i++) {

            String result = demoService.hello("luangeng" + i);
            LOGGER.info("result: " + result);

            result = demoService.hello("你好,", "luangeng" + i);
            LOGGER.info("result: " + result);

            Pojo p = new Pojo();
            p.setId(i);
            p.setName("luanegng" + i);
            p.setMan(true);
            p.setBirth(new Date());
            p.setList(new ArrayList<>());
            p.getList().add("pojo" + i);
            Pojo p2 = demoService.test(p);
            LOGGER.info(p2.toString());
            LOGGER.info(" ");
        }

    }
}

Output:

2017-12-02 19:34:33,887 INFO [com.luangeng.rpc.test.RpcTest] - result: Hello! luangeng0
2017-12-02 19:34:33,889 INFO [com.luangeng.rpc.test.RpcTest] - result: 你好,luangeng0
2017-12-02 19:34:33,894 INFO [com.luangeng.rpc.test.RpcTest] - Pojo{id=0, name='LUANEGNG0', birth=Sat Dec 02 19:34:33 CST 2017, man=false, list=[pojo0, last]}
2017-12-02 19:34:33,894 INFO [com.luangeng.rpc.test.RpcTest] -
2017-12-02 19:34:33,896 INFO [com.luangeng.rpc.test.RpcTest] - result: Hello! luangeng1
2017-12-02 19:34:33,898 INFO [com.luangeng.rpc.test.RpcTest] - result: 你好,luangeng1
2017-12-02 19:34:33,900 INFO [com.luangeng.rpc.test.RpcTest] - Pojo{id=-1, name='LUANEGNG1', birth=Sat Dec 02 19:34:33 CST 2017, man=false, list=[pojo1, last]}
2017-12-02 19:34:33,900 INFO [com.luangeng.rpc.test.RpcTest] -
2017-12-02 19:34:33,902 INFO [com.luangeng.rpc.test.RpcTest] - result: Hello! luangeng2
2017-12-02 19:34:33,903 INFO [com.luangeng.rpc.test.RpcTest] - result: 你好,luangeng2
2017-12-02 19:34:33,905 INFO [com.luangeng.rpc.test.RpcTest] - Pojo{id=-2, name='LUANEGNG2', birth=Sat Dec 02 19:34:33 CST 2017, man=false, list=[pojo2, last]}
2017-12-02 19:34:33,905 INFO [com.luangeng.rpc.test.RpcTest] -
2017-12-02 19:34:33,906 INFO [com.luangeng.rpc.test.RpcTest] - result: Hello! luangeng3
2017-12-02 19:34:33,908 INFO [com.luangeng.rpc.test.RpcTest] - result: 你好,luangeng3
2017-12-02 19:34:33,910 INFO [com.luangeng.rpc.test.RpcTest] - Pojo{id=-3, name='LUANEGNG3', birth=Sat Dec 02 19:34:33 CST 2017, man=false, list=[pojo3, last]}
