python_nba_tree

来源:互联网 发布:2016年java程序员工资 编辑:程序博客网 时间:2024/05/22 04:51
#! /usr/bin/env python
# coding=utf-8
"""Predict whether the home team wins an NBA game.

Reads a season's game log (``nba.txt``) and a standings table
(``nba1.txt``), engineers features from earlier results, and prints the
cross-validated accuracy of decision-tree and random-forest models on each
feature set.

scikit-learn is imported inside the functions that need it so that the
pure-pandas feature helpers can be used (and tested) without it.
"""
import os
from collections import defaultdict

import numpy as np
import pandas as pd

# Directory holding nba.txt (game log) and nba1.txt (standings).
root = "F:/Data/data/"

# Names given to the nine columns of the game-log CSV, in file order.
GAME_COLUMNS = ["Date", "Start", "Visitor Team", "Visitor Pts", "Home Team",
                "Home Pts", "Score Type", "OT?", "Notes"]

# Game-log team name -> name the standings file uses for that franchise.
TEAM_ALIASES = {"New Orleans Pelicans": "New Orleans Hornets"}


def load_games(path):
    """Load the game log and add the boolean ``HomeWin`` target column.

    The raw CSV must contain a ``Date`` column (parsed as dates) and exactly
    the nine columns listed in ``GAME_COLUMNS``; ``Start`` is dropped.
    ``HomeWin`` is True when the home side scored more points than the visitor.
    """
    games = pd.read_csv(path, parse_dates=["Date"])
    games.columns = GAME_COLUMNS
    games = games.drop("Start", axis=1)
    games["HomeWin"] = games["Visitor Pts"] < games["Home Pts"]
    return games


def add_last_win_features(games):
    """Return a copy of *games* with ``HomeLastWin`` and ``VisitorLastWin``.

    Each flag says whether that team won its previous game in *games*. Rows
    are processed in order, so the result is only meaningful when *games* is
    sorted chronologically. A team's first game gets False.
    """
    # team name -> did it win its most recent game? Reading a team that has
    # not played yet returns False (the defaultdict default); the entries are
    # written after each game below, so keys appear as teams are first seen.
    won_last = defaultdict(bool)
    home_last, visitor_last = [], []
    for home, visitor, home_win in zip(games["Home Team"],
                                       games["Visitor Team"],
                                       games["HomeWin"]):
        home_last.append(won_last[home])
        visitor_last.append(won_last[visitor])
        won_last[home] = bool(home_win)
        won_last[visitor] = not home_win
    # Build whole columns instead of editing rows: iterrows() yields copies,
    # so new keys on a row never become columns, and ``.ix`` no longer exists.
    return games.assign(HomeLastWin=home_last, VisitorLastWin=visitor_last)


def add_rank_feature(games, standings, aliases=None):
    """Return a copy of *games* with an integer ``HomeTeamRanksHigher`` column.

    The flag is 1 when the home team's ``Rk`` in *standings* is numerically
    smaller than the visitor's (assuming rank 1 is the best team), else 0.
    Team names are mapped through *aliases* (default ``TEAM_ALIASES``) before
    lookup; a team missing from *standings* raises KeyError.
    """
    if aliases is None:
        aliases = TEAM_ALIASES
    # If a team is listed twice, its first row is used.
    rank_of = standings.drop_duplicates("Team").set_index("Team")["Rk"].to_dict()

    def rank(team):
        return rank_of[aliases.get(team, team)]

    # A smaller Rk is a better finish, hence "<" for "ranks higher".
    higher = [int(rank(home) < rank(visitor))
              for home, visitor in zip(games["Home Team"], games["Visitor Team"])]
    return games.assign(HomeTeamRanksHigher=higher)


def add_head_to_head_feature(games):
    """Return a copy of *games* with an integer ``HomeTeamWonLast`` column.

    The flag is 1 when the home team won the previous game between the same
    two teams (at either venue), else 0, including for their first meeting.
    """
    # Sorted (team, team) pair -> winner of their most recent meeting. Sorting
    # makes the key independent of which side was home.
    last_winner = {}
    won_last = []
    for home, visitor, home_win in zip(games["Home Team"],
                                       games["Visitor Team"],
                                       games["HomeWin"]):
        pair = tuple(sorted((home, visitor)))
        won_last.append(int(last_winner.get(pair) == home))
        last_winner[pair] = home if home_win else visitor
    return games.assign(HomeTeamWonLast=won_last)


def encode_teams(games):
    """Return an (n_games, 2) integer array of [home, visitor] team codes."""
    from sklearn.preprocessing import LabelEncoder

    encoder = LabelEncoder()
    # Fit on both columns so a team seen only as a visitor still gets a code.
    encoder.fit(np.concatenate([games["Home Team"].values,
                                games["Visitor Team"].values]))
    home = encoder.transform(games["Home Team"].values)
    visitor = encoder.transform(games["Visitor Team"].values)
    # transform() returns 1-D arrays; vstack stacks them as two rows and .T
    # turns that into one row per game with two feature columns.
    return np.vstack([home, visitor]).T


def one_hot_teams(team_codes):
    """One-hot encode integer team codes into a dense 0/1 array.

    Trees split on thresholds and would treat integer codes as ordered
    values, so each team becomes its own binary column instead.
    """
    from sklearn.preprocessing import OneHotEncoder

    # toarray(), not todense(): the latter returns np.matrix, which recent
    # scikit-learn rejects as estimator input.
    return OneHotEncoder().fit_transform(team_codes).toarray()


def report_accuracy(clf, X, y):
    """Print and return the mean cross-validated accuracy of *clf* on (X, y)."""
    from sklearn.model_selection import cross_val_score

    accuracy = np.mean(cross_val_score(clf, X, y, scoring="accuracy"))
    print("精确度:{}".format(accuracy))
    return accuracy


def grid_search(X, y):
    """Grid-search random-forest hyper-parameters; print the best score/model."""
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import GridSearchCV

    parameter_space = {
        "max_features": [2, 10],
        "n_estimators": [100],
        "criterion": ["gini", "entropy"],
        "min_samples_leaf": [2, 4, 6],
    }
    grid = GridSearchCV(RandomForestClassifier(random_state=14), parameter_space)
    grid.fit(X, y)
    print("精确度:{}".format(grid.best_score_))
    print(grid.best_estimator_)
    return grid


def main(data_dir=root, run_grid_search=False):
    """Build each feature set in turn and print every model's accuracy.

    Accuracies quoted in comments are those printed by the author's original
    run; exact values depend on the data and the scikit-learn version (for
    example its default number of cross-validation folds).
    """
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.tree import DecisionTreeClassifier

    nba = load_games(os.path.join(data_dir, "nba.txt"))
    y_true = nba["HomeWin"].values

    # Feature set 1: did each team win its previous game?
    # (original run: 0.574679022572)
    nba = add_last_win_features(nba)
    x_previous_wins = nba[["HomeLastWin", "VisitorLastWin"]].values.astype(int)
    report_accuracy(DecisionTreeClassifier(random_state=14), x_previous_wins, y_true)

    # Feature set 2: + does the home team rank higher in the standings?
    # skiprows=[0] drops the file's first line so the second is the header.
    # (original run: 0.596657347967)
    standings = pd.read_csv(os.path.join(data_dir, "nba1.txt"), skiprows=[0])
    nba = add_rank_feature(nba, standings)
    x_home_higher = nba[["HomeLastWin", "VisitorLastWin",
                         "HomeTeamRanksHigher"]].values.astype(int)
    report_accuracy(DecisionTreeClassifier(random_state=14), x_home_higher, y_true)

    # Feature set 3: + did the home team win the previous meeting of the pair?
    # (original run: 0.603482432526)
    nba = add_head_to_head_feature(nba)
    x_last_winner = nba[["HomeLastWin", "VisitorLastWin", "HomeTeamRanksHigher",
                         "HomeTeamWonLast"]].values.astype(int)
    report_accuracy(DecisionTreeClassifier(random_state=14), x_last_winner, y_true)

    # Feature set 4: team identities alone, one-hot encoded.
    # (original run: 0.595154276248)
    x_teams_expanded = one_hot_teams(encode_teams(nba))
    report_accuracy(DecisionTreeClassifier(random_state=14), x_teams_expanded, y_true)

    # The random forest also gets the one-hot team columns, for the same reason.
    # The original run passed the integer codes here (0.583773383033 alone,
    # 0.579208945952 combined), which also left x_all too narrow for the
    # grid's max_features=10.
    forest = RandomForestClassifier(random_state=14)
    report_accuracy(forest, x_teams_expanded, y_true)
    x_all = np.hstack([x_last_winner, x_teams_expanded])
    report_accuracy(forest, x_all, y_true)

    # Slow, so off by default (the original had it commented out).
    if run_grid_search:
        grid_search(x_all, y_true)


if __name__ == "__main__":
    main()
0 0
原创粉丝点击