python 网络 select

来源:互联网 发布:无人机飞控调参软件 编辑:程序博客网 时间:2024/06/02 01:29

最近在写安卓手机客户端,需要和后台通讯,于是用 python 写了个后台服务用于自测。感觉代码有一定的通用性,发出来分享一下。

设计:

分成三个部分:报文设计、后台设计、后台测试用例。后台设计得比较挫,但是可以用。

细节:

报文部分

#include <stdint.h>

/* Packet header: sizeof(struct PkgHdr) == 16.
 * On the wire every field is a big-endian uint16 (the Python code
 * packs/unpacks it with struct format '>HHHHHHHH'). */
struct PkgHdr {
    uint16_t magic;   /* magic number: 0xfffe */
    uint16_t bodyLen; /* body length = msgLen + extLen, header excluded */
    uint16_t pkgType; /* 0: request, 1: response, ... (extensible) */
    uint16_t cmd;     /* request/response command word, currently 0 */
    uint16_t retCode; /* 0: success, other: failure (set in responses) */
    uint16_t msgFmt;  /* message format 0: json, 1: protobuf, ... */
    uint16_t msgLen;  /* message length */
    uint16_t extLen;  /* extension length */
};

/* Whole packet: size on the wire == 16 + msgLen + extLen.
 * C cannot size two consecutive arrays from header fields, so the body is a
 * single flexible array member laid out as:
 *   body[0 .. msgLen)                 msg: message to parse (json, pb, tlv, ...),
 *                                     empty when unused
 *   body[msgLen .. msgLen + extLen)   ext: binary extension data (file or
 *                                     stream), empty when unused */
struct Pkg {
    struct PkgHdr hdr; /* fixed-length header */
    uint8_t body[];    /* msg bytes followed by ext bytes */
};

后台模块

# coding=utf-8import argparseimport loggingimport osimport timeimport uuidimport jsonimport threadingimport multiprocessingimport randomimport selectimport socketimport queueimport uuidimport structfrom enum import Enum, uniqueimport tornado.ioloopg_select_timeout = 10class Server(object):    def __init__(self, host='192.168.100.41', port=33333, timeout=2, client_nums=10, speech_recognizer=None):        self.__host = host        self.__port = port        self.__timeout = timeout        self.__client_nums = client_nums        self.__buffer_size = 1024        self.__frame_length = 16        self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)        self.server.setblocking(0)        self.server.settimeout(self.__timeout)        self.server.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) #keepalive        self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) #端口复用        server_host = (self.__host, self.__port)        try:            self.server.bind(server_host)            self.server.listen(self.__client_nums)        except:            raise                    #网络相关处理        self.inputs = [self.server]           self.outputs = []                     self.message_queues = {}              self.client_info = {}        #业务相关处理        self.store_dir_path = ""        self.filedesc_wavpath_dict = {}        self.speech_recognizer = speech_recognizer    def store_data(self, file_desc, data):      file_desc_list = self.filedesc_wavpath_dict.keys()      if file_desc not in file_desc_list:        random_str_wav = str(_get_ms_time_int_part()) + '-' + str(uuid.uuid4()) + "test.wav"        wav_path = os.path.join(self.store_dir_path, random_str_wav)        self.filedesc_wavpath_dict[file_desc] = wav_path        logging.info("random_str:{}".format(random_str_wav))        logging.info("file desc:{}".format(file_desc))        logging.info("file path:{}".format(wav_path))      with open(self.filedesc_wavpath_dict[file_desc], "ab+") as inputwav:        
inputwav.write(data)      return self.filedesc_wavpath_dict[file_desc]    def readMSGHead(self, filedesc):      '''              //数据包头:sizeof(PkgHdr)=16              struct PkgHdr{              uint16_t magic;     //magic number:0xfffe              uint16_t bodyLen;   //包体长度,即msg和ext总长,不含包头              uint16_t pkgType;   //0,请求包 1,返回包 ...后续可扩展              uint16_t cmd;       //请求和返回命令字,目前默认填0              uint16_t retCode;   //0成功, 其它失败,返回包填写              uint16_t msgFmt;    //消息格式0:json 1:protobuf ...              uint16_t msgLen;    //消息长度              uint16_t extLen;    //扩展包长度              };      '''      frame_length_bytes = filedesc.recv(self.__frame_length)      magic, bodyLen, pkgType, cmd, retCode, msgFmt, msgLen, extLen = struct.unpack('>HHHHHHHH', frame_length_bytes)      return magic, bodyLen, pkgType, cmd, retCode, msgFmt, msgLen, extLen    def logMSG(self, msg ):      msg_head = msg[0:16]      magic, bodyLen, pkgType, cmd, retCode, msgFmt, msgLen, extLen = struct.unpack('>HHHHHHHH', msg_head)      logging.info("========================================")      logging.info("msg_head: {}".format(msg_head))      logging.info("bodyLen:  {}".format(bodyLen))      logging.info("pkgType:  {}".format(pkgType))      logging.info("cmd:      {}".format(cmd))      logging.info("retCode:  {}".format(retCode))      logging.info("msgFmt:   {}".format(msgFmt))      logging.info("msgLen:   {}".format(msgLen))      logging.info("extLen:   {}".format(extLen))      logging.info("msgjson:  {}".format(msg[16:16+msgLen]))      logging.info("========================================")    def readMSGMsg(self, filedesc, msgLen):      msg = filedesc.recv(msgLen)      msg = msg.decode()      msg_json = json.loads(msg)      return msg_json    def readMSGExt(self, filedesc, extLen):      ext = filedesc.recv(extLen)      return ext    def decodeMSG(self, filedesc):      MSGHeadInfo = self.readMSGHead(filedesc)      logging.info("decodeMSG MSGHeadInfo:{}".format(MSGHeadInfo))    
  msgLen = MSGHeadInfo[6]      extLen = MSGHeadInfo[7]      logging.info("decodeMSG msgLen:{}".format(msgLen))      logging.info("decodeMSG extLen:{}".format(extLen))      msg_json = self.readMSGMsg(filedesc, msgLen)      Ext = self.readMSGExt(filedesc, extLen)      return msg_json, Ext    def encodeMSG(self, pkgType_input, errCode, errMsg, refresh, text, ext):      magic = 0xfffe      bodyLen = 0      pkgType = pkgType_input      cmd = 0      retCode = 0      msgFmt = 0      msgLen = 0      extLen = 0      msg_json_data = {"errCode": errCode, "errMsg": errMsg, "refresh": refresh, "text": text}      msg = json.dumps(msg_json_data)      if msg:        msgLen = len(msg)      if ext:        extLen = len(ext)      bodyLen = msgLen + extLen      msgHeadByteArray = bytearray(struct.pack('>HHHHHHHH', magic, bodyLen, pkgType, cmd, retCode, msgFmt, msgLen, extLen))      msgMsgByteArray = bytearray(str(msg).encode())      msg = msgHeadByteArray + msgMsgByteArray      if ext:        msgExtByteArray = bytearray(ext)        msg = msg + msgExtByteArray      return msg    def call_asr(self, wav_file_path ):      decode_result = self.speech_recognizer.wav_to_txt(wav_file_path)      return decode_result    def run(self):        while True:            readable , writable , exceptional = select.select(self.inputs, self.outputs, self.inputs, g_select_timeout)            if not (readable or writable or exceptional) :                continue            for s in readable :                if s is self.server:#是客户端连接                    connection, client_address = s.accept()                    #print "connection", connection                    print( "%s connect. 
" %str(client_address) )                    connection.setblocking(False)                     self.inputs.append(connection) #客户端添加到inputs                    self.client_info[connection] = str(client_address)                    self.message_queues[connection] = queue.Queue()  #每个客户端一个消息队列                else:#是client, 数据发送过来                    receiveMSG = None                    try:                        receiveMSG = self.decodeMSG(s)                    except Exception as e:                        err_msg = "Client Error!!!"                        logging.error(err_msg)                        logging.error(str(e))                    if receiveMSG :                        msg_json, Ext = receiveMSG                        logging.info("receive msg_json:   {}".format(msg_json))                        logging.info("receiveExt length:  {}".format(len(Ext)))                        wav_file_path = self.store_data(s, Ext)                        dataoutput = "%s %s " % (time.strftime("%Y-%m-%d %H:%M:%S"), self.client_info[s])                        #dataoutput = "message from server"                        self.message_queues[s].put(dataoutput)                         if s not in self.outputs:                             self.outputs.append(s)                    else:                         #Interpret empty result as closed connection                        print ("Client:%s Close." 
% str( self.client_info[s]) )                        if s in self.outputs :                            self.outputs.remove(s)                        self.inputs.remove(s)                        s.close()                        del self.message_queues[s]                        del self.client_info[s]                        if s in self.filedesc_wavpath_dict.keys():                          del self.filedesc_wavpath_dict[s]            for s in writable: #outputs 有消息就要发出去了                try:                    next_msg = self.message_queues[s].get_nowait()  #非阻塞获取                except queue.Empty:                    err_msg = "Output Queue is Empty!"                    #g_logFd.writeFormatMsg(g_logFd.LEVEL_INFO, err_msg)                    self.outputs.remove(s)                except Exception as e:  #发送的时候客户端关闭了则会出现writable和readable同时有数据,会出现message_queues的keyerror                    err_msg = "Send Data Error! ErrMsg:%s" % str(e)                    logging.error(err_msg)                    if s in self.outputs:                        self.outputs.remove(s)                else:                    try:                        cli = s                        pkgType_input = 1                                                errCode = 0                        errMsg = "OK"                        refresh = 1                        text = next_msg                        ext = None                        '''                        logging.info("errCode   :{}".format(errCode))                         logging.info("errMsg    :{}".format(errMsg))                        logging.info("refresh   :{}".format(refresh))                        logging.info("text      :{}".format(text))                        logging.info("ext       :{}".format(ext))                        '''                        msgresp = self.encodeMSG(pkgType_input, errCode, errMsg, refresh, text, ext)                         self.logMSG(msgresp)                        cli.send(msgresp)                    except 
Exception as e: #发送失败就关掉                        err_msg = "Send Data to %s  Error! ErrMsg:%s" % (str(self.client_info[cli]), str(e))                        logging.error(err_msg)                        print( "Client: %s Close Error." % str(self.client_info[cli]) )                        if cli in self.inputs:                            self.inputs.remove(cli)                            cli.close()                        if cli in self.outputs:                            self.outputs.remove(s)                        if cli in self.message_queues:                            del self.message_queues[s]                        del self.client_info[cli]                        del self.filedesc_wavpath_dict[s]            for s in exceptional:                logging.error("Client:%s Close Error." % str(self.client_info[cli]))                if s in self.inputs:                    self.inputs.remove(s)                    s.close()                if s in self.outputs:                    self.outputs.remove(s)                if s in self.message_queues:                    del self.message_queues[s]                del self.client_info[s]                del self.filedesc_wavpath_dict[s]if "__main__" == __name__:    logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s %(message)s",                      filename='realtime_asr_server.log',                      level=logging.INFO)    Server().run()

后台测试代码

"""Threaded test client for the select-based PkgHdr/Pkg test server."""
import json
import logging
import socket
import struct
import sys
import threading
import time

# Magic number that starts every frame header.
_MAGIC = 0xfffe
# Header layout: eight big-endian uint16 fields (16 bytes in total).
_HEADER_FORMAT = '>HHHHHHHH'


class Client(object):
    """Test client: one thread sends a request every 0.1 s, another prints replies.

    Args:
        host, port: server address.
        timeout: socket timeout in seconds; lets the receive thread wake up
            between frames to check the stop flag.
        reconnect: accepted for interface compatibility; not used here.
    """

    def __init__(self, host, port=33333, timeout=1, reconnect=2):
        self.__host = host
        self.__port = port
        self.__timeout = timeout
        self.__buffer_size = 1024
        self.__flag = 1  # 1 while running; set to 0 to stop both threads
        self.client = None
        self.__lock = threading.Lock()
        self.__frame_length = 16

    def _recv_exact(self, filedesc, size, idle_timeout_ok=False):
        """Read exactly ``size`` bytes from socket ``filedesc``.

        One recv() may return fewer bytes than requested, so keep reading.
        A socket.timeout before the first byte is re-raised when
        ``idle_timeout_ok`` is true, so the caller can poll its stop flag
        between frames.  Once part of the data has been consumed, timeouts
        are retried while ``self.flag`` is set: giving up there would
        desynchronise the stream.

        Raises:
            ConnectionError: the peer closed the connection first.
            socket.timeout: idle timeout (see above), or stop requested.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            try:
                chunk = filedesc.recv(remaining)
            except socket.timeout:
                if idle_timeout_ok and remaining == size:
                    raise
                if not self.flag:
                    raise
                continue
            if not chunk:
                raise ConnectionError(
                    "connection closed with %d of %d bytes unread" % (remaining, size))
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def readMSGHead(self, filedesc):
        """Read and unpack the 16-byte frame header from ``filedesc``.

        Header layout (all fields uint16, big-endian)::

            struct PkgHdr {
                uint16_t magic;    // magic number: 0xfffe
                uint16_t bodyLen;  // body length (msg + ext), header excluded
                uint16_t pkgType;  // 0: request, 1: response, ... (extensible)
                uint16_t cmd;      // request/response command, currently 0
                uint16_t retCode;  // 0: success, other: failure (responses)
                uint16_t msgFmt;   // message format 0: json, 1: protobuf, ...
                uint16_t msgLen;   // message length
                uint16_t extLen;   // extension length
            };

        Returns:
            ``(magic, bodyLen, pkgType, cmd, retCode, msgFmt, msgLen, extLen)``.
        """
        frame_length_bytes = self._recv_exact(filedesc, self.__frame_length, idle_timeout_ok=True)
        return struct.unpack(_HEADER_FORMAT, frame_length_bytes)

    def readMSGMsg(self, filedesc, msgLen):
        """Read ``msgLen`` bytes of UTF-8 JSON; return None for an empty message."""
        if msgLen == 0:
            return None
        msg = self._recv_exact(filedesc, msgLen)
        return json.loads(msg.decode())

    def readMSGExt(self, filedesc, extLen):
        """Read exactly ``extLen`` bytes of binary extension data (b'' if 0)."""
        return self._recv_exact(filedesc, extLen)

    def decodeMSG(self, filedesc):
        """Read one complete frame and return ``(msg_json, ext)``.

        Raises:
            socket.timeout: no frame started within the socket timeout.
            ConnectionError: the peer closed the connection.
            ValueError: bad magic number or undecodable JSON.
        """
        MSGHeadInfo = self.readMSGHead(filedesc)
        magic = MSGHeadInfo[0]
        if magic != _MAGIC:
            raise ValueError("bad magic number 0x%04x" % magic)
        msgLen = MSGHeadInfo[6]
        extLen = MSGHeadInfo[7]
        logging.info("msgLen:{}".format(msgLen))
        logging.info("extLen:{}".format(extLen))
        msg_json = self.readMSGMsg(filedesc, msgLen)
        Ext = self.readMSGExt(filedesc, extLen)
        return msg_json, Ext

    def encodeMSG(self, pkgType_input, errCode, errMsg, refresh, text, ext):
        """Build a frame with a JSON message and optional binary extension.

        Returns:
            The encoded frame as a bytearray; cmd, retCode and msgFmt are 0.
        """
        msg_bytes = json.dumps(
            {"errCode": errCode, "errMsg": errMsg, "refresh": refresh, "text": text}).encode()
        ext_bytes = bytes(ext) if ext else b""
        # Lengths are byte counts of what is actually sent, not str lengths.
        msgLen = len(msg_bytes)
        extLen = len(ext_bytes)
        header = struct.pack(_HEADER_FORMAT, _MAGIC, msgLen + extLen, pkgType_input, 0, 0, 0, msgLen, extLen)
        return bytearray(header + msg_bytes + ext_bytes)

    @property
    def flag(self):
        """1 while the client is running, 0 once either thread asked to stop."""
        return self.__flag

    @flag.setter
    def flag(self, new_num):
        self.__flag = new_num

    def _stop(self):
        """Ask both worker threads to finish."""
        with self.__lock:
            self.flag = 0

    def __connect(self):
        """Open a TCP connection to the server, closing the socket on failure."""
        client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        client.settimeout(self.__timeout)
        client.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # port reuse
        try:
            client.connect((self.__host, self.__port))
        except Exception:
            client.close()
            raise
        return client

    def send_msg(self):
        """Send a fixed test request every 0.1 s until stopped or sending fails."""
        if not self.client:
            return
        while True:
            time.sleep(0.1)
            # Fixed payload; use sys.stdin.readline().strip() for interactive input.
            data = "hello from client"
            if "exit" == data.lower():
                self._stop()
                break
            with self.__lock:
                if not self.flag:
                    break
            frame = self.encodeMSG(0, 0, "OK", 0, data, b"xiaojiba")
            logging.info("client sendall")
            try:
                self.client.sendall(frame)
            except OSError as e:
                logging.error("client send failed: {}".format(e))
                self._stop()
                break

    def recv_msg(self):
        """Print the ``text`` of every response until stopped or disconnected."""
        if not self.client:
            return
        while True:
            with self.__lock:
                if not self.flag:
                    print('ByeBye~~')
                    break
            try:
                logging.info("client recv")
                msg_json, _ext = self.decodeMSG(self.client)
            except socket.timeout:
                continue  # idle: loop back and re-check the stop flag
            except (ConnectionError, ValueError) as e:
                # Server closed or framing broke: stop both threads.
                logging.error("client recv failed: {}".format(e))
                self._stop()
                continue
            if msg_json is None:
                logging.info("empty message")
                continue
            text = msg_json['text']
            isend = msg_json['refresh']
            logging.info("text:{}".format(text))
            logging.info("isend:{}".format(isend))
            print("%s\n" % text)
            time.sleep(0.1)

    def run(self):
        """Connect, run the sender and receiver threads, then close the socket."""
        self.client = self.__connect()
        send_proc = threading.Thread(target=self.send_msg)
        recv_proc = threading.Thread(target=self.recv_msg)
        recv_proc.start()
        send_proc.start()
        try:
            recv_proc.join()
            send_proc.join()
        finally:
            self.client.close()


if "__main__" == __name__:
    logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s %(message)s",
                        level=logging.INFO)
    Client('192.168.100.41').run()


 
原创粉丝点击