tensorflow tf.python_io模块源码阅读

来源:互联网 发布:淘宝一件代发发货 编辑:程序博客网 时间:2024/05/21 10:04

1.序言
该模块是 TensorFlow 用来处理 TFRecords 文件的接口，定义在 tensorflow/python/lib/io/python_io.py 中，主要包含以下四个部分：
class TFRecordCompressionType：记录的压缩类型。
class TFRecordOptions：用于操作 TFRecord 文件的选项。
class TFRecordWriter：将记录写入 TFRecords 文件的类。
tf_record_iterator(…)：从 TFRecords 文件中读取记录的迭代器。

2.源码解析

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import errors
from tensorflow.python.util import compat


class TFRecordCompressionType(object):
  """The type of compression for the record.

  Three kinds are defined: NONE (uncompressed), ZLIB and GZIP.
  """
  NONE = 0
  ZLIB = 1
  GZIP = 2


class TFRecordOptions(object):
  """Options used for manipulating TFRecord files.

  Only the compression type is carried; it is translated to a string by
  `get_compression_type_string` and handed to the `pywrap_tensorflow`
  record reader/writer constructors.
  """
  # Maps each TFRecordCompressionType value to the string passed to the
  # wrapped reader/writer; "" means no compression.
  compression_type_map = {
      TFRecordCompressionType.ZLIB: "ZLIB",
      TFRecordCompressionType.GZIP: "GZIP",
      TFRecordCompressionType.NONE: ""
  }

  def __init__(self, compression_type):
    """Creates options with the given compression type.

    Args:
      compression_type: One of the `TFRecordCompressionType` values.
    """
    self.compression_type = compression_type

  @classmethod
  def get_compression_type_string(cls, options):
    """Returns the compression type string for `options`.

    Args:
      options: A `TFRecordOptions` object, or None.

    Returns:
      "ZLIB", "GZIP", or "" (no compression). A falsy `options` yields "".

    Raises:
      KeyError: If `options.compression_type` is not a key of
        `compression_type_map`.
    """
    if not options:
      return ""
    return cls.compression_type_map[options.compression_type]


def tf_record_iterator(path, options=None):
  """An iterator that reads the records from a TFRecords file.

  Args:
    path: The path to the TFRecords file.
    options: (optional) A `TFRecordOptions` object; only its compression
      type is used.

  Yields:
    Strings.

  Raises:
    IOError: If `path` cannot be opened for reading.
  """
  compression_type = TFRecordOptions.get_compression_type_string(options)
  with errors.raise_exception_on_not_ok_status() as status:
    # NOTE(review): the literal 0 is presumably the starting read offset;
    # confirm against PyRecordReader_New.
    reader = pywrap_tensorflow.PyRecordReader_New(
        compat.as_bytes(path), 0, compat.as_bytes(compression_type), status)

  if reader is None:
    raise IOError("Could not open %s." % path)

  # try/finally guarantees the reader is closed even when the consumer
  # stops iterating early (break, generator.close(), garbage collection)
  # or an error escapes mid-file. Previously Close() only ran after the
  # whole file had been consumed, leaking the reader otherwise.
  try:
    while True:
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          reader.GetNext(status)
      except errors.OutOfRangeError:
        # OutOfRangeError is the end-of-file signal: stop iterating.
        break
      yield reader.record()
  finally:
    reader.Close()


class TFRecordWriter(object):
  """A class to write records to a TFRecords file.

  Implements `__enter__`/`__exit__`, so it can be used in a `with`
  statement; the file is closed when the block exits.
  """

  # TODO(josh11b): Support appending?
  def __init__(self, path, options=None):
    """Opens file `path` and creates a `TFRecordWriter` writing to it.

    Args:
      path: The path to the TFRecords file.
      options: (optional) A `TFRecordOptions` object; only its compression
        type is used.

    Raises:
      IOError: If `path` cannot be opened for writing.
    """
    compression_type = TFRecordOptions.get_compression_type_string(options)

    with errors.raise_exception_on_not_ok_status() as status:
      self._writer = pywrap_tensorflow.PyRecordWriter_New(
          compat.as_bytes(path), compat.as_bytes(compression_type), status)

  def __enter__(self):
    """Enters a `with` block; returns the writer itself."""
    return self

  def __exit__(self, unused_type, unused_value, unused_traceback):
    """Exits a `with` block, closing the file."""
    self.close()

  def write(self, record):
    """Writes a string record to the file.

    Args:
      record: str
    """
    # The actual write is delegated to the wrapped writer object.
    self._writer.WriteRecord(record)

  def flush(self):
    """Flushes the wrapped writer's buffered output to the file."""
    with errors.raise_exception_on_not_ok_status() as status:
      self._writer.Flush(status)

  def close(self):
    """Closes the file."""
    with errors.raise_exception_on_not_ok_status() as status:
      self._writer.Close(status)
原创粉丝点击