Finding the Best Binary Decision Tree Depth with Cross-Validation


Run 10-fold cross-validation to measure the average out-of-sample error at each tree depth, then see which depth gives a clear advantage for prediction.

```python
# -*- coding:utf-8 -*-
import numpy
import matplotlib.pyplot as plot
from sklearn.tree import DecisionTreeRegressor

# Build simple synthetic data: y = x + noise
npoints = 100
# 101 evenly spaced x values between -0.5 and 0.5
xplot = [(float(i) / float(npoints) - 0.5) for i in range(npoints + 1)]
# Reshape into a column of single-feature rows
x = [[s] for s in xplot]
# Add Gaussian noise to produce y = x + random
numpy.random.seed(1)
y = [s + numpy.random.normal(scale=0.1) for s in xplot]

nrow = len(x)
depthlist = [1, 2, 3, 4, 5, 6, 7]
xvalmse = []
nxval = 10

# Try each candidate depth in turn
for idepth in depthlist:
    ooserrors = 0.0
    # 10-fold cross-validation loop for the current depth
    for ixval in range(nxval):
        # Split the rows into test and training folds
        itest = [a for a in range(nrow) if a % nxval == ixval]
        itrain = [a for a in range(nrow) if a % nxval != ixval]
        xtrain = [x[r] for r in itrain]
        xtest = [x[r] for r in itest]
        ytrain = [y[r] for r in itrain]
        ytest = [y[r] for r in itest]
        # Train
        treemodel = DecisionTreeRegressor(max_depth=idepth)
        treemodel.fit(xtrain, ytrain)
        # Predict on the held-out fold
        treeprediction = treemodel.predict(xtest)
        # Accumulate squared out-of-sample errors
        error = [ytest[r] - treeprediction[r] for r in range(len(xtest))]
        ooserrors += sum([e * e for e in error])
    xvalmse.append(ooserrors / nrow)

plot.plot(depthlist, xvalmse)
plot.axis('tight')
plot.xlabel('depth')
plot.ylabel('mse')
plot.show()
```
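For comparison, the same depth sweep can also be written with scikit-learn's built-in cross-validation helper. The sketch below is not part of the original post: it reuses x, y, and depthlist from the script above, and the default 10-fold split uses contiguous blocks rather than the strided a % nxval assignment used above, so the per-fold numbers differ slightly.

```python
# Minimal sketch (assumption: x, y, depthlist come from the script above).
import numpy
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor

cvmse = []
for idepth in depthlist:
    # 10-fold CV; the scorer returns negative MSE, so flip the sign
    scores = cross_val_score(DecisionTreeRegressor(max_depth=idepth),
                             x, y, cv=10,
                             scoring='neg_mean_squared_error')
    cvmse.append(-scores.mean())

# Depth with the smallest cross-validated MSE
best_depth = depthlist[int(numpy.argmin(cvmse))]
print(best_depth, cvmse)
```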

With 100 samples, a depth of 3 gives the lowest cross-validated MSE.
[Figure: cross-validated MSE vs. tree depth, 100 samples]

With 1000 samples, a depth of 4 performs better.
[Figure: cross-validated MSE vs. tree depth, 1000 samples]
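To reproduce the 1000-sample run, only the data-generation constants in the script above need to change; a minimal sketch of the modified lines (assuming the rest of the depth-sweep loop is rerun unchanged):

```python
# Assumption: only the sample count changes; the depth sweep is rerun as-is.
npoints = 1000
xplot = [(float(i) / float(npoints) - 0.5) for i in range(npoints + 1)]
x = [[s] for s in xplot]
numpy.random.seed(1)
y = [s + numpy.random.normal(scale=0.1) for s in xplot]
nrow = len(x)
```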
