tensorflow之非线性方程
来源:互联网 发布:cnc编程快捷键 编辑:程序博客网 时间:2024/05/21 14:53
[python]import tensorflow as tfimport numpy as npimport matplotlib.pyplot as pltdef add_layer(inputs, in_size, out_size, activation_function=None): # activation_function=None线性函数 Weights = tf.Variable(tf.random_normal([in_size, out_size])) # Weight中都是随机变量 biases = tf.Variable(tf.zeros([1, out_size]) + 0.1) # biases推荐初始值不为0 Wx_plus_b = tf.matmul(inputs, Weights) + biases # inputs*Weight+biases if activation_function is None: # 线性回归 outputs = Wx_plus_b else: # 逻辑回归 outputs = activation_function(Wx_plus_b) return outputsdef get_next_batch(batch_size=64): x = np.random.choice(np.linspace(start=0, stop=1, num=200), batch_size) # 从0-1之间选取batch_size个数,免去对数据的归一化 x = np.reshape(x, newshape=[-1, 1]) # noise=np.random.normal(0, 0.05, x.shape)*50 # 噪声 y = np.power(x, 2) return x, yx_input = tf.placeholder(tf.float32, [None, 1])y_input = tf.placeholder(tf.float32, [None, 1])def model(): layer_1 = add_layer(x_input, 1, 4, activation_function=tf.nn.relu) # 隐藏层 layer_2 = add_layer(layer_1, 4, 4, activation_function=tf.nn.relu) # 隐藏层 prediction = add_layer(layer_2, 4, 1, activation_function=None) # 输出层 return predictionmodel_dir = 'model_non_linear' # 模型存储目录def train(): prediction = model() # 求标准差 loss = tf.reduce_mean( tf.reduce_sum(tf.square(y_input - prediction), reduction_indices=[1])) # square()平方,sum()求和,mean()平均值 optimizer = tf.train.GradientDescentOptimizer(0.005).minimize(loss) # 0.05学习效率,minimize(loss)减小loss误差 sess = tf.Session() saver = tf.train.Saver() checkpoint = tf.train.latest_checkpoint(model_dir) if checkpoint: saver.restore(sess, tf.train.latest_checkpoint(model_dir)) # 从模型中读取参数 else: sess.run(tf.global_variables_initializer()) # 变量初始化 # 训练20万次 for step in range(200000): x_train, y_train = get_next_batch(batch_size=256) loss_, optimizer_ = sess.run([loss, optimizer], feed_dict={x_input: x_train, y_input: y_train}) if step % 50 == 0: print(step, loss_) # 模型保存 saver.save(sess, '{}/non_linear_equation'.format(model_dir), global_step=step)def test(): predition = model() sess = tf.Session() saver = tf.train.Saver() saver.restore(sess, tf.train.latest_checkpoint(model_dir)) x_test, y_test = get_next_batch(200) y_pred = sess.run(predition, feed_dict={x_input: x_test}) # 数据可视化 fig = plt.figure() ax = fig.add_subplot(2, 1, 1) ax.scatter(x_test, y_test) bx = fig.add_subplot(2, 1, 2) bx.scatter(x_test, y_pred, c='r') plt.xlabel('X') plt.ylabel('Y') plt.show()if __name__ == '__main__': train() # test()
训练20万次后得到下图
阅读全文
0 0
- tensorflow之非线性方程
- MATLAB实例之对线性,非线性,超越方程的求解
- Matlab非线性方程求根
- 二分法解非线性方程
- 非线性方程求解
- 非线性方程求根
- 试位法求解非线性方程
- Python求解非线性方程
- Python解非线性方程
- Python求解非线性方程
- 非线性方程求根迭代法
- Matlab非线性方程求解
- 非线性方程的解法
- Newton_Raphson法求解非线性方程
- 割线法求解非线性方程
- fsolve函数求解非线性方程
- solve it--非线性方程求根
- 牛顿法求解非线性方程
- Ubuntu命令行常用的指令整理
- QT之qss教程- QPushButton
- 获取指定名称DLL
- DLL导出函数名称改编的解决方法
- 使用freemarker生成word,步骤详解并奉上源代码
- tensorflow之非线性方程
- GameEntity(九)—— InviteOtherPlayer
- Leetcode||49. Group Anagrams
- 基础判断网络
- 数据库里程(2):数据库的隔离机制
- 水经注地图发布服务中间件的适用范围
- Core Animation实战六(专用图层)
- java http请求数据 未完待续
- error LNK2001: 无法解析的外部符号 _GUID_DEVCLASS_ADB