机器学习之利用AdaBoost元算法提高分类性能
来源:互联网 发布:淘宝网400电话 编辑:程序博客网 时间:2024/06/05 22:31
本文主要记录本人在学习机器学习过程中的相关代码实现,参考《机器学习实战》
from numpy import *def loadSimpData(): datMat=matrix([[1.,2.1], [2.,1.1], [1.3,1.], [1.,1.], [2.,1.]]) classLabels=[1.0,1.0,-1.0,-1.0,1.0] return datMat,classLabelsdef stumpClassify(dataMatrix,dimen,threshVal,threshIneq): retArray=ones((shape(dataMatrix)[0],1)) if threshIneq=='lt': retArray[dataMatrix[:,dimen]<=threshVal]=-1.0 else: retArray[dataMatrix[:,dimen]>threshVal]=-1.0 return retArraydef buildStump(dataArr,classLabels,D): dataMatrix=mat(dataArr);labelMat=mat(classLabels).T m,n=shape(dataMatrix) numSteps=10.0;bestStump={};bestClasEst=mat(zeros((m,1))) minError=inf for i in range(n): rangeMin=dataMatrix[:,i].min() rangeMax=dataMatrix[:,i].max() stepSize=(rangeMax-rangeMin)/numSteps for j in range(-1,int(numSteps)+1): for inequal in ['lt','gt']: threshVal=(rangeMin+float(j)*stepSize) predictedVals=stumpClassify(dataMatrix,i,threshVal,inequal) errArr=mat(ones((m,1))) errArr[predictedVals==labelMat]=0 weightedError=D.T*errArr #~ print("split:dim %d,thresh %.2f,thresh inequal: %s,the \ #~ weighted error is %.3f" % (i,threshVal,inequal,weightedError)) if weightedError<minError: minError=weightedError bestClasEst=predictedVals.copy() bestStump['dim']=i bestStump['thresh']=threshVal bestStump['ineq']=inequal return bestStump,minError,bestClasEst#~ datMat,classLabels=loadSimpData()#~ D=mat(ones((5,1))/5)#~ print(buildStump(datMat,classLabels,D))def adaBoostTrainDS(dataArr,classLabels,numIt=40): weakClassArr=[] m=shape(dataArr)[0] D=mat(ones((m,1))/m) aggClassEst=mat(zeros((m,1))) for i in range(numIt): bestStump,error,classEst=buildStump(dataArr,classLabels,D) print('D:',D.T) alpha=float(0.5*log((1.0-error)/max(error,1e-16))) bestStump['alpha']=alpha weakClassArr.append(bestStump) print('classEst:',classEst.T) expon=multiply(-1*alpha*mat(classLabels).T,classEst) D=multiply(D,exp(expon)) D=D/D.sum() aggClassEst+=alpha*classEst print('aggClassEst:',aggClassEst.T) aggErrors=multiply(sign(aggClassEst)!=mat(classLabels).T,ones((m,1))) errorRate=aggErrors.sum()/m print('total error:',errorRate,'\n') if errorRate==0.0: break return weakClassArr#~ datMat,classLabels=loadSimpData()#~ classifierArray=adaBoostTrainDS(datMat,classLabels,9)#~ print(classifierArray)def adaClassify(datToClass,classifierArr): dataMatrix=mat(datToClass) m=shape(dataMatrix)[0] aggClassEst=mat(zeros((m,1))) for i in range(len(classifierArr)): classEst=stumpClassify(dataMatrix,classifierArr[i]['dim'], classifierArr[i]['thresh'],classifierArr[i]['ineq']) aggClassEst+=classifierArr[i]['alpha']*classEst print(aggClassEst) return sign(aggClassEst)#~ datMat,classLabels=loadSimpData()#~ classifierArray=adaBoostTrainDS(datMat,classLabels,30)#~ print(adaClassify([0,0],classifierArray))def loadDataSet(filename): with open(filename) as fr: numFeat=len(fr.readline().split('\t')) dataMat=[];labelMat=[] for line in fr.readlines(): lineArr=[] curLine=line.strip().split('\t') for i in range(numFeat-1): lineArr.append(float(curLine[i])) dataMat.append(lineArr) labelMat.append(float(curLine[-1])) return dataMat,labelMatdatArr,labelArr=loadDataSet('horseColicTest2.txt')classifierArray=adaBoostTrainDS(datArr,labelArr,10)print(classifierArray)
阅读全文
0 0
- 机器学习之利用AdaBoost元算法提高分类性能
- 机器学习实战-利用AdaBoost元算法提高分类性能
- 《机器学习实战》笔记之七——利用AdaBoost元算法提高分类性能
- 《机器学习实战》笔记之七——利用AdaBoost元算法提高分类性能
- 《机器学习实战》笔记之七——利用AdaBoost元算法提高分类性能
- 《机器学习实战》笔记之七——利用AdaBoost元算法提高分类性能
- 《机器学习实战》学习笔记:利用Adaboost元算法提高分类性能
- 《机器学习实战》学习笔记:利用Adaboost元算法提高分类性能
- 【机器学习实战-python3】Adaboost元算法提高分类性能
- 机器学习-python编写Adaboost元算法提高分类性能
- 机器学习实战——利用AdaBoost元算法提高分类性能
- [完]机器学习实战 第七章 利用AdaBoost元算法提高分类性能
- 代码注释:机器学习实战第7章 利用AdaBoost元算法提高分类性能
- python机器学习实战6:利用adaBoost元算法提高分类性能
- 读书笔记:机器学习实战【第7章:利用Adaboost元算法提高分类性能】
- 机器学习实战代码详解(七)利用AdaBoost元算法提高分类性能
- 机器学习实战读书笔记----利用Adaboost元算法提高分类性能
- 机器学习实战笔记-利用AdaBoost元算法提高分类性能
- 利用推广的方法证明NP-完全性
- 求一个数的二进制中1的个数(补码形式下)
- HDU 1728
- [Android6.0] RILC 系统结构及 LibRIL 运行机制
- extend 的js实现
- 机器学习之利用AdaBoost元算法提高分类性能
- 精通比特币
- Ural1017
- Quick Sort
- 欢迎使用CSDN-markdown编辑器
- java的反射
- 数列的极限
- 华为ETS8能否替代ETF8来做EFS0单板的接口板
- STL中string中c_str(),data(),copy()