使用SVM处理MNIST手写数字库

来源:互联网 发布:网络歌手小右个人 编辑:程序博客网 时间:2024/05/18 20:50
  2017年RoboMaster比赛中,大神符环节使用的是MNIST字体库中的手写数字,下面用OpenCV的SVM对其进行训练和识别:
# coding: utf-8
#
# Train, persist and evaluate an OpenCV (cv2.ml) SVM on the MNIST data stored
# in data/mnist.pkl.gz.  The script targets Python 2 (it imports cPickle).
import codecs
import os

import cv2
from cv2.ml import VAR_ORDERED

from canny import *
from find_contours import *

# Kept after the wildcard imports, as in the original, so that nothing those
# local modules export can shadow these names.
import cPickle
import gzip

import numpy as np


def vectorized_result(j):
    """Return a (10, 1) column vector of zeros with 1.0 at index ``j``."""
    e = np.zeros((10, 1))
    e[j] = 1.0
    return e


def load_data():
    """Load the three pickled data splits from ``data/mnist.pkl.gz``.

    Returns:
        A ``(training_data, classification_data, test_data)`` tuple, in the
        order the objects are stored in the pickle.
    """
    # NOTE: unpickling can execute arbitrary code; only load a trusted file.
    # The context manager closes the archive even when unpickling fails.
    with gzip.open(os.path.join('data', 'mnist.pkl.gz'), 'rb') as mnist:
        training_data, classification_data, test_data = cPickle.load(mnist)
    return training_data, classification_data, test_data


def wrap_data():
    """Reshape the splits returned by ``load_data`` into (input, target) pairs.

    Every input is reshaped to a (784, 1) column.  Training targets are turned
    into vectors with ``vectorized_result``; validation and test targets are
    passed through unchanged.

    Returns:
        ``(training_data, validation_data, test_data)``, each the result of
        ``zip(inputs, targets)`` (a list under Python 2).
    """
    # Assumes each split is an (inputs, labels) pair -- TODO confirm against
    # the layout of mnist.pkl.gz.
    tr_d, va_d, te_d = load_data()

    training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
    training_results = [vectorized_result(y) for y in tr_d[1]]
    training_data = zip(training_inputs, training_results)

    validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
    validation_data = zip(validation_inputs, va_d[1])

    test_input = [np.reshape(x, (784, 1)) for x in te_d[0]]
    test_data = zip(test_input, te_d[1])

    return training_data, validation_data, test_data


def train_svm(train_file='train_data.txt', test_file='train_result.txt'):
    """Train a C-SVC SVM with a polynomial kernel from two text files.

    Args:
        train_file: Text file read with ``np.loadtxt`` as float32; used as the
            sample matrix, one sample per row.
        test_file: Text file read with ``np.loadtxt`` as int32; despite its
            name it is used as the training responses (labels).

    Returns:
        The trained ``cv2.ml`` SVM object.
    """
    svm = cv2.ml.SVM_create()
    svm.setType(cv2.ml.SVM_C_SVC)
    # Tune the SVM parameters yourself.
    # NOTE(review): OpenCV's POLY kernel appears to require a positive degree
    # (svm.setDegree); none is set here -- confirm the intended parameters.
    svm.setKernel(cv2.ml.SVM_POLY)

    samples = np.loadtxt(train_file, np.float32)
    responses = np.loadtxt(test_file, np.int32)
    train_data = cv2.ml.TrainData_create(samples, cv2.ml.ROW_SAMPLE, responses)
    svm.train(train_data)
    return svm


def svm_test(svm, test_data):
    """Print (and return) the fraction of ``test_data`` scored as correct.

    Args:
        svm: A trained ``cv2.ml`` SVM.
        test_data: Sequence of ``(input, label)`` pairs; each input is
            flattened into a single float32 row before prediction.

    Returns:
        The accuracy as a float; ``0.0`` when ``test_data`` is empty.
    """
    total = len(test_data)
    correct = 0
    for pixels, label in test_data:
        row = np.array([pixels.ravel()], dtype=np.float32)
        _, predicted = svm.predict(row)
        # NOTE(review): samples whose true label is 0 are always counted as
        # correct, exactly as in the original.  Presumably the model is only
        # trained on digits 1-9 -- confirm, since this inflates the accuracy.
        if predicted[0][0] == label or label == 0:
            correct += 1

    # Guard against an empty test set instead of dividing by zero.
    accuracy = float(correct) / total if total else 0.0
    print('正确率  %s' % accuracy)
    return accuracy


def svm_predict(svm, sample):
    """Classify a single 2-D image with ``svm``.

    The image is resized to 28x28 (bicubic interpolation) when it has a
    different, non-empty shape, then flattened into one float32 row.

    Args:
        svm: A trained ``cv2.ml`` SVM.
        sample: 2-D array holding the image to classify.

    Returns:
        Whatever ``svm.predict`` returns for the flattened image.
    """
    resized = sample
    rows, cols = resized.shape
    if (rows != 28 or cols != 28) and rows * cols > 0:
        # cv2.resize returns a new array, so the caller's image is untouched.
        resized = cv2.resize(resized, (28, 28), interpolation=cv2.INTER_CUBIC)
    return svm.predict(np.array([resized.ravel()], dtype=np.float32))


if __name__ == '__main__':
    tr, val, test = wrap_data()
    # The file name below is the author's placeholder ("pick a file name
    # yourself"); it is kept verbatim because it is a runtime path.
    save_path = os.path.join('data', '自己想个文件名')
    if os.path.exists(save_path):
        # Reuse a previously trained model when one has been saved.
        print('find it')
        svm = cv2.ml.SVM_load(save_path)
    else:
        svm = train_svm()
        svm.save(save_path)
    svm_test(svm, test)
1 0
原创粉丝点击