Analysis of the MNIST Handwritten-Data Import Code in TensorFlow
Source: Internet | Published under: sql comment | Editor: 程序博客网 | Time: 2024/06/15 03:46
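This article looks at the TensorFlow (TF) handwritten-digit recognition example and walks through the source code that imports its data.
The example lives in the tensorflow/tensorflow/examples/tutorials/mnist directory, whose file tree is: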
[xzy@localhost mnist]$ tree
.
├── BUILD
├── fully_connected_feed.py
├── __init__.py
├── input_data.py
├── mnist_deep.py
├── mnist.py
├── mnist_softmax.py
├── mnist_softmax_xla.py
└── mnist_with_summaries.py
0 directories, 9 files
fully_connected_feed.py contains the following lines:
from tensorflow.examples.tutorials.mnist import input_data

# The call site. By default FLAGS.input_data_dir is /tmp/tensorflow/mnist/input_data
data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)
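The two arguments passed to read_data_sets come from command-line flags. The sketch below assumes the script declares them with argparse in roughly this way; the help strings and the TEST_TMPDIR fallback are assumptions for illustration, not a copy of the script, but it shows where the /tmp/tensorflow/mnist/input_data default comes from.

```python
import argparse
import os


def build_parser():
    """Declare the two flags that feed read_data_sets() (an assumed sketch)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input_data_dir',
        type=str,
        # Assumed default: $TEST_TMPDIR (or /tmp) joined with tensorflow/mnist/input_data
        default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
                             'tensorflow/mnist/input_data'),
        help='Directory to put the input data.')
    parser.add_argument(
        '--fake_data',
        default=False,
        action='store_true',
        help='If true, use fake data for unit testing.')
    return parser


# Parse an empty argument list to see the defaults.
FLAGS, _unparsed = build_parser().parse_known_args([])
print(FLAGS.input_data_dir)  # /tmp/tensorflow/mnist/input_data when TEST_TMPDIR is unset
print(FLAGS.fake_data)
```

Passing --input_data_dir on the command line is therefore enough to point the example at an already downloaded copy of MNIST.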
Open input_data.py:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets  # <-- note this line
So input_data.read_data_sets is simply the read_data_sets function from contrib. Go to tensorflow/contrib/learn/python/learn/datasets and open mnist.py, which defines it:
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'

... ...

def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000,
                   seed=None):
  if fake_data:

    def fake():
      return DataSet(
          [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)

    train = fake()
    validation = fake()
    test = fake()
    return base.Datasets(train=train, validation=validation, test=test)

  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

  local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                   SOURCE_URL + TRAIN_IMAGES)
  with open(local_file, 'rb') as f:
    train_images = extract_images(f)

  local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                   SOURCE_URL + TRAIN_LABELS)
  with open(local_file, 'rb') as f:
    train_labels = extract_labels(f, one_hot=one_hot)

  local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                   SOURCE_URL + TEST_IMAGES)
  with open(local_file, 'rb') as f:
    test_images = extract_images(f)

  local_file = base.maybe_download(TEST_LABELS, train_dir,
                                   SOURCE_URL + TEST_LABELS)
  with open(local_file, 'rb') as f:
    test_labels = extract_labels(f, one_hot=one_hot)

  if not 0 <= validation_size <= len(train_images):
    raise ValueError(
        'Validation size should be between 0 and {}. Received: {}.'
        .format(len(train_images), validation_size))

  validation_images = train_images[:validation_size]
  validation_labels = train_labels[:validation_size]
  train_images = train_images[validation_size:]
  train_labels = train_labels[validation_size:]

  options = dict(dtype=dtype, reshape=reshape, seed=seed)

  train = DataSet(train_images, train_labels, **options)
  validation = DataSet(validation_images, validation_labels, **options)
  test = DataSet(test_images, test_labels, **options)

  return base.Datasets(train=train, validation=validation, test=test)
This code calls maybe_download to fetch the data. To see how that works, open tensorflow/contrib/learn/python/learn/datasets/base.py:
def maybe_download(filename, work_directory, source_url):
  """Download the data from source url, unless it's already here.

  Args:
      filename: string, name of the file in the directory.
      work_directory: string, path to working directory.
      source_url: url to download from if file doesn't exist.

  Returns:
      Path to resulting file.
  """
  if not gfile.Exists(work_directory):  # create the working directory if it does not exist
    gfile.MakeDirs(work_directory)
  filepath = os.path.join(work_directory, filename)
  if not gfile.Exists(filepath):  # download only if the file is not already there
    # urlretrieve_with_retry is a helper in base.py (not a Python built-in) that
    # wraps urllib's urlretrieve with retries; it saves the remote file to a temp file
    temp_file_name, _ = urlretrieve_with_retry(source_url)
    # gfile.Copy (implemented in tensorflow/python/lib/io/file_io.py) copies the temp file to filepath
    gfile.Copy(temp_file_name, filepath)
    with gfile.GFile(filepath) as f:
      size = f.size()
    print('Successfully downloaded', filename, size, 'bytes.')
  return filepath
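The same "download only if missing" logic can be reproduced without TensorFlow. This sketch swaps gfile for os and shutil (so, unlike gfile, it only handles local paths), and its retry count, back-off delay and the fact that urlretrieve_with_retry returns just the temporary path are assumptions rather than base.py's exact behaviour. A file:// URL stands in for SOURCE_URL so the example runs offline.

```python
import os
import pathlib
import shutil
import tempfile
import time
import urllib.request


def urlretrieve_with_retry(url, attempts=3, delay=1.0):
    """Fetch url into a temporary local file, retrying with a doubling delay on failure."""
    for attempt in range(attempts):
        try:
            temp_file_name, _ = urllib.request.urlretrieve(url)
            return temp_file_name
        except OSError:  # urllib.error.URLError is a subclass of OSError
            if attempt == attempts - 1:
                raise
            time.sleep(delay)
            delay *= 2


def maybe_download(filename, work_directory, source_url):
    """Return work_directory/filename, downloading it from source_url only if missing."""
    os.makedirs(work_directory, exist_ok=True)  # create the working directory if needed
    filepath = os.path.join(work_directory, filename)
    if not os.path.exists(filepath):  # no cached copy yet, so download it
        temp_file_name = urlretrieve_with_retry(source_url)
        shutil.copy(temp_file_name, filepath)  # copy the download into the data directory
        size = os.path.getsize(filepath)
        print('Successfully downloaded', filename, size, 'bytes.')
    return filepath


# Usage: a local file served through a file:// URL plays the role of the remote mirror.
with tempfile.TemporaryDirectory() as tmp:
    remote = pathlib.Path(tmp, 'remote.gz')
    remote.write_bytes(b'fake mnist bytes')
    work_dir = os.path.join(tmp, 'input_data')
    first = maybe_download('train-images-idx3-ubyte.gz', work_dir, remote.as_uri())
    downloaded = pathlib.Path(first).read_bytes()
    remote.unlink()  # the second call must be served from the cached copy
    second = maybe_download('train-images-idx3-ubyte.gz', work_dir, remote.as_uri())
```

Because the second call finds the file already in work_dir, it never touches the (now deleted) source, which is why rerunning the MNIST example does not download the data again.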