TensorFlow里面mnist导入手写数据代码分析

来源:互联网 发布:sql comment 编辑:程序博客网 时间:2024/06/15 03:46

TensorFlow里面mnist导入手写数据代码分析

本文主要介绍了 TensorFlow(TF)手写识别中导入数据部分的源码分析。


  tensorflow/tensorflow/examples/tutorials/mnist   目录下,文件树如下:

[xzy@localhost mnist]$ tree
.
├── BUILD
├── fully_connected_feed.py
├── __init__.py
├── input_data.py
├── mnist_deep.py
├── mnist.py
├── mnist_softmax.py
├── mnist_softmax_xla.py
└── mnist_with_summaries.py

0 directories, 9 files

  fully_connected_feed.py  里面有一句代码如下:

from tensorflow.examples.tutorials.mnist import input_data
data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)    # <--- 调用语句,input_data 目录默认为 /tmp/tensorflow/mnist/input_data

打开  input_data.py   文件

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets  # <--- note this line

进入  tensorflow/contrib/learn/python/learn/datasets  ,打开  mnist.py  文件,里面有一个用 def 定义的函数

# CVDF mirror of http://yann.lecun.com/exdb/mnist/
SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'

# ... (other definitions elided in the original excerpt) ...


def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000,
                   seed=None):
  """Load the MNIST train/validation/test sets from train_dir.

  Downloads the four gzipped IDX files via base.maybe_download when they
  are not already present, then splits the first `validation_size`
  training examples off into the validation set.

  Returns:
      A base.Datasets namedtuple with train/validation/test DataSets.
  Raises:
      ValueError: if validation_size is outside [0, len(train_images)].
  """
  # fake_data short-circuits everything with empty in-memory DataSets
  # (used by tests / dry runs).
  if fake_data:

    def fake():
      return DataSet(
          [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)

    return base.Datasets(train=fake(), validation=fake(), test=fake())

  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

  # Fetch (if needed) and parse each of the four archives.
  local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                   SOURCE_URL + TRAIN_IMAGES)
  with open(local_file, 'rb') as f:
    train_images = extract_images(f)

  local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                   SOURCE_URL + TRAIN_LABELS)
  with open(local_file, 'rb') as f:
    train_labels = extract_labels(f, one_hot=one_hot)

  local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                   SOURCE_URL + TEST_IMAGES)
  with open(local_file, 'rb') as f:
    test_images = extract_images(f)

  local_file = base.maybe_download(TEST_LABELS, train_dir,
                                   SOURCE_URL + TEST_LABELS)
  with open(local_file, 'rb') as f:
    test_labels = extract_labels(f, one_hot=one_hot)

  if not 0 <= validation_size <= len(train_images):
    raise ValueError(
        'Validation size should be between 0 and {}. Received: {}.'
        .format(len(train_images), validation_size))

  # Carve the validation split off the front of the training data.
  validation_images = train_images[:validation_size]
  validation_labels = train_labels[:validation_size]
  train_images = train_images[validation_size:]
  train_labels = train_labels[validation_size:]

  options = dict(dtype=dtype, reshape=reshape, seed=seed)
  train = DataSet(train_images, train_labels, **options)
  validation = DataSet(validation_images, validation_labels, **options)
  test = DataSet(test_images, test_labels, **options)

  return base.Datasets(train=train, validation=validation, test=test)

这段代码里面调用了 maybe_download 函数下载数据,打开 tensorflow/contrib/learn/python/learn/datasets/base.py 文件

def maybe_download(filename, work_directory, source_url):
  """Download the data from source url, unless it's already here.

  Args:
      filename: string, name of the file in the directory.
      work_directory: string, path to working directory.
      source_url: url to download from if file doesn't exist.

  Returns:
      Path to resulting file.
  """
  # Create the working directory if it does not exist yet.
  if not gfile.Exists(work_directory):
    gfile.MakeDirs(work_directory)
  filepath = os.path.join(work_directory, filename)
  # Only download when the target file is not already present locally.
  if not gfile.Exists(filepath):
    # urlretrieve_with_retry fetches source_url into a temporary file
    # (an unqualified name in base.py — a retry helper, not a Python
    # builtin as the original article's comment claimed).
    temp_file_name, _ = urlretrieve_with_retry(source_url)
    # gfile.Copy (tensorflow/python/lib/io/file_io.py) copies the
    # downloaded temp file into the requested location.
    gfile.Copy(temp_file_name, filepath)
    with gfile.GFile(filepath) as f:
      size = f.size()
    print('Successfully downloaded', filename, size, 'bytes.')
  return filepath