Caffe 下用 PyQt5 编写的手写数字输入与预测小程序

来源:互联网 发布:c语言 音乐函数 编辑:程序博客网 时间:2024/05/16 17:22
  • PyQt5的界面
    主界面
    在 Qt Designer 中设计，并在 PyCharm 中转换得到如下 Python 代码
from PyQt5 import QtCore, QtGui, QtWidgets


class Ui_MainWindow(object):
    """Widget tree for the digit-recognition main window.

    Left side: a 140x140 ``DigitalMnistNum`` drawing canvas.  Right side: a
    vertical column with the clear / save / predict buttons and a large
    result label.  The buttons' ``clicked`` signals are connected to
    ``on_clearBtn`` / ``on_saveBtn`` / ``on_predictBtn`` of the window object
    handed to :meth:`setupUi`.
    """

    def setupUi(self, MainWindow):
        """Create all child widgets on *MainWindow* and wire up the buttons."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(319, 197)

        # The window is fixed-size in both directions.
        fixed = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed,
                                      QtWidgets.QSizePolicy.Fixed)
        fixed.setHorizontalStretch(0)
        fixed.setVerticalStretch(0)
        fixed.setHeightForWidth(MainWindow.sizePolicy().hasHeightForWidth())
        MainWindow.setSizePolicy(fixed)

        central = QtWidgets.QWidget(MainWindow)
        central.setObjectName("centralWidget")
        self.centralWidget = central

        # Drawing canvas: 140x140 pixels.
        canvas = DigitalMnistNum(central)
        canvas.setGeometry(QtCore.QRect(30, 20, 140, 140))
        canvas.setObjectName("widget")
        self.widget = canvas

        # Right-hand column that hosts the buttons and the result label.
        column = QtWidgets.QWidget(central)
        column.setGeometry(QtCore.QRect(190, 20, 106, 141))
        column.setObjectName("verticalLayoutWidget")
        self.verticalLayoutWidget = column

        layout = QtWidgets.QVBoxLayout(column)
        layout.setContentsMargins(11, 11, 11, 11)
        layout.setSpacing(6)
        layout.setObjectName("verticalLayout")
        self.verticalLayout = layout

        def add_button(object_name):
            # Buttons are created, named and added in the order they appear.
            button = QtWidgets.QPushButton(column)
            button.setObjectName(object_name)
            layout.addWidget(button)
            return button

        self.clearBtn = add_button("clearBtn")
        self.saveBtn = add_button("saveBtn")
        self.predictBtn = add_button("predictBtn")

        # Large bold Arial font for the predicted digit.
        big_bold = QtGui.QFont()
        big_bold.setFamily("Arial")
        big_bold.setPointSize(26)
        big_bold.setBold(True)
        big_bold.setWeight(75)

        label = QtWidgets.QLabel(column)
        label.setFont(big_bold)
        label.setObjectName("result")
        layout.addWidget(label)
        self.result = label

        # The three buttons share space equally; the label gets twice as much.
        for position, factor in enumerate((1, 1, 1, 2)):
            layout.setStretch(position, factor)

        MainWindow.setCentralWidget(central)

        self.retranslateUi(MainWindow)
        self.clearBtn.clicked.connect(MainWindow.on_clearBtn)
        self.saveBtn.clicked.connect(MainWindow.on_saveBtn)
        self.predictBtn.clicked.connect(MainWindow.on_predictBtn)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Apply the (translatable) window title and widget captions."""
        # Explicit literal calls are kept so translation tools can find them.
        tr = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(tr("MainWindow", "MainWindow"))
        self.clearBtn.setText(tr("MainWindow", "clear"))
        self.saveBtn.setText(tr("MainWindow", "save"))
        self.predictBtn.setText(tr("MainWindow", "predict"))
        self.result.setText(tr("MainWindow", "result"))


# DigitalMnistNum is the drawing-canvas widget subclass; it is only looked up
# when setupUi runs, so importing it after the class definition is fine.
from DigitalMnistNum import DigitalMnistNum
  • 主窗口的具体实现
from PyQt5 import QtCore, QtWidgets
from Ui_MainWindow import Ui_MainWindow
import caffe


class MainWindow(QtWidgets.QMainWindow):
    """Main window: drawing canvas plus clear / save / predict buttons.

    The Caffe classifier is injected afterwards through :meth:`setNet`; until
    that happens, pressing *predict* does nothing except log a debug message.
    """

    def __init__(self, parent=None):
        super(MainWindow, self).__init__(parent)
        self.ui = Ui_MainWindow()
        self.ui.setupUi(self)
        # Classifier used by on_predictBtn; None until setNet() is called.
        self._net = None

    def setNet(self, net):
        """Install the classifier (e.g. a ``caffe.Classifier``) used for prediction."""
        self._net = net

    def on_clearBtn(self):
        """Erase the drawing canvas."""
        QtCore.qDebug("on_clearBtn")
        self.ui.widget.clearBitmap()

    def on_saveBtn(self):
        """Write the current drawing to disk via the canvas' saveBitmap()."""
        QtCore.qDebug("on_saveBtn")
        self.ui.widget.saveBitmap()

    def on_predictBtn(self):
        """Save the drawing to aaa.bmp, classify it and show the best class.

        The canvas is first saved (saveBitmap), then re-read from disk as a
        single-channel image and passed to the network; the index of the
        highest score is displayed in the result label.
        """
        if self._net is None:
            # Clicked before the model finished loading: nothing to predict with.
            QtCore.qDebug("on_predictBtn: network not loaded yet")
            return
        self.ui.widget.saveBitmap()
        # Must match the file name the canvas' saveBitmap() writes by default.
        IMAGE_FILE = r"aaa.bmp"
        image = caffe.io.load_image(IMAGE_FILE, color=False)
        # oversample=False: classify a single centre crop instead of 10 crops.
        prediction = self._net.predict([image], oversample=False)
        self.ui.result.setText(str(prediction[0].argmax()))

    def setLabelText(self, text):
        """Show *text* in the result label (used for status such as LOAD/OK)."""
        self.ui.result.setText(text)

    def setBitmapSize(self, size):
        """Forward the saved-image size to the canvas widget."""
        self.ui.widget.setBitmapSize(size)
  • 画图板功能的实现
from PyQt5 import QtCore, QtGui, QtWidgets


class DigitalMnistNum(QtWidgets.QWidget):
    """Canvas widget for drawing one digit, white strokes on black (MNIST style).

    Strokes are painted into an off-screen pixmap ``self.pix`` that always
    matches the widget size; :meth:`saveBitmap` downscales it to
    ``self.bitmapSize`` and writes it to disk for the classifier.
    """

    def __init__(self, parent=None):
        super(DigitalMnistNum, self).__init__(parent)
        self.pen = QtGui.QPen()
        self.pen.setStyle(QtCore.Qt.SolidLine)
        self.pen.setWidth(12)               # stroke width in canvas pixels
        self.pen.setColor(QtCore.Qt.white)  # white strokes
        # Round caps/joins give smooth pen-like strokes instead of squares.
        self.pen.setCapStyle(QtCore.Qt.RoundCap)
        self.pen.setJoinStyle(QtCore.Qt.RoundJoin)
        self.bitmapSize = QtCore.QSize(28, 28)
        # Last point of the stroke in progress; None when no stroke is active.
        self.startPos = None
        self.resetBitmap()

    def resetBitmap(self):
        """Replace the canvas with a blank black pixmap of the widget's size."""
        # A full-colour QPixmap (instead of a 1-bit QBitmap) allows antialiased
        # strokes and grey levels after downscaling, closer to MNIST images.
        self.pix = QtGui.QPixmap(self.size())
        self.pix.fill(QtCore.Qt.black)  # black background

    def clearBitmap(self):
        """Erase the drawing and repaint."""
        self.resetBitmap()
        self.update()

    def saveBitmap(self, fileName="aaa.bmp"):
        """Save the canvas scaled to ``bitmapSize`` as *fileName*.

        Returns:
            The boolean result of ``QPixmap.save`` (False if writing failed).
        """
        # SmoothTransformation averages pixels when shrinking, so thin stroke
        # parts become grey instead of being dropped by nearest-neighbour sampling.
        tmp = self.pix.scaled(self.bitmapSize, QtCore.Qt.KeepAspectRatio,
                              QtCore.Qt.SmoothTransformation)
        QtCore.qDebug(str(tmp.size()))
        return tmp.save(fileName)

    def setBitmapSize(self, size):
        """Set the saved image size from a ``(width, height)`` pair."""
        # int() also accepts numpy integers, which QSize may reject.
        self.bitmapSize = QtCore.QSize(int(size[0]), int(size[1]))

    def _strokePainter(self):
        """Return a QPainter on the canvas set up with the drawing pen."""
        painter = QtGui.QPainter(self.pix)
        painter.setRenderHint(QtGui.QPainter.Antialiasing)
        painter.setPen(self.pen)
        return painter

    def mousePressEvent(self, event):
        if event.button() == QtCore.Qt.LeftButton:
            # Start a new stroke with a single dot.
            self.startPos = event.pos()
            painter = self._strokePainter()
            painter.drawPoint(self.startPos)
            painter.end()
        self.update()

    def mouseMoveEvent(self, event):
        # Without mouse tracking, move events arrive while ANY button is held;
        # only extend a stroke that was started with the left button.
        if self.startPos is None or not (event.buttons() & QtCore.Qt.LeftButton):
            return
        painter = self._strokePainter()
        painter.drawLine(self.startPos, event.pos())
        painter.end()
        self.startPos = event.pos()
        self.update()

    def paintEvent(self, event):
        # If the widget was resized, start over with a canvas of the new size
        # (this discards the current drawing).
        if self.size() != self.pix.size():
            QtCore.qDebug(str(self.size()) + "," + str(self.pix.size()) + "," + str(event.type()))
            self.resetBitmap()
        painter = QtGui.QPainter(self)
        painter.drawPixmap(QtCore.QPoint(0, 0), self.pix)

    def mouseReleaseEvent(self, event):
        if event.button() == QtCore.Qt.LeftButton:
            self.startPos = None  # stroke finished
        self.update()
  • 程序入口
import os
import sys

from PyQt5 import QtWidgets, QtGui

caffe_root = r"D:\caffe"
# pycaffe must be on sys.path BEFORE anything imports caffe -- including
# MainWindowC, which does `import caffe` at module level.
sys.path.insert(0, os.path.join(caffe_root, "python"))

import caffe  # noqa: E402  (needs the sys.path entry above)
from MainWindowC import MainWindow  # noqa: E402

MODEL_FILE = os.path.join(caffe_root, "examples", "mnist", "lenet.prototxt")
PRETRAINED = os.path.join(caffe_root, "examples", "mnist", "lenet_iter_10000.caffemodel")

# caffe.Classifier.predict resizes every input to image_dims and then takes a
# centre crop of the net's input size (28x28 for MNIST LeNet).  Passing the
# 140x140 canvas size here meant the net only saw the middle fifth of the
# drawing.  Save the drawing at 28x28 and use the same value for image_dims.
# (Square, so the (H, W) vs (W, H) ordering does not matter.)
image_dims = [28, 28]


def main():
    """Show the window, load the LeNet classifier, then run the event loop."""
    app = QtWidgets.QApplication(sys.argv)
    win = MainWindow()
    win.show()
    win.setLabelText("LOAD")
    # Loading the model blocks before the event loop starts; flush pending
    # paint events so the "LOAD" status is actually visible meanwhile.
    app.processEvents()

    win.setBitmapSize(image_dims)
    net = caffe.Classifier(MODEL_FILE, PRETRAINED, image_dims=image_dims)
    win.setNet(net)
    win.setLabelText("OK")
    return app.exec_()


if __name__ == "__main__":
    sys.exit(main())

但是预测效果很差，在网上和资料中都没有找到原因。在训练集上效果不错，但对真实手写输入却不行，难道是中国人与外国人的手写差别大？下一步可以考虑用更深、更复杂的网络（如加入 dropout 层）来重新训练、测试和预测。