VOC2007 转成 TFRecord

来源:互联网 发布:推广引流软件 编辑:程序博客网 时间:2024/06/06 19:45
#定义TFRecordWrite,关键参数是所要写入文件的目录和名字    while i < len(filenames):        # Open new TFRecord file.        tf_filename = _get_output_filename(output_dir, name, fidx)        with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:            j = 0            while i < len(filenames) and j < SAMPLES_PER_FILES:                sys.stdout.write('\r>> Converting image %d/%d' % (i+1, len(filenames)))                sys.stdout.flush()                filename = filenames[i]                img_name = filename[:-4]                _add_to_tfrecord(dataset_dir, img_name, tfrecord_writer)                i += 1                j += 1            fidx += 1def _add_to_tfrecord(dataset_dir, name, tfrecord_writer):    """Loads data from image and annotations files and add them to a TFRecord.    Args:      dataset_dir: Dataset directory;      name: Image name to add to the TFRecord;      tfrecord_writer: The TFRecord writer to use for writing.    """    image_data, shape, bboxes, labels, labels_text, difficult, truncated = \        _process_image(dataset_dir, name)    example = _convert_to_example(image_data, labels, labels_text,                                  bboxes, shape, difficult, truncated)    tfrecord_writer.write(example.SerializeToString())def _convert_to_example(image_data, labels, labels_text, bboxes, shape,                        difficult, truncated):    """Build an Example proto for an image example.    Args:      image_data: string, JPEG encoding of RGB image;      labels: list of integers, identifier for the ground truth;      labels_text: list of strings, human-readable labels;      bboxes: list of bounding boxes; each box is a list of integers;          specifying [xmin, ymin, xmax, ymax]. All boxes are assumed to belong          to the same label as the image label.      shape: 3 integers, image shapes in pixels.    
Returns:      Example proto    """    xmin = []    ymin = []    xmax = []    ymax = []    for b in bboxes:        assert len(b) == 4        # pylint: disable=expression-not-assigned        [l.append(point) for l, point in zip([ymin, xmin, ymax, xmax], b)]        # pylint: enable=expression-not-assigned    print xmin,ymin,ymax,xmax    image_format = b'JPEG'    example = tf.train.Example(features=tf.train.Features(feature={            'image/height': int64_feature(shape[0]),            'image/width': int64_feature(shape[1]),            'image/channels': int64_feature(shape[2]),            'image/shape': int64_feature(shape),            'image/object/bbox/xmin': float_feature(xmin),            'image/object/bbox/xmax': float_feature(xmax),            'image/object/bbox/ymin': float_feature(ymin),            'image/object/bbox/ymax': float_feature(ymax),            'image/object/bbox/label': int64_feature(labels),            'image/object/bbox/label_text': bytes_feature(labels_text),            'image/object/bbox/difficult': int64_feature(difficult),            'image/object/bbox/truncated': int64_feature(truncated),            'image/format': bytes_feature(image_format),            'image/encoded': bytes_feature(image_data)}))    return example
原创粉丝点击