利用TensorFlow实现VGG16
来源:互联网 发布:网狐棋牌源码 编辑:程序博客网 时间:2024/05/19 17:48
上一篇文章,实现了网络的输入,这次继续完成网络的训练,网络采用VGG16的结构。其中为了方便,keep_prob无论是训练还是测试,都这成了1,大家应该根据需要feed进不同的值。网络的输入TFRecord.createBatch(),为上一篇文章中产生数据的方法。
1.定义网络参数import tensorflow as tfimport numpy as npimport TFRecord#定义网络参数learning_rate = 0.001display_step = 5epochs = 10keep_prob = 0.5
2.定义各种类型的层
#定义卷积操作def conv_op(input_op, name, kh, kw, n_out, dh, dw): input_op = tf.convert_to_tensor(input_op) n_in = input_op.get_shape()[-1].value with tf.name_scope(name) as scope: kernel = tf.get_variable(scope+"w", shape = [kh, kw, n_in, n_out], dtype = tf.float32, initializer = tf.contrib.layers.xavier_initializer_conv2d()) conv = tf.nn.conv2d(input_op, kernel, (1, dh, dw, 1), padding = 'SAME') bias_init_val = tf.constant(0.0, shape = [n_out], dtype = tf.float32) biases = tf.Variable(bias_init_val, trainable = True, name = 'b') z = tf.nn.bias_add(conv, biases) activation = tf.nn.relu(z, name = scope) return activation#定义全连接操作def fc_op(input_op, name, n_out): n_in = input_op.get_shape()[-1].value with tf.name_scope(name) as scope: kernel = tf.get_variable(scope+'w', shape = [n_in, n_out], dtype = tf.float32, initializer = tf.contrib.layers.xavier_initializer()) biases = tf.Variable(tf.constant(0.1, shape = [n_out], dtype = tf.float32), name = 'b') # tf.nn.relu_layer对输入变量input_op与kernel做矩阵乘法加上bias,再做RELU非线性变换得到activation activation = tf.nn.relu_layer(input_op, kernel, biases, name = scope) return activation #定义池化层def mpool_op(input_op, name, kh, kw, dh, dw): return tf.nn.max_pool(input_op, ksize = [1, kh, kw, 1], strides = [1, dh, dw, 1], padding = 'SAME', name = name)
3. 定义网络结构
def inference_op(input_op, keep_prob): # block 1 -- outputs 112x112x64 conv1_1 = conv_op(input_op, name="conv1_1", kh=3, kw=3, n_out=64, dh=1, dw=1) conv1_2 = conv_op(conv1_1, name="conv1_2", kh=3, kw=3, n_out=64, dh=1, dw=1) pool1 = mpool_op(conv1_2, name="pool1", kh=2, kw=2, dw=2, dh=2) # block 2 -- outputs 56x56x128 conv2_1 = conv_op(pool1, name="conv2_1", kh=3, kw=3, n_out=128, dh=1, dw=1) conv2_2 = conv_op(conv2_1, name="conv2_2", kh=3, kw=3, n_out=128, dh=1, dw=1) pool2 = mpool_op(conv2_2, name="pool2", kh=2, kw=2, dh=2, dw=2) # # block 3 -- outputs 28x28x256 conv3_1 = conv_op(pool2, name="conv3_1", kh=3, kw=3, n_out=256, dh=1, dw=1) conv3_2 = conv_op(conv3_1, name="conv3_2", kh=3, kw=3, n_out=256, dh=1, dw=1) conv3_3 = conv_op(conv3_2, name="conv3_3", kh=3, kw=3, n_out=256, dh=1, dw=1) pool3 = mpool_op(conv3_3, name="pool3", kh=2, kw=2, dh=2, dw=2) # block 4 -- outputs 14x14x512 conv4_1 = conv_op(pool3, name="conv4_1", kh=3, kw=3, n_out=512, dh=1, dw=1) conv4_2 = conv_op(conv4_1, name="conv4_2", kh=3, kw=3, n_out=512, dh=1, dw=1) conv4_3 = conv_op(conv4_2, name="conv4_3", kh=3, kw=3, n_out=512, dh=1, dw=1) pool4 = mpool_op(conv4_3, name="pool4", kh=2, kw=2, dh=2, dw=2) # block 5 -- outputs 7x7x512 conv5_1 = conv_op(pool4, name="conv5_1", kh=3, kw=3, n_out=512, dh=1, dw=1) conv5_2 = conv_op(conv5_1, name="conv5_2", kh=3, kw=3, n_out=512, dh=1, dw=1) conv5_3 = conv_op(conv5_2, name="conv5_3", kh=3, kw=3, n_out=512, dh=1, dw=1) pool5 = mpool_op(conv5_3, name="pool5", kh=2, kw=2, dw=2, dh=2) # flatten shp = pool5.get_shape() flattened_shape = shp[1].value * shp[2].value * shp[3].value resh1 = tf.reshape(pool5, [-1, flattened_shape], name="resh1") # fully connected fc6 = fc_op(resh1, name="fc6", n_out=4096) fc6_drop = tf.nn.dropout(fc6, keep_prob, name="fc6_drop") fc7 = fc_op(fc6_drop, name="fc7", n_out=4096) fc7_drop = tf.nn.dropout(fc7, keep_prob, name="fc7_drop") logits = fc_op(fc7_drop, name="fc8", n_out=2) return logits
4.开始训练
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)) optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(cost) correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(labels,1)) accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) return optimizer, cost, accuracy
5. 主函数
if __name__ == "__main__": train_filename = "/home/wc/DataSet/traffic/testTFRecord/train.tfrecords" test_filename = "/home/wc/DataSet/traffic/testTFRecord/test.tfrecords" image_batch, label_batch = TFRecord.createBatch(filename = train_filename, batchsize=2) test_image, test_label = TFRecord.createBatch(filename = test_filename, batchsize=20) pred = inference_op(input_op = image_batch, keep_prob = keep_prob) test_pred = inference_op(input_op = test_image, keep_prob = keep_prob) optimizer, cost, accuracy = train(logits = pred, labels = label_batch) test_optimizer, test_cost, test_acc = train(logits = test_pred, labels = test_label) initop = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer()) with tf.Session() as sess: sess.run(initop) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess = sess, coord = coord) step = 0 while step < epochs: step += 1 print step _, loss, acc = sess.run([optimizer,cost,accuracy]) if step % display_step ==0: print loss,acc print "training finish!" _, testLoss, testAcc = sess.run([test_optimizer,test_cost,test_acc]) print "Test acc = "+ str(testAcc) print "Test Finish!"
其中的一些参数,都是随便设置的,我们的目的仅仅是将网络结构搭好,把数据feed进网络,将网络训练起来,其中epochs参数,是通过 总样本数 20 / batch_size 2 计算得来的,因为上篇文章中,读取tfrecord数据的方法中,创建了一个文件队列,
filename_queue = tf.train.string_input_producer([filename], shuffle=False,num_epochs = 1)
这里的num_epochs = 1,使得数据只能被读取一个周期,如果不设置,则可以重复读取。
阅读全文
2 0
- 利用TensorFlow实现VGG16
- TensorFlow实现中文字体分类(三):模型-VGG16
- Docker-tensorflow跑VGG16
- vgg16 on keras for tensorflow
- Tensorflow使用slim工具(vgg16模型)实现图像分类与分割
- Tensorflow深度学习之十五:VGG16模型的简单自主实现
- Tensorflow使用slim工具(vgg16模型)实现图像分类与分割
- 利用TensorFlow实现CNN
- VGG16
- vgg16测试模型的实现
- 基于tensorflow + Vgg16进行图像分类识别的实验
- 基于tensorflow + Vgg16进行图像分类识别的实验
- 基于tensorflow + Vgg16进行图像分类识别的实验
- keras实现VGG16 CIFAR10数据集
- 利用 TensorFlow 实现上下文的 Chat-bots
- 利用 TensorFlow 实现“看图说话”
- 利用 TensorFlow 实现排序和搜索算法
- 利用 TensorFlow 实现排序和搜索算法
- 扬州游记_2017
- Win10 去掉桌面快捷方式小箭头
- Codeforces Round #432 C. Five Dimensional Points
- hdu 4763 kmp的简单应用
- Java设计模式之抽象工厂模式
- 利用TensorFlow实现VGG16
- 常见的HTTP相应状态码
- 什么是协方差矩阵?
- JAVA在win10环境下配置环境变量
- Java经典面试题(其一)——Java异常和克隆
- js页面加载触发事件
- Leetcode #3. Longest Substring Without Repeating Characters
- 【opencv学习之九】opencv3.2配置opencv_contrib方法
- python中的正则表达式(re模块)