tensorflow 使用正则化
来源:互联网 发布:最短路径的算法 编辑:程序博客网 时间:2024/06/10 17:45
Tensorflow 使用正则化T
import tensorflow.contrib.layers as layers
def easier_network(x, reg): """ A network based on tf.contrib.learn, with input `x`. """ with tf.variable_scope('EasyNet'): out = layers.flatten(x) out = layers.fully_connected(out, num_outputs=200, weights_initializer = layers.xavier_initializer(uniform=True), weights_regularizer = layers.l2_regularizer(scale=reg), activation_fn = tf.nn.tanh) out = layers.fully_connected(out, num_outputs=200, weights_initializer = layers.xavier_initializer(uniform=True), weights_regularizer = layers.l2_regularizer(scale=reg), activation_fn = tf.nn.tanh) out = layers.fully_connected(out, num_outputs=10, # Because there are ten digits! weights_initializer = layers.xavier_initializer(uniform=True), weights_regularizer = layers.l2_regularizer(scale=reg), activation_fn = None) return out
def main(_): mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) x = tf.placeholder(tf.float32, [None, 784]) y_ = tf.placeholder(tf.float32, [None, 10]) # Make a network with regularization y_conv = easier_network(x, FLAGS.regu) weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'EasyNet') print("") for w in weights: shp = w.get_shape().as_list() print("- {} shape:{} size:{}".format(w.name, shp, np.prod(shp))) print("") reg_ws = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, 'EasyNet') for w in reg_ws: shp = w.get_shape().as_list() print("- {} shape:{} size:{}".format(w.name, shp, np.prod(shp))) print("") # Make the loss function `loss_fn` with regularization. cross_entropy = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv)) loss_fn = cross_entropy + tf.reduce_sum(reg_ws) train_step = tf.train.AdamOptimizer(1e-4).minimize(loss_fn)
tf.GraphKeys.REGULARIZATION_LOSSES得到在图中正则化的损失
regularizer=tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(weight_decay),gen_vars+d_vars)这样也可以
阅读全文
0 0
- tensorflow 使用正则化
- TensorFlow正则化
- 【TensorFlow】MNIST(使用全连接神经网络+滑动平均+正则化+指数衰减法+激活函数)
- tensorflow:3.1)add_to_collection和L2正则化
- L2正则化—tensorflow实现
- tensorflow regularizer(正则化)防止过拟合
- tensorflow 实现神经网络带正则化
- TensorFlow优化模型之正则化
- 【TensorFlow】正则化(过拟合问题)
- 【TensorFlow】MNIST(使用LeNet5+滑动平均+正则化+指数衰减法+激活函数+模型持久化)
- 【tensorflow 学习】给LSTM加上L2正则化
- tensorflow中正则化防止过拟合以及Batch Normalization
- [Tensorflow]L2正则化和collection【tf.GraphKeys】
- tensorflow使用
- tensorflow 使用
- Machine Learning with Scikit-Learn and Tensorflow 6.7 正则化超参数
- tensorflow06 《TensorFlow实战Google深度学习框架》笔记-04-04正则化
- TensorFlow之损失函数、学习率、正则
- 【云和恩墨】性能优化:Linux环境下合理配置大内存页(HugePage)
- spring in action 面向切面
- 数组遍历排序
- NDK的基础教程 四 动态注册
- TypeError: Expected int32, got list containing Tensors of type '_Message' instead.解决方法
- tensorflow 使用正则化
- Hyper-V 虚拟机虚拟网卡慢问题解决方案
- 多态相关(虚函数,覆盖,纯虚函数,抽象类)
- HDU 1907 John (Nim博弈 模板)
- JQuery学习笔记(One)
- 一种排序
- 关于Unity3D中使用SQLite数据库发生的几种常见错误(适用新手)
- MYSQL常用命令
- c++后台开发需要掌握哪些知识