sofmax 算法 ---多分类回归
来源:互联网 发布:中美网络大战华为 编辑:程序博客网 时间:2024/05/22 15:53
上一篇博文中已经用逻辑回归解决了二分类的问题,那多分类的问题呢,能不能解决呢,很显然通过逻辑回归不能解决,但是,softmax 算法,也被称为是逻辑回归的扩展,很好的解决了多分类的问题。
在推倒出softmax 模型时 我们用到了 指数分布族 和 广义线性模型 两个概念。对于逻辑回归,我们用的是二项分布,但是对于softmax我们必须用多项式分布,但是二项分布可以理解为多项式分布的一个特例,所以从二项分布去理解多项式分布,或者从逻辑回归去理解softmax算法是很有帮助的,我可能会在下一篇博文中 贴出 怎么从逻辑回归去理解softmax算法。
softmax算法的推导
:
以上可以算是 softmax 的推导 以及写出了 梯度的公式。 所以后面可以用梯度下降的方法,去优化参数,得到一个分类的结果。
直接上代码,(python):
#_*_coding:utf-8 _*_import numpy as npimport matplotlib.pyplot as pltclass SoftMax(object):def __init__(self,data,label,alpha,k,thetaDim):self.x = np.vstack((np.ones((1,data.shape[1])),data));self.label = labelself.alpha = alphaself.thetaDim = thetaDimself.theta = np.random.random((k-1,thetaDim)) # np.zeros((1,3))self.k= kdef hypothesis(self,x,i): #x is single example i row theta P_y_i=np.exp(self.theta[i,:].dot(x.T) )/(1+np.sum(np.exp(self.theta.dot(x.T).T)) ) # 计算 p(y=i|x)return P_y_idef learn(self):theta_tmp = np.ones((self.k-1,self.thetaDim)) # 创建 用于更新theta 的差值P_y_j=np.ones( (self.x.shape[1],1) )count =0;while count <1000:count+=1;for j in range(0,self.k-1):for i in range(0,self.x.shape[1]):P_y_i= self.hypothesis(self.x[:,i],j) #计算 p(y=i|x)P_y_j[i,:]=P_y_i #对于每一个x都计算一下p(y=i|x)theta_tmp[j,:]=np.sum( self.x.T*( (j==self.label.T) -P_y_j ),axis = 0 )#然后更新 theta (j) 的差值self.theta=self.theta +self.alpha*theta_tmpif __name__ =="__main__":trainData =np.array([ [1,1,1,2,2,2,3,3,3,1,1,1,2,2,2,3,3,3,6,6,6,7,7,7,8,8,8,6,6,6,7,7,7,8,8,8], [1,2,3,1,2,3,1,2,3,6,7,8,6,7,8,6,7,8,1,2,3,1,2,3,1,2,3,6,7,8,6,7,8,6,7,8]]) trainLabel = np.array([[0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3,3]]) #traindata and trainlabeldata_class1,data_class2,data_class3,data_class4 = np.hsplit(trainData,4) #split traindata to four classsoftmax = SoftMax(trainData,trainLabel,0.1,4,3) # softmax algorithmsoftmax.learn()theta = softmax.theta #theta after trainx = np.array([1,8,3.5]) #testdata after construct actually x= [3.5,3.5] result =[]for i in range(0,3):re= np.exp(theta[i,:].dot(x.T) )/(1+np.sum(np.exp(theta.dot(x.T).T)) )result.append(re)result.append(1-sum(result) )testLabel=result.index(max(result))print "the class of testdata = ", testLabel #print the class of testdata#draw testdataiconLabel = {0:'+',1:'^',2:'o',3:'s'}plt.plot(x[1],x[2],iconLabel[testLabel])#draw traindataplt.plot(trainData[0],trainData[1],'+')plt.plot(data_class1[0],data_class1[1],'+')plt.plot(data_class2[0],data_class2[1],'^')plt.plot(data_class3[0],data_class3[1],'o')plt.plot(data_class4[0],data_class4[1],'s')plt.axis([0,10,0,10])plt.show()
代码不多,但是这一个写起来还是比前几个稍微复杂的,至少得理解了公式以后才能尝试自己写出代码。
贴出结果图:
0 0
- sofmax 算法 ---多分类回归
- 分类算法:Logistic回归
- CART分类回归树算法
- 分类算法之逻辑回归
- 分类&回归算法-随机森林
- 多类分类回归
- 分类算法--并行逻辑回归算法
- R分类算法-Logistic回归算法
- 分类算法之logistic 回归模型
- 利用CART算法建立分类回归树
- 机器学习算法-分类回归树CART
- 数据挖掘笔记-分类-回归算法-最小二乘法
- 机器学习-分类算法-逻辑回归
- R语言使用逻辑回归分类算法
- 6章 分类问题、逻辑回归算法
- 拉格朗日多项式逻辑回归分类算法
- 分类算法之逻辑回归详解
- Multinomial 回归多分类推导
- 生产者消费者模型
- Android开发者应该深入学习的10个开源应用项目
- 三,Java集合类(1)
- 关于sqlserver in 走不走索引
- JAVA学习代码——日志文件
- sofmax 算法 ---多分类回归
- angularjs2 五
- 解决live CD方式启动Ubuntu系统不能启动openssh-server服务
- 通过异步过程调用(APC)注入DLL
- 常用的机器学习算法优缺点
- Windows下python和pip的环境配置 ---转载
- jsp自定义标签用法实例详解
- iOS-Runtime
- android程序中Zxing二维码扫描图片变形 问题解决方法