tensorflow基础使用4
来源:互联网 发布:sql 多个case when 编辑:程序博客网 时间:2024/06/15 14:08
非线性回归
# coding: utf-8import tensorflow as tfimport numpy as npimport matplotlib.pyplot as plt#使用numpy生成200个随机点x_data = np.linspace(-0.5,0.5,200)[:,np.newaxis]noise = np.random.normal(0,0.02,x_data.shape)y_data = np.square(x_data) + noise#定义两个placeholderx = tf.placeholder(tf.float32,[None,1])y = tf.placeholder(tf.float32,[None,1])#定义神经网络中间层Weights_L1 = tf.Variable(tf.random_normal([1,10]))biases_L1 = tf.Variable(tf.zeros([1,10]))Wx_plus_b_L1 = tf.matmul(x,Weights_L1) + biases_L1L1 = tf.nn.tanh(Wx_plus_b_L1)#定义神经网络输出层Weights_L2 = tf.Variable(tf.random_normal([10,1]))biases_L2 = tf.Variable(tf.zeros([1,1]))Wx_plus_b_L2 = tf.matmul(L1,Weights_L2) + biases_L2prediction = tf.nn.tanh(Wx_plus_b_L2)#二次代价函数loss = tf.reduce_mean(tf.square(y-prediction))#使用梯度下降法训练train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)with tf.Session() as sess: #变量初始化 sess.run(tf.global_variables_initializer()) for _ in range(2000): sess.run(train_step,feed_dict={x:x_data,y:y_data}) #获得预测值 prediction_value = sess.run(prediction,feed_dict={x:x_data}) #画图 plt.figure() plt.scatter(x_data,y_data) plt.plot(x_data,prediction_value,'r-',lw=5) plt.show()
阅读全文
0 0
- tensorflow基础使用4
- tensorflow基础使用1
- tensorflow基础使用2
- tensorflow基础使用3
- tensorflow基础使用5
- TensorFlow使用基础(Basic Usage)
- TensorFlow基础
- Tensorflow基础
- Tensorflow基础
- Tensorflow 基础
- Tensorflow基础
- tensorflow基础
- Tensorflow基础
- TensorFlow基础
- tensorflow基础
- Tensorflow基础
- TensorFlow基础知识点(二)交互式使用/Interactive Usage
- Tensorflow基础:使用验证数据集判断模型效果
- SpringBoot获得application.properties中数据的几种方式
- Django运行访问项目出现的问题:DisallowedHost at / Invalid HTTP_HOST header
- Qt Creator快捷键
- UEFI 双启动情况下禁用 GRUB 的启动菜单
- [LeetCode] 4. Longest Palindrome Substring 分析+代码
- tensorflow基础使用4
- 超级读入挂
- Shiro异常:java.lang.IllegalArgumentException: Line argument must contain a key and a value. Only one
- ubuntu Qt5环境变量设置
- 足球赛确定淘汰赛名单-map<string, int>问题
- 在网页上动态显示当前时间
- HDU 1029 Ignatius and the Princess IV 水题
- Java导入Excel数据方法
- 删除列表中重复值1