Machine Learning Notes (3): A quick-and-dirty logistic regression


I threw together a logistic regression program in TensorFlow, and the result is decent.
Writing it made me realize I still can't fully apply what I've learned; I need more practice and should spend more time reading the API docs.
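For reference, the cost the code below minimizes is just the standard logistic regression cross-entropy (the textbook formula, nothing specific to this dataset):

h(x) = \sigma(w^{\top}x + b) = \frac{1}{1 + e^{-(w^{\top}x + b)}}

J(w, b) = -\frac{1}{m}\sum_{i=1}^{m}\Big[\, y^{(i)}\log h(x^{(i)}) + \big(1 - y^{(i)}\big)\log\big(1 - h(x^{(i)})\big) \Big]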

import tensorflow as tf
import numpy as np

alpha = 0.01  # learning rate

# Read the comma-separated data: the first two columns are features, the last is the 0/1 label.
fp1 = open("ex2data1.txt", "r")
x = []
y = []
for line in fp1.readlines():
    pre = line.strip("\n").split(",")
    x.append([pre[0], pre[1]])
    y.append(pre[2])
fp1.close()

train_x = np.array(x).astype(np.float32)                 # shape (m, 2)
train_y = np.array(y).astype(np.float32).reshape(-1, 1)  # shape (m, 1), so it lines up with h below
print(len(train_x), len(train_y))

# Model: h = sigmoid(x * w + b)
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
w = tf.Variable(tf.zeros([2, 1]))
b = tf.Variable(-.3)
h = tf.nn.sigmoid(tf.matmul(x, w) + b)

# Cross-entropy cost; reduce_mean already averages over the m examples.
loss = -tf.reduce_mean(y * tf.log(h) + (1 - y) * tf.log(1 - h))

optimizer = tf.train.GradientDescentOptimizer(alpha)
train = optimizer.minimize(loss)

sess = tf.Session()
sess.run(tf.initialize_all_variables())  # tf.global_variables_initializer() in newer TF 1.x
for step in range(100):
    sess.run(train, {x: train_x, y: train_y})
    if step % 10 == 0:
        print(step, sess.run(w).flatten(), sess.run(b).flatten())
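A side note on the API point above: the manual parsing loop can most likely be replaced with a single np.loadtxt call. A minimal sketch, assuming ex2data1.txt is comma-separated with two feature columns followed by a 0/1 label column (the filename and layout are taken from the code above):

import numpy as np

# Load the whole comma-separated file at once; columns 0-1 are features, column 2 is the label.
data = np.loadtxt("ex2data1.txt", delimiter=",", dtype=np.float32)
train_x = data[:, :2]                   # shape (m, 2)
train_y = data[:, 2].reshape(-1, 1)     # shape (m, 1), same shape as the hypothesis h
print(train_x.shape, train_y.shape)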

Result 1

100 100
0 [ 0.00114512  0.00115519] [-0.29998258]
10 [ 0.00478334  0.00485309] [-0.29992235]
20 [ 0.00507055  0.0051795 ] [-0.29991126]
30 [ 0.00508022  0.00522374] [-0.29990426]
40 [ 0.00506677  0.0052418 ] [-0.2998977]
50 [ 0.00505249  0.00525635] [-0.29989114]
60 [ 0.00503923  0.00526948] [-0.29988459]
70 [ 0.00502709  0.00528147] [-0.29987803]
80 [ 0.00501597  0.00529244] [-0.29987147]
90 [ 0.00500578  0.00530248] [-0.29986492]

Judging from these numbers, it seems to have converged.
Result 2:

118 118
0 [  3.06241617e-07   1.02363003e-06] [-0.29999441]
10 [  3.36825246e-06   1.12584212e-05] [-0.29993838]
20 [  6.42952500e-06   2.14904649e-05] [-0.29988235]
30 [  9.49005971e-06   3.17197591e-05] [-0.29982659]
40 [  1.25498582e-05   4.19463177e-05] [-0.29977086]
50 [  1.56089245e-05   5.21701477e-05] [-0.29971513]
60 [  1.86672514e-05   6.23912347e-05] [-0.2996594]
70 [  2.17248453e-05   7.26095896e-05] [-0.29960367]
80 [  2.47817025e-05   8.28252014e-05] [-0.29954794]
90 [  2.78378229e-05   9.30380847e-05] [-0.29949221]

Here I increased the learning rate, and sure enough the parameters did not settle down, exactly as the theory predicts...
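One way to make the "converged vs. not stabilized" call less of a guess is to print the loss value itself alongside the weights. A minimal sketch of the same training loop with that extra print (it reuses sess, train, loss, x, y, train_x and train_y from the program above):

for step in range(100):
    sess.run(train, {x: train_x, y: train_y})
    if step % 10 == 0:
        # A steadily decreasing loss indicates convergence; an oscillating or
        # growing loss suggests the learning rate is too large.
        cur_loss = sess.run(loss, {x: train_x, y: train_y})
        print(step, cur_loss, sess.run(w).flatten(), sess.run(b).flatten())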
This little experiment showed that I haven't mastered many of TensorFlow's functions; I should read the API docs more.
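On that "read more of the API" point: TensorFlow 1.x already provides a numerically stable sigmoid cross-entropy op, which avoids the NaNs that a hand-written log(h) produces when h reaches exactly 0 or 1. A sketch of how the model part could look with it, keeping the same x, y, w, b and alpha as above (this is an alternative I'm suggesting, not what the runs above used):

z = tf.matmul(x, w) + b     # raw logits
h = tf.nn.sigmoid(z)        # predicted probabilities, kept only for inspection

# Stable cross-entropy computed directly from the logits instead of from sigmoid(z).
loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=z))
train = tf.train.GradientDescentOptimizer(alpha).minimize(loss)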
