机器学习实战 第九章 源码勘误

来源:互联网 发布:淘宝抓取商品软件 编辑:程序博客网 时间:2024/05/07 07:05

机器学习实战 第九章 源码勘误

    最近学习了《机器学习实战》这本书的第九章-树回归,发现代码运行出错,就试着改了改,跟大家分享一下。

    这一章的代码总共有两个 Python 文件:regTrees.py 和 treeExplore.py,其中 regTrees.py 运行时会出现两处错误。

先贴上原版的《regTrees.py》

'''
Created on Feb 4, 2011
Tree-Based Regression Methods
@author: Peter Harrington

Corrected listing. Differences from the code as printed in the book:
  * binSplitDataSet no longer appends a trailing [0] to each subset
    (that kept only the first matching row);
  * chooseBestSplit builds its candidate split values from a plain list of
    floats, because iterating a matrix column yields unhashable 1x1
    matrices and set() raises TypeError;
  * Python 3 compatibility (list(map(...)), print(...)) and NumPy 2
    compatibility (asmatrix instead of the removed mat alias);
  * treeForeCast reads the split feature from the sample's columns, so
    multi-feature samples work.
'''
from numpy import *


def loadDataSet(fileName):
    """Parse a tab-delimited file of floats into a list of rows.

    The last column of each row is assumed to be the target value.
    Blank lines are skipped.

    Returns a list of lists of floats; wrap it with mat()/asmatrix()
    before passing it to createTree.
    """
    dataMat = []
    with open(fileName) as fr:
        for line in fr:
            curLine = line.strip().split('\t')
            if curLine == ['']:  # blank line: nothing to convert
                continue
            # list(...) so each row is a real list on Python 3 (map is lazy)
            dataMat.append(list(map(float, curLine)))
    return dataMat


def binSplitDataSet(dataSet, feature, value):
    """Split the rows of dataSet on column `feature`.

    Returns (mat0, mat1): the rows whose feature value is > value, and the
    rows whose feature value is <= value.
    """
    # The book's version appended a trailing [0] to both expressions, which
    # kept only the first matching row instead of the whole subset.
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1


def regLeaf(dataSet):
    """Leaf value of a regression tree: the mean of the target column."""
    return mean(dataSet[:, -1])


def regErr(dataSet):
    """Total squared error of the target column around its mean
    (variance times number of rows)."""
    return var(dataSet[:, -1]) * shape(dataSet)[0]


def linearSolve(dataSet):
    """Fit ordinary least squares to dataSet (helper used in two places).

    The last column is the target Y. The remaining columns, prefixed with a
    column of ones for the intercept, form X.

    Returns (ws, X, Y), where ws = (X^T X)^-1 X^T Y.
    Raises NameError if X^T X is singular.
    """
    m, n = shape(dataSet)
    X = asmatrix(ones((m, n)))  # column 0 stays 1: the intercept term
    X[:, 1:n] = dataSet[:, 0:n - 1]
    Y = dataSet[:, -1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n\
        try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y


def modelLeaf(dataSet):
    """Leaf value of a model tree: the linear-regression coefficients."""
    ws, X, Y = linearSolve(dataSet)
    return ws


def modelErr(dataSet):
    """Sum of squared residuals of the linear model fitted to dataSet."""
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))


def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Find the best binary split of dataSet.

    ops = (tolS, tolN): tolS is the minimum error reduction required to
    split, and tolN is the minimum number of rows allowed on each side.

    Returns (featureIndex, splitValue), or (None, leafValue) when the data
    should become a leaf.
    """
    tolS = ops[0]
    tolN = ops[1]
    # exit cond 1: all target values are identical
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    n = shape(dataSet)[1]
    # the best feature is chosen by the reduction of error from S
    S = errType(dataSet)
    bestS = inf
    bestIndex = 0
    bestValue = 0
    for featIndex in range(n - 1):
        # Iterate plain floats: iterating the matrix column itself yields
        # 1x1 matrices, which are unhashable and make set() fail.
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # exit cond 2: the error decrease is below the threshold
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # exit cond 3: either side of the split would be too small
    if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
        return None, leafType(dataSet)
    return bestIndex, bestValue


def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Recursively build a tree from dataSet (a NumPy matrix, target last).

    Returns a leaf value when no split is worthwhile. Otherwise returns a
    dict with keys 'spInd', 'spVal', 'left' (rows > spVal) and 'right'
    (rows <= spVal).
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None:  # a stop condition was hit: val is the leaf value
        return val
    retTree = {'spInd': feat, 'spVal': val}
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree


def isTree(obj):
    """True if obj is an internal tree node (a dict) rather than a leaf."""
    return isinstance(obj, dict)


def getMean(tree):
    """Collapse tree bottom-up into the average of its two children.

    Subtrees are replaced in place by their collapsed values.
    """
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0


def prune(tree, testData):
    """Post-prune tree against held-out testData (target in the last column).

    Two sibling leaves are merged into their mean when that lowers the
    squared error on testData. The tree is modified in place. Returns the
    pruned tree, or a leaf value if the whole node was merged.
    """
    if shape(testData)[0] == 0:  # no test data reaches here: collapse
        return getMean(tree)
    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)
    if isTree(tree['left']) or isTree(tree['right']):
        return tree
    # both branches are now leaves: merge them if that lowers test error
    errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
        sum(power(rSet[:, -1] - tree['right'], 2))
    treeMean = (tree['left'] + tree['right']) / 2.0
    errorMerge = sum(power(testData[:, -1] - treeMean, 2))
    if errorMerge < errorNoMerge:
        print("merging")
        return treeMean
    return tree


def regTreeEval(model, inDat):
    """Evaluate a regression-tree leaf: the leaf value itself."""
    return float(model)


def modelTreeEval(model, inDat):
    """Evaluate a model-tree leaf (coefficients) on a 1 x n sample inDat."""
    n = shape(inDat)[1]
    X = asmatrix(ones((1, n + 1)))  # column 0 is the intercept term
    X[:, 1:n + 1] = inDat
    # index the 1x1 product: float() of a non-0-d array is deprecated
    return float((X * model)[0, 0])


def treeForeCast(tree, inData, modelEval=regTreeEval):
    """Predict the target for one sample by walking tree.

    inData is a single sample, e.g. a 1 x n matrix row. The sample goes
    left when its split feature is > spVal, matching binSplitDataSet.
    """
    if not isTree(tree):
        return modelEval(tree, inData)
    # Read the feature by column: inData[k] on a 1 x n matrix selects row k,
    # which only worked for single-feature samples with spInd == 0.
    featVal = asarray(inData).ravel()[tree['spInd']]
    branch = tree['left'] if featVal > tree['spVal'] else tree['right']
    return treeForeCast(branch, inData, modelEval)


def createForeCast(tree, testData, modelEval=regTreeEval):
    """Predict every row of testData; returns an m x 1 matrix."""
    m = len(testData)
    yHat = asmatrix(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, asmatrix(testData[i]), modelEval)
    return yHat

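出错的两处分别在 binSplitDataSet(dataSet, feature, value) 和 chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)) 这两个函数中。在较新版本的 numpy 里,前者在索引表达式末尾多写的 [0] 会让每个子集只剩下第一行。后者直接对矩阵的一列调用 set(),迭代出来的元素是 1×1 的 matrix 对象,不可哈希,会抛出 TypeError: unhashable type: 'matrix',因此需要先用 tolist() 转成普通的浮点数列表。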

下面是修正版:

def binSplitDataSet(dataSet, feature, value):
    """Partition the rows of dataSet by comparing column `feature` to value.

    Returns a pair: first the rows where the feature is greater than value,
    then the rows where it is less than or equal to value.
    """
    column = dataSet[:, feature]
    aboveRows = nonzero(column > value)[0]
    belowRows = nonzero(column <= value)[0]
    return dataSet[aboveRows, :], dataSet[belowRows, :]

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Search every feature/value pair for the split that most lowers error.

    ops[0] is the minimum error reduction needed to accept a split, and
    ops[1] is the minimum number of rows each side must keep.
    Returns (feature index, split value), or (None, leaf value) when the
    node should not be split.
    """
    tolS, tolN = ops[0], ops[1]
    targets = dataSet[:, -1].T.tolist()[0]
    if len(set(targets)) == 1:  # exit cond 1: one distinct target value
        return None, leafType(dataSet)
    numFeatures = shape(dataSet)[1] - 1
    baseErr = errType(dataSet)
    bestErr, bestFeat, bestSplit = inf, 0, 0
    for feat in range(numFeatures):
        # Fix: collect plain floats via tolist(); iterating the matrix column
        # directly yields unhashable 1x1 matrices and set() fails.
        candidates = set(row[0] for row in dataSet[:, feat].tolist())
        for split in candidates:
            left, right = binSplitDataSet(dataSet, feat, split)
            tooSmall = shape(left)[0] < tolN or shape(right)[0] < tolN
            if tooSmall:
                continue
            err = errType(left) + errType(right)
            if err < bestErr:
                bestFeat, bestSplit, bestErr = feat, split, err
    if baseErr - bestErr < tolS:  # exit cond 2: gain below threshold
        return None, leafType(dataSet)
    left, right = binSplitDataSet(dataSet, bestFeat, bestSplit)
    if shape(left)[0] < tolN or shape(right)[0] < tolN:  # exit cond 3
        return None, leafType(dataSet)
    return bestFeat, bestSplit

这两个错都有点奇葩,我现在还没学懂树回归,所以也没法完全搞懂这两个错究竟是什么意思,也许是python或工具包版本的问题…至少现在(python:2.7.11,numpy:1.11.0)是能运行了。





回去学学树回归的原理,再来发第九章的学习笔记奋斗



2 0
原创粉丝点击