R: building and printing decision trees for discrete and continuous data with an "ID4.5" (C4.5-style) algorithm
This is my first post. Since our data mining course project gave me the time, I'm sharing the decision tree code I implemented. It has plenty of holes and flaws, and more than a little hackery, but it does run; of the code you find online, very few examples actually work. Without further ado, here is the code.
Special thanks to blogger skyonefly for his code.
Link: R语言决策树代码 (R decision tree code)
########## Decision tree with an "ID4.5" (C4.5-style) algorithm ##########

############################## Part1: basic functions ##############################

# Compute the Shannon entropy of the label column
calShannonEnt <- function(dataSet, labels){
  numEntries <- length(dataSet[,labels])
  key <- rep("a", numEntries)
  for(i in 1:numEntries)
    key[i] <- dataSet[i,labels]
  shannonEnt <- 0
  prob <- table(key)/numEntries
  for(i in 1:length(prob))
    shannonEnt = shannonEnt - prob[i]*log(prob[i], 2)
  return(shannonEnt)
}

# Split the data set: keep the rows whose value in column `axis` equals `value`
splitDataSet <- function(dataSet, axis, value, tempSet = dataSet){
  retDataSet = NULL
  for(i in 1:nrow(dataSet)){
    if(dataSet[i,axis] == value){
      tempDataSet = tempSet[i,]
      retDataSet = rbind(retDataSet, tempDataSet)
    }
  }
  rownames(retDataSet) = NULL
  return(retDataSet)
}

# Choose the internal node with the largest gain ratio
chooseBestFeatureToSplita <- function(dataSet, labels, bestInfoGain){
  numFeatures = ncol(dataSet) - 1
  baseEntropy = calShannonEnt(dataSet, labels)
  bestFeature = -1
  for(i in 1:numFeatures){
    featureLabels = levels(factor(dataSet[,i]))
    newEntropy = 0.0
    SplitInfo = 0.0
    for(j in 1:length(featureLabels)){
      subDataSet = splitDataSet(dataSet, i, featureLabels[j])
      prob = length(subDataSet[,1])*1.0/nrow(dataSet)
      newEntropy = newEntropy + prob*calShannonEnt(subDataSet, labels)
      SplitInfo = -prob*log2(prob) + SplitInfo
    }
    infoGain = baseEntropy - newEntropy
    if(SplitInfo > 0){                 # guard against division by zero
      GainRatio = infoGain/SplitInfo
      if(GainRatio > bestInfoGain){
        bestInfoGain = GainRatio
        bestFeature = i
      }
    }
  }
  return(bestFeature)
}

# Return the label with the highest frequency
majorityCnt <- function(classList){
  count = as.numeric(table(classList))
  majorityList = levels(as.factor(classList))
  if(length(count) == 1){
    return(majorityList[1])
  }else{
    f = max(count)
    return(majorityList[which(count == f)][1])
  }
}

# Check whether the class labels have only one factor level
oneValue <- function(classList){
  count = as.numeric(table(classList))
  if(length(count) == 1){
    return(TRUE)
  }else
    return(FALSE)
}

# Print the tree: spread the node list across the columns of a data frame
printTree <- function(tree){
  df <- data.frame()
  col <- 1
  point <- c()
  count <- 0
  for(i in 1:length(tree)){
    if(rownames(tree)[[i]] == 'labelFeature'){          # internal node
      df[i,col] = tree[[i]]
      col = col + 1
      count = count + 1
      point[count] = col
      names(point)[count] = tree[[i]]
    }else if(rownames(tree)[[i]] == 'FeatureValue'){    # branch value
      for(j in 1:length(point)){
        len = grep(names(point)[j], tree[[i]])
        if(length(len) >= 1){
          col = point[j]
        }
      }
      df[i,col] = tree[[i]]
      col = col + 1
    }else{                                              # leaf label
      df[i,col] = tree[[i]]
      col = col + 1
    }
  }
  for(i in 1:nrow(df)){
    for(j in 1:ncol(df)){
      if(is.na(df[i,j])){
        df[i,j] = ""
      }
    }
  }
  return(df)
}

numericCol <- NULL

# Load the data: drop the ID column, record which columns are numeric,
# then split into training (rows 1-14), test (rows 15-21) and prediction (row 22) sets
load <- function(filePath){
  dataSet <- read.table(filePath, header = T)
  dataSet = dataSet[,-1]
  numericCol <<- c()
  for(i in 1:ncol(dataSet)){
    if(is.numeric(dataSet[,i])){
      numericCol[i] <<- colnames(dataSet)[i]
    }
  }
  dataSet = as.matrix(dataSet)
  trainSet = dataSet[1:14,]
  testSet = dataSet[15:21,]
  preSet = dataSet[22,,drop=FALSE]
  result = list(trainSet, testSet, preSet)
  return(result)
}
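As a quick sanity check of calShannonEnt (a toy example of mine, not from the original post): nine "yes" labels and five "no" labels should give an entropy of about 0.940 bits.

# Toy check: 9 "yes" / 5 "no" gives about 0.940
toy <- data.frame(buys_computer = c(rep("yes", 9), rep("no", 5)),
                  stringsAsFactors = FALSE)
calShannonEnt(toy, "buys_computer")
# -(9/14)*log2(9/14) - (5/14)*log2(5/14) = 0.940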
############################## Part2: handling continuous values ##############################

# Convert every numeric column in the data set to a discrete one
toDiscrete <- function(numericCol, dataSet, labels){
  s <- c()
  if(length(numericCol) > 0){
    for(i in 1:length(numericCol)){
      if(numericCol[i] %in% colnames(dataSet)){
        atr = dataSet[,numericCol[i]]
        atr = as.numeric(atr)
        Split = BestSplit(dataSet, numericCol[i], labels)   # find the best split point
        s[i] = Split
        names(s)[i] = numericCol[i]
        dataSet[,numericCol[i]] = toDiscreteCol(atr, Split) # write the discretized column back
      }
    }
  }
  result = list(dataSet, s)
  return(result)
}

# Compute the Gini index of the label column
jini <- function(data, labels){
  nument <- length(data[,1])
  key <- rep("a", nument)
  for(i in 1:nument){
    key[i] <- data[i,labels]
  }
  ent <- 0
  prob <- table(key)/nument
  for(i in 1:length(prob))
    ent = ent + prob[i]*prob[i]
  ent = 1 - ent
  return(ent)
}

# Use the midpoint with the smallest Gini index as the split point
BestSplit <- function(dataSet, colname, labels){
  numFeatures = nrow(dataSet) - 1
  sorted = sort(as.numeric(dataSet[,colname]))
  atr = dataSet[,colname]
  bestGini = 999
  bestSplit = -1
  for(i in 1:numFeatures){
    middle = (sorted[i] + sorted[i+1])/2   # midpoint of two consecutive sorted values
    tempCol = toDiscreteCol(atr, middle)
    dataSet[,colname] = tempCol
    featureLabels = levels(factor(dataSet[,colname]))
    Gini = 0.0
    for(j in 1:length(featureLabels)){
      subDataSet = splitDataSet(dataSet, colname, featureLabels[j])
      prob = length(subDataSet[,1])*1.0/nrow(dataSet)
      Gini = Gini + prob*jini(subDataSet, labels)
    }
    if(Gini <= bestGini){
      bestGini = Gini
      bestSplit = middle
    }
  }
  return(bestSplit)
}

# Discretize one column: 'a' for values at or below the split point, 'b' above it
toDiscreteCol <- function(atr, Split){
  str <- c()
  for(i in 1:length(atr)){
    if(atr[i] <= Split){
      str[i] = 'a'
    }else{
      str[i] = 'b'
    }
  }
  return(str)
}
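To make Part2 concrete, a small example of my own: BestSplit tries the midpoint between every pair of consecutive sorted values and keeps the one with the smallest weighted Gini index, and toDiscreteCol then recodes the column as 'a' at or below the split point and 'b' above it. With a hypothetical split point of 29.5:

# Toy check of the discretization helper
toDiscreteCol(c(16, 25, 34, 41), 29.5)
# [1] "a" "a" "b" "b"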
############################## Part3: building the decision tree ##############################

# Recursively build the tree
creatTree <- function(dataSet, labels, bestInfoGain){
  result = toDiscrete(numericCol, dataSet, labels)
  tempSet = dataSet
  dataSet = result[[1]]
  Split = result[[2]]
  decision_tree = list()
  classList = dataSet[,labels]
  # all samples belong to the same class: return a leaf
  if(oneValue(classList)){
    label = classList[1]
    return(rbind(decision_tree, label))
  }
  # only the label column is left: every attribute has been used up
  if(ncol(dataSet) == 1){
    label = majorityCnt(classList)
    decision_tree = rbind(decision_tree, label)
    return(decision_tree)
  }
  # choose bestFeature as the splitting attribute
  bestFeature = chooseBestFeatureToSplita(dataSet, labels, bestInfoGain)
  # every gain ratio fell below the bestInfoGain threshold: return a majority leaf
  if(bestFeature == -1){
    label = majorityCnt(classList)
    decision_tree = rbind(decision_tree, label)
    return(decision_tree)
  }
  bestFeatureName = colnames(dataSet)[bestFeature]
  labelFeature = bestFeatureName
  # add the internal node (rbind records the variable name as the row name,
  # which printTree relies on, so labelFeature/FeatureValue must keep these names)
  decision_tree = rbind(decision_tree, labelFeature)
  attriCol = dataSet[,bestFeature]
  temp_tree = data.frame()
  stayData = dataSet
  factor = levels(as.factor(attriCol))
  for(j in 1:length(factor)){
    # split off the sub data set for this branch
    dataSet = splitDataSet(stayData, bestFeature, factor[j], tempSet)
    if(bestFeatureName %in% numericCol){
      if(factor[j] == 'a'){
        character = '<'
      }else{
        character = '>'
      }
      numpd = paste(character, Split[bestFeatureName])
      FeatureValue = paste(bestFeatureName, numpd)
    }else{
      FeatureValue = paste(bestFeatureName, factor[j])
    }
    decision_tree = rbind(decision_tree, FeatureValue)
    # delete the attribute column that has just been used
    dataSet = dataSet[,-bestFeature,drop=FALSE]
    # recurse on the branch
    temp_tree = creatTree(dataSet, labels, bestInfoGain)
    decision_tree = rbind(decision_tree, temp_tree)
  }
  return(decision_tree)
}

############################## Part4: prediction ##############################

# Compute the accuracy on the test set
test <- function(testSet, labels){
  count <- 0
  for(i in 1:nrow(testSet)){
    pre = predict(testSet[i,,drop=FALSE], myTree)
    if(!is.null(pre) && pre == testSet[i,labels]){
      count = count + 1
    }
  }
  return(count/nrow(testSet))
}

# Return the prediction for the first row of testSet
pre <- function(testSet){
  for(i in 1:nrow(testSet)){
    pre = predict(testSet[i,,drop=FALSE], myTree)
    return(pre)
  }
}

# Walk the printed tree data frame to classify one row
# (note: this masks the built-in stats::predict for the rest of the session)
predict <- function(testSet, df, row = 1, col = 1){
  if(df[row,col] == "yes" | df[row,col] == "no"){   # reached a leaf
    return(df[row,col])
  }else{
    if(length(grep(" ", df[row,col])) == 0 & nchar(df[row,col]) > 0){
      labelFeature = df[row,col]                 # attribute name at this node
      FeatureValue = testSet[,labelFeature][1]   # the row's value for that attribute
      if(labelFeature %in% numericCol){
        # numeric attribute: find the two branch rows and their shared split point
        count <- 1
        rows <- c()
        Split <- NULL
        for(i in row:nrow(df)){
          if(count > 2){
            break
          }
          if(nchar(df[i,col+1]) > 0){
            Split = as.numeric(strsplit(df[i,col+1], " ")[[1]][3])
            rows[count] = i
            count = count + 1
          }
        }
        if(as.numeric(FeatureValue) < Split){
          prediction = predict(testSet, df, rows[1] + 1, col + 2)
        }else{
          prediction = predict(testSet, df, rows[2] + 1, col + 2)
        }
        return(prediction)
      }else{
        # discrete attribute: find the branch row matching "attribute value"
        sum = paste(labelFeature, FeatureValue)
        for(i in row:nrow(df)){
          if(df[i,col+1] == sum){
            row = i
            break
          }
        }
        prediction = predict(testSet, df, row + 1, col + 2)
        return(prediction)
      }
    }
  }
}

############################## Part5: load the data and run ##############################

# Load the data
myData = load("C:/Users/gino2/Desktop/R/DecisiontreeSampledata.txt")
# "D:/littlestar/DecisiontreeSampledata.txt"          discrete data
# "D:/littlestar/DecisiontreeSampledataContinue.txt"  continuous data

# Build the decision tree
# three arguments: training set, label column, gain-ratio threshold
tree = creatTree(myData[[1]], 'buys_computer', 0.12)
myTree <- printTree(tree)
print(myTree)

# Evaluate and predict
rightProb = test(myData[[2]], 'buys_computer')
print(paste("Accuracy on the test set:", rightProb))
print(paste("Prediction for row 22:", pre(myData[[3]])))
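For intuition about the 0.12 threshold, a back-of-the-envelope check of mine using the well-known textbook values for these 14 training rows (5 youth / 4 middle_age / 5 senior; 9 yes / 5 no overall): age comfortably clears the bar, so it can be chosen as the root.

gain      <- 0.940 - 0.694                             # information gain of age, 0.246
splitInfo <- -2*(5/14)*log2(5/14) - (4/14)*log2(4/14)  # split info of age, about 1.577
gain / splitInfo                                       # gain ratio about 0.156 > 0.12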
Running results (the output screenshot from the original post is not preserved here).
Below is the data I used.
Purely discrete:
ID age income student credit_rating buys_computer
1 youth high no fair no
2 youth high no excellent no
3 middle_age high no fair yes
4 senior medium no fair yes
5 senior low yes fair yes
6 senior low yes excellent no
7 middle_age low yes excellent yes
8 youth medium no fair no
9 youth low yes fair yes
10 senior medium yes fair yes
11 youth medium yes excellent yes
12 middle_age medium no excellent yes
13 middle_age high yes fair yes
14 senior medium no excellent no
15 youth medium no fair no
16 youth low yes excellent yes
17 middle_age medium yes fair yes
18 middle_age high no excellent yes
19 middle_age low no excellent yes
20 senior low yes excellent yes
21 senior high no fair no
22 middle_age medium no fair NA
Continuous + discrete:
Id age income student credit_rating buys_computer
1 16 high no fair no
2 25 high no excellent no
3 34 high no fair yes
4 41 medium no fair yes
5 45 low yes fair yes
6 48 low yes excellent no
7 39 low yes excellent yes
8 27 medium no fair no
9 25 low yes fair yes
10 49 medium yes fair yes
11 18 medium yes excellent yes
12 36 medium no excellent yes
13 38 high yes fair yes
14 50 medium no excellent no
15 19 medium no fair no
16 25 low yes excellent yes
17 35 medium yes fair yes
18 31 high no excellent yes
19 38 low no excellent yes
20 45 low yes excellent yes
21 43 high no fair no
22 36 medium no fair NA
Finally, I want to thank everyone who gave me the drive and energy to finish this code.