TensorFlow Code Reading

Source: Internet | Editor: 程序博客网 (Programmer Blog Network) | Time: 2024/06/02 00:26

I. fully_connected_feed.py

1. The main function

parser = argparse.ArgumentParser()
parser.add_argument(
    '--learning_rate',
    type=float,
    default=0.01,
    help='Initial learning rate.')
...
FLAGS, unparsed = parser.parse_known_args()
...
run_training()

The main function uses argparse to define parameters such as learning_rate, max_steps, input_data_dir, hidden1, hidden2 and batch_size. It stores them in FLAGS and then calls run_training to train the model.
2. The run_training() function
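The flag-handling pattern can be tried on its own, without TensorFlow. parse_known_args returns a pair: the parsed flags, plus a list of any arguments the parser did not recognize, instead of failing on them.

This is a minimal sketch. Only learning_rate and its default come from the excerpt. The max_steps and batch_size flags, their defaults, and the run_training stand-in are illustrative assumptions.

```python
import argparse


def build_parser():
    """Build a parser in the style of the excerpt (subset of flags only)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--learning_rate', type=float, default=0.01,
                        help='Initial learning rate.')
    # Assumed flags/defaults, added for illustration.
    parser.add_argument('--max_steps', type=int, default=2000,
                        help='Number of steps to run trainer.')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='Batch size.')
    return parser


def run_training(flags):
    # Hypothetical stand-in for the real training routine.
    return f"lr={flags.learning_rate}, steps={flags.max_steps}, batch={flags.batch_size}"


if __name__ == '__main__':
    # Unknown command-line arguments end up in `unparsed` instead of raising.
    FLAGS, unparsed = build_parser().parse_known_args()
    print(run_training(FLAGS))
```

For example, parse_known_args(['--learning_rate', '0.1', '--unknown_flag']) yields learning_rate=0.1 and leaves ['--unknown_flag'] in the second element.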

# Build a Graph that computes predictions from the inference model.
logits = mnist.inference(images_placeholder,
                         FLAGS.hidden1,
                         FLAGS.hidden2)
# Add to the Graph the Ops for loss calculation.
loss = mnist.loss(logits, labels_placeholder)
# Add to the Graph the Ops that calculate and apply gradients.
train_op = mnist.training(loss, FLAGS.learning_rate)
# Add the Op to compare the logits to the labels during evaluation.
eval_correct = mnist.evaluation(logits, labels_placeholder)
...
# Start the training loop.
for step in xrange(FLAGS.max_steps):
  start_time = time.time()
  # Fill a feed dictionary with the actual set of images and labels
  # for this particular training step.
  feed_dict = fill_feed_dict(data_sets.train,
                             images_placeholder,
                             labels_placeholder)
  # Run one step of the model.  The return values are the activations
  # from the `train_op` (which is discarded) and the `loss` Op.  To
  # inspect the values of your Ops or variables, you may include them
  # in the list passed to sess.run() and the value tensors will be
  # returned in the tuple from the call.
  _, loss_value = sess.run([train_op, loss],
                           feed_dict=feed_dict)
  duration = time.time() - start_time
  ...
  # Save a checkpoint and evaluate the model periodically.
  if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
    checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
    saver.save(sess, checkpoint_file, global_step=step)
    # Evaluate against the training set.
    print('Training Data Eval:')
    do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_sets.train)
    # Evaluate against the validation set.
    print('Validation Data Eval:')
    do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_sets.validation)
    # Evaluate against the test set.
    print('Test Data Eval:')
    do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_sets.test)

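run_training first adds the logits, loss, train_op and eval_correct ops to the computation graph. It then trains by repeatedly calling sess.run([train_op, loss], feed_dict=feed_dict). Every 1000 steps, and at the final step, it saves a checkpoint and calls do_eval to evaluate on the train, validation and test sets.
3. The do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_set) function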
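The control flow of this loop can be imitated in plain Python: one optimizer update per step, plus a checkpoint-and-evaluate branch that fires every 1000 steps and on the last step.

This is only a toy sketch under stated assumptions. It uses a hypothetical one-weight model y = w * x, trained by plain gradient descent on made-up data with a mean-squared-error loss. Instead of saving checkpoints, it records the step indices at which the real script would save and evaluate.

```python
def train(max_steps, eval_every=1000, learning_rate=0.01):
    """Toy stand-in for run_training: fit y = w * x by gradient descent.

    The model and data are hypothetical; only the loop structure
    (one update per step, periodic checkpoint/eval) mirrors the excerpt.
    """
    xs = [1.0, 2.0, 3.0, 4.0]
    ys = [2.0, 4.0, 6.0, 8.0]  # generated with w = 2
    w = 0.0
    eval_steps = []
    for step in range(max_steps):
        # Gradient of the mean squared error with respect to w.
        grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        w -= learning_rate * grad  # plays the role of running train_op
        # Same schedule as the excerpt: every eval_every steps and at the end.
        if (step + 1) % eval_every == 0 or (step + 1) == max_steps:
            eval_steps.append(step)
    return w, eval_steps
```

With max_steps=2500 the evaluation branch fires at steps 999, 1999 and 2499. The last one comes from the `(step + 1) == max_steps` clause, which guarantees a final evaluation even when max_steps is not a multiple of 1000.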

for step in xrange(steps_per_epoch):
  feed_dict = fill_feed_dict(data_set,
                             images_placeholder,
                             labels_placeholder)
  true_count += sess.run(eval_correct, feed_dict=feed_dict)
precision = float(true_count) / num_examples

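For each batch in one epoch, do_eval builds a feed_dict and calls sess.run(eval_correct, feed_dict=feed_dict) to count the correctly predicted examples. It sums these counts and divides by the number of examples. The result is the accuracy, even though the code names it precision.
4. mnist.py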
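The accumulation in do_eval follows a simple pattern: add up per-batch correct counts, then divide once at the end.

This is a minimal plain-Python sketch. The batches list of (predictions, labels) pairs and the count_correct function are hypothetical stand-ins for the data set and the eval_correct op. Here the number of examples is counted from the batches actually seen.

```python
def count_correct(predictions, labels):
    """Stand-in for eval_correct: number of positions where prediction == label."""
    return sum(int(p == l) for p, l in zip(predictions, labels))


def do_eval(batches, count_fn):
    """Accumulate per-batch correct counts into one overall accuracy."""
    true_count = 0
    num_examples = 0
    for predictions, labels in batches:
        true_count += count_fn(predictions, labels)
        num_examples += len(labels)
    return true_count / num_examples
```

Summing integer counts and dividing once is more accurate than averaging per-batch accuracies when the batches differ in size.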

def inference(images, hidden1_units, hidden2_units):
    ...
    return logits
def loss(logits, labels):
    ...
    return loss
def training(loss, learning_rate):
    ...
    return train_op
def evaluation(logits, labels):
    ...
    return tf.reduce_sum(tf.cast(correct, tf.int32))

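In this file, each function has one job:
- inference builds the model and computes the logits.
- loss computes the loss from the logits and labels.
- training defines the optimization algorithm and returns the training op, train_op.
- evaluation counts how many examples are predicted correctly.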
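To make concrete what loss and evaluation compute, here is a plain-Python sketch over lists of logits.

It assumes that the loss is the mean softmax cross-entropy and that "correct" means the highest-scoring class equals the label (top-1). Both are typical for this kind of classifier, but treat them as assumptions here rather than a transcription of mnist.py. When scores tie, this version picks the first maximal index.

```python
import math


def softmax_cross_entropy(logits, label):
    """Cross-entropy of one example, computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def loss(batch_logits, labels):
    """Mean cross-entropy over the batch."""
    return sum(softmax_cross_entropy(z, y)
               for z, y in zip(batch_logits, labels)) / len(labels)


def evaluation(batch_logits, labels):
    """Count examples whose highest-scoring class equals the label."""
    correct = 0
    for z, y in zip(batch_logits, labels):
        predicted = max(range(len(z)), key=lambda i: z[i])
        correct += int(predicted == y)
    return correct
```

Subtracting the maximum logit before exponentiating does not change the result but avoids overflow for large logits.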

II. resnet.py

1. The main function
The key code is as follows:

mnist = tf.contrib.learn.datasets.DATASETS['mnist']('/tmp/mnist')
classifier = tf.estimator.Estimator(model_fn=res_net_model)
train_input_fn = tf.estimator.inputs.numpy_input_fn(x={X_FEATURE: mnist.train.images},...)
classifier.train(input_fn=train_input_fn, steps=100)
scores = classifier.evaluate(input_fn=test_input_fn)

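Here mnist holds the dataset, and classifier is an Estimator, which supports both train and evaluate. It is constructed with res_net_model as its model_fn. The dict x={X_FEATURE: mnist.train.images} takes over the role that placeholders and feed_dict play in fully_connected_feed.py. The input function built by numpy_input_fn feeds batches of these images to the model as its features.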
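Conceptually, an input function like the one built by numpy_input_fn slices a dict of feature arrays and a label array into (features, labels) batches.

The sketch below is a hypothetical, simplified stand-in written with plain lists; it is not the TensorFlow API. Unlike the real function, it does not shuffle, repeat for several epochs, or produce tensors.

```python
def numpy_like_input_fn(x, y, batch_size):
    """Yield (features, labels) batches from a dict of equal-length lists.

    Hypothetical helper: single pass, no shuffling.
    """
    n = len(y)
    for start in range(0, n, batch_size):
        # Slice every feature column with the same window as the labels.
        features = {key: values[start:start + batch_size]
                    for key, values in x.items()}
        yield features, y[start:start + batch_size]
```

The model function then receives the features dict and looks up its input by key (X_FEATURE in resnet.py), which is why the key used here must match the one the model reads.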

2. The res_net_model(features, labels, mode) function
The key code is as follows:
1) Configuration of the ResNet bottleneck groups: a namedtuple holds the parameters of each group

BottleneckGroup = namedtuple('BottleneckGroup',
                             ['num_blocks', 'num_filters', 'bottleneck_size'])
groups = [
    BottleneckGroup(3, 128, 32), BottleneckGroup(3, 256, 64),
    BottleneckGroup(3, 512, 128), BottleneckGroup(3, 1024, 256)
]

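There are 4 groups, each containing 3 bottleneck blocks. The number of filters grows from 128 to 1024 across the groups.
2) Structure of a bottleneck block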
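Because the configuration is plain data, it can be inspected with ordinary Python, for example to count the blocks or list one scope name per block.

The 'group_i/block_j' naming scheme below is a hypothetical choice for illustration. The namedtuple and the groups list are copied from the excerpt.

```python
from collections import namedtuple

BottleneckGroup = namedtuple('BottleneckGroup',
                             ['num_blocks', 'num_filters', 'bottleneck_size'])
groups = [
    BottleneckGroup(3, 128, 32), BottleneckGroup(3, 256, 64),
    BottleneckGroup(3, 512, 128), BottleneckGroup(3, 1024, 256)
]


def block_names(groups):
    """Return one (hypothetical) scope name per bottleneck block, in order."""
    return ['group_%d/block_%d' % (gi, bi)
            for gi, group in enumerate(groups)
            for bi in range(group.num_blocks)]


# 4 groups x 3 blocks each.
total_blocks = sum(g.num_blocks for g in groups)
```

Keeping the layer layout in a list of namedtuples means the architecture can be changed by editing data rather than code.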

with tf.variable_scope(name + '/conv_in'):
  conv = tf.layers.conv2d(
      net,
      filters=group.num_filters,
      kernel_size=1,
      padding='valid',
      activation=tf.nn.relu)
  conv = tf.layers.batch_normalization(conv)
...
net = net + conv

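tf.variable_scope adds a name prefix (here name + '/conv_in') to the variables created inside it, such as the conv layer's kernel and bias. This gives each block's parameters distinct names in the computation graph.
tf.layers.conv2d builds a 2D convolution layer. Here it is a 1x1 convolution with 'valid' padding and ReLU activation, followed by batch normalization.
The line net = net + conv is the residual (shortcut) connection. It looks surprising at first, but it does not add two models together: it adds two tensors. The block's output conv is added element-wise to its input net, and the block is built so that the two tensors have the same shape.
3) Model outputs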
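The residual connection computes y = x + F(x): the input is added element-wise to the output of the block's transformation.

This is a minimal plain-Python sketch on flat lists. The relu_scale function is a hypothetical stand-in for the conv + batch-norm branch, and the explicit length check makes the "same shape" requirement visible.

```python
def residual_block(x, transform):
    """Return x + transform(x), element-wise."""
    fx = transform(x)
    if len(fx) != len(x):
        raise ValueError("residual branch must preserve the input shape")
    return [a + b for a, b in zip(x, fx)]


def relu_scale(x):
    # Hypothetical stand-in for the convolution branch: scale by 0.5, then ReLU.
    return [max(0.0, 0.5 * v) for v in x]
```

For example, residual_block([1.0, -2.0, 4.0], relu_scale) adds [0.5, 0.0, 2.0] to the input and gives [1.5, -2.0, 6.0]. If the branch outputs all zeros, the block reduces to the identity, which is one reason deep residual stacks are easier to train.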

logits = tf.layers.dense(net, N_DIGITS, activation=None)
predicted_classes = tf.argmax(logits, 1)
# Prediction
if mode == tf.estimator.ModeKeys.PREDICT:
  predictions = {
      'class': predicted_classes,
      'prob': tf.nn.softmax(logits)
  }
  return tf.estimator.EstimatorSpec(mode, predictions=predictions)
# Training (the loss is also reused for evaluation)
onehot_labels = tf.one_hot(tf.cast(labels, tf.int32), N_DIGITS, 1, 0)
loss = tf.losses.softmax_cross_entropy(
    onehot_labels=onehot_labels, logits=logits)
if mode == tf.estimator.ModeKeys.TRAIN:
  optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
# Evaluation
eval_metric_ops = {
    'accuracy': tf.metrics.accuracy(
        labels=labels, predictions=predicted_classes)
}
return tf.estimator.EstimatorSpec(
    mode, loss=loss, eval_metric_ops=eval_metric_ops)

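Depending on mode, the model returns a different EstimatorSpec.
The EstimatorSpec arguments used here are mode plus:
- predictions, in PREDICT mode;
- loss and train_op, in TRAIN mode;
- loss and eval_metric_ops, in EVAL mode.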
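The per-mode branching can be mimicked without TensorFlow: one model function returns a different bundle of fields depending on the mode it is called with.

In the sketch below, the mode strings and the Spec namedtuple are hypothetical stand-ins for tf.estimator.ModeKeys and tf.estimator.EstimatorSpec, modeling only the fields the excerpt uses. The function takes precomputed logits instead of images, and train_op is just a placeholder string.

```python
import math
from collections import namedtuple

# Hypothetical stand-ins for the mode keys and EstimatorSpec.
PREDICT, TRAIN, EVAL = "infer", "train", "eval"
Spec = namedtuple("Spec",
                  ["mode", "predictions", "loss", "train_op", "eval_metric_ops"],
                  defaults=[None, None, None, None])


def model_fn(logits, labels, mode):
    """Return a different Spec per mode, mirroring the branches above."""
    predicted = [max(range(len(z)), key=z.__getitem__) for z in logits]
    if mode == PREDICT:
        # Labels are not needed at prediction time.
        return Spec(mode, predictions={"class": predicted})
    # Mean softmax cross-entropy, shared by TRAIN and EVAL.
    loss = sum(math.log(sum(math.exp(v) for v in z)) - z[y]
               for z, y in zip(logits, labels)) / len(labels)
    if mode == TRAIN:
        # A real train_op would update weights; this is only a placeholder.
        return Spec(mode, loss=loss, train_op="minimize(loss)")
    accuracy = sum(p == y for p, y in zip(predicted, labels)) / len(labels)
    return Spec(mode, loss=loss, eval_metric_ops={"accuracy": accuracy})
```

Returning early in the PREDICT branch matters: at prediction time there are usually no labels, so the loss must not be computed there.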
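Because it uses tf.estimator and tf.layers (e.g. tf.layers.conv2d), the model code in resnet.py is noticeably more encapsulated than fully_connected_feed.py. The Estimator handles the session, the training loop and checkpointing, all of which fully_connected_feed.py writes by hand.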
