Resnet Cifar-10调试

来源:互联网 发布:自动化编程是什么 编辑:程序博客网 时间:2024/06/03 05:27

一、下载和运行

https://github.com/tensorflow/models 页面即可下载具体项目是 models/tutorials/image/cifar10_estimator/$ curl -o cifar-10-python.tar.gz https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz$ tar xzf cifar-10-python.tar.gz$ python generate_cifar10_tfrecords.py --input_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar-10/cifar-10-batches-py --output_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar-10/python cifar10_main.py --data_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar-10 \                         --model_dir=/tmp/cifar10 \                         --is_cpu_ps=True \                         --force_gpu_compatible=True \                         --num_gpus=1 \                         --train_steps=10000$ tensorboard --logdir=/tmp/cifar10

二、代码分析

1.cifar10_main.py

1.1 命令行参数处理
FLAGS = tf.flags.FLAGS

tf.flags.FLAGS定义在/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform文件夹下flags.py中。

FLAGS = _FlagValues()class _FlagValues(object):    def _parse_flags(self, args=None):        result, unparsed = _global_parser.parse_known_args(args=args)    def __getattr__(self, name):        if not parsed:              self._parse_flags()

python中如下代码的作用。

if __name__ = "__main__": #使用这种方式保证了,如果此文件被其它文件import的时候,不会执行main中的代码    tf.app.run() #解析命令行参数,调用main函数 main(sys.argv)在tf.app.run()中     flags_passthrough = f._parse_flags(args=args)
1.2 训练和评估

1)训练和评估输入

train_input_fn = functools.partial(input_fn, subset='train',                                  num_shards=FLAGS.num_gpus)eval_input_fn = functools.partial(input_fn, subset='eval',                                 num_shards=FLAGS.num_gpus)

functools.partial的作用就是表明train_input_fn函数就是带了train和FLAGS.num_gpus参数的input_fn函数。
2)Session配置

sess_config = tf.ConfigProto()sess_config.allow_soft_placement = Truesess_config.log_device_placement = FLAGS.log_device_placementsess_config.intra_op_parallelism_threads = FLAGS.num_intra_threadssess_config.inter_op_parallelism_threads = FLAGS.num_inter_threadssess_config.gpu_options.force_gpu_compatible = FLAGS.force_gpu_compatible

3)Estimator配置

config = tf.estimator.RunConfig()config = config.replace(session_config=sess_config)classifier = tf.estimator.Estimator(    model_fn=_resnet_model_fn, model_dir=FLAGS.model_dir, config=config)

4)训练和评估

classifier.train(input_fn=train_input_fn,                 steps=train_steps,                 hooks=hooks)eval_results = classifier.evaluate(    input_fn=eval_input_fn,    steps=eval_steps)
原创粉丝点击