Tensorflow之nn 简单神经网络学习
来源:互联网 发布:mplayerx for mac下载 编辑:程序博客网 时间:2024/05/16 09:53
最近在深入研究深度学习,关于机器学习的基本知识就略过不说了,在深度学习里面,一些概念性的东西还是很好理解的,重点是如何利用已有的知识去构建一个合适解决实际问题的模型,然后用各种小trick去把参数调优。
试水阶段,为了训一个简单的二分类问题,搭了一个两层的神经网络
这里是不加bias的简单模型
def init_weight(shape): return tf.Variable(tf.random_normal(shape, stddev=0.01))def model(X, w_h, w_o): h = tf.nn.sigmoid(tf.matmul(X, w_h)) return tf.matmul(h, w_o)
注意tensorflow的结构是先构图(模型)再训练,所以需要先将图中的变量和结点声明清楚:
X = tf.placeholder("float", [None, 36])Y = tf.placeholder("float", [None, 2])w_h = init_weight([36, 16])w_o = init_weight([16, 2])py_x = model(X, w_h, w_o)predict_op = tf.argmax(py_x,1)cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(py_x, Y))train_op = tf.train.GradientDescentOptimizer(0.05).minimize(cost)
X为训练样本,Y为训练标签,w_h, w_o为模型中的参数。cost是交叉熵代价,可以设置不同的代价函数。train_op是训练过程中的下降梯度(参考概念back-propagation),当然有不同的下降方法可以使用,0.05是学习速率,这也是个很重要的参数。
模型和变量构建完之后,就可以开始训练了,需要启动session:
with tf.Session() as sess: tf.initialize_all_variables().run() for i in range(1000): for start, end in zip(range(0, len(trX), 128), range(128, len(trX)+1, 128)): sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]}) accu = np.mean(np.argmax(teY, axis=1) == sess.run(predict_op, feed_dict={X:teX}))
可以手动设置训练1000此停止,也可以根据validation accuracy的变化情况决定何时停止。trX,trY,teX,teY分别为载入数据后的训练集(样本,标签)和测试集(样本,标签)。可以通过观察测试集的准确率调整模型训练参数,也可以观察训练集准确率、代价函数(cost)等。详细的调试经验研究之后再更新。
0 0
- Tensorflow之nn 简单神经网络学习
- 机器学习-神经网络NN
- 【机器学习】--神经网络(NN)
- 小白学Tensorflow之简单神经网络
- Tensorflow深度学习之二:简单卷积神经网络CNN
- 简单的Tensorflow实现NN
- 机器学习之径向基神经网络(RBF NN)
- [TensorFlow 学习笔记-04]卷积函数之tf.nn.conv2d
- TensorFlow学习笔记之tf.nn.softmax()与tf.nn.softmax_cross_entropy_with_logits的用法
- Tensorflow学习---tf.nn.embedding_lookup
- tensorflow入门之训练简单的神经网络
- tensorflow之安装及简单神经网络搭建
- tensorflow实战之 简单卷积神经网络
- 神经网络NN简单理解以及算法
- 机器学习算法练习之(二):Python和Tensorflow分别实现简单的神经网络
- tensorflow构建简单神经网络
- Tensorflow简单的神经网络
- TensorFlow学习---tf.nn.softmax_cross_entropy_with_logits的用法
- 椭圆机能减肥吗
- Android如何绘制View
- IntelliJ IDEA 13&14 插件推荐及快速上手建议
- mysql的root用户没有grant权限
- unity 安卓热更新代码的最新方法: 通过Mono加载新的重新编译的dll
- Tensorflow之nn 简单神经网络学习
- 硬币求和
- Google Dapper-大规模分布式系统的基础跟踪设施
- java collection库
- 15 个 Android 通用流行框架大全
- Unity Editor Scripting 2
- JavaEE Web服务端必备的核心基础(图)
- java concurrent库
- 将数据存储到文件中