Handwritten digit recognition with ANNs
Source: Internet | Editor: 程序博客网 | Time: 2024/06/07 07:14
digits_ann.py:
import cv2
import cPickle
import numpy as np
import gzip

def load_data():
    mnist = gzip.open('mnist.pkl.gz', 'rb')
    training_data, classification_data, test_data = cPickle.load(mnist)
    mnist.close()
    return training_data, classification_data, test_data

def wrap_data():
    tr_d, va_d, te_d = load_data()
    training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
    training_results = [vectorized_result(y) for y in tr_d[1]]
    training_data = zip(training_inputs, training_results)
    validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
    validation_data = zip(validation_inputs, va_d[1])
    test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
    test_data = zip(test_inputs, te_d[1])
    return training_data, validation_data, test_data

def vectorized_result(j):
    e = np.zeros((10, 1))
    e[j] = 1.0
    return e

def create_ANN(hidden=20):
    ann = cv2.ml.ANN_MLP_create()
    ann.setLayerSizes(np.array([784, hidden, 10]))
    ann.setTrainMethod(cv2.ml.ANN_MLP_RPROP)
    ann.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM)
    ann.setTermCriteria((cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 20, 1))
    return ann

def train(ann, samples=10000, epochs=1):
    tr, val, test = wrap_data()
    for x in xrange(epochs):
        counter = 0
        for img in tr:
            if(counter > samples):
                break
            if(counter % 1000 == 0):
                print "Epoch %d: Trained %d/%d" % (x, counter, samples)
            counter += 1
            data, digit = img
            ann.train(np.array([data.ravel()], dtype=np.float32), cv2.ml.ROW_SAMPLE, np.array([digit.ravel()], dtype=np.float32))
        print "Epoch %d complete" % x
    return ann, test

def test(ann, test_data):
    sample = np.array(test_data[0][0].ravel(), dtype=np.float32).reshape(28, 28)
    cv2.imshow("sample", sample)
    cv2.waitKey()
    print ann.predict(np.array([test_data[0][0].ravel()], dtype=np.float32))

def predict(ann, sample):
    resized = sample.copy()
    rows, cols = resized.shape
    if (rows != 28 or cols != 28) and rows * cols > 0:
        resized = cv2.resize(resized, (28, 28), interpolation=cv2.INTER_CUBIC)
    return ann.predict(np.array([resized.ravel()], dtype=np.float32))
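To see what wrap_data and vectorized_result hand to ann.train, here is a small sketch that needs only NumPy. The image is random data standing in for one MNIST digit (an assumption made so the snippet runs without mnist.pkl.gz), and the label 3 is arbitrary.

```python
import numpy as np

def vectorized_result(j):
    """Return a 10x1 column vector with 1.0 at index j and 0.0 elsewhere."""
    e = np.zeros((10, 1))
    e[j] = 1.0
    return e

# Hypothetical stand-in for one MNIST image: 784 pixel intensities in [0, 1).
fake_image = np.random.RandomState(0).rand(784)
label = 3  # arbitrary label, chosen only for illustration

# Same reshaping as wrap_data: a 784x1 column and a one-hot 10x1 target.
column = np.reshape(fake_image, (784, 1))
target = vectorized_result(label)

# train() flattens both back into single float32 rows (one row per sample).
sample_row = np.array([column.ravel()], dtype=np.float32)
target_row = np.array([target.ravel()], dtype=np.float32)

print(sample_row.shape)             # (1, 784)
print(target_row.shape)             # (1, 10)
print(int(np.argmax(target_row)))   # 3
```

The index of the largest value in the 10-element target row is the digit, which is also how a 10-output network's response is usually read back.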
main.py:
import cv2
import numpy as np
import digits_ann as ANN

def inside(r1, r2):
    x1,y1,w1,h1 = r1
    x2,y2,w2,h2 = r2
    if (x1 > x2) and (y1 > y2) and (x1+w1 < x2+w2) and (y1+h1 < y2 + h2):
        return True
    else:
        return False

def wrap_digit(rect):
    x, y, w, h = rect
    padding = 5
    hcenter = x + w/2
    vcenter = y + h/2
    if (h > w):
        w = h
        x = hcenter - (w/2)
    else:
        h = w
        y = vcenter - (h/2)
    return (x-padding, y-padding, w+padding, h+padding)

ann, test_data = ANN.train(ANN.create_ANN(56), 20000)
font = cv2.FONT_HERSHEY_SIMPLEX

path = "./c.png"
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
bw = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
bw = cv2.GaussianBlur(bw, (7, 7), 0)
ret, thbw = cv2.threshold(bw, 127, 255, cv2.THRESH_BINARY_INV)
thbw = cv2.erode(thbw, np.ones((2,2), np.uint8), iterations=2)
image, cntrs, hier = cv2.findContours(thbw.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

rectangles = []

for c in cntrs:
    r = x,y,w,h = cv2.boundingRect(c)
    a = cv2.contourArea(c)
    b = (img.shape[0]-3) * (img.shape[1] - 3)
    is_inside = False
    for q in rectangles:
        if inside(r, q):
            is_inside = True
            break
    if not is_inside:
        if not a == b:
            rectangles.append(r)

for r in rectangles:
    x,y,w,h = wrap_digit(r)
    cv2.rectangle(img, (x,y), (x+w, y+h), (0, 255, 0), 2)
    roi = thbw[y:y+h, x:x+w]
    try:
        digit_class = int(ANN.predict(ann, roi.copy())[0])
    except:
        continue
    cv2.putText(img, "%d" % digit_class, (x, y-1), font, 1, (0, 255, 0))

cv2.imshow("thbw", thbw)
cv2.imshow("contours", img)
cv2.imwrite("results.jpg", img)
cv2.waitKey()
Results:
Comparison of the results on different pictures:
Conclusion:
The comparison above shows that this recognizer is not accurate enough.
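One way to put a number on this, rather than judging by eye, is to count how often the predicted digit matches the true one. The sketch below is a hypothetical helper, not part of digits_ann.py, and both label lists are made up for illustration; they are not results from this experiment.

```python
def accuracy(predicted, expected):
    """Fraction of positions where the predicted label equals the true label."""
    if len(predicted) != len(expected):
        raise ValueError("label lists must have the same length")
    if not expected:
        return 0.0
    hits = sum(1 for p, t in zip(predicted, expected) if p == t)
    return hits / float(len(expected))

# Made-up labels for illustration only (not measured results).
predicted = [7, 2, 1, 0, 4, 1, 4, 9]
expected = [7, 2, 1, 0, 4, 1, 4, 3]

print(accuracy(predicted, expected))  # 0.875
```

With the real network, predicted would be filled from ANN.predict over the MNIST test set returned by train, and expected from the labels stored alongside each test image.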
Reference: 《Learning OpenCV 3 Computer Vision with Python》
"mnist.pkl.gz" was downloaded from http://www.cnblogs.com/xueliangliu/archive/2013/04/03/2997437.html