Java nio 实践

来源:互联网 发布:wifi有信号没网络 编辑:程序博客网 时间:2024/06/05 16:58

花了一个星期的时间,通过在网上查阅的资料,了解了Java NIO的一些基本原理,相关API的基本使用。虽说这东西并不新鲜,但毕竟是第一次接触。

NIO的非阻塞优势在于仅需要少量的线程、一定量的的线程池便可以支撑起一定量的客户端并发访问。

在了解其基本原理和相关API的使用之后,动手写了一个demo,主要的类有:

NIOServer:NIO服务器端,用于接收多个客户端请求,并且对多个客户端请求进行响应

NIOClient:客户端,访问服务端并且发送请求,接收服务端响应

ReadHandler:ServerReadHandler和ClientReadHandler的抽象类,主要进行OP_READ(读)事件的处理

ServerReadHandler:服务端用于处理OP_READ(读)事件的处理类

ClientReadHandler:客户端用于处理OP_READ(读)事件的处理类

MsgPacket:用于封装请求数据和响应数据的类

下面是具体的代码:

NIOServer

package com.study.nio;import java.io.IOException;import java.net.InetSocketAddress;import java.net.Socket;import java.nio.ByteBuffer;import java.nio.channels.SelectionKey;import java.nio.channels.Selector;import java.nio.channels.ServerSocketChannel;import java.nio.channels.SocketChannel;import java.util.Iterator;import java.util.LinkedList;import java.util.Map;import java.util.Set;import java.util.concurrent.ConcurrentHashMap;import java.util.concurrent.Executors;import java.util.concurrent.ThreadPoolExecutor;public class NIOServer {// 缓存大小 最好设置大一些,但是也不要设得太大(不可少于2)private static final int BUFFER_SIZE = 1024 * 8;// 通道管理器private Selector selector;// 读处理线程池private ThreadPoolExecutor readPoolExecutor;// 分配读缓存private ByteBuffer byteBuffer;// 存放与客户端连接的SocketChannel和其待写队列private Map<SocketChannel, LinkedList<ByteBuffer>> writeBufferMap;// 存放与客户端连接的SocketChannel和ReadHandler(一个客户端对应一个ReadHandler)private Map<SocketChannel, ReadHandler> SRMap;// 现有连接数private int connnectCount;private String ip;private int port;public NIOServer() {}public NIOServer(String ip, int port) {this.ip = ip;this.port = port;}// 初始化private void init() {final int processors = Runtime.getRuntime().availableProcessors();readPoolExecutor = (ThreadPoolExecutor) Executors.newFixedThreadPool(processors * 2);// 初始化分配读缓存byteBuffer = ByteBuffer.allocate(BUFFER_SIZE);// 初始化写队列writeBufferMap = new ConcurrentHashMap<SocketChannel, LinkedList<ByteBuffer>>();// 初始化SRMapSRMap = new ConcurrentHashMap<SocketChannel, ReadHandler>();// 初始化连接数connnectCount = 0;ServerSocketChannel serverSocketChannel = null;try {serverSocketChannel = ServerSocketChannel.open();serverSocketChannel.socket().bind(new InetSocketAddress(ip, port));serverSocketChannel.configureBlocking(false);selector = Selector.open();serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT);} catch (IOException e) {e.printStackTrace();}}// 启动public void startup() {// 完成初始化init();// 报告服务器启动System.out.println(Report.reportCurrentTime() + "server startup...");System.out.println(Report.reportCurrentTime() + "server listen on " + ip + " port " + port);new Thread(new Runnable() {@Overridepublic void run() {// 一直跑while(!Thread.interrupted()) {try {// 阻塞在这里 如果写成selector.select(1000)则最多阻塞1000msint nKey = selector.select();if (nKey > 0) {Set<SelectionKey> keySet = selector.selectedKeys();Iterator<SelectionKey> iterator = keySet.iterator();while (iterator.hasNext()) {final SelectionKey key = iterator.next();// 移除,避免重复处理iterator.remove();// 根据key的类型进行判断 OP_ACCEPT|OP_READ|OP_WRITEif (key.isValid() && key.isAcceptable()) {// 处理接收连接请求acceptConnection(key);} else if (key.isValid() && key.isReadable()) {// 处理读操作readFromChannel(key);} else if(key.isValid() && key.isWritable()) {// 处理写操作writeToChannel(key);}}}} catch (IOException e) {e.printStackTrace();}}}}).start();}// 接收客户端连接请求private void acceptConnection(SelectionKey key) {ServerSocketChannel serverSocketChannel = (ServerSocketChannel) key.channel();// 接受客户端连接Socket socket = null;try {socket = serverSocketChannel.accept().socket();SocketChannel socketChannel = socket.getChannel();// 设置通道非阻塞socketChannel.configureBlocking(false);// 注册读权限socketChannel.register(selector,SelectionKey.OP_READ);// 测试写数据socketChannel.write(ByteBuffer.wrap(new MsgPacket(Report.reportCurrentTime() + "欢迎来到本地服务器").getBytes()));System.out.println(Report.reportCurrentTime() + "accept one Client");connnectCount++;// 保存SocketChannel与处理该通道的ReadHandlerSRMap.put(socketChannel, new ServerReadHandler(socketChannel, this));// 保存SocketChannel与其待写队列writeBufferMap.put(socketChannel, new LinkedList<ByteBuffer>());} catch (IOException e) {e.printStackTrace();}}// 从通道读private void readFromChannel(SelectionKey key) {final SocketChannel socketChannel = (SocketChannel) key.channel();synchronized (byteBuffer) {byteBuffer.clear();try {final int count = socketChannel.read(byteBuffer);//System.out.println("count = " + count);if(count > 0) { // 接收final byte[] data = new byte[count];System.arraycopy(byteBuffer.array(), 0, data, 0, count);final ReadHandler readHandler = SRMap.get(socketChannel);readHandler.read(data);// 线程池处理readPoolExecutor.execute(new Runnable() {@Overridepublic void run() {//readHandler.handle(array, count);readHandler.handle();}});} else if (count < 0) { // 客户端主动断开连接connnectCount --;// 释放资源releaseResource(socketChannel);socketChannel.close();key.cancel();System.out.println("客户端主动断开连接," + " 剩余连接数: " + connnectCount);}} catch (IOException e) {e.printStackTrace();try {if(socketChannel != null && socketChannel.isOpen()) {// 处理客户端异常断开socketChannel.close();}// 取消感兴趣的事件key.cancel();// 移除与该socketChannel的资源releaseResource(socketChannel);} catch (IOException e1) {e1.printStackTrace();}}}}// 模拟响应public synchronized void respone(SocketChannel socketChannel, String msg) {MsgPacket msgPacket = new MsgPacket(msg);LinkedList<ByteBuffer> bufferQueue = writeBufferMap.get(socketChannel);// 添加到写队列bufferQueue.add(ByteBuffer.wrap(msgPacket.getBytes()));socketChannel.keyFor(this.selector).interestOps(SelectionKey.OP_WRITE);// 唤醒selector.wakeup();}// 往通道写private synchronized void writeToChannel(SelectionKey key) {SocketChannel socketChannel = (SocketChannel) key.channel();LinkedList<ByteBuffer> bufferQueue = writeBufferMap.get(socketChannel);while(!bufferQueue.isEmpty()) {ByteBuffer buffer = bufferQueue.get(0);try {socketChannel.write(buffer);if(buffer.remaining() > 0) {// 该缓冲区中的字节还没有写完,break,让下一个write key继续写break;}// 写完一个bufferbufferQueue.remove(0);} catch (IOException e) {e.printStackTrace();// 处理客户端异常断开try {if(socketChannel != null && socketChannel.isOpen()) {socketChannel.close();}} catch (IOException e1) {// 关闭时可能遇到ClosedChannelExceptione1.printStackTrace();} finally {// 取消感兴趣的事件key.cancel();// 释放资源releaseResource(socketChannel);}}}if(bufferQueue.isEmpty()) {// 全部数据写完了 取消写等待事件(不取消会造成cpu很快达到100%,因为OP_WRITE没有移除,seletor.select()不会阻塞,一直执行while死循环)key.interestOps(SelectionKey.OP_READ);}}// 释放资源private synchronized void releaseResource(SocketChannel socketChannel) {SRMap.remove(socketChannel);//System.out.println("SRMap size = " + SRMap.size());writeBufferMap.remove(socketChannel);//System.out.println("writeBufferMap size = " + writeBufferMap.size());}public static void main(String[] args) {NIOServer server = new NIOServer("127.0.0.1", 9000);server.startup();}}


NIOClient

package com.study.nio;import java.io.IOException;import java.net.InetSocketAddress;import java.nio.ByteBuffer;import java.nio.channels.SelectionKey;import java.nio.channels.Selector;import java.nio.channels.SocketChannel;import java.util.Iterator;import java.util.LinkedList;import java.util.Set;public class NIOClient {// 空闲等待最长时间private static final int MAX_IDLE_COUNT = 60;// 定义最大缓存区大小private static final int BUFFER_SIZE = 1024 * 8;// 是否关闭客户端的标志private boolean isClosed = false;// 通道管理器private Selector selector;// 与服务端交互的socket通道private SocketChannel socketChannel;// 分配的读缓存private ByteBuffer readBuffer;// 读处理器private ReadHandler readHandler;// 待写队列private LinkedList<ByteBuffer> bufferQueue;// 当前空闲计数private int idleCount;private String serverIP;private int port;private String name;public String getName() {return this.name;}public NIOClient() {}public NIOClient(String serverIP, int port) {this.serverIP = serverIP;this.port = port;}public NIOClient(String name, String serverIP, int port) {this.name = name;this.serverIP = serverIP;this.port = port;}// 完成初始化工作private void init() {idleCount = 0;// 初始化读缓存readBuffer = ByteBuffer.allocate(BUFFER_SIZE);// 初始化待写队列bufferQueue = new LinkedList<ByteBuffer>();SocketChannel socketChannel = null;try {socketChannel = SocketChannel.open();// 需要设置为非阻塞模式才能进行一系列操作socketChannel.configureBlocking(false);socketChannel.connect(new InetSocketAddress(serverIP, port));selector = Selector.open();socketChannel.register(selector, SelectionKey.OP_CONNECT);} catch (IOException e) {e.printStackTrace();}}// 启动public void startup() {// 完成初始化init();new Thread(new Runnable() {@Overridepublic void run() {try {while (!isClosed) {int nKey = selector.select(1000);// 结合MAX_IDLE_COUNT 等价于 MAX_IDLE_COUNT(s)空闲时间检测if (nKey > 0) {idleCount = 0;Set<SelectionKey> keySet = selector.selectedKeys();Iterator<SelectionKey> iterator = keySet.iterator();while (iterator.hasNext()) {SelectionKey key = iterator.next();iterator.remove();if (key.isConnectable()) {// 连接事件finishedConnection(key);} else if (key.isReadable()) {// 读事件readFromChanel(key);} else if (key.isWritable()) {// 写事件writeToChannel(key);}}} else {idleCount++;if(idleCount >= MAX_IDLE_COUNT) {// 空闲超时,断开与客户端的连接close();}}}} catch (IOException e) {e.printStackTrace();}}}).start();}private void finishedConnection(SelectionKey key) {SocketChannel socketChannel = (SocketChannel) key.channel();if (socketChannel.isConnectionPending()) {try {socketChannel.finishConnect();socketChannel.configureBlocking(false);// 注册读权限socketChannel.register(selector,SelectionKey.OP_READ);this.socketChannel = socketChannel;readHandler = new ClientReadHandler(this, this.socketChannel);System.out.println(Report.reportCurrentTime() + this.name + " Connect to Server");final String msg = "本地服务器你好!" + "我是" + this.getName();final String msg1 = "我是" + this.getName() + " 本地服务器你好!";send(msg);send(msg1);} catch (IOException e) {e.printStackTrace();}}}private void readFromChanel(SelectionKey key) {        SocketChannel channel = (SocketChannel) key.channel();        // 清空读缓存        readBuffer.clear();        try {int count = channel.read(readBuffer);//System.out.println("count = " + count);if(count > 0) {byte[] data = new byte[count];System.arraycopy(readBuffer.array(), 0, data, 0, count);readHandler.read(data);readHandler.handle();}} catch (IOException e) {e.printStackTrace();try {socketChannel.close();key.cancel();} catch (IOException e1) {e1.printStackTrace();}}}public boolean isClosed() {return isClosed;}public void close() {try {// 使线程退出isClosed = true;selector.close();socketChannel.close();} catch (IOException e) {e.printStackTrace();}}public void send(String msg) {MsgPacket msgPacket = new MsgPacket(msg);// 添加到写队列bufferQueue.add(ByteBuffer.wrap(msgPacket.getBytes()));socketChannel.keyFor(this.selector).interestOps(SelectionKey.OP_WRITE);// 唤醒selector.wakeup();}// 往通道写private void writeToChannel(SelectionKey key) {SocketChannel socketChannel = (SocketChannel) key.channel();while(!bufferQueue.isEmpty()) {ByteBuffer buffer = bufferQueue.get(0);try {socketChannel.write(buffer);if(buffer.remaining() > 0) {break;}bufferQueue.remove(0);} catch (IOException e) {e.printStackTrace();}}if(bufferQueue.isEmpty()) {// 全部数据写完了 取消写等待事件key.interestOps(SelectionKey.OP_READ);}}public static void main(String[] args) {NIOClient client = new NIOClient("127.0.0.1", 9000);client.startup();System.out.println(Report.reportCurrentTime() + " client startup");}}


ReadHandler

package com.study.nio;import java.nio.channels.SocketChannel;import java.util.LinkedList;public abstract class ReadHandler {// 存放接收的字节数组protected LinkedList<byte[]> dataList = new LinkedList<byte[]>();// 存放读取的字节数protected int readCount = 0;// 存放未处理完成的包protected LinkedList<MsgPacket> packetQueue = new LinkedList<MsgPacket>();// 对应的SocketChannelprotected SocketChannel socketChannel;public ReadHandler() {}public ReadHandler(SocketChannel socketChannel) {this.socketChannel = socketChannel;}public synchronized void read(byte[] data) {dataList.add(data);readCount = readCount + data.length;}public synchronized void handle() {if(readCount == 0) {return ;}boolean flag = true;byte[] bytes = toBytes();//System.out.println("handle data = " + new String(bytes));while (flag) {switch (bytes[0]) {case MsgPacket.MSG_FLAG:if (bytes.length < MsgPacket.HEADER_SIZE) {// 此时连包的header都没有接收完MsgPacket msgPacket = new MsgPacket(bytes);packetQueue.add(msgPacket);flag = false;} else {byte[] header = new byte[MsgPacket.HEADER_SIZE];header[0] = MsgPacket.MSG_FLAG;header[1] = bytes[1];header[2] = bytes[2];MsgPacket msgPacket = new MsgPacket(header);bytes = adjust(bytes, bytes.length - header.length);if (bytes.length - header.length == 0) {// 刚好是接收了一个包的header部分packetQueue.add(msgPacket);flag = false;} else {int more = bytes.length;int length = bytes.length > msgPacket.getNeedDataLength() ? msgPacket.getNeedDataLength() : bytes.length;byte[] newdata = new byte[length];// 如果除header外剩余的数据大于msgPacket中记录的数据长度,那么证明还有剩余的包System.arraycopy(bytes, 0, newdata, 0, length);msgPacket.read(newdata);if (msgPacket.isCompleted()) {// 包完整地处理了,处理响应动作response(msgPacket);// 处理完一个包,计算剩余长度more = more - length;if (more > 0) {// 还有其他包数据,重新调整bytes数组bytes = adjust(bytes, more);} else { // 刚好是处理一个包的长度flag = false;}} else {// 包的数据部分接收不完整// 放到队列再进行处理packetQueue.add(msgPacket);flag = false;}}}break;default:// 发送的数据包括上一次未完成的部分,接着处理前面未处理完的包if(!packetQueue.isEmpty()) {MsgPacket msgPacket = packetQueue.get(0);// 可能的情况: 1.不知道包长度(header没有接受完整) 2.知道包长度if(!msgPacket.isHeaderCompleted()) {// 包header没有接收完整// 包的header部分接收不完整byte[] unCompletedHeader = msgPacket.getHeader();// 现有header长度int curLength = unCompletedHeader.length;byte[] header = new byte[MsgPacket.HEADER_SIZE];// 需要补全的header长度int needLength = header.length - unCompletedHeader.length;// 复制原来的System.arraycopy(unCompletedHeader, 0, header, 0, curLength);// 加上不足的System.arraycopy(bytes, 0, header, curLength, needLength);// 重新设置header,计算数据长度和包长度msgPacket.resetHeader(header);msgPacket.calLength();// 重新调整bytes数组(去掉header)int more = bytes.length - needLength;bytes = adjust(bytes, more);}int more = bytes.length;int length = bytes.length > msgPacket.getNeedDataLength() ? msgPacket.getNeedDataLength() : bytes.length;byte[] newdata = new byte[length];// 如果除header外剩余的数据大于msgPacket中记录的数据长度,那么证明还有剩余的包System.arraycopy(bytes, 0, newdata, 0, length);msgPacket.read(newdata);if (msgPacket.isCompleted()) {// 包完整地处理了,进行相关响应动作response(msgPacket);// 退出队列packetQueue.remove(0);// 处理完一个包,计算剩余长度more = more - length;if (more > 0) {// 还有其他包数据,重新调整bytes数组bytes = adjust(bytes, more);} else { // 刚好是处理一个包的长度flag = false;}} else {// 包的数据部分接收不完整// 本来就已经在队列里面了 不需要再add//packetQueue.add(msgPacket);flag = false;}}break;}}}// 将收到的字节转换为一个数组,进行处理private byte[] toBytes() {byte[] bytes = new byte[readCount];int destPos = 0;while(!dataList.isEmpty()) {byte[] bytes0 = dataList.remove(0);System.arraycopy(bytes0, 0, bytes, destPos, bytes0.length);destPos = destPos + bytes0.length;}readCount = 0;return bytes;}// 调整bytes中剩余的字节到新的数组,并返回protected byte[] adjust(byte[] bytes, int more) {byte[] tmp = new byte[more];int completed = bytes.length - more;// 从后往前复制for(int i = bytes.length - 1,j = tmp.length - 1; i >= completed; i--,j--) {tmp[j] = bytes[i];}return tmp;}// 收到数据后的响应动作protected abstract void response(MsgPacket msgPacket);}


ServerHandler

package com.study.nio;import java.nio.channels.SocketChannel;public class ServerReadHandler extends ReadHandler {private NIOServer nioServer;public ServerReadHandler(SocketChannel socketChannel, NIOServer nioServer) {super(socketChannel);this.nioServer = nioServer;}protected synchronized void response(MsgPacket msgPacket) {final String content = new String(msgPacket.getData());System.out.println(Report.reportCurrentTime() + "receive content -> "+ content);// 模拟耗时操作try {Thread.sleep(100);// 模拟一下回复客户端nioServer.respone(socketChannel, Report.reportCurrentTime() + "Server reply -> " + content);} catch (InterruptedException e) {e.printStackTrace();}}}


ClientReadHandler

package com.study.nio;import java.nio.channels.SocketChannel;public class ClientReadHandler extends ReadHandler {private NIOClient nioClient;public ClientReadHandler(NIOClient nioClient, SocketChannel socketChannel) {super(socketChannel);this.nioClient = nioClient;}protected void response(MsgPacket msgPacket) {String content = new String(msgPacket.getData());System.out.println(Report.reportCurrentTime() + nioClient.getName() + " receive content -> "+ content);}}



MsgPacket

package com.study.nio;/** * 该包用于封装发送的消息和读取发送的消息 * @author CrazyPig * */public class MsgPacket {public static final byte MSG_FLAG = 0x01;public static final int HEADER_SIZE = 3;private byte[] header = new byte[HEADER_SIZE];// 第一个字节为包类型,后面两个字节表示数据长度private byte[] data;// 数据private int length;// 包长度private int curDataLength = 0;// 当前data[]长度public byte[] getData() {return this.data;}public byte[] getHeader() {return this.header;}public int getPacketLength() {return this.length;}public int getDataLength() {return this.length - header.length;}public int getCurDataLength() {return this.curDataLength;}// 获取还需要填充的数据字节长度public int getNeedDataLength() {return this.getDataLength() - this.getCurDataLength();}// 根据发送的消息构造一个包public MsgPacket(String msg) {this.data = msg.getBytes();this.length = data.length + header.length;int dlen = data.length;genHeader(dlen);}// 根据收到的header构造一个包public MsgPacket(byte[] header) {this.header = header;if(header.length < HEADER_SIZE) {return ;}// 求数据长度calLength();}// 求数据长度public void calLength() {byte high = (byte) ((header[1] << 8) & 0xff00);byte low = (byte) (header[2] & 0x00ff);int dataLength = (high | low);this.data = new byte[dataLength];this.length = dataLength + header.length;}public void read(byte[] newdata) {// 复制newdata[]数组内容到data[]数组System.arraycopy(newdata, 0, this.data, this.curDataLength, newdata.length);this.curDataLength += newdata.length;}// 判断包是否完整public boolean isCompleted() {int curLength = this.curDataLength + this.header.length;return curLength == this.length;}// 判断header是否完整public boolean isHeaderCompleted() {return header.length == HEADER_SIZE;}public void genHeader(int dataLength) {header[0] = MSG_FLAG;// header[1] 高字节header[1] = (byte) ((dataLength & 0xff00) >> 8);header[2] = (byte) (dataLength & 0x00ff);}// 返回整个包的数据public byte[] getBytes() {byte[] allBytes = new byte[this.length];System.arraycopy(header, 0, allBytes, 0, this.header.length);System.arraycopy(data, 0, allBytes, this.header.length, this.data.length);return allBytes;}// 设置headerpublic void resetHeader(byte[] header) {this.header = header;}}


再用一个类进行测试,模拟多个Client连接并访问服务端

package com.study.nio;public class NIOPowerTest {public static final int CLIENT_SIZE = 300;public static void main(String[] args) {NIOClient[] client = new NIOClient[CLIENT_SIZE];for(int i = 0; i < client.length; i++) {client[i] = new NIOClient("Client" + i, "127.0.0.1", 9000);client[i].startup();System.out.println(Report.reportCurrentTime() + " CLIENT" + i + " startup");}}}





0 0
原创粉丝点击