线性回归
来源:互联网 发布:web of science数据库 编辑:程序博客网 时间:2024/06/07 21:57
import tensorflow as tf#初始化变量和模型参数, 定义训练闭环中运算W = tf.Variable(tf.zeros([2, 1]), name = "weights")b = tf.Variable(0.0, name = "bias")#计算推断模型在数据X上的输出, 并将结果返回def inference(X): return tf.matmul(X, W) + b#依据数据X及其期望输出Y的计算损失def loss(X, Y): Y_predicted = inference(X) return tf.reduce_sum(tf.squared_difference(Y, Y_predicted))'''a = tf.to_float([1,2,3,4,5])b = tf.to_float([1,1,1,1,1])with tf.Session() as sess: print sess.run(tf.squared_difference(a, b)) print sess.run(tf.square(a - b))输出结果:[ 0. 1. 4. 9. 16.][ 0. 1. 4. 9. 16.]'''#读取或生成训练数据X及其期望输出Ydef inputs(): weights_age = [[84, 46], [73, 20], [65, 52], [70, 30], [76, 57], [69, 25],[63, 28], [72,36], [79, 57], [75, 44], [27, 44], [89,31], [65, 52],[57,23], [59, 60],[69, 48], [60, 34], [79, 51], [75, 50], [82, 34], [59,46], [67,23], [85, 37],[55, 40], [63, 30] ] blood_fat_content = [ 354, 190, 405, 263, 451, 302, 288, 385, 402, 365, 209, 290, 346, 254, 395, 434, 220, 374, 308, 220, 311, 181, 274, 303, 244 ] return tf.to_float(weights_age), tf.to_float(blood_fat_content)#对训练得到的模型进行评估def evaluate(sess, X, Y): print sess.run(inference([[80., 25.]])) print sess.run(inference([[65., 25.]]))#依据计算的总损失训练或调整模型参数def train(tol_loss): learning_rate = 0.0000001 return tf.train.GradientDescentOptimizer(learning_rate).minimize(tol_loss)#在一个会话中对象启动数据流图, 搭建流程with tf.Session() as sess: tf.global_variables_initializer().run() X, Y = inputs() tol_loss = loss(X, Y) train_op = train(tol_loss) #搭建多线程 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners() #实际的训练次数 training_steps = 1000 for step in range(training_steps): sess.run([train_op]) #处于调试和学习的目的, 查看损失在训练过程中递减情况 if step % 10 == 0: print "loss: ", sess.run([tol_loss]) evaluate(sess, X, Y) coord.request_stop() coord.join(threads)
阅读全文
0 0
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- HDU
- loadrunner Lr_类函数之lr_decrypt()
- iterm2自动登陆,解决分栏后vi混乱
- 欧拉角和旋转矩阵相互转换
- 【SSH】Hibernate学习(一)
- 线性回归
- java 向DB2插入数据
- 昂贵的聘礼(dijkstra)
- [LeetCode-Algorithms-10] "Regular Expression Matching" (2017.10.12-WEEK6)
- string转换为LPCWSTR
- 中点bresenham算法画线
- Lesson 3 上机练习题——继承
- [HNOI2004]打鼹鼠
- opencv归一化函数normalize详解