tensorflow 学习之 cifar_10 模型定义(补)
来源:互联网 发布:2016双色球预测软件 编辑:程序博客网 时间:2024/05/22 06:25
# -*- coding: utf-8 -*-
# CIFAR-10 model definition (TensorFlow 1.x graph API).
#
# Provides the input pipelines (distorted_inputs / inputs), the network
# (inference), the loss, the training op (train) and a helper that downloads
# and unpacks the binary dataset (maybe_download_and_extract).
import os
import re
import sys
import tarfile
import urllib.request

import tensorflow as tf

import new_cifar10_input

FLAGS = tf.app.flags.FLAGS  # parsed command-line flags

# Model parameters, settable from the command line.
tf.app.flags.DEFINE_integer('batch_size', 128,
                            """Number of images to process in a batch.""")
tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
                           """Path to the CIFAR-10 data directory.""")
tf.app.flags.DEFINE_boolean('use_fp16', False,
                            """Train the model using fp16.""")

# Global constants describing the data set (re-exported from the input module).
# NOTE(review): the attribute is spelled 'IMAGE_SISE' here; confirm that this
# matches the name actually defined in new_cifar10_input.
IMAGE_SIZE = new_cifar10_input.IMAGE_SISE
NUM_CLASSES = new_cifar10_input.NUM_CLASSES
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = new_cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = new_cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
# Misspelled historical names, kept so existing importers keep working.
NUM_EXAMOLES_PER_EPOCH_FOR_TRAIN = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
NUM_EXAMOLES_PER_EPOCH_FOR_EVAL = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

# Constants describing the training process.
MOVING_AVERAGE_DECAY = 0.999       # decay used for the moving averages of the variables
MOVING_AVERAGE_DEVAY = MOVING_AVERAGE_DECAY  # misspelled alias, kept for compatibility
NUM_EPOCHS_PER_DECAY = 350.0       # the learning rate drops (staircase) every 350 epochs
LEARNING_RATE_DECAY_FACTOR = 0.1   # multiplicative learning-rate decay factor
INITIAL_LEARNING_RATE = 0.1        # initial learning rate

TOWER_NAME = 'tower'
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'


def _activation_summary(x):
    """Attach a histogram and a sparsity scalar summary to tensor `x`.

    A leading '<TOWER_NAME>_<digits>/' prefix is stripped from the op name
    before it is used as the summary tag.

    Args:
        x: the tensor to summarize.
    """
    tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
    tf.summary.histogram(tensor_name + '/activations', x)
    tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))


def _variable_on_cpu(name, shape, initializer):
    """Create (or fetch) a variable pinned to the CPU.

    Args:
        name: variable name.
        shape: list of ints giving the variable shape.
        initializer: initializer passed to tf.get_variable.

    Returns:
        The variable, of dtype float16 when FLAGS.use_fp16 is set and
        float32 otherwise.
    """
    # tf.device is the context manager that pins newly created ops to a device.
    with tf.device('/cpu:0'):
        dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
        var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
    return var


def _variable_with_weight_decay(name, shape, stddev, wd):
    """Create a truncated-normal initialized variable with optional L2 decay.

    Args:
        name: variable name.
        shape: list of ints giving the variable shape.
        stddev: standard deviation of the truncated normal initializer.
        wd: weight-decay coefficient; when None no decay term is added,
            otherwise l2_loss(var) * wd is added to the 'losses' collection.

    Returns:
        The created variable.
    """
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    var = _variable_on_cpu(
        name, shape,
        tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
    if wd is not None:
        weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
        tf.add_to_collection('losses', weight_decay)
    return var


def distorted_inputs():
    """Build the distorted (augmented) training input via new_cifar10_input.

    Returns:
        (images, labels) as returned by new_cifar10_input.distorted_inputs,
        both cast to float16 when FLAGS.use_fp16 is set.

    Raises:
        ValueError: if FLAGS.data_dir is empty.
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    images, labels = new_cifar10_input.distorted_inputs(
        data_dir=data_dir, batch_size=FLAGS.batch_size)
    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)
        labels = tf.cast(labels, tf.float16)
    return images, labels


def inputs(eval_data):
    """Build the undistorted input via new_cifar10_input.

    Args:
        eval_data: forwarded unchanged to new_cifar10_input.inputs.

    Returns:
        (images, labels) as returned by new_cifar10_input.inputs, both cast
        to float16 when FLAGS.use_fp16 is set.

    Raises:
        ValueError: if FLAGS.data_dir is empty.
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    # The batch size comes from the command-line flag; there is no local
    # `batch_size` name in this function.
    images, labels = new_cifar10_input.inputs(
        eval_data=eval_data, data_dir=data_dir, batch_size=FLAGS.batch_size)
    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)
        labels = tf.cast(labels, tf.float16)
    return images, labels


def inference(images):
    """Build the CIFAR-10 network: two conv layers, two dense layers, logits.

    Args:
        images: input image batch; the first conv kernel expects 3 input
            channels and the dense layer reshapes to FLAGS.batch_size rows.

    Returns:
        The un-normalized logits tensor with NUM_CLASSES columns.
    """
    # First convolution + pooling + local response normalization.
    with tf.variable_scope('conv1') as scope:
        kernel = _variable_with_weight_decay(
            'weights', shape=[5, 5, 3, 64], stddev=5e-2, wd=0.0)
        conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
        biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
        pre_activation = tf.nn.bias_add(conv, biases)
        conv1 = tf.nn.relu(pre_activation, name=scope.name)
        _activation_summary(conv1)
    pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                           padding='SAME', name='pool1')
    norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                      name='norm1')

    # Second convolution; here normalization comes before pooling.
    with tf.variable_scope('conv2') as scope:
        kernel = _variable_with_weight_decay(
            'weights', shape=[5, 5, 64, 64], stddev=5e-2, wd=0.0)
        conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
        biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
        pre_activation = tf.nn.bias_add(conv, biases)
        conv2 = tf.nn.relu(pre_activation, name=scope.name)
        _activation_summary(conv2)
    norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                      name='norm2')
    pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                           padding='SAME', name='pool2')

    # Fully connected layers.
    with tf.variable_scope('fc1') as scope:
        # Flatten every example to one row so a single matmul can be applied.
        reshape = tf.reshape(pool2, [FLAGS.batch_size, -1])
        dim = reshape.get_shape()[1].value
        weights = _variable_with_weight_decay(
            'weights', shape=[dim, 384], stddev=0.04, wd=0.004)
        biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
        fc1 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)
        _activation_summary(fc1)

    with tf.variable_scope('fc2') as scope:
        weights = _variable_with_weight_decay(
            'weights', shape=[384, 192], stddev=0.04, wd=0.004)
        biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
        fc2 = tf.nn.relu(tf.matmul(fc1, weights) + biases, name=scope.name)
        _activation_summary(fc2)

    # Linear output layer; softmax is applied later inside the loss.
    with tf.variable_scope('softmax_linear') as scope:
        weights = _variable_with_weight_decay(
            'weights', [192, NUM_CLASSES], stddev=1 / 192.0, wd=0.0)
        biases = _variable_on_cpu('biases', [NUM_CLASSES],
                                  tf.constant_initializer(0.0))
        softmax_linear = tf.add(tf.matmul(fc2, weights), biases,
                                name=scope.name)
        _activation_summary(softmax_linear)

    return softmax_linear


def loss(logits, labels):
    """Compute the total loss: mean cross entropy plus all weight-decay terms.

    Args:
        logits: un-normalized output of inference().
        labels: sparse (class index) labels; cast to int64 here.

    Returns:
        A scalar tensor, the sum of everything in the 'losses' collection.
    """
    labels = tf.cast(labels, tf.int64)
    # Cross entropy computed from sparse labels and raw logits.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name='cross_entropy_per_example')
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
    tf.add_to_collection('losses', cross_entropy_mean)
    return tf.add_n(tf.get_collection('losses'), name='total_loss')


def _add_loss_summaries(total_loss):
    """Add raw and smoothed scalar summaries for every loss term.

    An exponential moving average (decay 0.9) is maintained for each entry of
    the 'losses' collection and for `total_loss`.

    Args:
        total_loss: the total loss tensor returned by loss().

    Returns:
        The op that updates the moving averages of the losses.
    """
    loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
    # Everything stored under 'losses': the cross entropy and the decay terms.
    losses = tf.get_collection('losses')
    loss_averages_op = loss_averages.apply(losses + [total_loss])
    for l in losses + [total_loss]:
        # The raw value gets a '(raw)' suffix; the smoothed one keeps the name.
        tf.summary.scalar(l.op.name + '(raw)', l)
        tf.summary.scalar(l.op.name, loss_averages.average(l))
    return loss_averages_op


def train(total_loss, global_step):
    """Build the training op.

    Creates a gradient-descent optimizer with a staircase exponentially
    decaying learning rate, applies it to all trainable variables and keeps
    moving averages of those variables.

    Args:
        total_loss: the total loss tensor returned by loss().
        global_step: variable counting the processed training steps.

    Returns:
        A no-op that depends on both the gradient update and the
        moving-average update.
    """
    # Quantities that drive the learning-rate schedule.
    num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
    decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

    # Decay the learning rate exponentially (in steps) with the step count.
    lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE, global_step,
                                    decay_steps, LEARNING_RATE_DECAY_FACTOR,
                                    staircase=True)
    tf.summary.scalar('learning_rate', lr)

    # Moving averages and summaries of all losses.
    loss_averages_op = _add_loss_summaries(total_loss)

    # Compute gradients; control_dependencies requires a list of ops.
    with tf.control_dependencies([loss_averages_op]):
        opt = tf.train.GradientDescentOptimizer(lr)
        grads = opt.compute_gradients(total_loss)

    # Apply gradients.
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    # Histograms for the trainable variables.
    for var in tf.trainable_variables():
        tf.summary.histogram(var.op.name, var)

    # Histograms for the gradients.
    for grad, var in grads:
        if grad is not None:
            tf.summary.histogram(var.op.name + '/gradients', grad)

    # Track moving averages of all trainable variables.
    variable_averages = tf.train.ExponentialMovingAverage(
        MOVING_AVERAGE_DECAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())

    with tf.control_dependencies([apply_gradient_op, variable_averages_op]):
        train_op = tf.no_op(name='train')

    return train_op


def maybe_download_and_extract():
    """Download DATA_URL into FLAGS.data_dir and unpack it, skipping done steps.

    The archive is downloaded only when it is not already on disk, and it is
    extracted only when the 'cifar-10-batches-bin' directory does not exist.
    """
    dest_directory = FLAGS.data_dir
    os.makedirs(dest_directory, exist_ok=True)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            # urlretrieve may report a non-positive total size when the
            # server does not announce the length; avoid dividing by it.
            if total_size > 0:
                percent = float(count * block_size) / float(total_size) * 100.0
            else:
                percent = 0.0
            sys.stdout.write('\r >>Downloading %s %.1f%%' % (filename, percent))
            sys.stdout.flush()
        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Successfully download', filename, statinfo.st_size, 'bytes.')
    extracted_dir_path = os.path.join(dest_directory, 'cifar-10-batches-bin')
    if not os.path.exists(extracted_dir_path):
        # NOTE(review): extractall trusts the member paths in the archive;
        # the file is fetched over plain HTTP, so consider verifying it.
        with tarfile.open(filepath, 'r:gz') as archive:
            archive.extractall(dest_directory)
代码中还有很多不理解的地方,得继续学习相关 API。
阅读全文
0 0
- tensorflow 学习之 cifar_10 模型定义(补)
- tensorflow 学习之 cifar_10 模型定义
- tensorflow学习之cifar_10模型评估
- TensorFlow学习笔记(一补):使用Anaconda安装TensorFlow
- tensorflow cifar_10 代码阅读与理解
- 深度学习模型之tensorflow应用
- TensorFlow学习笔记(二十一) tensorflow机器学习模型
- tensorflow学习:定义变量
- TensorFlow学习笔记(二):TensorFlow实现线性回归模型
- TensorFlow学习笔记(三):TensorFlow实现逻辑回归模型
- thinkphp学习之模型数据表名定义
- TensorFlow个人学习(回归模型)
- TensorFlow学习笔记(二十六)CNN的9大模型之LeNet5的原理讲解
- TensorFlow学习笔记(二十七)CNN的9大模型之Dan CiresanNet
- TensorFlow学习笔记(二十八)CNN的9大模型之AlexNet
- tensorflow之inception_v3模型的部分加载及权重的部分恢复(23)---《深度学习》
- Tensorflow深度学习之九:滑动平均模型
- TensorFlow优化模型之学习率的设置
- Lua流程控制
- Java中equals和==的那些事
- define声明一个常数问题
- 源码 Music音乐播放器代码结构
- Android电话响铃、接听、挂断状态
- tensorflow 学习之 cifar_10 模型定义(补)
- linux下安装jdk和tomcat
- Android加密已有的sqlite数据库---sqlcipher
- Redis-Java客户端Jedis
- 最大序列和问题
- Lua 函数
- Android:如何快速对系统重启问题进行归类
- 一般线性模型、混合线性模型、广义线性模型
- 解决ueditor百度富文本编辑器图片可以上传但是在线管理图片无法显示