代码注释:机器学习实战第9章 树回归
来源:互联网 发布:淘宝联盟客服电话 编辑:程序博客网 时间:2024/04/29 11:06
写在开头的话:在学习《机器学习实战》的过程中发现书中很多代码并没有注释,这对新入门的同学是一个挑战,特此贴出我对代码做出的注释,仅供参考,欢迎指正。
1、将CART算法用于回归
# coding: gbk
from numpy import *


def loadDataSet(fileName):
    """Parse a tab-delimited text file into a list of float rows.

    Args:
        fileName: path to the data file; one sample per line, fields
            separated by tabs, every field convertible to float.

    Returns:
        A list of lists of floats (one inner list per input line).
    """
    dataMat = []
    # 'with' guarantees the handle is closed even if a line fails to
    # parse (the original left the file open).
    with open(fileName) as fr:
        for line in fr.readlines():
            curLine = line.strip().split('\t')
            # Wrap map() in list() so each row is a real list on
            # Python 3 too, where map() returns a lazy iterator.
            fltLine = list(map(float, curLine))
            dataMat.append(fltLine)
    return dataMat


def binSplitDataSet(dataSet, feature, value):
    """Split a data matrix in two on a single feature and threshold.

    Args:
        dataSet: numpy matrix, one sample per row.
        feature: column index to split on.
        value:   threshold value.

    Returns:
        (mat0, mat1): mat0 holds the rows with
        dataSet[:, feature] > value, mat1 holds the rest.
    """
    # nonzero(...)[0] yields the row indices that satisfy the test.
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1


def regLeaf(dataSet):
    """Leaf generator for regression trees: mean of the target
    (last) column."""
    return mean(dataSet[:, -1])


def regErr(dataSet):
    """Error measure for regression trees: total squared deviation of
    the target column (population variance times sample count)."""
    return var(dataSet[:, -1]) * shape(dataSet)[0]


def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Find the best feature/value pair to binary-split dataSet on.

    Args:
        dataSet:  numpy matrix; the LAST column is the target variable.
        leafType: function that builds a leaf value from a data set.
        errType:  function that measures the error of a data set.
        ops:      (tolS, tolN) = minimum error reduction required to
                  accept a split, and minimum samples per branch.

    Returns:
        (featureIndex, splitValue), or (None, leafValue) when no split
        passes the pre-pruning tests (all targets equal, insufficient
        error reduction, or a branch would be too small).
    """
    tolS = ops[0]  # minimum error reduction to accept a split
    tolN = ops[1]  # minimum number of samples in each branch
    # Exit with a leaf if every target value is identical.
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    S = errType(dataSet)  # error of the unsplit data set
    bestS = inf
    bestIndex = 0
    bestValue = 0
    for featIndex in range(n - 1):  # last column is the target
        # Transpose + tolist()[0] flattens the matrix column so set()
        # iterates over scalar candidate values.
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            # Skip splits that leave either branch too small.
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # Exit with a leaf if the best split barely reduces the error.
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # Exit with a leaf if the chosen split still yields a tiny branch.
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue


def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Recursively grow a CART tree.

    Returns either a bare leaf value, or a dict with keys 'spInd',
    'spVal', 'left' and 'right' ('left' holds the samples whose feature
    value is GREATER than 'spVal').
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # 'is None' instead of '== None': identity is the correct test for
    # the no-split sentinel (== can be hijacked by operator overloads).
    if feat is None:
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree


def isTree(obj):
    """Return True if obj is an internal tree node (a dict); leaves
    are plain values."""
    return (type(obj).__name__ == 'dict')


def getMean(tree):
    """Collapse a tree bottom-up into a single value.

    Each subtree is replaced in place by its mean; the result is the
    average of the two (collapsed) children.
    """
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    # NOTE(review): this is an UNWEIGHTED average of the two branch
    # means, regardless of how many samples each branch held — this
    # follows the book's algorithm.
    return (tree['left'] + tree['right']) / 2.0


def prune(tree, testData):
    """Post-prune a tree against a held-out test set.

    Sibling leaves are merged whenever merging lowers the squared
    error on testData.

    Args:
        tree:     tree built by createTree.
        testData: numpy matrix of held-out samples.

    Returns:
        The (possibly collapsed) tree, or a single mean value if the
        whole tree collapses.
    """
    if shape(testData)[0] == 0:  # no test data reaches this node
        return getMean(tree)
    # Descend first: route the test data the same way the tree splits.
    if (isTree(tree['right']) or isTree(tree['left'])):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)
    # If both children are now leaves, consider merging them.
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        # Test error of keeping the split ...
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
            sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        # ... versus the error of a single merged leaf.
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            # print(...) works identically on Python 2 and 3.
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree
2、模型树
def linearSolve(dataSet):
    """Format dataSet into X (with a bias column) and Y, then solve
    ordinary least squares.

    Args:
        dataSet: numpy matrix; last column is the target, the rest
            are features.

    Returns:
        (ws, X, Y): coefficient column vector (ws[0] is the
        intercept), the design matrix, and the target vector.

    Raises:
        NameError: if X.T * X is singular.  NameError is kept (rather
            than the more idiomatic ValueError) for backward
            compatibility with callers that catch it.
    """
    m, n = shape(dataSet)
    X = mat(ones((m, n)))
    Y = mat(ones((m, 1)))
    X[:, 1:n] = dataSet[:, 0:n - 1]  # column 0 stays 1.0: the intercept
    Y = dataSet[:, -1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\ntry increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y


def modelLeaf(dataSet):
    """Leaf generator for model trees: the OLS coefficient vector of
    the local linear fit."""
    ws, X, Y = linearSolve(dataSet)
    return ws


def modelErr(dataSet):
    """Error measure for model trees: sum of squared residuals of the
    local linear fit."""
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))


def regTreeEval(model, inDat):
    """Evaluate a regression-tree leaf: the leaf IS the value.

    inDat is ignored; the parameter exists only so regTreeEval and
    modelTreeEval share the same signature.
    """
    return float(model)


def modelTreeEval(model, inDat):
    """Evaluate a model-tree leaf: prepend the bias term 1.0 to the
    1 x n input row and apply the stored coefficient vector."""
    n = shape(inDat)[1]
    X = mat(ones((1, n + 1)))
    X[:, 1:n + 1] = inDat
    return float(X * model)


def treeForeCast(tree, inData, modelEval=regTreeEval):
    """Forecast one sample by walking the tree down to a leaf.

    Args:
        tree:      tree (dict) built by createTree, or a bare leaf.
        inData:    1 x n numpy matrix holding the sample's features.
        modelEval: leaf evaluator (regTreeEval or modelTreeEval).

    Returns:
        The forecast value (float).
    """
    if not isTree(tree):
        return modelEval(tree, inData)
    # BUGFIX: index the (row 0, spInd) ELEMENT explicitly.  The
    # original `inData[tree['spInd']]` selected a whole row of the
    # 1 x n matrix, which only worked by accident when n == 1.
    if inData[0, tree['spInd']] > tree['spVal']:
        # value greater than the split -> left branch
        if isTree(tree['left']):
            return treeForeCast(tree['left'], inData, modelEval)
        else:
            return modelEval(tree['left'], inData)
    else:
        # value less than or equal -> right branch
        if isTree(tree['right']):
            return treeForeCast(tree['right'], inData, modelEval)
        else:
            return modelEval(tree['right'], inData)


def createForeCast(tree, testData, modelEval=regTreeEval):
    """Forecast every sample in testData.

    Args:
        tree:      tree built by createTree, or a bare leaf.
        testData:  sequence/matrix of samples (len() = sample count).
        modelEval: leaf evaluator (regTreeEval or modelTreeEval).

    Returns:
        An m x 1 numpy matrix of forecasts.
    """
    m = len(testData)
    yHat = mat(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)
    return yHat
3、使用Python的Tkinter库创建GUI
#coding:gbkfrom numpy import *from Tkinter import *import regTreesimport matplotlibmatplotlib.use('TkAgg')from matplotlib.backends.backend_tkagg import FigureCanvasTkAggfrom matplotlib.figure import Figure#作用:重绘新面板#输入:容许的误差下降值,切分的最少样本数#输出:无def reDraw(tolS, tolN): reDraw.f.clf() reDraw.a = reDraw.f.add_subplot(111) #检查复选框是否选中 if chkBtnVar.get():#复选框被选中,执行模型树 if tolN < 2:#切分的最少样本数不能少于2 tolN = 2 myTree = regTrees.createTree(reDraw.rawDat, regTrees.modelLeaf, regTrees.modelErr, (tolS, tolN)) yHat = regTrees.createForeCast(myTree, reDraw.testDat, regTrees.modelTreeEval) else:#复选框没被选中,执行会归树 myTree = regTrees.createTree(reDraw.rawDat, ops = (tolS, tolN)) yHat = regTrees.createForeCast(myTree, reDraw.testDat) #绘制sine.txt中的点 reDraw.a.scatter(reDraw.rawDat[:, 0], reDraw.rawDat[:, 1], s = 5) #绘制得到的直线 reDraw.a.plot(reDraw.testDat, yHat, linewidth = 2.0) reDraw.canvas.show()#作用:得到tolN和tolS的取值#输入:无#输出:无def getInputs(): # 得到tolN的取值 try: tolN = int(tolNentry.get()) except: tolN = 10 print "enter Integer for tolN" tolNentry.delete(0, END) tolNentry.insert(0, '10') # 得到tolS的取值 try: tolS = float(tolSentry.get()) except: tolS = 1.0 print "enter Float for tolS" tolSentry.delete(0, END) tolSentry.insert(0, '1.0') return tolN, tolS#作用:绘制新的树#输入:无#输出:无def drawNewTree(): tolN, tolS = getInputs() reDraw(tolS, tolN)root = Tk()#绘制主面板reDraw.f = Figure(figsize = (5, 4), dpi = 100)reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master = root)reDraw.canvas.show()reDraw.canvas.get_tk_widget().grid(row = 0, columnspan = 3)#创建退出按钮#Button(root, text = "Quit", fg = 'black', command = root.quit).grid(row = 1, column = 2)#插入标题标签#Label(root, text = "Plot Place Holder").grid(row = 0, columnspan = 3)#插入tolN标签Label(root, text = "tolN").grid(row = 1, column = 0)#插入tolN文本框tolNentry = Entry(root)tolNentry.grid(row = 1, column = 1)tolNentry.insert(0, '10')#插入tolS标签Label(root, text = "tolS").grid(row = 2, column = 0)#插入tolS文本框tolSentry = Entry(root)tolSentry.grid(row = 2, column = 1)tolSentry.insert(0, 
'1.0')#创建ReDraw按钮Button(root, text = "ReDraw", command = drawNewTree).grid(row = 1, column = 2, rowspan = 3)#创建按钮勾选状态变量chkBtnVar = IntVar()#创建Model Tree按钮chkBtn = Checkbutton(root, text = "Model Tree", variable = chkBtnVar)chkBtn.grid(row = 3, column = 0, columnspan = 2)#sine.txt中的数据集reDraw.rawDat = mat(regTrees.loadDataSet('exp2.txt'))#sine.txt第一行即x轴的最大值到最小值reDraw.testDat = arange(min(reDraw.rawDat[:, 0]), max(reDraw.rawDat[:, 0]), 0.01)reDraw(1.0, 10)root.mainloop()
1 0
- 代码注释:机器学习实战第9章 树回归
- 代码注释:机器学习实战第8章 预测数值型数据:回归
- 【机器学习实战】第9章 树回归
- 第9章 机器学习实战之树回归
- 代码注释:机器学习实战第3章 决策树
- 【机器学习实战】第9章 树回归(Tree Regression)
- 代码注释:机器学习实战第2章 k-近邻算法
- 代码注释:机器学习实战第4章 基于概率论的分类方法:朴素贝叶斯
- 代码注释:机器学习实战第7章 利用AdaBoost元算法提高分类性能
- 代码注释:机器学习实战第11章 使用Apriori算法来发现频繁集
- 机器学习实战第5章 Logistic回归的weights
- 机器学习实战---读书笔记: 第5章 基Logistic回归
- 读书笔记:机器学习实战【第5章:Logistic回归】
- 【机器学习实战】第5章 Logistic回归
- 【机器学习实战】第5章 Logistic回归
- 第8章 机器学习实战之线性回归
- 机器学习实战第5章-logistic(逻辑回归)
- 机器学习实战-树回归
- iOS——UIActivityIndicatorView
- dfs-329. Longest Increasing Path in a Matrix[Hard]
- Node.js 笔记(一) nodejs、npm、express安装
- git提交忽略不必要的文件或文件夹
- MKL 函数cblas_?gemm参数分析
- 代码注释:机器学习实战第9章 树回归
- struts2 拦截器
- AndroidStudio报错: undefined reference to 'AndroidBitmap_getInfo'
- 1005. 继续(3n+1)猜想 (25)-PAT乙级
- nodejs中REPL执行环境解析
- AdaBoost算法
- JAVA基础知识整理(七) ---数据库
- Android调试自测工具 (Hugo、Timber、Scalpel)
- 三次握手和四次握手