机器学习实战笔记_09_树回归_代码错误修正

来源:互联网 发布:淘宝网百丽女鞋 编辑:程序博客网 时间:2024/06/05 04:38
from numpy import *
def loadDataSet(fileName):
    """Load a tab-separated numeric text file into a list of float rows.

    Args:
        fileName: Path of the text file; every non-blank line holds
            tab-separated values that ``float`` can parse.

    Returns:
        A list with one ``list`` of floats per non-blank line.

    Raises:
        ValueError: If a field cannot be converted to ``float``.
    """
    dataMat = []
    # The context manager guarantees the handle is closed, even on error.
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # Tolerate blank / trailing empty lines instead of crashing
                # on float('').
                continue
            # Build a real list: on Python 3 a bare map() is a lazy iterator,
            # which cannot later be turned into a numeric matrix.
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat


def binSplitDataSet(dataSet, feature, value):
    """Split the rows of ``dataSet`` in two by thresholding one column.

    Args:
        dataSet: Data matrix (the script passes a numpy matrix).
        feature: Column index to compare.
        value: Threshold.

    Returns:
        ``(mat0, mat1)``: rows whose ``feature`` column is ``> value`` and
        rows whose ``feature`` column is ``<= value``, respectively.
    """
    # Marked by the author as correction #1 relative to the book's listing:
    # nonzero(...)[0] yields the matching row indices, and whole rows are kept.
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1


def regLeaf(dataSet):
    """Return the value used for a leaf: the mean of the last column."""
    return mean(dataSet[:, -1])


def regErr(dataSet):
    """Return the variance of the last column times the number of rows."""
    return var(dataSet[:, -1]) * shape(dataSet)[0]


def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Pick the (feature, value) split that gives the lowest total error.

    Args:
        dataSet: Data matrix whose last column is the target. The column
            extraction below relies on numpy-matrix indexing (column slices
            stay 2-D).
        leafType: Callable that builds a leaf value from a data subset.
        errType: Callable that returns the error of a data subset.
        ops: ``(tolS, tolN)`` - minimum error reduction required to split,
            and minimum number of rows allowed on each side of a split.

    Returns:
        ``(None, leafType(dataSet))`` when one of the three exit conditions
        is met, otherwise ``(bestIndex, bestValue)``.
    """
    tolS, tolN = ops[0], ops[1]

    # Exit condition 1: all target values are identical.
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)

    n = shape(dataSet)[1]
    # The best feature is chosen by the reduction in error from the unsplit set.
    S = errType(dataSet)
    bestS = inf
    bestIndex = 0
    bestValue = 0
    for featIndex in range(n - 1):
        # Marked by the author as correction #2: the book iterated
        # set(dataSet[:, featIndex]); the column is flattened to a plain list
        # of scalars first so the values can be put in a set.
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS

    # Exit condition 2: the decrease (S - bestS) is below the threshold.
    # This also covers "no admissible split found", where bestS is still inf.
    if (S - bestS) < tolS:
        return None, leafType(dataSet)

    # Exit condition 3: the chosen split leaves too few rows on one side.
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
        return None, leafType(dataSet)

    # The best feature to split on and the value used for that split.
    return bestIndex, bestValue


def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Recursively build a tree as nested dicts.

    Args:
        dataSet: Data matrix whose last column is the target.
        leafType: Callable that builds a leaf value from a data subset.
        errType: Callable that returns the error of a data subset.
        ops: ``(tolS, tolN)`` forwarded to ``chooseBestSplit``.

    Returns:
        A leaf value when splitting hits a stop condition, otherwise a dict
        with keys ``'spInd'``, ``'spVal'``, ``'left'`` and ``'right'``;
        ``'left'`` is built from the rows above the split value and
        ``'right'`` from the remaining rows.
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None:
        # Splitting hit a stop condition: val is the leaf value.
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree


if __name__ == '__main__':
    # Quick sanity check of binSplitDataSet, left disabled by the author:
    #   testMat = asmatrix(eye(4))
    #   mat0, mat1 = binSplitDataSet(testMat, 1, 0.5)
    myDat = loadDataSet('ex00.txt')
    # asmatrix is what the former `mat` alias pointed to; it exists in both
    # old and current NumPy releases.
    myMat = asmatrix(myDat)
    createTree(myMat)


本人使用的是 Python 2.7,但照着书上的源代码敲入后,运行时总是报错。排查后发现书中的代码有两处错误,读者可以把我的代码和书上的代码对照。

错误的地方已在 regTrees.py 的代码注释中标出。


0 0
原创粉丝点击