ROS Robot Diego 1#: Integrating TensorFlow object_detection for Image Recognition


Google recently released an object detection API that makes image recognition much easier, along with a pre-trained model and sample code. The official documentation is at https://github.com/tensorflow/models/blob/master/object_detection/g3doc/installation.md
Judging from the official demo images, the results are quite good. Based on the official sample code, this post builds a ROS node that subscribes to an Image topic, runs the Object Detection API on each frame, and publishes the recognition result as a CompressedImage message.

1. Installing object_detection

This assumes TensorFlow is already installed. object_detection just needs to be installed following the official installation guide, using the commands below:

sudo apt-get install protobuf-compiler python-pil python-lxml
sudo pip install jupyter
sudo pip install matplotlib
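The official guide also requires compiling the protobuf definitions with protoc and adding the models directory to PYTHONPATH. Once that is done, a quick import check can confirm everything is reachable; this is a minimal sketch of my own, not part of the original tutorial:

#!/usr/bin/env python
# sanity_check.py -- minimal sketch (not from the original post) to verify the installation.
# Assumes object_detection is on PYTHONPATH as described in the official installation guide.
import tensorflow as tf
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

print("TensorFlow version:", tf.__version__)
print("object_detection utilities imported successfully")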

2. Creating the diego_tensorflow package

catkin_create_pkg diego_tensorflow std_msgs rospy roscpp cv_bridge

Create scripts and launch directories under the diego_tensorflow directory.

The scripts directory holds the Python source code.
The launch directory holds the ROS launch files.

Download the object_detection code into the scripts directory.

3. The ROS node

The code is in ObjectDetectionDemo.py:

#!/usr/bin/env python
import rospy
from sensor_msgs.msg import Image as ROSImage
from sensor_msgs.msg import CompressedImage as ROSImage_C
from cv_bridge import CvBridge
import cv2
import matplotlib
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from collections import defaultdict
from io import StringIO
from PIL import Image

# This is needed since the notebook is stored in the object_detection folder.
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util


class ObjectDetectionDemo():
    def __init__(self):
        rospy.init_node('object_detection_demo')

        # Set the shutdown function (stop the robot)
        rospy.on_shutdown(self.shutdown)

        model_path = rospy.get_param("~model_path", "")
        image_topic = rospy.get_param("~image_topic", "")

        self._cv_bridge = CvBridge()

        # What model to download.
        MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'
        MODEL_FILE = MODEL_NAME + '.tar.gz'
        DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'

        # Path to frozen detection graph. This is the actual model that is used for the object detection.
        PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

        # List of the strings that is used to add correct label for each box.
        PATH_TO_LABELS = os.path.join(model_path + '/data', 'mscoco_label_map.pbtxt')
        #PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')

        NUM_CLASSES = 90

        opener = urllib.request.URLopener()
        opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
        tar_file = tarfile.open(MODEL_FILE)
        for file in tar_file.getmembers():
            file_name = os.path.basename(file.name)
            if 'frozen_inference_graph.pb' in file_name:
                tar_file.extract(file, os.getcwd())

        self.detection_graph = tf.Graph()
        with self.detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')

        label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
        categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
        self.category_index = label_map_util.create_category_index(categories)

        self._sub = rospy.Subscriber(image_topic, ROSImage, self.callback, queue_size=1)
        self._pub = rospy.Publisher('object_detection', ROSImage_C, queue_size=1)

    def callback(self, image_msg):
        with self.detection_graph.as_default():
            with tf.Session(graph=self.detection_graph) as sess:
                cv_image = self._cv_bridge.imgmsg_to_cv2(image_msg, "bgr8")
                pil_img = Image.fromarray(cv_image)
                (im_width, im_height) = pil_img.size
                # the array based representation of the image will be used later in order to prepare the
                # result image with boxes and labels on it.
                image_np = np.array(pil_img.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)
                # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
                image_np_expanded = np.expand_dims(image_np, axis=0)
                image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
                # Each box represents a part of the image where a particular object was detected.
                boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
                # Each score represents the level of confidence for each of the objects.
                # Score is shown on the result image, together with the class label.
                scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
                classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
                num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
                # Actual detection.
                (boxes, scores, classes, num_detections) = sess.run(
                    [boxes, scores, classes, num_detections],
                    feed_dict={image_tensor: image_np_expanded})
                # Visualization of the results of a detection.
                vis_util.visualize_boxes_and_labels_on_image_array(
                    image_np,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    self.category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8)
                ros_compressed_image = self._cv_bridge.cv2_to_compressed_imgmsg(image_np)
                self._pub.publish(ros_compressed_image)

    def shutdown(self):
        rospy.loginfo("Stopping the tensorflow object detection...")
        rospy.sleep(1)


if __name__ == '__main__':
    try:
        ObjectDetectionDemo()
        rospy.spin()
    except rospy.ROSInterruptException:
        rospy.loginfo("RosTensorFlow_ObjectDetectionDemo terminated.")

This code wraps the official sample code in a ROS node that subscribes to an image topic and publishes the result as a CompressedImage. See the official documentation for how the detection itself works; only the ROS wrapper code is explained here.

    def __init__(self):
        rospy.init_node('object_detection_demo')

        # Set the shutdown function (stop the robot)
        rospy.on_shutdown(self.shutdown)

        model_path = rospy.get_param("~model_path", "")
        image_topic = rospy.get_param("~image_topic", "")

The code above is the standard ROS node initialization and parameter-reading code.

        self._sub = rospy.Subscriber(image_topic, ROSImage, self.callback, queue_size=1)
        self._pub = rospy.Publisher('object_detection', ROSImage_C, queue_size=1)

The code above subscribes to the Image topic and registers the callback function. It also defines the published topic, object_detection; later we can subscribe to object_detection to display the recognition results.

    def callback(self, image_msg):
        with self.detection_graph.as_default():
            with tf.Session(graph=self.detection_graph) as sess:
                cv_image = self._cv_bridge.imgmsg_to_cv2(image_msg, "bgr8")
                pil_img = Image.fromarray(cv_image)
                (im_width, im_height) = pil_img.size
                # the array based representation of the image will be used later in order to prepare the
                # result image with boxes and labels on it.
                image_np = np.array(pil_img.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)
                # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
                image_np_expanded = np.expand_dims(image_np, axis=0)
                image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
                # Each box represents a part of the image where a particular object was detected.
                boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
                # Each score represents the level of confidence for each of the objects.
                # Score is shown on the result image, together with the class label.
                scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
                classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
                num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
                # Actual detection.
                (boxes, scores, classes, num_detections) = sess.run(
                    [boxes, scores, classes, num_detections],
                    feed_dict={image_tensor: image_np_expanded})
                # Visualization of the results of a detection.
                vis_util.visualize_boxes_and_labels_on_image_array(
                    image_np,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    self.category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8)
                ros_compressed_image = self._cv_bridge.cv2_to_compressed_imgmsg(image_np)
                self._pub.publish(ros_compressed_image)

callback is the main processing function: each frame captured by the camera is run through detection and published on the object_detection topic.
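Note that the callback opens a new tf.Session for every frame, which adds avoidable overhead. As a possible optimization (my own sketch, not part of the original code), the session and tensor handles can be created once and reused; the class below illustrates the pattern, assuming a pre-loaded detection_graph and an image_topic are passed in:

#!/usr/bin/env python
# Sketch of a possible optimization (an assumption, not the original author's code):
# open the TensorFlow session and look up the tensors once, then reuse them per frame.
import numpy as np
import rospy
import tensorflow as tf
from cv_bridge import CvBridge
from sensor_msgs.msg import Image as ROSImage


class ReusableSessionDetector(object):
    def __init__(self, detection_graph, image_topic):
        rospy.init_node('object_detection_fast')
        self._cv_bridge = CvBridge()
        self._graph = detection_graph
        # Session and tensor handles are created once, not inside every callback.
        self._sess = tf.Session(graph=self._graph)
        self._image_tensor = self._graph.get_tensor_by_name('image_tensor:0')
        self._boxes = self._graph.get_tensor_by_name('detection_boxes:0')
        self._scores = self._graph.get_tensor_by_name('detection_scores:0')
        self._classes = self._graph.get_tensor_by_name('detection_classes:0')
        self._num = self._graph.get_tensor_by_name('num_detections:0')
        rospy.on_shutdown(lambda: self._sess.close())
        self._sub = rospy.Subscriber(image_topic, ROSImage, self.callback, queue_size=1)

    def callback(self, image_msg):
        cv_image = self._cv_bridge.imgmsg_to_cv2(image_msg, "bgr8")
        image_np_expanded = np.expand_dims(cv_image, axis=0)
        (boxes, scores, classes, num) = self._sess.run(
            [self._boxes, self._scores, self._classes, self._num],
            feed_dict={self._image_tensor: image_np_expanded})
        rospy.loginfo("detections in frame: %d", int(num[0]))
        # Visualization and publishing would follow here, exactly as in the original callback.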

4. The launch file

Create object_detection_demo.launch in the launch directory with the following content:

<launch>
    <node pkg="diego_tensorflow" name="ObjectDetectionDemo" type="ObjectDetectionDemo.py" output="screen">
        <param name="image_topic" value="/usb_cam/image_raw" />
        <param name="model_path" value="$(find diego_tensorflow)/scripts/object_detection" />
    </node>
</launch>

5. Startup

roscore
roslaunch usb_cam usb_cam-test.launch
roslaunch diego_tensorflow object_detection_demo.launch

6. Subscribing to the detection results from a phone app

We can view the results with a phone app; just set its Image topic to /object_detection.
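If a desktop display is preferred instead of the phone app, a small subscriber node can decode and show the CompressedImage messages. This is a minimal sketch of my own (node and window names are arbitrary), not part of the original post:

#!/usr/bin/env python
# Minimal sketch (not from the original post): display /object_detection results on a desktop
# by decoding the CompressedImage messages with OpenCV.
import cv2
import numpy as np
import rospy
from sensor_msgs.msg import CompressedImage


def on_image(msg):
    # Decode the compressed payload into an OpenCV BGR image and show it.
    img = cv2.imdecode(np.frombuffer(msg.data, np.uint8), cv2.IMREAD_COLOR)
    cv2.imshow("object_detection", img)
    cv2.waitKey(1)


if __name__ == '__main__':
    rospy.init_node('object_detection_viewer')
    rospy.Subscriber('/object_detection', CompressedImage, on_image, queue_size=1)
    rospy.spin()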
The following images show the results with depth information added. The depth is computed from an Xtion depth camera and can be used to estimate the distance between a detected object and the robot.
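The post does not show the depth code. As an illustration only (the topic name, encoding, and box format below are assumptions), one way to estimate the distance is to sample the Xtion depth image at the center of each detection box:

#!/usr/bin/env python
# Illustration only (an assumption, not the original author's code): estimate an object's
# distance by reading the depth image at the center of a detection box.
# Assumes a depth topic of /camera/depth/image_raw with 32FC1 depth in meters, and boxes in
# the normalized [ymin, xmin, ymax, xmax] format returned by the detection API.
import numpy as np
import rospy
from cv_bridge import CvBridge
from sensor_msgs.msg import Image

_bridge = CvBridge()
_latest_depth = None


def depth_callback(msg):
    global _latest_depth
    _latest_depth = _bridge.imgmsg_to_cv2(msg, "32FC1")


def distance_to_box(box):
    """Return the depth (in meters) at the center of a normalized detection box, or None."""
    if _latest_depth is None:
        return None
    h, w = _latest_depth.shape[:2]
    ymin, xmin, ymax, xmax = box
    cy = int((ymin + ymax) / 2.0 * h)
    cx = int((xmin + xmax) / 2.0 * w)
    d = _latest_depth[cy, cx]
    return float(d) if np.isfinite(d) else None


if __name__ == '__main__':
    rospy.init_node('depth_distance_demo')
    rospy.Subscriber('/camera/depth/image_raw', Image, depth_callback, queue_size=1)
    rospy.spin()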

Judging from the tests, detection quality is very good for any object the model covers. However, even though Diego 1# uses an i7 processor, there is noticeable lag: CPU usage on the quad-core i7 is very high and about half of the 8 GB of RAM is in use. A dedicated GPU really seems to be needed.
TensorFlow object detection provides the five models shown below. Only ssd_mobilenet_v1_coco is used here; it is a lightweight model suitable for mobile devices. A more accurate model is faster_rcnn_inception_resnet_v2_atrous_coco, but it needs a dedicated GPU: on the i7 mini PC, processing every video frame with it is about as fast as a snail, although its detection accuracy is clearly much higher.
(Figure: the object detection models provided by the API)
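Switching models only requires pointing MODEL_NAME in the node at a different archive from the detection model zoo. A sketch follows; the exact archive name and date suffix are assumptions and should be checked against the model zoo page:

# Sketch: trade speed for accuracy by downloading a different model zoo archive.
# The archive name below is an assumption; verify it against the model zoo page.
MODEL_NAME = 'faster_rcnn_inception_resnet_v2_atrous_coco_11_06_2017'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'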
