【实验楼】基于BP神经网络的手写体识别——python3版

来源:互联网 发布:mysql复杂查询实例 编辑:程序博客网 时间:2024/06/16 05:13

用自己的机器跑BP神经网络手写体识别,刚开始因为Python2与3之间存在不兼容,所以需要对代码本身做一定的修改。(实验地址:https://www.shiyanlou.com/courses/593)


实验楼参考文档:https://www.shiyanlou.com/courses/593/labs/1966/document

代码提供(Python2版):https://github.com/aosabook/500lines/tree/master/ocr


环境我用的是Python3.6。对参考文档中的代码主要修改了以下几个部分:

1、ocr.py文件中train函数

    def train(self, training_data_array):        for data in training_data_array:            # 前向传播得到结果向量            y1 = np.dot(np.mat(self.theta1), np.mat(data['y0']).T)  # 修改后            sum1 = y1 + np.mat(self.input_layer_bias)            y1 = self.sigmoid(sum1)            y2 = np.dot(np.array(self.theta2), y1)            y2 = np.add(y2, self.hidden_layer_bias)            y2 = self.sigmoid(y2)            # 后向传播得到误差向量            actual_vals = [0] * 10            actual_vals[data['label']] = 1  # 修改后            output_errors = np.mat(actual_vals).T - np.mat(y2)            hidden_errors = np.multiply(np.dot(np.mat(self.theta2).T, output_errors), self.sigmoid_prime(sum1))            # 更新权重矩阵与偏置向量            self.theta1 += self.LEARNING_RATE * np.dot(np.mat(hidden_errors), np.mat(data['y0']))  # 修改后            self.theta2 += self.LEARNING_RATE * np.dot(np.mat(output_errors), np.mat(y1).T)            self.hidden_layer_bias += self.LEARNING_RATE * output_errors            self.input_layer_bias += self.LEARNING_RATE * hidden_errors

2、运行中报TypeError: a bytes-like object is required, not 'str'。定位后在server.py中处理接收到的POST请求一块,因为json.dumps()返回是str,需要对此部分解码即可。

self.send_response(response_code)        self.send_header("Content-type", "application/json")        self.send_header("Access-Control-Allow-Origin", "*")        self.end_headers()        if response:            self.wfile.write(json.dumps(response).encode())  # 修改后        return

3、Python3中不在用BaseHTTPServer,而是并入http.server中,所以此处导入http.server即可。贴出完整的server.py文件:

# -*- coding: UTF-8 -*-import http.server  # 修改后import jsonfrom ocr import OCRNeuralNetworkimport numpy as npimport random# 服务器端配置HOST_NAME = 'localhost'PORT_NUMBER = 9000# 这个值是通过运行神经网络设计脚本得到的最优值HIDDEN_NODE_COUNT = 15# 加载数据集data_matrix = np.loadtxt(open('data.csv', 'rb'), delimiter=',')data_labels = np.loadtxt(open('dataLabels.csv', 'rb'))# 转换成list类型data_matrix = data_matrix.tolist()data_labels = data_labels.tolist()# 数据集一共5000个数据,train_indice存储用来训练的数据的序号train_indice = list(range(5000))# 打乱训练顺序random.shuffle(train_indice)nn = OCRNeuralNetwork(HIDDEN_NODE_COUNT, data_matrix, data_labels, train_indice);class JSONHandler(http.server.BaseHTTPRequestHandler):   # 修改后    """处理接收到的POST请求"""    def do_POST(self):        response_code = 200        response = ""        var_len = int(self.headers.get('Content-Length'))        content = self.rfile.read(var_len)        payload = json.loads(content)        # 如果是训练请求,训练然后保存训练完的神经网络        if payload.get('train'):            nn.train(payload['trainArray'])            nn.save()        # 如果是预测请求,返回预测值        elif payload.get('predict'):            try:                print(nn.predict(data_matrix[0]))                response = {"type": "test", "result": str(nn.predict(payload['image']))}            except:                response_code = 500        else:            response_code = 400        self.send_response(response_code)        self.send_header("Content-type", "application/json")        self.send_header("Access-Control-Allow-Origin", "*")        self.end_headers()        if response:            self.wfile.write(json.dumps(response).encode())  # 修改后        returnif __name__ == '__main__':    server_class = http.server.HTTPServer  # 修改后    httpd = server_class((HOST_NAME, PORT_NUMBER), JSONHandler)    try:        # 启动服务器        httpd.serve_forever()    except KeyboardInterrupt:        pass    else:        print("Unexpected server exception occurred.")    finally:        httpd.server_close()

运行结果:


阅读全文
0 0
原创粉丝点击