自己实现简单RPC功能

来源:互联网 发布:软件外包群 编辑:程序博客网 时间:2024/06/05 10:51

最近对RMI RPC比较感兴趣, 所以自己做了一个简单的实现, 如果有时间,之后会继续完善。


RPC主要分为服务端与客户端。 

服务端的实现如下:

[java] view plaincopy在CODE上查看代码片派生到我的代码片
  1. package com.zf.rpc.server;  
  2. import java.io.IOException;  
  3. import java.io.InputStream;  
  4. import java.io.ObjectInput;  
  5. import java.io.ObjectInputStream;  
  6. import java.io.ObjectOutput;  
  7. import java.io.ObjectOutputStream;  
  8. import java.io.OutputStream;  
  9. import java.lang.reflect.InvocationTargetException;  
  10. import java.lang.reflect.Method;  
  11. import java.net.ServerSocket;  
  12. import java.net.Socket;  
  13. import java.util.concurrent.ConcurrentHashMap;  
  14. import java.util.concurrent.ExecutorService;  
  15. import java.util.concurrent.Executors;  
  16. import java.util.concurrent.atomic.AtomicBoolean;  
  17.   
  18.   
  19. public class RPCServer {  
  20.       
  21.     private static final ExecutorService taskPool = Executors.newFixedThreadPool(50) ;  
  22.   
  23.     /** 
  24.      * 服务接口对象库 
  25.      * key:接口名    value:接口实现 
  26.      */  
  27.     private static final ConcurrentHashMap<String, Object> serviceTargets =   
  28.         new ConcurrentHashMap<String, Object>() ;  
  29.   
  30.     private static AtomicBoolean run = new AtomicBoolean(false) ;  
  31.       
  32.     /** 
  33.      * 注册服务 
  34.      * @param service 
  35.      */  
  36.     public void registService(Object service){  
  37.         Class<?>[] interfaces = service.getClass().getInterfaces() ;  
  38.         if(interfaces == null){  
  39.             throw new IllegalArgumentException("服务对象必须实现接口");   
  40.         }  
  41.         Class<?> interfacez = interfaces[0] ;  
  42.         String interfaceName = interfacez.getName() ;  
  43.         serviceTargets.put(interfaceName, service) ;  
  44.     }  
  45.   
  46.     /** 
  47.      * 启动Server 
  48.      * @param port 
  49.      */  
  50.     public void startServer(final int port){  
  51.         Runnable lifeThread = new Runnable() {  
  52.             @Override  
  53.             public void run() {  
  54.                 ServerSocket lifeSocket = null ;  
  55.                 Socket client = null ;  
  56.                 ServiceTask serviceTask = null ;  
  57.                 try {  
  58.                     lifeSocket = new ServerSocket(port) ;  
  59.                     run.set(true) ;  
  60.                     while(run.get()){  
  61.                         client = lifeSocket.accept() ;  
  62.                         serviceTask = new ServiceTask(client);   
  63.                         serviceTask.accept() ;  
  64.                     }  
  65.                 } catch (IOException e) {  
  66.                     e.printStackTrace();  
  67.                 }  
  68.             }  
  69.         };  
  70.         taskPool.execute(lifeThread) ;    
  71.         System.out.println("服务启动成功...");  
  72.     }  
  73.       
  74.     public void stopServer(){  
  75.         run.set(false) ;  
  76.         taskPool.shutdown() ;  
  77.     }  
  78.   
  79.     public static final class ServiceTask implements Runnable{  
  80.   
  81.         private Socket client  ;  
  82.           
  83.         public ServiceTask(Socket client){  
  84.             this.client = client ;  
  85.         }  
  86.           
  87.         public void accept(){  
  88.             taskPool.execute(this) ;  
  89.         }  
  90.   
  91.         @Override  
  92.         public void run() {  
  93.             InputStream is = null ;  
  94.             ObjectInput oi = null ;  
  95.             OutputStream os = null ;  
  96.             ObjectOutput oo = null ;  
  97.             try {  
  98.                 is = client.getInputStream() ;  
  99.                 os = client.getOutputStream() ;  
  100.                 oi = new ObjectInputStream(is);  
  101.                 String serviceName = oi.readUTF() ;  
  102.                 String methodName = oi.readUTF();  
  103.                 Class<?>[] paramTypes =  (Class[]) oi.readObject() ;    
  104.                 Object[] arguments = (Object[]) oi.readObject() ;  
  105.                 System.out.println("serviceName:" + serviceName + " methodName:" + methodName);  
  106.                 Object targetService = serviceTargets.get(serviceName) ;  
  107.                 if(targetService == null){  
  108.                     throw new ClassNotFoundException(serviceName + "服务未找到!") ;  
  109.                 }  
  110.                   
  111.                 Method targetMethod = targetService.getClass().getMethod(methodName, paramTypes) ;  
  112.                 Object result = targetMethod.invoke(targetService, arguments) ;  
  113.                   
  114.                 oo = new ObjectOutputStream(os) ;  
  115.                 oo.writeObject(result) ;  
  116.             } catch (IOException e) {  
  117.                 e.printStackTrace();  
  118.             } catch (ClassNotFoundException e) {  
  119.                 e.printStackTrace();  
  120.             } catch (SecurityException e) {  
  121.                 e.printStackTrace();  
  122.             } catch (NoSuchMethodException e) {  
  123.                 e.printStackTrace();  
  124.             } catch (IllegalArgumentException e) {  
  125.                 e.printStackTrace();  
  126.             } catch (IllegalAccessException e) {  
  127.                 e.printStackTrace();  
  128.             } catch (InvocationTargetException e) {  
  129.                 e.printStackTrace();  
  130.             }finally{  
  131.                 try {  
  132.                     if(oo != null){  
  133.                         oo.close() ;  
  134.                     }  
  135.                     if(os != null){  
  136.                         os.close() ;  
  137.                     }  
  138.                     if(is != null){  
  139.                         is.close() ;  
  140.                     }  
  141.                     if(oi != null){  
  142.                         oi.close() ;  
  143.                     }  
  144.                 } catch (IOException e) {  
  145.                     e.printStackTrace();  
  146.                 }  
  147.             }  
  148.         }  
  149.   
  150.     }  
  151.   
  152.   
  153. }  

客户端如下:

[java] view plaincopy在CODE上查看代码片派生到我的代码片
  1. package com.zf.rpc.client;  
  2.   
  3. import java.io.InputStream;  
  4. import java.io.ObjectInput;  
  5. import java.io.ObjectInputStream;  
  6. import java.io.ObjectOutput;  
  7. import java.io.ObjectOutputStream;  
  8. import java.io.OutputStream;  
  9. import java.lang.reflect.InvocationHandler;  
  10. import java.lang.reflect.Method;  
  11. import java.lang.reflect.Proxy;  
  12. import java.net.Socket;  
  13.   
  14.   
  15.   
  16. public class RPCClient {  
  17.   
  18.     /** 
  19.      * 根据接口类型得到代理的接口实现 
  20.      * @param <T> 
  21.      * @param host  RPC服务器IP 
  22.      * @param port  RPC服务端口 
  23.      * @param serviceInterface  接口类型 
  24.      * @return  被代理的接口实现 
  25.      */  
  26.     @SuppressWarnings("unchecked")  
  27.     public static <T> T findService(final String host , final int port ,final Class<T> serviceInterface){  
  28.         return (T) Proxy.newProxyInstance(serviceInterface.getClassLoader(), new Class[]{serviceInterface}, new InvocationHandler() {  
  29.             @Override  
  30.             public Object invoke(final Object proxy, final Method method, final Object[] args)  
  31.             throws Throwable {  
  32.                 Socket socket = null ;  
  33.                 InputStream is = null ;  
  34.                 OutputStream os = null ;  
  35.                 ObjectInput oi = null ;  
  36.                 ObjectOutput oo = null ;  
  37.                 try {  
  38.                     socket = new Socket(host, port) ;  
  39.                     os = socket.getOutputStream() ;  
  40.                     oo = new ObjectOutputStream(os);  
  41.                     oo.writeUTF(serviceInterface.getName()) ;  
  42.                     oo.writeUTF(method.getName()) ;  
  43.                     oo.writeObject(method.getParameterTypes()) ;  
  44.                     oo.writeObject(args);  
  45.   
  46.                     is = socket.getInputStream() ;  
  47.                     oi = new ObjectInputStream(is) ;  
  48.                     return oi.readObject() ;  
  49.                 } catch (Exception e) {  
  50.                     System.out.println("调用服务异常...");  
  51.                     return null ;  
  52.                 }finally{  
  53.                     if(is != null){  
  54.                         is.close() ;  
  55.                     }  
  56.                     if(os != null){  
  57.                         is.close() ;  
  58.                     }  
  59.                     if(oi != null){  
  60.                         is.close() ;  
  61.                     }  
  62.                     if(oo != null){  
  63.                         is.close() ;  
  64.                     }  
  65.                     if(socket != null){  
  66.                         is.close() ;  
  67.                     }  
  68.                 }  
  69.             }  
  70.         });   
  71.     }  
  72.   
  73. }  




现在写一个接口和一个实现。

[java] view plaincopy在CODE上查看代码片派生到我的代码片
  1. package com.zf.rpc.test;  
  2.   
  3. public interface IHelloWorld {  
  4.   
  5.     String sayHello(String name) ;  
  6.       
  7. }  

[java] view plaincopy在CODE上查看代码片派生到我的代码片
  1. package com.zf.rpc.test;  
  2.   
  3. public class HelloWorld implements IHelloWorld{  
  4.   
  5.     @Override  
  6.     public String sayHello(String name) {  
  7.         return "hello " + name + "!";  
  8.     }  
  9.   
  10. }  



下面就可以开始测试了。 

先写RPC服务端,并启动

[java] view plaincopy在CODE上查看代码片派生到我的代码片
  1. package com.zf.rpc.test;  
  2.   
  3. import com.zf.rpc.server.RPCServer;  
  4.   
  5. public class RPCServerTest {  
  6.       
  7.     public static void main(String[] args) {  
  8.           
  9.         RPCServer server = new RPCServer() ;  
  10.         server.registService(new HelloWorld()) ;  
  11.         server.startServer(8080) ;  
  12.           
  13.     }  
  14.   
  15. }  

启动后,会看到输出

服务启动成功...


然后写RPC客户端,并启动

[java] view plaincopy在CODE上查看代码片派生到我的代码片
  1. package com.zf.rpc.test;  
  2.   
  3. import com.zf.rpc.client.RPCClient;  
  4.   
  5. public class RPCClientTest {  
  6.   
  7.     public static void main(String[] args) {  
  8.   
  9.         IHelloWorld helloWorld =   
  10.             RPCClient.findService("127.0.0.1" , 8080 , IHelloWorld.class) ;  
  11.         String  result = helloWorld.sayHello("is_zhoufeng");  
  12.         System.out.println(result );  
  13.   
  14.     }  
  15.   
  16. }  


会看到客户端输出:

hello is_zhoufeng!


到此, 一个远程调用就实现了。

0 0