sofmax 算法 ---多分类回归

来源:互联网 发布:中美网络大战华为 编辑:程序博客网 时间:2024/05/22 15:53

上一篇博文中已经用逻辑回归解决了二分类的问题,那多分类的问题呢,能不能解决呢,很显然通过逻辑回归不能解决,但是,softmax 算法,也被称为是逻辑回归的扩展,很好的解决了多分类的问题。

在推倒出softmax 模型时 我们用到了 指数分布族 和 广义线性模型 两个概念。对于逻辑回归,我们用的是二项分布,但是对于softmax我们必须用多项式分布,但是二项分布可以理解为多项式分布的一个特例,所以从二项分布去理解多项式分布,或者从逻辑回归去理解softmax算法是很有帮助的,我可能会在下一篇博文中 贴出 怎么从逻辑回归去理解softmax算法。

softmax算法的推导


以上可以算是 softmax 的推导 以及写出了 梯度的公式。 所以后面可以用梯度下降的方法,去优化参数,得到一个分类的结果。

直接上代码,(python):

#_*_coding:utf-8 _*_import numpy as npimport matplotlib.pyplot as pltclass SoftMax(object):def __init__(self,data,label,alpha,k,thetaDim):self.x = np.vstack((np.ones((1,data.shape[1])),data));self.label = labelself.alpha = alphaself.thetaDim = thetaDimself.theta = np.random.random((k-1,thetaDim))  #   np.zeros((1,3))self.k= kdef hypothesis(self,x,i): #x is single example   i row theta P_y_i=np.exp(self.theta[i,:].dot(x.T) )/(1+np.sum(np.exp(self.theta.dot(x.T).T))  )  # 计算 p(y=i|x)return P_y_idef learn(self):theta_tmp = np.ones((self.k-1,self.thetaDim)) # 创建 用于更新theta 的差值P_y_j=np.ones( (self.x.shape[1],1) )count =0;while count <1000:count+=1;for j in range(0,self.k-1):for i in range(0,self.x.shape[1]):P_y_i= self.hypothesis(self.x[:,i],j) #计算  p(y=i|x)P_y_j[i,:]=P_y_i #对于每一个x都计算一下p(y=i|x)theta_tmp[j,:]=np.sum( self.x.T*( (j==self.label.T) -P_y_j ),axis = 0  )#然后更新 theta (j) 的差值self.theta=self.theta +self.alpha*theta_tmpif __name__ =="__main__":trainData =np.array([   [1,1,1,2,2,2,3,3,3,1,1,1,2,2,2,3,3,3,6,6,6,7,7,7,8,8,8,6,6,6,7,7,7,8,8,8],   [1,2,3,1,2,3,1,2,3,6,7,8,6,7,8,6,7,8,1,2,3,1,2,3,1,2,3,6,7,8,6,7,8,6,7,8]])   trainLabel = np.array([[0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3,3]])  #traindata and  trainlabeldata_class1,data_class2,data_class3,data_class4 = np.hsplit(trainData,4)  #split traindata to four classsoftmax = SoftMax(trainData,trainLabel,0.1,4,3) # softmax algorithmsoftmax.learn()theta = softmax.theta #theta after trainx = np.array([1,8,3.5])  #testdata after construct   actually  x= [3.5,3.5]  result =[]for i in range(0,3):re= np.exp(theta[i,:].dot(x.T) )/(1+np.sum(np.exp(theta.dot(x.T).T))  )result.append(re)result.append(1-sum(result) )testLabel=result.index(max(result))print "the class of testdata = ", testLabel  #print the class of testdata#draw testdataiconLabel = {0:'+',1:'^',2:'o',3:'s'}plt.plot(x[1],x[2],iconLabel[testLabel])#draw traindataplt.plot(trainData[0],trainData[1],'+')plt.plot(data_class1[0],data_class1[1],'+')plt.plot(data_class2[0],data_class2[1],'^')plt.plot(data_class3[0],data_class3[1],'o')plt.plot(data_class4[0],data_class4[1],'s')plt.axis([0,10,0,10])plt.show()


代码不多,但是这一个写起来还是比前几个稍微复杂的,至少得理解了公式以后才能尝试自己写出代码。

贴出结果图:


0 0
原创粉丝点击