解决Tensorflow中文社区MNIST机器学习入门里使用示例代码无法连接服务器的问题

来源:互联网 发布:java 单例模式优点 编辑:程序博客网 时间:2024/05/09 03:45

今天上课开始讲神经网络,看着我的python3.6,caffe不装了,换tensorflow,gpu版折腾几个小时也装不上——放弃,cpu版先用着。

跟着tensorflow的中文文档先学下,但是很不给力啊,示例就一堆事:

http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html

首先这个input_data.py就根本下载不下来,好在百度一下就发现有很多小伙伴备份过了,我也贴一下:

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at##     http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Functions for downloading and reading MNIST data."""from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport gzipimport osimport tempfileimport numpyfrom six.moves import urllibfrom six.moves import xrange  # pylint: disable=redefined-builtinimport tensorflow as tffrom tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
我马上试了试:

等了一会儿....根本连不上服务器...被墙啦?我架着梯子呢!可能人品不好,再试一次,还是连不上服务器,这不是坑爹吗!

能不能不从服务器下载啊,毕竟这四个训练集文件是可以手动下载的。

我纠结了一下,自己改代码吧。
先找到主代码的位置:

tensorflow.contrib.learn.python.learn.datasets.mnist.py 里面的:read_data_sets这个函数
果然呐,里面有各种下载的信息,那就把它们删了,直接读文件。

以下是我自己改的,大家有需求就直接复制吧:

四个数据集文件后缀改成gz,直接放在工程根目录里就行。

from tensorflow.contrib.learn.python.learn.datasets.mnist import *def read_data_setss(fake_data=False,                    one_hot=False,                    dtype=dtypes.float32,                    reshape=True,                    validation_size=5000,                    seed=None, ):    if fake_data:        def fake():            return DataSet(                [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)        train = fake()        validation = fake()        test = fake()        return base.Datasets(train=train, validation=validation, test=test)    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'    with gfile.Open(TRAIN_IMAGES, 'rb') as f:        train_images = extract_images(f)    with gfile.Open(TRAIN_LABELS, 'rb') as f:        train_labels = extract_labels(f, one_hot=one_hot)    with gfile.Open(TEST_IMAGES, 'rb') as f:        test_images = extract_images(f)    with gfile.Open(TEST_LABELS, 'rb') as f:        test_labels = extract_labels(f, one_hot=one_hot)    if not 0 <= validation_size <= len(train_images):        raise ValueError('Validation size should be between 0 and {}. Received: {}.'                         .format(len(train_images), validation_size))    validation_images = train_images[:validation_size]    validation_labels = train_labels[:validation_size]    train_images = train_images[validation_size:]    train_labels = train_labels[validation_size:]    options = dict(dtype=dtype, reshape=reshape, seed=seed)    train = DataSet(train_images, train_labels, **options)    validation = DataSet(validation_images, validation_labels, **options)    test = DataSet(test_images, test_labels, **options)    return base.Datasets(train=train, validation=validation, test=test)mnists = read_data_setss(one_hot=True)



阅读全文
0 0
原创粉丝点击