【实验楼】基于BP神经网络的手写体识别——python3版
来源:互联网 发布:mysql复杂查询实例 编辑:程序博客网 时间:2024/06/16 05:13
用自己的机器跑BP神经网络手写体识别,刚开始因为Python2与3之间存在不兼容,所以需要对代码本身做一定的修改。(实验地址:https://www.shiyanlou.com/courses/593)
实验楼参考文档:https://www.shiyanlou.com/courses/593/labs/1966/document
代码提供(Python2版):https://github.com/aosabook/500lines/tree/master/ocr
环境我用的是Python3.6。对参考文档中的代码主要修改了以下几个部分:
1、ocr.py文件中train函数
def train(self, training_data_array): for data in training_data_array: # 前向传播得到结果向量 y1 = np.dot(np.mat(self.theta1), np.mat(data['y0']).T) # 修改后 sum1 = y1 + np.mat(self.input_layer_bias) y1 = self.sigmoid(sum1) y2 = np.dot(np.array(self.theta2), y1) y2 = np.add(y2, self.hidden_layer_bias) y2 = self.sigmoid(y2) # 后向传播得到误差向量 actual_vals = [0] * 10 actual_vals[data['label']] = 1 # 修改后 output_errors = np.mat(actual_vals).T - np.mat(y2) hidden_errors = np.multiply(np.dot(np.mat(self.theta2).T, output_errors), self.sigmoid_prime(sum1)) # 更新权重矩阵与偏置向量 self.theta1 += self.LEARNING_RATE * np.dot(np.mat(hidden_errors), np.mat(data['y0'])) # 修改后 self.theta2 += self.LEARNING_RATE * np.dot(np.mat(output_errors), np.mat(y1).T) self.hidden_layer_bias += self.LEARNING_RATE * output_errors self.input_layer_bias += self.LEARNING_RATE * hidden_errors
2、运行中报TypeError: a bytes-like object is required, not 'str'。定位后在server.py中处理接收到的POST请求一块,因为json.dumps()返回是str,需要对此部分解码即可。
self.send_response(response_code) self.send_header("Content-type", "application/json") self.send_header("Access-Control-Allow-Origin", "*") self.end_headers() if response: self.wfile.write(json.dumps(response).encode()) # 修改后 return
3、Python3中不在用BaseHTTPServer,而是并入http.server中,所以此处导入http.server即可。贴出完整的server.py文件:
# -*- coding: UTF-8 -*-import http.server # 修改后import jsonfrom ocr import OCRNeuralNetworkimport numpy as npimport random# 服务器端配置HOST_NAME = 'localhost'PORT_NUMBER = 9000# 这个值是通过运行神经网络设计脚本得到的最优值HIDDEN_NODE_COUNT = 15# 加载数据集data_matrix = np.loadtxt(open('data.csv', 'rb'), delimiter=',')data_labels = np.loadtxt(open('dataLabels.csv', 'rb'))# 转换成list类型data_matrix = data_matrix.tolist()data_labels = data_labels.tolist()# 数据集一共5000个数据,train_indice存储用来训练的数据的序号train_indice = list(range(5000))# 打乱训练顺序random.shuffle(train_indice)nn = OCRNeuralNetwork(HIDDEN_NODE_COUNT, data_matrix, data_labels, train_indice);class JSONHandler(http.server.BaseHTTPRequestHandler): # 修改后 """处理接收到的POST请求""" def do_POST(self): response_code = 200 response = "" var_len = int(self.headers.get('Content-Length')) content = self.rfile.read(var_len) payload = json.loads(content) # 如果是训练请求,训练然后保存训练完的神经网络 if payload.get('train'): nn.train(payload['trainArray']) nn.save() # 如果是预测请求,返回预测值 elif payload.get('predict'): try: print(nn.predict(data_matrix[0])) response = {"type": "test", "result": str(nn.predict(payload['image']))} except: response_code = 500 else: response_code = 400 self.send_response(response_code) self.send_header("Content-type", "application/json") self.send_header("Access-Control-Allow-Origin", "*") self.end_headers() if response: self.wfile.write(json.dumps(response).encode()) # 修改后 returnif __name__ == '__main__': server_class = http.server.HTTPServer # 修改后 httpd = server_class((HOST_NAME, PORT_NUMBER), JSONHandler) try: # 启动服务器 httpd.serve_forever() except KeyboardInterrupt: pass else: print("Unexpected server exception occurred.") finally: httpd.server_close()
运行结果:
阅读全文
0 0
- 【实验楼】基于BP神经网络的手写体识别——python3版
- BP神经网络自由手写体数字识别系统
- 基于BP神经网络的字符识别研究
- 基于BP神经网络的数字识别
- 基于BP神经网络ANN的鼠标手势识别C#.NET实验程序
- 基于BP神经网络ANN的鼠标手势识别C#.NET实验程序
- 基于MFC的手写体识别
- BP神经网络的数据分类—语音特征信号识别
- 用于手写体数字识别的神经网络
- 一种基于BP神经网络的车牌字符识别方法
- 基于BP神经网络的字符识别研究(中文翻译)
- 基于BP神经网络的数字识别基础系统(一)
- 基于BP神经网络的数字识别基础系统(二)
- 深度学习笔记——TensorFlow学习笔记(三)使用TensorFlow实现的神经网络进行MNIST手写体数字识别
- Python3实现BP神经网络
- 基于KNN的手写体识别和数码管数字识别
- MATLAB的bp神经网络识别函数
- BP神经网络的简单字符识别算法
- Android各类有用的开源库项目
- 第K小数
- php核心学习-设计模式的学习-责任链模式
- Digital.Vision.Phoenix.v2015.3.020.Win64 1DVD
- Myeclipse错误:Errors occurred during the build. Errors running builder 'DeploymentBuilder' on project
- 【实验楼】基于BP神经网络的手写体识别——python3版
- Spring Boot
- 走迷宫
- 传统IT七大职业的云计算转型之路
- 哲学家吃饭问题(资源加锁和超时释放)
- leetcode(380). Insert Delete GetRandom O(1)
- java字符串
- FastDFS文件服务器安装文档
- 建造者模式