第七期 使用 Keras 演示神经网络 《显卡就是开发板》

来源:互联网 发布:开淘宝店没有营业执照 编辑:程序博客网 时间:2024/05/22 14:04

  这一期我们来演示一种更加简洁的深度神经网络构建方法–Keras,下面这张图片展示了Keras在网络栈中的位置。

这里写图片描述

  可见Keras是一种比较高级的API,也就是说用它来构建网络使用的代码量会更少,下面用一段代码来演示一下,我们使用通过ImageNet预先训练好的VGG16结构的网络来分类一张图片。

%matplotlib inlinefrom keras.applications.vgg16 import VGG16from keras.preprocessing import imagefrom keras.applications.vgg16 import preprocess_input, decode_predictionsimport numpy as npimport cv2from matplotlib import pyplot as pltmodel = VGG16(weights='imagenet')img_path = 'demo.jpg'img = image.load_img(img_path, target_size=(224, 224))x = image.img_to_array(img)x = np.expand_dims(x, axis=0)x = preprocess_input(x)features = model.predict(x)preds = decode_predictions(features, top=1)[0][0]print(preds)# Label and show the imageimg = cv2.imread('demo.jpg',cv2.IMREAD_COLOR)img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) cv2.putText(img, "Label: {}, {:.2f}%".format(preds[1], preds[2] * 100),             (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0,255,0), 2, cv2.LINE_AA)plt.figure(dpi=150)plt.imshow(img, cmap = 'gray', interpolation = 'bicubic')plt.xticks([]), plt.yticks([])  # to hide tick values on X and Y axisplt.show()

这里写图片描述

  通过上面的代码就可以实现将一张图片打上ImageNet中的标签。可以看到,通过jupyter notebook 可以非常直观的展示出已训练模型的计算结果,并且可以发现,使用Keras的API相对与Tensorflow可以非常简洁的产生计算结果,所以Tensorflow的1.4版本可是引入Keras的API了,现在Tensorflow官方文档已经将自己的接口分成High-level API 和 Low-level API,以后我们演示党有福啦,几行代码就可以演示结果了。
  上面的代码运行时会自动下载VGG16的weight文件,如果下载失败可以到我的网盘里下载 https://pan.baidu.com/s/1dE5PrHJ ,将 vgg16_weights_tf_dim_ordering_tf_kernels.h5 和 imagenet_class_index.json
文件保存到 ${HOME}/.keras/models/ 目录下即可。

对应源码地址: https://github.com/aggresss/GPUDemo/blob/master/keras_demo.ipynb
参考文档: https://keras.io/

阅读全文
0 0
原创粉丝点击