tensorflow中slim高级库的应用
来源:互联网 发布:淘宝网成功的原因 编辑:程序博客网 时间:2024/06/08 09:02
tensorflow中slim库学习
在阅读用tensorflow实现的深度学习网络结构的源码时,经常会看到作者使用TF中封装的slim高级库,看起来(实际上也是)比直接调用TF的API简洁好多。为了弄懂网络源码和学习slim库应用,特地查阅了一些资料,在这里做一下学习时的记录。
tensorflow中关于slim库的介绍
某位博主关于上面slim英文介绍的一些翻译
下面直接贴出我实现的一个应用slim库的代码:
import tensorflow as tfimport tensorflow.contrib.slim as slimfrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets('./data/MNIST',one_hot=True)def cal_loss(y_pre,y_label): return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_label, logits=y_pre))def cal_accuracy(y_pre,y_label): return tf.reduce_mean(tf.cast(tf.equal(tf.arg_max(y_pre, dimension=1),tf.arg_max(y_label, dimension=1)),tf.float32))def network(inputs,y_label): with slim.arg_scope([slim.conv2d],######可以在列表里添加其他要简化的操作,比如再添加全连接。函数中下面的参数是默认执行的操作 activation_fn=tf.nn.relu,########可以应用自己编写的激活函数 weights_initializer=slim.xavier_initializer(),####默认xavier_initializer初始化权值 biases_initializer=tf.zeros_initializer(), weights_regularizer=slim.l2_regularizer(0.0005), padding='SAME'): print inputs.get_shape() net = slim.conv2d(inputs,num_outputs=32,kernel_size=[3,3],stride=1,scope='conv1') print net.get_shape() net = slim.max_pool2d(net, kernel_size=[2,2], stride=2, scope='pool1') print net.get_shape() net = slim.conv2d(net, num_outputs=64, kernel_size=[3,3], stride=1, scope='conv2') print net.get_shape() net = slim.max_pool2d(net, kernel_size=[2,2], stride=2, scope='pool2') print net.get_shape() net = slim.conv2d(net,num_outputs=64,kernel_size=[3,3],scope='conv3') print net.get_shape() fc_flat = slim.flatten(net) print fc_flat.get_shape() fc1 = slim.fully_connected(fc_flat, num_outputs=512, scope='fc1') print fc1.get_shape() y_out = slim.fully_connected(fc1, num_outputs=10, scope='y_out') print y_out.get_shape() accuracy = cal_accuracy(y_out, y_label) l2_loss = tf.add_n(slim.losses.get_regularization_losses()) return cal_loss(y_out,y_label) + l2_loss , accuracy def main(): x_data = tf.placeholder(dtype=tf.float32, shape=[None,784], name='x_data') y_label = tf.placeholder(dtype=tf.float32, shape=[None,10], name='y_label') x_input = tf.reshape(x_data, shape=[-1,28,28,1], name='x_input') loss,accuracy = network(x_input,y_label) train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss) with tf.Session() as sess: sess.run(tf.initialize_all_variables()) for i in range(30001): xs,ys = mnist.train.next_batch(64) if i % 1000 == 0: loss_op,ac = sess.run([loss,accuracy],feed_dict={x_data:xs,y_label:ys}) print 'the %dth iteration loss: %f'%(i,loss_op) print 'the %dth iteration accuracy: %f'%(i,ac) sess.run(train_op,feed_dict={x_data:xs,y_label:ys}) total_acc = sess.run(accuracy,feed_dict={x_data:mnist.validation.images,y_label:mnist.validation.labels}) print 'the total accuracy: %f'%(total_acc) if __name__ == '__main__': main()
代码运行结果:
在上面添加的链接中有关于slim的详细介绍,大家仔细看看即可,有不懂的或者我写错了的地方,可以交流~
阅读全文
0 0
- tensorflow中slim高级库的应用
- 【Tensorflow slim】 slim.arg_scope的用法
- 使用TF-Slim:在TensorFlow中定义复杂模型的高层库
- TensorFlow-Slim image classification library:TensorFlow-Slim 图像分类库
- TensorFlow-Slim图像分类库
- Tensorflow slim库使用小记
- 使用Tensorflow的slim库进行迁移学习
- tensorflow slim【TF-Slim】
- tensorflow中slim模块api介绍
- tensorflow中slim模块api介绍
- tensorflow的slim换地址了
- 【Tensorflow slim】slim.data包
- 【Tensorflow slim】slim evaluation 函数
- 【Tensorflow slim】slim layers包
- 【Tensorflow slim】slim learning包
- 【Tensorflow slim】slim losses包
- 【Tensorflow slim】slim nets包
- 【Tensorflow slim】slim variables包
- mybatis的一些特殊SQL用法
- JVM复习
- 线性表之连续存储(数组)
- JQuery基础
- xilinx--IOB(1)
- tensorflow中slim高级库的应用
- 清理vmware碎片
- 【Data Struct】冒泡排序算法
- iptables 规则备份和恢复,firewalld防火墙机制
- ThinkPHP5.1 action变量路由&controller变量路由
- Linux基础(一):文件处理命令
- 泊松分布
- class文件结构
- 精选11道Java技术面试题并有答案