线性回归

来源:互联网 发布:web of science数据库 编辑:程序博客网 时间:2024/06/07 21:57
import tensorflow as tf#初始化变量和模型参数, 定义训练闭环中运算W = tf.Variable(tf.zeros([2, 1]), name = "weights")b = tf.Variable(0.0, name = "bias")#计算推断模型在数据X上的输出, 并将结果返回def inference(X):    return tf.matmul(X, W) + b#依据数据X及其期望输出Y的计算损失def loss(X, Y):    Y_predicted = inference(X)    return tf.reduce_sum(tf.squared_difference(Y, Y_predicted))'''a = tf.to_float([1,2,3,4,5])b = tf.to_float([1,1,1,1,1])with tf.Session() as sess:    print sess.run(tf.squared_difference(a, b))    print sess.run(tf.square(a - b))输出结果:[  0.   1.   4.   9.  16.][  0.   1.   4.   9.  16.]'''#读取或生成训练数据X及其期望输出Ydef inputs():    weights_age = [[84, 46], [73, 20], [65, 52], [70, 30], [76, 57], [69, 25],[63, 28], [72,36],                               [79, 57], [75, 44], [27, 44], [89,31], [65, 52],[57,23], [59, 60],[69, 48], [60, 34],                                [79, 51], [75, 50], [82, 34], [59,46], [67,23], [85, 37],[55, 40], [63, 30]                  ]    blood_fat_content = [ 354, 190, 405, 263, 451, 302, 288, 385, 402, 365, 209, 290, 346,                                       254, 395, 434, 220, 374, 308, 220, 311, 181, 274, 303, 244                        ]    return tf.to_float(weights_age), tf.to_float(blood_fat_content)#对训练得到的模型进行评估def evaluate(sess, X, Y):    print sess.run(inference([[80., 25.]]))    print sess.run(inference([[65., 25.]]))#依据计算的总损失训练或调整模型参数def train(tol_loss):    learning_rate = 0.0000001    return tf.train.GradientDescentOptimizer(learning_rate).minimize(tol_loss)#在一个会话中对象启动数据流图, 搭建流程with tf.Session() as sess:    tf.global_variables_initializer().run()    X, Y = inputs()      tol_loss = loss(X, Y)    train_op = train(tol_loss)    #搭建多线程    coord = tf.train.Coordinator()    threads = tf.train.start_queue_runners()    #实际的训练次数    training_steps = 1000    for step in range(training_steps):        sess.run([train_op])        #处于调试和学习的目的, 查看损失在训练过程中递减情况        if step % 10 == 0:            print "loss: ", sess.run([tol_loss])    evaluate(sess, X, Y)    coord.request_stop()    coord.join(threads)