【深度学习】VGGNet的tensorflow实现code
来源:互联网 发布:网络飞行游戏 编辑:程序博客网 时间:2024/05/16 13:04
在上一篇博客中对VGGNet论文做了学习笔记以及网络结构记录,链接见此http://blog.csdn.net/qq_29340857/article/details/71440674,本篇博客的代码都是基于上一篇的理解写的,话不多说,code如下。
输入数据集定义(因为使用随机梯度下降,所以使用placeholder定义训练数据):
tf_train_data=tf.placeholder(tf.float32,shape=(batch_size,img_size,img_size,3)) tf_train_label=tf.placeholder(tf.float32,shape=(batch_size,img_num)) tf_valid_data=tf.constant(valid_data) tf_test_data=tf.constant(test_data)
参数初始化(不算pool层,fc层+conv层一共是16层神经网络):
w={ 'w1':tf.get_variable('w1',[3,3,3,64],initializer=tf.contrib.layers.xavier_initializer_conv2d()), 'w2':tf.get_variable('w2',[3,3,64,64],initializer=tf.contrib.layers.xavier_initializer_conv2d()), 'w3':tf.get_variable('w3',[3,3,64,128],initializer=tf.contrib.layers.xavier_initializer_conv2d()), 'w4':tf.get_variable('w4',[3,3,128,128],initializer=tf.contrib.layers.xavier_initializer_conv2d()), 'w5':tf.get_variable('w5',[3,3,128,256],initializer=tf.contrib.layers.xavier_initializer_conv2d()), 'w6':tf.get_variable('w6',[3,3,256,256],initializer=tf.contrib.layers.xavier_initializer_conv2d()), 'w7':tf.get_variable('w7',[3,3,256,256],initializer=tf.contrib.layers.xavier_initializer_conv2d()), 'w8':tf.get_variable('w8',[3,3,256,512],initializer=tf.contrib.layers.xavier_initializer_conv2d()), 'w9':tf.get_variable('w9',[3,3,512,512],initializer=tf.contrib.layers.xavier_initializer_conv2d()), 'w10':tf.get_variable('w10',[3,3,512,512],initializer=tf.contrib.layers.xavier_initializer_conv2d()), 'w11':tf.get_variable('w11',[3,3,512,512],initializer=tf.contrib.layers.xavier_initializer_conv2d()), 'w12':tf.get_variable('w12',[3,3,512,512],initializer=tf.contrib.layers.xavier_initializer_conv2d()), 'w13':tf.get_variable('w13',[3,3,512,512],initializer=tf.contrib.layers.xavier_initializer_conv2d()), 'w14':tf.Variable(tf.random_normal([img_size/32*img_size/32*512,4096])), 'w15':tf.Variable(tf.random_normal([4096,4096])), 'w16':tf.Variable(tf.random_normal([4096,1000])), } b={ 'b1':tf.Variable(tf.zeros([64])), 'b2':tf.Variable(tf.zeros([64])), 'b3':tf.Variable(tf.zeros([128])), 'b4':tf.Variable(tf.zeros([128])), 'b5':tf.Variable(tf.zeros([256])), 'b6':tf.Variable(tf.zeros([256])), 'b7':tf.Variable(tf.zeros([256])), 'b8':tf.Variable(tf.zeros([512])), 'b9':tf.Variable(tf.zeros([512])), 'b10':tf.Variable(tf.zeros([512])), 'b11':tf.Variable(tf.zeros([512])), 'b12':tf.Variable(tf.zeros([512])), 'b13':tf.Variable(tf.zeros([512])), 'b14':tf.Variable(tf.zeros([4096])), 'b15':tf.Variable(tf.zeros([4096])), 'b16':tf.Variable(tf.zeros([1000])), }
论文原文中提到两种参数初始化方法:
1. pre-training:训练A结构网络,直接把其前四层conv的参数和fc的前两层参数直接赋给D结构网络,就是本文代码实现的
2. no pre-training:利用Xaiver Initialization初始化参数
在此采用Xavier Initialization。
网络结构搭建:
#13层卷积层+5层max_pool+3层FC层 def model(input_data): conv1=tf.nn.conv2d(input_data,w['w1'],[1,1,1,3],padding="SAME")#img:224*224*64 h1=tf.nn.relu(conv1+b['b1']) conv2=tf.nn.conv2d(h1,w['w2'],[1,1,1,64],padding="SAME") h2=tf.nn.relu(conv2+b['b2']) max1=tf.nn.max_pool(h2,[2,2,64,64],[1,2,2,64],padding="VALID") conv3=tf.nn.conv2d(max1,w['w3'],[1,1,1,64],padding="SAME")#img:112*112*128 h3=tf.nn.relu(conv3+b['b3']) conv4=tf.nn.conv2d(h3,w['w4'],[1,1,1,128],padding="SAME") h4=tf.nn.relu(conv4+b['b4']) max2=tf.nn.max_pool(h4,[2,2,128,128],[1,2,2,128],padding="VALID") conv5=tf.nn.conv2d(max2,w['w5'],[1,1,1,128],padding="SAME")#img:56*56*256 h5=tf.nn.relu(conv5+b['b5']) conv6=tf.nn.conv2d(h5,w['w6'],[1,1,1,256],padding="SAME") h6=tf.nn.relu(conv6+b['b6']) conv7=tf.nn.conv2d(h6,w['w7'],[1,1,1,256],padding="SAME") h7=tf.nn.relu(conv7)+b['b7'] max3=tf.nn.max_pool(h7,[2,2,256,256],[1,2,2,256],padding="VALID") conv8=tf.nn.conv2d(max3,w['w8'],[1,1,1,256],padding="SAME")#img:28*28*512 h8=tf.nn.relu(conv8+b['b8']) conv9=tf.nn.conv2d(h8,w['w9'],[1,1,1,512],padding="SAME") h9=tf.nn.relu(conv9+b['b9']) conv10=tf.nn.conv2d(h9,w['w10'],[1,1,1,512],padding="SAME") h10=tf.nn.relu(conv10+b['b10']) max4=tf.nn.max_pool(h10,[2,2,512,512],[1,2,2,512],padding="VALID") conv11=tf.nn.conv2d(max4,w['w11'],[1,1,1,512],padding="SAME")#img:14*14*512 h11=tf.nn.relu(conv11+b['b11']) conv12=tf.nn.conv2d(h11,w['w12'],[1,1,1,512],padding="SAME") h12=tf.nn.relu(conv12+b['b12']) conv13=tf.nn.conv2d(h12,w['w13'],[1,1,1,512],padding="SAME") h13=tf.nn.relu(conv13+b['b13']) max5=tf.nn.max_pool(h13,[2,2,512,512],[1,2,2,512],padding="VALID") shapes=max5.get_shape().as_list()#img:7*7*512 reshape=tf.reshape(max5,[shapes[0],shapes[1]*shapes[2]*shapes[3]]) fc1=tf.matmul(reshape,w['w14'])+b['b14'] h14=tf.nn.dropout(tf.nn.relu(fc1),drop_param) fc2=tf.matmul(h14,w['w15'])+b['b15'] h15=tf.nn.dropout(tf.nn.relu(fc2),drop_param) fc3=tf.matmul(h15,w['w16'])+b['b16'] return fc3
L2正则化:
l2_loss=None for i in np.arange(1,17): k="w"+str(i) l2_loss+=l2_param*tf.nn.l2_loss(w[k]) loss=tf.reduce_mean(tf.softmax_cross_entropy_with_logits(logits,tf_train_label))+l2_loss
Momentum Optimizer更新参数:
optimizer=tf.train.MomentumOptimizer(learning_rate,momentum_param)
对Momentum Optimizer不了解的小伙伴和想了解更多参数更新方法的小伙伴,可以参见我之前写的博客http://blog.csdn.net/qq_29340857/article/details/71353751
train,validation,test数据集预测:
train_predictions=tf.nn.softmax(logits) valid_predictions=tf.nn.softmax(model(tf_valid_data)) test_predictions=tf.nn.softmax(model(tf_test_data))
至此,VGGNet的搭建工作差不多完成,在下一篇博客中会记录实际训练时遇到的一些问题以及解决方案,欢迎继续关注!
阅读全文
0 0
- 【深度学习】VGGNet的tensorflow实现code
- TensorFlow实现经典深度学习网络(2):TensorFlow实现VGGNet
- tensorflow34《TensorFlow实战》笔记-06-02 TensorFlow实现VGGNet code
- Tensorflow实战学习(三十一)【实现VGGNet】
- Tensorflow实现VGGNet
- Tensorflow实现VGGNet
- 深度学习的TensorFlow实现
- TensorFlow学习--VGGNet实现&图像识别
- VGGNet原理及Tensorflow实现
- 【深度学习】VGGNet学习笔记
- 神经网络之VGGNet模型的实现(Python+TensorFlow)
- 深度学习之解读VGGNet
- tensorflow40《TensorFlow实战》笔记-08-01 TensorFlow实现深度强化学习-策略网络 code
- tensorflow41《TensorFlow实战》笔记-08-02 TensorFlow实现深度强化学习-估值网络 code
- 学习笔记TF031:实现VGGNet
- 深度学习算法之AlexNet和VGGNet
- 深度学习经典卷积神经网络之VGGNet
- 【深度学习】TensorFlow实现LeNet5
- 在Linux系统下,重启Tomcat使用命令操作的
- mycat后台连接调用关系
- python中关于类的理解
- 易飞去除批处理删除历史单据权限
- [UE4]后期处理(Post Processing)相关的官方文档
- 【深度学习】VGGNet的tensorflow实现code
- 培养创造力的10个注意点
- 微信开发后台处理消息时使用反射,去掉繁琐的if判断
- Scrollow嵌套RecyclverView出现滑动卡顿
- DotNet 资源大全中文版(Awesome最新版)
- 链接:NFC:高级NFC
- PHP最佳实践(译)
- JAVA大神之路
- 解决Nutz连接Oracle扫描建表,错误信息:无效字符