Resnet Cifar-10调试
来源:互联网 发布:自动化编程是什么 编辑:程序博客网 时间:2024/06/03 05:27
一、下载和运行
https://github.com/tensorflow/models 页面即可下载具体项目是 models/tutorials/image/cifar10_estimator/$ curl -o cifar-10-python.tar.gz https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz$ tar xzf cifar-10-python.tar.gz$ python generate_cifar10_tfrecords.py --input_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar-10/cifar-10-batches-py --output_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar-10/python cifar10_main.py --data_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar-10 \ --model_dir=/tmp/cifar10 \ --is_cpu_ps=True \ --force_gpu_compatible=True \ --num_gpus=1 \ --train_steps=10000$ tensorboard --logdir=/tmp/cifar10
二、代码分析
1.cifar10_main.py
1.1 命令行参数处理
FLAGS = tf.flags.FLAGS
tf.flags.FLAGS定义在/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform文件夹下flags.py中。
FLAGS = _FlagValues()class _FlagValues(object): def _parse_flags(self, args=None): result, unparsed = _global_parser.parse_known_args(args=args) def __getattr__(self, name): if not parsed: self._parse_flags()
python中如下代码的作用。
if __name__ = "__main__": #使用这种方式保证了,如果此文件被其它文件import的时候,不会执行main中的代码 tf.app.run() #解析命令行参数,调用main函数 main(sys.argv)在tf.app.run()中 flags_passthrough = f._parse_flags(args=args)
1.2 训练和评估
1)训练和评估输入
train_input_fn = functools.partial(input_fn, subset='train', num_shards=FLAGS.num_gpus)eval_input_fn = functools.partial(input_fn, subset='eval', num_shards=FLAGS.num_gpus)
functools.partial的作用就是表明train_input_fn函数就是带了train和FLAGS.num_gpus参数的input_fn函数。
2)Session配置
sess_config = tf.ConfigProto()sess_config.allow_soft_placement = Truesess_config.log_device_placement = FLAGS.log_device_placementsess_config.intra_op_parallelism_threads = FLAGS.num_intra_threadssess_config.inter_op_parallelism_threads = FLAGS.num_inter_threadssess_config.gpu_options.force_gpu_compatible = FLAGS.force_gpu_compatible
3)Estimator配置
config = tf.estimator.RunConfig()config = config.replace(session_config=sess_config)classifier = tf.estimator.Estimator( model_fn=_resnet_model_fn, model_dir=FLAGS.model_dir, config=config)
4)训练和评估
classifier.train(input_fn=train_input_fn, steps=train_steps, hooks=hooks)eval_results = classifier.evaluate( input_fn=eval_input_fn, steps=eval_steps)
阅读全文
0 0
- Resnet Cifar-10调试
- CIFAR-10
- Caffe学习-CIFAR-10
- cifar 10 最高正确率
- caffe CIFAR-10
- caffe学习:CIFAR-10
- 深度学习 :CIFAR-10
- caffe CIFAR 10 database
- CIFAR-10训练模型
- cifar
- Resnet
- ResNet
- ResNet
- ResNet
- ResNet
- ResNet
- 用python读取cifar-10与cifar-100图像数据
- 用python读取cifar-10与cifar-100图像数据
- 几种简单的Dialog对话框
- 深入浅出JMS(一)--JMS基本概念
- Javaweb 自动登录 详细讲解
- 滑屏效果实现
- 各种音视频编解码学习详解
- Resnet Cifar-10调试
- word-break
- Notes on Tensorflow
- hdu6106Classes(交集计算集合)
- 基于maven发送邮件系列(2)---用spring的timer实现定时发送邮件
- Java 验证表单工具类,史上最全
- 第9章 多元函数微分法及其应用
- 用mybatis在java后台insert数据,能运行但数据库没有添加成功
- Eclipse:An internal error occurred during: "Build Project". GC overhead limit exceeded