SVM处理mnist字体库
来源:互联网 发布:网络歌手小右个人 编辑:程序博客网 时间:2024/05/18 20:50
2017年robomaster比赛中,大神符环节使用的是
# decoding:utf-8import osimport cv2import numpy as npimport codecsfrom cv2.ml import VAR_ORDEREDimport codecsfrom cv2.ml import VAR_ORDEREDfrom canny import *from find_contours import *import numpy as npimport cPickleimport gzipdef vectorized_result(j): e = np.zeros((10, 1)) e[j] = 1.0 return edef load_data(): mnist = gzip.open(os.path.join('data', 'mnist.pkl.gz'), 'rb') training_data, classification_data, test_data = cPickle.load(mnist) mnist.close() return training_data, classification_data, test_datadef wrap_data(): tr_d, va_d, te_d = load_data() # print type(tr_d), type(va_d), type(te_d) training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]] training_results = [vectorized_result(y) for y in tr_d[1]] training_data = zip(training_inputs, training_results) validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]] validation_data = zip(validation_inputs, va_d[1]) test_input = [np.reshape(x, (784, 1)) for x in te_d[0]] test_data = zip(test_input, te_d[1]) return training_data, validation_data, test_datadef train_svm(train_file='train_data.txt', test_file= 'train_result.txt'): svm = cv2.ml.SVM_create() svm.setType(cv2.ml.SVM_C_SVC) #自己设置一下SVM参数 svm.setKernel(cv2.ml.SVM_POLY) t_d = np.loadtxt(train_file, np.float32) m_d = np.loadtxt(test_file, np.int32) train_data = cv2.ml.TrainData_create(t_d, cv2.ml.ROW_SAMPLE, m_d) svm.train(train_data) return svmdef svm_test(svm, test_data): le = len(test_data) sum_tem = 0 for i in range(le): sample = np.array([test_data[i][0].ravel()], dtype=np.float32).reshape(28, 28) a, b =svm.predict(np.array([test_data[i][0].ravel()], dtype=np.float32)) if b[0][0] == test_data[i][1] or test_data[i][1] == 0: sum_tem += 1 print '正确率 ', float(sum_tem * 1.0/ le)def svm_predict(svm, sample): resized = sample.copy() rows, cols = resized.shape if (rows != 28 or cols != 28) and rows * cols > 0: resized = cv2.resize(resized, (28, 28), interpolation=cv2.INTER_CUBIC) return svm.predict(np.array([resized.ravel()], dtype=np.float32))if __name__ == '__main__': tr, val, test = wrap_data() save_path = os.path.join('data', '自己想个文件名') if os.path.exists(save_path): print 'find it' svm = cv2.ml.SVM_load(save_path) else: svm = train_svm() svm.save(save_path) svm_test(svm, test)
1 0
- SVM处理mnist字体库
- ANN处理mnist字体库
- mnist svm
- MNIST数据集处理
- Windows Caffe 学习笔记(四)搭建自己的网络,训练和测试MNIST手写字体库
- svm处理流程
- 处理成svm的
- mnist
- mnist
- mnist
- MNIST
- MNIST数据库处理--matlab生成mnist_uint8.mat
- 初涉LeNet5处理mnist (CNN卷积神经网络)
- SVM处理多分类情况
- iOS字体库
- icon字体库
- 字体库引用
- iOS字体库
- Windows下使用Hadoop2.6.0-eclipse-plugin插件
- 求字符串的子集
- Log4J入门教程(一) 入门例程
- react demo2 (JSX入门)
- 锁的种类
- SVM处理mnist字体库
- RSA 数据加密解密
- Android日常错误-----app按home键,再次点击图标直接进入APP,以及APP保活问题
- Linux命令中Ctrl+z、Ctrl+c和Ctrl+d的区别和使用
- 猜神童年龄
- 二叉树的链式存储,先序建树,以及4种遍历方式
- Kaldi学习笔记(二)
- Hdu 3401 题解 单调队列优化DP
- oracle分析函数系列之rank,dense_rank,row_number:实现排名策略