用 keras 建立超简单的汉字识别模型
来源:互联网 发布:淘宝海外物流怎么发货 编辑:程序博客网 时间:2024/06/06 19:25
之前看过很多 mnist 的识别模型,都是识别数字的,为啥不做一个汉字识别模型呢?因为汉字手写的库找不到啊。当时我还想自己从字库生成汉字用作识别(已经做出来了,导出字体图片再识别之)。
后来看了这篇文章和这篇文章 : CASIA-HWDB 这个神奇的东西。原文是用 tensorflow 实现的,比较复杂,现在改成用 keras 去完成。
数据集下载
$ wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zip # zip 解压没得说, 之后还要解压 alz 压缩文件 $ wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip
正好用新学的 keras 来尝试建模识别。
首先要将下载来的 gnt 文件解压。这部分我完全不懂,图像处理部分直接使用他们的代码了。
其中 3500.txt 是常用的 3500 个汉字,这个我用来跟另外一个根据字体生成汉字的脚本配合使用。
import os import numpy as np import struct from PIL import Image data_dir = './hanwriting' train_data_dir = os.path.join(data_dir, 'HWDB1.1trn_gnt') test_data_dir = os.path.join(data_dir, 'HWDB1.1tst_gnt') # f = open('3500.txt', 'r', encoding="utf8") f = open('3500.txt', 'r') total_words = f.readlines()[0].decode("utf-8") print(total_words) def read_from_gnt_dir(gnt_dir=train_data_dir): def one_file(f): header_size = 10 while True: header = np.fromfile(f, dtype='uint8', count=header_size) if not header.size: break sample_size = header[0] (header[1] << 8) (header[2] << 16) (header[3] << 24) tagcode = header[5] (header[4] << 8) width = header[6] (header[7] << 8) height = header[8] (header[9] << 8) if header_size width*height != sample_size: break image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width)) yield image, tagcode for file_name in os.listdir(gnt_dir): if file_name.endswith('.gnt'): file_path = os.path.join(gnt_dir, file_name) with open(file_path, 'rb') as f: for image, tagcode in one_file(f): yield image, tagcode char_set = set() for _, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir): tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312') char_set.add(tagcode_unicode) char_list = list(char_set) char_dict = dict(zip(sorted(char_list), range(len(char_list)))) print len(char_dict) import pickle f = open('char_dict', 'wb') pickle.dump(char_dict, f) f.close() train_counter = 0 test_counter = 0 for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir): tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312') if tagcode_unicode in total_words: im = Image.fromarray(image) dir_name = './data/train/' '%0.5d'%char_dict[tagcode_unicode] if not os.path.exists(dir_name): os.mkdir(dir_name) im.convert('RGB').save(dir_name '/' str(train_counter) '.png') train_counter = 1 for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir): tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312') if tagcode_unicode in total_words: im = Image.fromarray(image) dir_name = './data/test/' '%0.5d'%char_dict[tagcode_unicode] if not os.path.exists(dir_name): os.mkdir(dir_name) im.convert('RGB').save(dir_name '/' str(test_counter) '.png') test_counter = 1
解压完会生成一个 train 和一个 test 的文件夹,里面分别用数字为文件夹名,里面都是一些别人手写的汉字的图片。
如果用 tensorflow 写的话,大概需要 300 行,需要处理图像(当然 tf 也会帮你处理大部分繁琐的操作),需要写批量加载,还有各种东西。
到了 keras,十分简单。总共的代码就 70 多行,连图像加载和偏移处理都是智能的。图片转换都给你包办了,简直贴心。
from __future__ import print_function import os from keras.preprocessing.image import ImageDataGenerator from keras.layers import Input, Dense, Dropout, Convolution2D, MaxPooling2D, Flatten from keras.models import Model, load_model data_dir = './data' train_data_dir = os.path.join(data_dir, 'train') test_data_dir = os.path.join(data_dir, 'test') # dimensions of our images. img_width, img_height = 64, 64 charset_size = 3751 nb_validation_samples = 800 nb_samples_per_epoch = 2000 nb_nb_epoch = 20000; def train(model): train_datagen = ImageDataGenerator( rescale=1. / 255, rotation_range=0, width_shift_range=0.1, height_shift_range=0.1 ) test_datagen = ImageDataGenerator(rescale=1./255) train_generator = train_datagen.flow_from_directory( train_data_dir, target_size=(img_width, img_height), batch_size=1024, color_mode="grayscale", class_mode='categorical') validation_generator = test_datagen.flow_from_directory( test_data_dir, target_size=(img_width, img_height), batch_size=1024, color_mode="grayscale", class_mode='categorical') model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy']) model.fit_generator(train_generator, samples_per_epoch=nb_samples_per_epoch, nb_epoch=nb_nb_epoch, validation_data=validation_generator, nb_val_samples=nb_validation_samples) def build_model(include_top=True, input_shape=(64, 64, 1), classes=charset_size): img_input = Input(shape=input_shape) x = Convolution2D(32, 3, 3, activation='relu', border_mode='same', name='block1_conv1')(img_input) x = Convolution2D(32, 3, 3, activation='relu', border_mode='same', name='block1_conv2')(x) x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) x = Convolution2D(64, 3, 3, activation='relu', border_mode='same', name='block2_conv1')(x) x = Convolution2D(64, 3, 3, activation='relu', border_mode='same', name='block2_conv2')(x) x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) if include_top: x = Flatten(name='flatten')(x) x = Dropout(0.05)(x) x = Dense(1024, activation='relu', name='fc2')(x) x = Dense(classes, activation='softmax', name='predictions')(x) model = Model(img_input, x, name='model') return model model = build_model() # model = load_model("./model.h5") train(model) # model.save("./model.h5")
可以看到生成模型的代码就 12 行,十分简洁。开头两套双卷积池化层,后面接一个 dropout 防过拟合,再接两个全链接层,最后一个 softmax 输出结果。
于是开我的 GTX1080 机器开跑,大约花了半天时间。
Epoch 20000/20000
1024/2000 [==============>...............] - ETA: 0s - loss: 0.2178 - acc: 0.9482
2048/2000 [==============================] - 2s - loss: 0.2118 - acc: 0.9478 - val_loss: 0.4246 - val_acc: 0.9102
在 20000 次 Epoch 后,准确率在 95%,验证的准确率在 91%左右,基本可以识别大部分库里的汉字了。
实际看来汉字识别是图像识别的一种,不过汉字数量比较多,很多手写的连人类都无法识别,估计难以达到 mnist 数据集的准确率。
最后可以看到,keras 是非常适合新手阶段去尝试的,代码也十分简洁。不过由于底层隐藏的比较深,如果深入研究的话容易会遇到瓶颈,而且包装太多,想对他做出修改优化也不是太容易。后期研究还是建议使用 tensorflow 和 pytorch。(个人在看 pytorch,比 tensorflow 要简洁不少,而且大部分 paper 都移植过去了,github 最近热门全是他,潜力无限)
- 用 keras 建立超简单的汉字识别模型
- 基于keras建立简单的CNN
- iOS简单的手写汉字识别
- keras的模型可视化
- 语音识别 一个超简单的语音听写识别编程
- 超简单:用Google Spreadsheets建立留言板
- Keras学习---RNN模型建立篇
- 一个Delphi超简单的取汉字首拼函数
- deep learning keras: 关于动物识别的vgg_16模型与调优
- 解读 Keras 在 ImageNet 中的应用:详解 5 种主要的图像识别模型
- 看似简单的行为识别模型
- 使用PowerPivot建立简单的分析模型
- 保存Keras训练的模型
- Keras上实现简单线性回归模型
- 一个超简单的语音识别编程,听写程序
- 识别汉字的方法
- Node.js建立一个超简单的HTTP服务器
- [5]深度学习和Keras----一个图像识别的简单Demo
- oracle创建表空间和授权
- stylus入门使用方法
- 欢乐西游通用缓存系统设计—应用Redis
- Problem J: 新奇的加法运算
- Java访问修饰符
- 用 keras 建立超简单的汉字识别模型
- Linux命令行——stat命令详解
- 【转】C++中读取一行数据:get和getline
- NoSQL(Not Only SQL)不同分类
- 数据库创建内存表
- 【java-web开发】spring复习
- win32实现画图小程序
- 第一次启动Android studio创建文档时出现错误解决方法
- 算法导论 练习题 15.3-3