使用Tensorflow物体识别API抠出视频中的猪

来源:互联网 发布:如何更改电脑mac地址 编辑:程序博客网 时间:2024/05/18 03:39

Tensorflow Object Detection API

猪检测代码以及后续进行猪分类的程序都开源在github了。

主要在官方的demo code上做了如下修改:

  1. 扩展det出的box,以更好地包裹目标,crop时限定不超出图像边界[expand_ratio]
  2. 如检测出pig, animal可能都是对的,可以依据运行结果调整接受规则,抑制检测到的概率比较大的无关类别,提高鲁棒性[class_keep]
  3. 使用mini batch的方式,以充分利用GPU提高程序运行效率。

下面重点看一下与obj det API有关的核心代码:

# Load a (frozen) Tensorflow model into memory'''tf.GraphDef():The GraphDef class is an object created by the ProtoBuf. 详见https://www.tensorflow.org/extend/tool developers/graph_def: A GraphDef proto containing operations to be imported into the default graph'''detection_graph = tf.Graph()with detection_graph.as_default():  od_graph_def = tf.GraphDef()  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:    serialized_graph = fid.read()    od_graph_def.ParseFromString(serialized_graph)    tf.import_graph_def(od_graph_def, name='')
'''这里用了几个util函数。'''label_map = label_map_util.load_labelmap(PATH_TO_LABELS)categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)category_index = label_map_util.create_category_index(categories)
'''重点看定义计算图。在这个脚本中图片是通过feed_dict={image_tensor: image_np_expanded})传递给计算图的。之前的博文介绍过如何使用自己生成的tfrecord,另外还可以使用tf1.4新出的dataset API。关于get_tensor_by_name,就是通过名字来获得张量,具体见下面一段小测试代码。但是还是看不出来为什么这个计算图能work,看起来就是获取了几个张量,应该就是检测框等张量依赖于image_tensor,我们去源码里确认一下。发现在object_detection/inference/detection_inference.py文件中build_inference_graph函数里,这个函数主要作用是Loads the inference graph and connects it to the input image.具体如下:  tf.import_graph_def(      graph_def, name='', input_map={'image_tensor': image_tensor})官方文档:input_map: A dictionary mapping input names (as strings) in graph_def to Tensor objects. The values of the named input tensors in the imported graph will be re-mapped to the respective Tensor values.再来看看build_inference_graph函数是在哪被调用的。然后发现确实在inference文件夹下被调用了,但是我们这里通过feed的方式并不是调用这个函数。猜想一定是导出网络时定义了image_tensor这个变量名,如在object_detection/exporter.py可以看到image_tensor是placeholder,意料之中。至于计算图具体的连接关系就是模型定义本身了,下次分析训练的代码再看。'''with detection_graph.as_default():  with tf.Session(graph=detection_graph) as sess:    # Definite input and output Tensors for detection_graph    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')    # Each box represents a part of the image where a particular object was detected.    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')    # Each score represent how level of confidence for each of the objects.    # Score is shown on the result image, together with the class label.    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')    num_detections = detection_graph.get_tensor_by_name('num_detections:0')    '''注意此处省略了一些代码'''      (boxes, scores, classes, num) = sess.run(          [detection_boxes, detection_scores, detection_classes, num_detections],          feed_dict={image_tensor: image_np_expanded})
import tensorflow as tfc = tf.constant([[1.0, 2.0], [3.0, 4.0]])d = tf.constant([[1.0, 1.0], [0.0, 1.0]])e = tf.matmul(c, d, name='example')with tf.Session() as sess:    test =  sess.run(e)    print (e.name) #example:0    print(test)    test = tf.get_default_graph().get_tensor_by_name("example:0")    print (test) #Tensor("example:0", shape=(2, 2), dtype=float32)    print (test.eval())'''输出是:example_2:0[[ 1.  3.] [ 3.  7.]]Tensor("example:0", shape=(2, 2), dtype=float32)[[ 1.  3.] [ 3.  7.]]'''
阅读全文
1 0
原创粉丝点击