tensorflow练习9:图像分类器

来源:互联网 发布:安华卫浴知乎 编辑:程序博客网 时间:2024/06/13 11:59

这一节继续使用谷歌的image_retain作为模型进行训练。下载文件:
https://github.com/tensorflow/tensorflow。
使用examples中的image_retraining进行训练:
运行命令:

python tensorflow/tensorflow/examples/image_retraining/retrain.py --bottleneck_dir bottleneck --how_many_training_steps 4000 --model_dir model --output_graph output_graph.pb --output_labels output_labels.txt --image_dir girl_types/

训练一段时间后,得到输出文件:output_graph.pb与output_labels.txt。使用训练好的文件:

import tensorflow as tfimport sys# 命令行参数,传入要判断的图片路径image_file = sys.argv[1]#print(image_file)# 读取图像image = tf.gfile.FastGFile(image_file, 'rb').read()# 加载图像分类标签labels = []for label in tf.gfile.GFile("output_labels.txt"):    labels.append(label.rstrip())# 加载Graphwith tf.gfile.FastGFile("output_graph.pb", 'rb') as f:    graph_def = tf.GraphDef()    graph_def.ParseFromString(f.read())    tf.import_graph_def(graph_def, name='')with tf.Session() as sess:    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')    predict = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image})    # 根据分类概率进行排序    top = predict[0].argsort()[-len(predict[0]):][::-1]    for index in top:        human_string = labels[index]        score = predict[0][index]        print(human_string, score)
原创粉丝点击