Decision Trees: An sklearn Implementation

Source: Internet | Editor: 程序博客网 | Time: 2024/06/06 18:13

Official documentation: http://scikit-learn.org/stable/modules/tree.html

Training set: a file named AllElectronics.csv with the following contents:

RID,age,income,student,credit_rating,Class_bugs_computer
1,youth,high,no,fair,no
2,youth,high,no,excellent,no
3,middle_aged,high,no,fair,yes
4,senior,medium,no,fair,yes
5,senior,low,yes,fair,yes
6,senior,low,yes,excellent,no
7,middle_aged,low,yes,excellent,yes
8,youth,medium,no,fair,no
9,youth,low,yes,fair,yes
10,senior,medium,yes,fair,yes
11,youth,medium,yes,excellent,yes
12,middle_aged,medium,no,excellent,yes
13,middle_aged,high,yes,fair,yes
14,senior,medium,no,excellent,no
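The script below reads the data from AllElectronics.csv, so the rows above have to be saved to that file first. Here is a minimal sketch using only the Python standard library that writes them out; `write_training_set` is a hypothetical helper name, not something from the original post.

```python
import csv

# Header and the 14 rows listed above, copied verbatim (including the column name as given)
HEADER = ["RID", "age", "income", "student", "credit_rating", "Class_bugs_computer"]
ROWS = [
    ["1", "youth", "high", "no", "fair", "no"],
    ["2", "youth", "high", "no", "excellent", "no"],
    ["3", "middle_aged", "high", "no", "fair", "yes"],
    ["4", "senior", "medium", "no", "fair", "yes"],
    ["5", "senior", "low", "yes", "fair", "yes"],
    ["6", "senior", "low", "yes", "excellent", "no"],
    ["7", "middle_aged", "low", "yes", "excellent", "yes"],
    ["8", "youth", "medium", "no", "fair", "no"],
    ["9", "youth", "low", "yes", "fair", "yes"],
    ["10", "senior", "medium", "yes", "fair", "yes"],
    ["11", "youth", "medium", "yes", "excellent", "yes"],
    ["12", "middle_aged", "medium", "no", "excellent", "yes"],
    ["13", "middle_aged", "high", "yes", "fair", "yes"],
    ["14", "senior", "medium", "no", "excellent", "no"],
]


def write_training_set(path="AllElectronics.csv"):
    """Hypothetical helper: write HEADER and ROWS to `path` as a CSV file."""
    # newline="" lets the csv module control line endings itself
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(HEADER)
        writer.writerows(ROWS)
    return path
```

Call `write_training_set()` once from the directory the script runs in. Opening the file with `newline=""` follows the `csv` module's own recommendation and avoids blank lines between rows on Windows.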

Install Graphviz. On macOS, installing Graphviz takes just one command:

brew install graphviz
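To confirm the install worked and to run the `dot` rendering command mentioned in the script's comments from Python, here is a minimal sketch. It assumes Graphviz puts a `dot` executable on PATH; `render_dot_to_pdf` and its `dot_cmd` parameter are hypothetical names used only for illustration.

```python
import shutil
import subprocess


def render_dot_to_pdf(dot_path, pdf_path, dot_cmd="dot"):
    """Hypothetical helper: render a Graphviz .dot file to a PDF.

    Returns False without running anything if `dot_cmd` is not found on PATH
    (i.e. Graphviz does not appear to be installed); otherwise runs
    `dot -Tpdf <dot_path> -o <pdf_path>` and returns True.
    """
    dot_exe = shutil.which(dot_cmd)
    if dot_exe is None:
        return False
    # check=True raises subprocess.CalledProcessError if dot exits with an error
    subprocess.run([dot_exe, "-Tpdf", dot_path, "-o", pdf_path], check=True)
    return True
```

For example, `render_dot_to_pdf("allEletronicInfomationGainOri.dot", "output.pdf")` after the script has run. A `False` return means the `brew install graphviz` step has not taken effect in the current shell.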

Implementation:

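from sklearn.feature_extraction import DictVectorizer
import csv
from sklearn import preprocessing
from sklearn import tree
import os

# Delete the .dot file left over from a previous run, if there is one
exists = os.path.exists("allEletronicInfomationGainOri.dot")
print("dot file exists:", exists)
if exists:
    print("file deleted")
    os.remove("allEletronicInfomationGainOri.dot")

# Read the data from the CSV file; the with-block closes the file afterwards
with open(r'AllElectronics.csv', 'r', newline='') as allElectronicsData:
    # csv.reader yields the file one row at a time
    reader = csv.reader(allElectronicsData)
    # next() reads the first row of the CSV file (the header)
    headers = next(reader)
    print(headers)

    featureList = []
    labelList = []
    for row in reader:
        # The last column is the class label; add it to labelList
        labelList.append(row[len(row) - 1])
        # Start at column 1 to skip RID (a row ID, not a feature) and stop before the label,
        # giving e.g. {'age': 'youth', 'income': 'high', 'student': 'no', 'credit_rating': 'fair'}
        rowDict = {}
        for i in range(1, len(row) - 1):
            rowDict[headers[i]] = row[i]
        featureList.append(rowDict)
print("featureList", featureList)

# Instantiate the vectorizer: it one-hot encodes the string-valued features
vec = DictVectorizer()
dummyX = vec.fit_transform(featureList).toarray()
print("dummyX", dummyX)
# get_feature_names() was removed in scikit-learn 1.2; use get_feature_names_out()
print(vec.get_feature_names_out())

# Turn the yes/no labels into a 0/1 column vector
lb = preprocessing.LabelBinarizer()
dummyY = lb.fit_transform(labelList)
print("dummyY", str(dummyY))
print("labelList", str(labelList))

# criterion: the measure used to score candidate splits.
# "entropy" scores splits by information gain, the measure ID3 uses; note that scikit-learn
# itself implements an optimized version of CART, so every split is binary
clf = tree.DecisionTreeClassifier(criterion="entropy")
clf = clf.fit(dummyX, dummyY)
print("clf:" + str(clf))

# On macOS, Graphviz installs with a single command:
# brew install graphviz
# The command to render the .dot file as a PDF:
# dot -T pdf allEletronicInfomationGainOri.dot -o output.pdf
with open('allEletronicInfomationGainOri.dot', 'w') as f:
    tree.export_graphviz(clf, feature_names=vec.get_feature_names_out(), out_file=f)

# Test code below
oneRowX = dummyX[0, :]
print("oneRowX", oneRowX)
# Copy the row: plain assignment would alias it, and the edits below would change dummyX too
newRowX = oneRowX.copy()
# Column 0 is age=middle_aged and column 2 is age=youth, so this turns the
# first customer from youth into middle_aged
newRowX[0] = 1
newRowX[2] = 0
print("newRowX", newRowX)

# Prediction (as an exercise, try improving this: hold out the last 4 of the 14 rows
# as a test set and check whether they are predicted correctly)
# clf.predict(newRowX) raises an error, because predict expects a 2-D array (n_samples, n_features)
# Following the official docs, wrap the row in a list so the input is 2-D and the code runs
predictedY = clf.predict([newRowX])
print("predictedY:" + str(predictedY))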
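The `criterion="entropy"` setting scores splits by information gain: Info(D) = -Σ p_i log2 p_i, and Gain(A) = Info(D) - Σ_v |D_v|/|D| · Info(D_v). As a worked illustration on the 14 training rows, here is a minimal standard-library sketch that computes ID3-style gain for each attribute with one branch per attribute value. This is an illustration of the measure, not what scikit-learn runs internally: sklearn's CART splits the one-hot columns in two, so its tree need not coincide with an ID3 tree. `entropy` and `information_gain` are hypothetical helper names.

```python
from collections import Counter
from math import log2

# (age, income, student, credit_rating, class) for the 14 rows above, in RID order
DATA = [
    ("youth", "high", "no", "fair", "no"),
    ("youth", "high", "no", "excellent", "no"),
    ("middle_aged", "high", "no", "fair", "yes"),
    ("senior", "medium", "no", "fair", "yes"),
    ("senior", "low", "yes", "fair", "yes"),
    ("senior", "low", "yes", "excellent", "no"),
    ("middle_aged", "low", "yes", "excellent", "yes"),
    ("youth", "medium", "no", "fair", "no"),
    ("youth", "low", "yes", "fair", "yes"),
    ("senior", "medium", "yes", "fair", "yes"),
    ("youth", "medium", "yes", "excellent", "yes"),
    ("middle_aged", "medium", "no", "excellent", "yes"),
    ("middle_aged", "high", "yes", "fair", "yes"),
    ("senior", "medium", "no", "excellent", "no"),
]
ATTRIBUTES = ["age", "income", "student", "credit_rating"]


def entropy(labels):
    """Info(D) = -sum(p_i * log2(p_i)) over the class proportions p_i."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def information_gain(rows, attr_index):
    """Gain(A) = Info(D) - sum(|D_v|/|D| * Info(D_v)), one partition D_v per value v of A."""
    labels = [row[-1] for row in rows]
    partitions = {}
    for row in rows:
        partitions.setdefault(row[attr_index], []).append(row[-1])
    remainder = sum(len(part) / len(rows) * entropy(part) for part in partitions.values())
    return entropy(labels) - remainder


for i, name in enumerate(ATTRIBUTES):
    print(f"Gain({name}) = {information_gain(DATA, i):.3f}")
```

With 9 "yes" and 5 "no" labels, Info(D) is about 0.940. age has the largest gain (about 0.247), ahead of student (about 0.152), credit_rating (about 0.048) and income (about 0.029), so ID3 would split on age first.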

For reference, see these study notes:
Section 6: Decision tree algorithm implementation (scikit-learn)
http://blog.csdn.net/youyuyixiu/article/details/52895111
