Reading TensorFlow Code
Source: Internet | Editor: 程序博客网 | Date: 2024/06/02
I. fully_connected_feed.py
1. The main function
```python
parser = argparse.ArgumentParser()
parser.add_argument(
    '--learning_rate',
    type=float,
    default=0.01,
    help='Initial learning rate.')
...
FLAGS, unparsed = parser.parse_known_args()
...
run_training()
```
The main function uses argparse to collect hyperparameters such as learning_rate, max_steps, input_data_dir, hidden1, hidden2, and batch_size into FLAGS, then calls run_training() to train the model.
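The argparse pattern above can be tried on its own, independent of TensorFlow. This is a minimal sketch; the flag names mirror those in the script, and the command-line arguments passed to parse_known_args are invented for illustration:

```python
import argparse

# Build a parser with the same kind of hyperparameter flags the script defines.
parser = argparse.ArgumentParser()
parser.add_argument('--learning_rate', type=float, default=0.01,
                    help='Initial learning rate.')
parser.add_argument('--max_steps', type=int, default=2000,
                    help='Number of steps to run trainer.')
parser.add_argument('--batch_size', type=int, default=100,
                    help='Batch size.')

# parse_known_args() returns the recognized flags plus any leftover
# arguments, so unknown flags do not raise an error.
FLAGS, unparsed = parser.parse_known_args(['--learning_rate', '0.1', '--extra'])
print(FLAGS.learning_rate)  # 0.1
print(FLAGS.batch_size)     # 100 (default, not overridden)
print(unparsed)             # ['--extra']
```

Using parse_known_args rather than parse_args is what lets the script tolerate flags it does not define itself.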
2. The run_training() function
```python
# Build a Graph that computes predictions from the inference model.
logits = mnist.inference(images_placeholder, FLAGS.hidden1, FLAGS.hidden2)
# Add to the Graph the Ops for loss calculation.
loss = mnist.loss(logits, labels_placeholder)
# Add to the Graph the Ops that calculate and apply gradients.
train_op = mnist.training(loss, FLAGS.learning_rate)
# Add the Op to compare the logits to the labels during evaluation.
eval_correct = mnist.evaluation(logits, labels_placeholder)
...
# Start the training loop.
for step in xrange(FLAGS.max_steps):
    start_time = time.time()

    # Fill a feed dictionary with the actual set of images and labels
    # for this particular training step.
    feed_dict = fill_feed_dict(data_sets.train, images_placeholder,
                               labels_placeholder)

    # Run one step of the model.  The return values are the activations
    # from the `train_op` (which is discarded) and the `loss` Op.  To
    # inspect the values of your Ops or variables, you may include them
    # in the list passed to sess.run() and the value tensors will be
    # returned in the tuple from the call.
    _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

    duration = time.time() - start_time
    ...
    # Save a checkpoint and evaluate the model periodically.
    if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
        saver.save(sess, checkpoint_file, global_step=step)
        # Evaluate against the training set.
        print('Training Data Eval:')
        do_eval(sess, eval_correct, images_placeholder,
                labels_placeholder, data_sets.train)
        # Evaluate against the validation set.
        print('Validation Data Eval:')
        do_eval(sess, eval_correct, images_placeholder,
                labels_placeholder, data_sets.validation)
        # Evaluate against the test set.
        print('Test Data Eval:')
        do_eval(sess, eval_correct, images_placeholder,
                labels_placeholder, data_sets.test)
```
run_training first adds logits, loss, train_op, and eval_correct to the computation graph, then calls sess.run([train_op, loss], feed_dict=feed_dict) in a loop to train, and periodically calls do_eval to evaluate on the train, validation, and test sets.
3. The do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_set) function
```python
true_count = 0  # Counts the number of correct predictions.
for step in xrange(steps_per_epoch):
    feed_dict = fill_feed_dict(data_set, images_placeholder,
                               labels_placeholder)
    true_count += sess.run(eval_correct, feed_dict=feed_dict)
precision = float(true_count) / num_examples
```
do_eval builds a feed_dict for each batch and calls sess.run(eval_correct, feed_dict=feed_dict) to accumulate the number of correct predictions, from which it computes the prediction precision.
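The bookkeeping in do_eval is just a running sum divided by the example count. Here is a sketch of that arithmetic without TensorFlow; the per-batch correct counts are invented stand-ins for what each sess.run(eval_correct, ...) call would return:

```python
# Sketch of do_eval's bookkeeping: sum per-batch correct counts,
# then divide by the number of examples evaluated.
batch_size = 100
num_examples = 1000
steps_per_epoch = num_examples // batch_size

# Pretend each sess.run(eval_correct, ...) call returned these counts.
per_batch_correct = [97, 95, 98, 96, 99, 94, 97, 98, 96, 95]

true_count = 0
for step in range(steps_per_epoch):
    true_count += per_batch_correct[step]

precision = float(true_count) / num_examples
print(precision)  # 0.965
```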
4. mnist.py
```python
def inference(images, hidden1_units, hidden2_units):
    ...
    return logits

def loss(logits, labels):
    ...
    return loss

def training(loss, learning_rate):
    ...
    return train_op

def evaluation(logits, labels):
    ...
    return tf.reduce_sum(tf.cast(correct, tf.int32))
```
In this module, inference builds the model and computes the logits; loss computes the loss from the logits and labels; training defines the model's optimization algorithm and returns the training op train_op; evaluation counts the number of correctly predicted samples.
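The counting that evaluation performs (argmax over the logits, compare with the label, sum the matches) can be mimicked in plain Python. This is an illustrative sketch only; the logits values below are invented:

```python
def argmax(row):
    # Index of the largest logit = the predicted class.
    return max(range(len(row)), key=lambda i: row[i])

def evaluation(logits, labels):
    # Count how many predictions match their labels, analogous to
    # tf.reduce_sum(tf.cast(correct, tf.int32)) in mnist.py.
    correct = [argmax(row) == label for row, label in zip(logits, labels)]
    return sum(int(c) for c in correct)

logits = [[0.1, 2.0, 0.3],   # predicts class 1
          [1.5, 0.2, 0.1],   # predicts class 0
          [0.0, 0.1, 3.0]]   # predicts class 2
labels = [1, 2, 2]
print(evaluation(logits, labels))  # 2
```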
II. resnet.py
1. The main function
The key code is as follows:
```python
mnist = tf.contrib.learn.datasets.DATASETS['mnist']('/tmp/mnist')
classifier = tf.estimator.Estimator(model_fn=res_net_model)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={X_FEATURE: mnist.train.images}, ...)
classifier.train(input_fn=train_input_fn, steps=100)
scores = classifier.evaluate(input_fn=test_input_fn)
```
Here mnist is the dataset and classifier is an Estimator (which can both train and evaluate); the Estimator is constructed with res_net_model as its model_fn. The dict x={X_FEATURE: mnist.train.images} plays the role of a placeholder, mapping the feature name to the input data.
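The job of numpy_input_fn, turning an in-memory array into a stream of (feature dict, label) batches, can be sketched with a plain generator. This is a hypothetical stdlib-only stand-in, not the real TensorFlow implementation:

```python
def make_input_fn(features, labels, batch_size):
    # Roughly the contract of tf.estimator.inputs.numpy_input_fn:
    # return a callable that yields ({feature_name: batch}, label_batch).
    def input_fn():
        for start in range(0, len(features), batch_size):
            end = start + batch_size
            yield {'x': features[start:end]}, labels[start:end]
    return input_fn

train_input_fn = make_input_fn(list(range(10)), list(range(10)), batch_size=4)
batches = list(train_input_fn())
print(len(batches))        # 3 (two full batches plus a final partial one)
print(batches[0][0]['x'])  # [0, 1, 2, 3]
```

Wrapping the data in a function (rather than passing arrays directly) is what lets the Estimator rebuild its input pipeline fresh for each train or evaluate call.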
2. The res_net_model(features, labels, mode) function
The key code is as follows:
1) Configuring the ResNet's bottleneck groups, with a namedtuple holding each group's parameters
```python
BottleneckGroup = namedtuple('BottleneckGroup',
                             ['num_blocks', 'num_filters', 'bottleneck_size'])
groups = [
    BottleneckGroup(3, 128, 32),
    BottleneckGroup(3, 256, 64),
    BottleneckGroup(3, 512, 128),
    BottleneckGroup(3, 1024, 256)
]
```
There are four groups in total, each containing three bottleneck blocks.
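One benefit of the namedtuple configuration is that the block structure is easy to query by field name. A small runnable sketch using the same group definitions:

```python
from collections import namedtuple

BottleneckGroup = namedtuple('BottleneckGroup',
                             ['num_blocks', 'num_filters', 'bottleneck_size'])
groups = [
    BottleneckGroup(3, 128, 32),
    BottleneckGroup(3, 256, 64),
    BottleneckGroup(3, 512, 128),
    BottleneckGroup(3, 1024, 256),
]

# Total bottleneck blocks across all four groups: 4 * 3 = 12.
total_blocks = sum(g.num_blocks for g in groups)
print(total_blocks)           # 12
print(groups[1].num_filters)  # 256
```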
2) The structure of a bottleneck block
```python
with tf.variable_scope(name + '/conv_in'):
    conv = tf.layers.conv2d(
        net,
        filters=group.num_filters,
        kernel_size=1,
        padding='valid',
        activation=tf.nn.relu)
    conv = tf.layers.batch_normalization(conv)
...
net = net + conv
```
tf.variable_scope gives the layer's variables a named scope (used when building the model's computation graph).
tf.layers.conv2d builds a 2-D convolutional layer.
net = net + conv is the remarkable line: the block's input tensor is simply added to the output of its convolution branch. This is the residual shortcut that gives ResNet its name.
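The shortcut addition is nothing more than an element-wise sum of the block's input and the convolution branch's output. A toy sketch in plain Python (real tensors would be 4-D and the values here are invented):

```python
# The residual connection: output = input (identity path) + conv branch.
net = [1.0, 2.0, 3.0]    # block input, carried through unchanged
conv = [0.1, -0.2, 0.3]  # output of the conv/batch-norm branch

net = [n + c for n, c in zip(net, conv)]
print(net)  # [1.1, 1.8, 3.3]
```

Because the identity path passes the input through untouched, gradients can flow directly to earlier layers, which is what makes very deep networks like this trainable.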
3) The model's output
```python
logits = tf.layers.dense(net, N_DIGITS, activation=None)
predicted_classes = tf.argmax(logits, 1)

# Prediction
if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'class': predicted_classes,
        'prob': tf.nn.softmax(logits)
    }
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

# Training
onehot_labels = tf.one_hot(tf.cast(labels, tf.int32), N_DIGITS, 1, 0)
loss = tf.losses.softmax_cross_entropy(
    onehot_labels=onehot_labels, logits=logits)
if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

# Evaluation
eval_metric_ops = {
    'accuracy': tf.metrics.accuracy(
        labels=labels, predictions=predicted_classes)
}
return tf.estimator.EstimatorSpec(
    mode, loss=loss, eval_metric_ops=eval_metric_ops)
```
As the code shows, the model returns a different EstimatorSpec depending on the mode.
The parameters EstimatorSpec accepts include predictions, loss, train_op, and eval_metric_ops.
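The model_fn contract, one function returning three different result shapes selected by mode, can be sketched with an ordinary Python function. This is a hypothetical mock with made-up values, not TensorFlow code:

```python
# Mock of the model_fn contract: one function, three return shapes.
PREDICT, TRAIN, EVAL = 'predict', 'train', 'eval'

def model_fn(mode):
    if mode == PREDICT:
        # Only predictions are needed; no loss is computed.
        return {'predictions': {'class': 7, 'prob': [0.1] * 10}}
    if mode == TRAIN:
        # Training needs the loss and an op that minimizes it.
        return {'loss': 0.25, 'train_op': 'minimize'}
    # Evaluation needs the loss plus the metric ops.
    return {'loss': 0.25, 'eval_metric_ops': {'accuracy': 0.97}}

print(sorted(model_fn(TRAIN)))  # ['loss', 'train_op']
print(sorted(model_fn(EVAL)))   # ['eval_metric_ops', 'loss']
```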
Because it uses tf.estimator and tf.layers.conv2d, the model code in resnet.py is much more highly encapsulated than that of fully_connected_feed.py!