TensorFlow学习笔记(二):TensorFlow实现线性回归模型
来源:互联网 发布:床上用品在哪买 知乎 编辑:程序博客网 时间:2024/05/16 11:00
一、线性回归模型中所涉及到API
#导入TensorFlow包import tensorflow as tf#TensorFlow程序分为两个阶段:准备阶段和执行阶段#--------------准备阶段--------------#定义变量、定义函数、定义操作步骤等,中间隐藏着把变量转化为张量的步骤#用tf.Variable来定义变量 #注意:定义矩阵的写法([[],[],[]...])a = tf.Variable([[2,3]])b = tf.Variable([[4],[2]])#矩阵相乘 math multiplyc = tf.matmul(a,b)print('c----->',c)#创建用0填充的矩阵d = tf.zeros([2,4])#平方e = tf.square([2])#平均值f = tf.reduce_mean([1,3])#均匀分布的随机数g = tf.random_uniform([1,10])#--------------执行阶段--------------#获取一个能运行TensorFlow的session图,tf.session#把准备阶段所定义的所有变量都放进session图里进行运行with tf.Session() as sess: #初始化所有的变量 init = tf.global_variables_initializer() sess.run(init) #用sess.run获取最终值 print('a:',a) print('a =',sess.run(a)) print('b =',sess.run(b)) print('c =',sess.run(c)) print('d =',sess.run(d)) print('e =',sess.run(e)) print('f =',sess.run(f)) print('g =',sess.run(g))
二、线性回归模型代码实现
train
# 一元的线性回归模型的训练# 1.通过训练数据,推测出线性回归函数(y = w * x = b)中w和b的值# 2.通过验证数据,验证得到的函数是否符合预期。# 引入Tensorflow函数import tensorflow as tf# 引入绘图表(为了清晰了解训练结果)import matplotlib.pyplot as plt# 引入测试数据模块import testData as td# 1.获取训练数据# 通过testData来模拟第三方接口# get_train_data 获取训练数据 参数:data_length(获取数据的个数) 返回值:二维数组 [0]代表x(横坐标) [1]代表y(纵坐标)# get_validate_data 获取验证数据 参数:data_length(获取数据的个数) 返回值:二维数组 [0]代表x(横坐标) [1]代表y(纵坐标)trainData = td.get_train_data(200)trainx = [v[0] for v in trainData]trainy = [v[1] for v in trainData]# 2.构造预测的线性回归函数 有= W * x + bW = tf.Variable(tf.random_uniform([1]))b = tf.Variable(tf.zeros([1]))y = W * trainx + b# 3.判断假设函数的好坏# 代价函数cost = tf.reduce_mean(tf.square(y - trainy))# 4.调整假设函数# 梯度下降算法找最优解optimizer = tf.train.GradientDescentOptimizer(0.08)train = optimizer.minimize(cost)with tf.Session() as sess: ###########初始化所有变量值########### init = tf.global_variables_initializer() sess.run(init) #初始化W和b的值 print("cost=",sess.run(cost),"W=",sess.run(W),"b=",sess.run(b)) #循环运行 for k in range(500): sess.run(train) #输出训练好的W和b print("cost=", sess.run(cost), "W=", sess.run(W), "b=", sess.run(b)) print("执行完成!") #构造图形结构 plt.plot(trainx, trainy, 'ro', label='train data') plt.plot(trainx, sess.run(y), label='tain result') plt.legend() plt.show()
……
……
……
test
import matplotlib.pyplot as plt#引入测试数据import testData as ptvalidateData = pt.get_validate_data(40)va_x = [v[0] for v in validateData]va_y = [v[1] for v in validateData]#训练结果y = []for x in va_x : y.append(x * 0.3 + 0.8)#构造图形结构plt.plot(va_x, va_y, 'ro', label = 'validate Data')plt.plot(va_x, y, label = 'train result' )plt.legend()plt.show()
引入testData
#引入Numpyimport numpy as np#构造一个线性回归函数#y = W * x + bW = 0.3b = 0.8#生成测试数据def get_train_data(data_lenght): train_arr = [] for i in range(data_lenght): tr_x = np.random.uniform(0.0, 1.0) tr_y = tr_x * W + b + np.random.uniform(-0.02,0.02) train_arr.append([tr_x, tr_y]) #这里是什么意思 = = return train_arr#生成校验数据def get_validate_data(data_lenght): validate_arr = [] for i in range(data_lenght): va_x = np.random.uniform(-0.0, 1.0) va_y = va_x * W + b + np.random.uniform(-0.02, 0.02) validate_arr.append([va_x,va_y]) return validate_arr
重点内容
疑问:
1.train中循环不是很理解代码如何执行
2.append用法不了解
代码运行中碰到的问题:
1.虽然Anaconda默认环境下有自带的matplotlib包,但是TensorFlow环境下没有,需要在Anaonda prompt下安装matplotlib。
conda install ......
conda list
2.testData需要自己写,感觉很想c里面的自己写的函数,后缀是py。
三、课后扩展
1.随机梯度下降
方式:刚在代码500次循环中,每次拿一个或部分训练点(不是全部)进行计算
优点:在大数据下训练更快
缺点:有可能得到不是全局的最优解,只是局部最优解
2.批量梯度下降
方式:每次拿所有训练点进行计算
优点:可以得到全局最优解
缺点:数据量大的时候,训练会很慢
阅读全文
0 0
- TensorFlow学习笔记(二):TensorFlow实现线性回归模型
- TensorFlow学习笔记(三):TensorFlow实现逻辑回归模型
- TensorFlow学习笔记(2)--构造线性回归模型
- 学习TensorFlow,线性回归模型
- TensorFlow学习笔记--线性回归
- Tensorflow学习(一)--使用TensorFlow实现线性回归
- Tensorflow学习笔记(一):初识TensorFlow——实现线性回归
- TensorFlow 实现一元线性回归模型
- 机器学习与TensorFlow编程(1)线性回归模型
- tensorflow 实现线性回归
- Tensorflow实现线性回归
- Tensorflow实现线性回归
- TensorFlow实现线性回归
- TensorFlow学习笔记3:线性回归
- tensorflow tutorials(一):用tensorflow建立线性回归模型
- TensorFlow学习笔记(二十一) tensorflow机器学习模型
- [TensorFlow]入门学习笔记(4)-BasicModel 线性回归,逻辑回归和最近邻模型
- tensorflow tutorials(二):用tensorflow建立岭回归模型
- Maven详解
- Java使用465端口发送邮件(绕过25端口限制)
- Unity3D-CDK兑换模拟
- 什么是Java的永久代(PermGen)内存泄漏
- 最新maven视频教程附全套软件文档源码 18课
- TensorFlow学习笔记(二):TensorFlow实现线性回归模型
- 剑指offer(算法和数据操作篇)
- codeforces 120F Spiders
- 算法基础(转载)
- Sublime Text 3技巧:支持GB2312和GBK编码
- JAVA线程同步
- ICCV2017论文分类
- jmap -heap 命令
- 多线程 批处理 数据导入工具 Java