tensorflow cifar_10 代码阅读与理解
来源:互联网 发布:linux网络工程师培训 编辑:程序博客网 时间:2024/05/18 03:41
前言
Tensorflow 提供cifar_10 benchmark问题的示例代码,并且在中文翻译的官方文档中有专门的一章介绍该卷积神经网络(CNN),作为刚刚接触深度学习与Tensorflow框架的菜鸟,对tf提供的大量库函数与深度学习的trick并不十分熟悉,因此花了两天的时间通读懂了代码,下面具体剖析一下整个程序的过程,作为学习记录。
准备工作
从Github https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10下载到的cifar10程序共包括1. cifar10.py
2. cifar10_eval.py
3. cifar10_input.py
4. cifar10_multi_gpu_train.py
5. cifar10_train.py
本文只讨论CPU版本,因此可以自动忽略cifar10_multi_gpu_train.py文件。另外,为加快程序调试,避免在程序运行时再去自动下载资源,可以提前去 http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz下好图片压缩文件,放在默认位置/tmp/cifar10_data 中即可
代码阅读
总体流程
程序入口为cifar10_train中的main方法,其代码其注释如下:def train(): """Train CIFAR-10 for a number of steps.""" with tf.Graph().as_default(): # use the default graph in the process in the context global_step = tf.contrib.framework.get_or_create_global_step() # Returns and create (if necessary) the global step variable. However the method is depressed in V0.8.0 #global_step = tf.Variable(0, name='global_step', trainable=False) # Get images and labels for CIFAR-10. images, labels = cifar10.distorted_inputs() # Build a Graph that computes the logits predictions from the # inference model. logits = cifar10.inference(images) # Calculate loss. loss = cifar10.loss(logits, labels) # Build a Graph that trains the model with one batch of examples and # updates the model parameters. train_op = cifar10.train(loss, global_step) class _LoggerHook(tf.train.SessionRunHook): """Logs loss and runtime.""" def begin(self): self._step = -1 def before_run(self, run_context): self._step += 1 self._start_time = time.time() return tf.train.SessionRunArgs(loss) # Asks for loss value. def after_run(self, run_context, run_values): duration = time.time() - self._start_time loss_value = run_values.results if self._step % 10 == 0: num_examples_per_step = FLAGS.batch_size examples_per_sec = num_examples_per_step / duration sec_per_batch = float(duration) format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 'sec/batch)') print (format_str % (datetime.now(), self._step, loss_value, examples_per_sec, sec_per_batch)) #For a chief, this utility sets proper session initializer/restorer. It also creates hooks related to checkpoint and summary saving. For workers, this utility sets proper session creator which waits for the chief to inialize/restore. with tf.train.MonitoredTrainingSession( checkpoint_dir=FLAGS.train_dir, hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), tf.train.NanTensorHook(loss), _LoggerHook()], config=tf.ConfigProto( log_device_placement=FLAGS.log_device_placement)) as mon_sess: while not mon_sess.should_stop(): mon_sess.run(train_op)
其中,专门定义_LoggerHook类,在mon_sess这个对话中注册,代码中最后一句,表示在停止条件达到之前,循环运行train_op,更新网络系数
读取文件
调用cifar10.py中的 distorted_inputs方法,其主要语句是filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6)] for f in filenames: if not tf.gfile.Exists(f): raise ValueError('Failed to find file: ' + f) # Create a queue that produces the filenames to read. filename_queue = tf.train.string_input_producer(filenames) # Pass the list of filenames to the tf.train.string_input_producer function. string_input_producer creates a FIFO queue for holding the filenames until the reader needs them. # Read examples from files in the filename queue. read_input = read_cifar10(filename_queue) reshaped_image = tf.cast(read_input.uint8image, tf.float32) height = IMAGE_SIZE width = IMAGE_SIZE # Image processing for training the network. Note the many random # distortions applied to the image. # Randomly crop a [height, width] section of the image. distorted_image = tf.random_crop(reshaped_image, [height, width, 3]) # Randomly flip the image horizontally. distorted_image = tf.image.random_flip_left_right(distorted_image) # Because these operations are not commutative, consider randomizing # the order their operation. distorted_image = tf.image.random_brightness(distorted_image, max_delta=63) distorted_image = tf.image.random_contrast(distorted_image, lower=0.2, upper=1.8) # Subtract off the mean and divide by the variance of the pixels. float_image = tf.image.per_image_standardization(distorted_image) #Linearly scales image to have zero mean and unit norm. # Set the shapes of tensors. float_image.set_shape([height, width, 3]) read_input.label.set_shape([1]) # Ensure that the random shuffling has good mixing properties. min_fraction_of_examples_in_queue = 0.4 min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * min_fraction_of_examples_in_queue) #tf.train.shuffle_batch equals 20000. print ('Filling queue with %d CIFAR images before starting to train. ' 'This will take a few minutes.' % min_queue_examples) # Generate a batch of images and labels by building up a queue of examples. return _generate_image_and_label_batch(float_image, read_input.label, min_queue_examples, batch_size, shuffle=True)
# 采用多线程并行读入样本,构成一个训练batch,大小为128,需要注意的是,这里返回的是一个batch,images张量形式为(128,32,32,3)
return _generate_image_and_label_batch(float_image, read_input.label, min_queue_examples, batch_size, shuffle=True)
其中,read_cifar10方法内容如下:
class CIFAR10Record(object): pass result = CIFAR10Record() # Dimensions of the images in the CIFAR-10 dataset. # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the # input format. label_bytes = 1 # 2 for CIFAR-100 result.height = 32 result.width = 32 result.depth = 3 image_bytes = result.height * result.width * result.depth # Every record consists of a label followed by the image, with a # fixed number of bytes for each. record_bytes = label_bytes + image_bytes # Read a record, getting filenames from the filename_queue. No # header or footer in the CIFAR-10 format, so we leave header_bytes # and footer_bytes at their default of 0. # 下面的FixedLengthRecordReader与reader.read,tf.decode_raw配合起来,是以固定长度读取文件名队列中数据的一个常用方法 reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) result.key, value = reader.read(filename_queue) # Convert from a string to a vector of uint8 that is record_bytes long. record_bytes = tf.decode_raw(value, tf.uint8) # To read binary files in which each record is a fixed number of bytes, use tf.FixedLengthRecordReader with the tf.decode_raw operation. The decode_raw op converts from a string to a uint8 tensor. # For example, the CIFAR-10 dataset uses a file format where each record is represented using a fixed number of bytes: 1 byte for the label followed by 3072 bytes of image data. Once you have a uint8 tensor, standard operations can slice out each piece and reformat as needed. For CIFAR-10, you can see how to do the reading and decoding in # The first bytes represent the label, which we convert from uint8->int32. # 下面采用tf.strided_slice方法在record_bytes中提取第一个bytes作为标签 result.label = tf.cast( tf.strided_slice(record_bytes, [0], [label_bytes], [1]), tf.int32) #unfortunately, the method "tf.strided_slice" is deprecated in this version, What can be subsititued? # strided_slice Extracts a strided slice from a tensor. # The remaining bytes after the label represent the image, which we reshape # from [depth * height * width] to [depth, height, width]. # 下面采用tf.strided_slice方法在record_bytes中的图片数据信息 depth_major = tf.reshape( tf.strided_slice(record_bytes, [label_bytes], [label_bytes + image_bytes], [1]), [result.depth, result.height, result.width]) # Convert from [depth, height, width] to [height, width, depth]. result.uint8image = tf.transpose(depth_major, [1, 2, 0]) return result如代码所述,read_cifar10其实返回了一个训练样本,包括result.label 和result.uint8image两个数据成员。其中,_generate_image_and_label_batch方法内容如下:
num_preprocess_threads = 16 if shuffle: # 随机产生一个batch,有16个线程,而读入的缓存大小为20000,capacity为20000+3*128 images, label_batch = tf.train.shuffle_batch( [image, label], batch_size=batch_size, num_threads=num_preprocess_threads, capacity=min_queue_examples + 3 * batch_size, min_after_dequeue=min_queue_examples) else: images, label_batch = tf.train.batch( [image, label], batch_size=batch_size, num_threads=num_preprocess_threads, capacity=min_queue_examples + 3 * batch_size) # Display the training images in the visualizer. tf.summary.image('images', images) return images, tf.reshape(label_batch, [batch_size])
我认为载入的难以理解点就是明明是一个一个样本读取,最后却能返回一个完成的batch,并且还有一个载入的缓存,大小为20000,极客学院网站上解释的图
首先,我们先创建数据流图,这个数据流图由一些流水线的阶段组成,阶段间用队列连接在一起。第一阶段将生成文件名,我们读取这些文件名并且把他们排到文件名队列中。第二阶段从文件中读取数据(使用Reader),产生样本,而且把样本放在一个样本队列中。根据你的设置,实际上也可以拷贝第二阶段的样本,使得他们相互独立,这样就可以从多个文件中并行读取。在第二阶段的最后是一个排队操作,就是入队到队列中去,在下一阶段出队。因为我们是要开始运行这些入队操作的线程,所以我们的训练循环会使得样本队列中的样本不断地出队。
我的理解是FixedLengthRecordReader,reader.read读取的是一个文件名列表中任意一个文件中的一个样本信息,对应上图中dequeue,在这之后可以针对这一个样本进行处理,而最终的tf.train.shuffle_batch,则将16个不同的reader读到的样本组成batch并返回。这些方法必须配套使用,即虽然没有显式的将多线程及batch构成过程编程实现,但tensorflow帮我们实现了上述的机制。
模型定义
反而模型定义部分无需多讲,只要注意以下几点:1. 虽然与MNIST相比,这里的深度学习网络也是采用了两个卷积层和两个池化层,但中间加入局部响应正则化层和两个全连接层
2. 定义两个全连接层时,通过_variable_with_weight_decay方法将权重的二范数值引入最终的loss计算,其相关代码如下所示:
weights = _variable_with_weight_decay('weights', shape=[384, 192], stddev=0.04, wd=0.004)
def _variable_with_weight_decay(name, shape, stddev, wd): """Helper to create an initialized Variable with weight decay. Note that the Variable is initialized with a truncated normal distribution. A weight decay is added only if one is specified. Args: name: name of the variable shape: list of ints stddev: standard deviation of a truncated Gaussian wd: add L2Loss weight decay multiplied by this float. If None, weight decay is not added for this Variable. Returns: Variable Tensor """ dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 var = _variable_on_cpu( name, shape, tf.truncated_normal_initializer(stddev=stddev, dtype=dtype)) if wd is not None: weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss') #Computes half the L2 norm of a tensor without the sqrt #output = sum(t ** 2) / 2 tf.add_to_collection('losses', weight_decay) return var
其实,tf.nn.l2_loss就是将所有weights平方和除以2之后,然后weight_decay计算是乘以0.004系数加到losses计算中
训练目标
loss计算方法定义如下:def loss(logits, labels): """Add L2Loss to all the trainable variables. Add summary for "Loss" and "Loss/avg". Args: logits: Logits from inference(). labels: Labels from distorted_inputs or inputs(). 1-D tensor of shape [batch_size] Returns: Loss tensor of type float. """ # Calculate the average cross entropy loss across the batch. labels = tf.cast(labels, tf.int64) cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( logits, labels, name='cross_entropy_per_example') cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy') tf.add_to_collection('losses', cross_entropy_mean) # Wrapper for Graph.add_to_collection() using the default graph. # The total loss is defined as the cross entropy loss plus all of the weight # decay terms (L2 loss). return tf.add_n(tf.get_collection('losses'), name='total_loss') #It seems that the cross_entropy_mean is add_to_collection and added
通过collection中'losses'字段,最后的tf.add_n将通常的熵值与上面所说的weights的二范数值相加作为loss
迭代过程
def train(total_loss, global_step): """Train CIFAR-10 model. Create an optimizer and apply to all trainable variables. Add moving average for all trainable variables. Args: total_loss: Total loss from loss(). global_step: Integer Variable counting the number of training steps processed. Returns: train_op: op for training. """ # Variables that affect learning rate. num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size #50000/128 decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY) # Decay the learning rate exponentially based on the number of steps. lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE, global_step, decay_steps, LEARNING_RATE_DECAY_FACTOR, staircase=True) #When training a model, it is often recommended to lower the learning rate as the training progresses. This function applies an exponential decay function to a provided initial learning rate. It requires a global_step value to compute the decayed learning rate. You can just pass a TensorFlow variable that you increment at each training step. tf.summary.scalar('learning_rate', lr) # Generate moving averages of all losses and associated summaries. loss_averages_op = _add_loss_summaries(total_loss) # Compute gradients. with tf.control_dependencies([loss_averages_op]): # Returns a context manager that specifies control dependencies. # Use with the with keyword to specify that all operations constructed within the context should have control dependencies on control_inputs. For example: opt = tf.train.GradientDescentOptimizer(lr) grads = opt.compute_gradients(total_loss) # Apply gradients. apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) # Add histograms for trainable variables. for var in tf.trainable_variables(): tf.summary.histogram(var.op.name, var) # Add histograms for gradients. for grad, var in grads: if grad is not None: tf.summary.histogram(var.op.name + '/gradients', grad) # Track the moving averages of all trainable variables. variable_averages = tf.train.ExponentialMovingAverage( MOVING_AVERAGE_DECAY, global_step) variables_averages_op = variable_averages.apply(tf.trainable_variables()) with tf.control_dependencies([apply_gradient_op, variables_averages_op]): train_op = tf.no_op(name='train') #Does nothing. Only useful as a placeholder for control edges. return train_op
训练过程中主要的不同之处是
lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE, global_step, decay_steps, LEARNING_RATE_DECAY_FACTOR, staircase=True)就是说在每一次迭代过程中,都需要重新计算一次learning rate,而这里初始的INITIAL_LEARNING_RATE为0.1,global_step为当前的迭代次数,decay_steps就是每多少代,learning_rate衰减到到LEARNING_RATE_DECAY_FACTOR×INITIAL_LEARNING_RATE值,比如本程序中LEARNING_RATE_DECAY_FACTOR = 0.1 ,而decay_steps = num_batches_per_epoch * NUM_EPOCHS_PER_DECAY = 50000/128×350 ,也就说每十多万次迭代,lr衰减为原来0.1,然后根据每代的lr,用梯度法计算
opt = tf.train.GradientDescentOptimizer(lr) grads = opt.compute_gradients(total_loss)而后面的
variable_averages = tf.train.ExponentialMovingAverage( MOVING_AVERAGE_DECAY, global_step)产生一个滑动平均计算对象,MOVING_AVERAGE_DECAY = 0.999,则每一代中的decay值更新如下
min(decay, (1 + num_updates) / (10 + num_updates))
采用这个计算得到的decay值对上面梯度法更新得到的所有参数进行平滑处理如下:
shadow_variable = decay * shadow_variable + (1 - decay) * variable
PS
在实际运行过程中,本人采用0.12r版本,出现版本不兼容导致的错误和警告,主要有以下两个问题:1.cifar10_input中的tf.strided_slice方法原程序中只提供3个参数,而在0.12r中的版本需要提供4个参数,相差一个步长stride的值,这里补充为[1]即可
2.写日志文件大量采用了tf.contrib.deprecated库中的方法,已全部失效, 可直接采用tf.summary.scalar等方法cifar10_train.pycifar10_train
2 0
- tensorflow cifar_10 代码阅读与理解
- Tensorflow代码阅读
- 理解ResNet结构与TensorFlow代码分析
- 理解ResNet结构与TensorFlow代码分析
- 理解ResNet结构与TensorFlow代码分析
- tensorflow 学习之 cifar_10 模型定义
- tensorflow学习之cifar_10模型评估
- Tensorflow做阅读理解与完形填空
- tensorflow 学习之 cifar_10 模型定义(补)
- tensorflow seq2seq模型 代码阅读分析
- TensorFlow学习之CNN-Cifar10代码阅读与详解(一):cifar10数据批量读取
- Mahout MinHash代码阅读理解
- ResNet-TensorFlow Model Zoo代码理解
- 从Tensorflow代码中理解LSTM网络
- HBase源代码阅读与理解
- HBase源代码阅读与理解
- HashMap源码阅读与理解
- ConcurrentHashMap源码阅读与理解
- BZOJ 2194: 快速傅立叶之二
- alertdialog button位置潜谈
- 今天弯个腰随便捡点小钱哦,提前跟大家说一声
- Duplicating managed version 版本冲突
- linux 命令总结
- tensorflow cifar_10 代码阅读与理解
- Python:用迭代器和生成器降低程序内存占用率
- 绝对有用!上传视频到视频网站的教学【以及各大视频网站的特点介绍】
- 浅谈CSRF攻击方式
- zTree的菜单筛选
- 如何利用DigitalOcean设置主机名称
- 理解OAuth 2.0
- Jquery选择器总结
- 慕课网C++学习笔记 20170111