ChiMerge 算法: 以鸢尾花数据集为例

来源:互联网 发布:淘宝店招商 编辑:程序博客网 时间:2024/05/01 13:52

ChiMerge 是监督的、自底向上的(即基于合并的)数据离散化方法。它依赖于卡方分析:具有最小卡方值的相邻区间合并在一起,直到满足确定的停止准则。

基本思想:对于精确的离散化,相对类频率在一个区间内应当完全一致。因此,如果两个相邻的区间具有非常类似的类分布,则这两个区间可以合并;否则,它们应当保持分开。而低卡方值表明它们具有相似的类分布。

参考:

1. ChiMerge:Discretization of numeric attributs

2. Chi算法

参考1的要点:

1、 最简单的离散算法是: 等宽区间。 从最小值到最大值之间,,均分为N等份, 这样, 如果 A, B为最小最大值, 则每个区间的长度为w=(B-A) / N, 则区间边界值为 A+W, A+2W, .  A+(N-1)W.

2、 还有一种简单算法,等频区间。区间的边界值要经过选择,使得每个区间包含大致相等的实例数量。比如说 N=10,每个区间应该包含大约10%的实例。

3、 以上两种算法有弊端:比如,等宽区间划分,划分为5区间,最高工资为50000,则所有工资低于10000的人都被划分到同一区间。等频区间可能正好相反,所有工资高于

50000的人都会被划分到50000这一区间中。这两种算法都忽略了实例所属的类型,落在正确区间里的偶然性很大。

4 C4CARTPVM算法在离散属性时会考虑类信息,但是是在算法实施的过程中间,而不是在预处理阶段。例如,C4算法(ID3决策树系列的一种),将数值属性离散为两个区间,而取这两个区间时,该属性的信息增益是最大的。

5、 评价一个离散算法是否有效很难,因为不知道什么是最高效的分类。

6、 离散化的主要目的是:消除数值属性以及为数值属性定义准确的类别。

7、 高质量的离散化应该是:区间内一致,区间之间区分明显。

8 ChiMerge算法用卡方统计量来决定相邻区间是否一致或者是否区别明显。如果经过验证,类别属性独立于其中一个区间,则这个区间就要被合并。

9 ChiMerge算法包括2部分:1、初始化,2、自底向上合并,当满足停止条件的时候,区间合并停止。

第一步:初始化

根据要离散的属性对实例进行排序:每个实例属于一个区间

第二步:合并区间,又包括两步骤

(1) 计算每一对相邻区间的卡方值

(2) 将卡方值最小的一对区间合并

预先设定一个卡方的阈值,在阈值之下的区间都合并,阈值之上的区间保持分区间。

卡方的计算公式:

 

参数说明;

m=2(每次比较的区间数是2个)

k=类别数量

Aij=i区间第j类的实例的数量

Ri=i区间的实例数量

Cj=j类的实例数量

N=总的实例数量

Eij= Aij的期望频率

10、卡方阈值的确定:先选择显著性水平,再由公式得到对应的卡方值。得到卡方值需要指定自由度,自由度比类别数量小1。例如,有3类,自由度为2,则90%置信度(10%显著性水平)下,卡方的值为4.6。阈值的意义在于,类别和属性独立时,有90%的可能性,计算得到的卡方值会小于4.6,这样,大于阈值的卡方值就说明属性和类不是相互独立的,不能合并。如果阈值选的大,区间合并就会进行很多次,离散后的区间数量少、区间大。用户可以不考虑卡方阈值,此时,用户可以考虑这两个参数:最小区间数,最大区间数。用户指定区间数量的上限和下限,最多几个区间,最少几个区间。

11 ChiMerge算法推荐使用.90.95.99置信度,最大区间数取1015之间.


举例:

取鸢尾花数据集作为待离散化的数据集合,使用ChiMerge算法,对四个数值属性分别进行离散化,令停机准则为max_interval=6。(韩家炜 数据挖掘概念与技术 第三版 习题3.12)

下面是我用Python写的程序,大致分两步:

第一步,整理数据

读入鸢尾花数据集,构造可以在其上使用ChiMerge的数据结构,即, 形如 [('4.3', [1, 0, 0]), ('4.4', [3, 0, 0]),...]的列表,每一个元素是一个元组,元组的第一项是字符串,表示区间左端点,元组的第二项是一个列表,表示在此区间各个类别的实例数目;

第二步,离散化

使用ChiMerge方法对具有最小卡方值的相邻区间进行合并,直到满足最大区间数(max_interval)为6

程序最终返回区间的分裂点


[python] view plain copy print?
  1. __author__ = "Yinlong Zhao (zhaoyl[at]sjtu[dot]edu[dot]cn)"  
  2. __date__ = "$Date: 2013/03/25 $"  
  3.   
  4. from time import ctime  
  5.   
  6. def read(file):  
  7.     '''''read raw date from a file '''  
  8.     Instances=[]  
  9.     fp=open(file,'r')  
  10.     for line in fp:  
  11.         line=line.strip('\n'#discard '\n'  
  12.         if line!='':  
  13.             Instances.append(line.split(','))  
  14.     fp.close()  
  15.     return(Instances)  
  16.   
  17.   
  18. def split(Instances,i):  
  19.     ''''' Split the 4 attibutes, collect the data of the ith attributs, i=0,1,2,3 
  20.         Return a list like [['0.2', 'Iris-setosa'], ['0.2', 'Iris-setosa'],...]'''  
  21.     log=[]  
  22.     for r in Instances:  
  23.         log.append([r[i],r[4]])  
  24.     return(log)  
  25.   
  26.   
  27. def count(log):  
  28.     '''''Count the number of the same record 
  29.        Return a list like [['4.3', 'Iris-setosa', 1], ['4.4', 'Iris-setosa', 3],...]'''  
  30.     log_cnt=[]  
  31.     log.sort(key=lambda log:log[0])  
  32.     i=0  
  33.     while(i<len(log)):  
  34.         cnt=log.count(log[i])#count the number of the same record  
  35.         record=log[i][:]  
  36.         record.append(cnt) # the return value of append is None  
  37.         log_cnt.append(record)  
  38.         i+=cnt#count the next diferent item   
  39.     return(log_cnt)  
  40.   
  41.   
  42. def build(log_cnt):  
  43.     '''''Build a structure (a list of truples) that ChiMerge algorithm works properly on it '''  
  44.     log_dic={}  
  45.     for record in log_cnt:  
  46.         if record[0not in log_dic.keys():  
  47.             log_dic[record[0]]=[0,0,0]  
  48.         if record[1]=='Iris-setosa':  
  49.             log_dic[record[0]][0]=record[2]  
  50.         elif record[1]=='Iris-versicolor':  
  51.             log_dic[record[0]][1]=record[2]  
  52.         elif record[1]=='Iris-virginica':  
  53.             log_dic[record[0]][2]=record[2]  
  54.         else:  
  55.             raise TypeError("Data Exception")  
  56.     log_truple=sorted(log_dic.items())  
  57.     return(log_truple)  
  58.   
  59. def collect(Instances,i):  
  60.     ''''' collect data for discretization '''  
  61.     log=split(Instances,i)  
  62.     log_cnt=count(log)  
  63.     log_tuple=build(log_cnt)  
  64.     return(log_tuple)  
  65.   
  66.   
  67. def combine(a,b):  
  68.     '''''  a=('4.4', [3, 1, 0]), b=('4.5', [1, 0, 2]) 
  69.          combine(a,b)=('4.4', [4, 1, 2])  '''  
  70.     c=a[:] # c[0]=a[0]  
  71.     for i in range(len(a[1])):  
  72.         c[1][i]+=b[1][i]  
  73.     return(c)  
  74.   
  75.   
  76. def chi2(A):  
  77.     ''''' Compute the Chi-Square value '''     
  78.     m=len(A);  
  79.     k=len(A[0])  
  80.     R=[]  
  81.     for i in range(m):  
  82.         sum=0  
  83.         for j in range(k):  
  84.             sum+=A[i][j]  
  85.         R.append(sum)  
  86.     C=[]  
  87.     for j in range(k):  
  88.         sum=0  
  89.         for i in range(m):  
  90.             sum+=A[i][j]  
  91.         C.append(sum)  
  92.     N=0  
  93.     for ele in C:  
  94.         N+=ele  
  95.     res=0  
  96.     for i in range(m):  
  97.         for j in range(k):  
  98.             Eij=R[i]*C[j]/N  
  99.             if Eij!=0:  
  100.                 res=res+(A[i][j]-Eij)**2/Eij  
  101.     return res  
  102.   
  103.   
  104. def ChiMerge(log_tuple,max_interval):  
  105.     ''''' ChiMerge algorithm  '''  
  106.     ''''' Return split points '''      
  107.     num_interval=len(log_tuple)  
  108.     while(num_interval>max_interval):                 
  109.         num_pair=num_interval-1  
  110.         chi_values=[]  
  111.         for i in range(num_pair):  
  112.             arr=[log_tuple[i][1],log_tuple[i+1][1]]  
  113.             chi_values.append(chi2(arr))  
  114.         min_chi=min(chi_values) # get the minimum chi value   
  115.         for i in range(num_pair-1,-1,-1): # treat from the last one  
  116.             if chi_values[i]==min_chi:  
  117.                 log_tuple[i]=combine(log_tuple[i],log_tuple[i+1]) # combine the two adjacent intervals  
  118.                 log_tuple[i+1]='Merged'  
  119.         while('Merged' in log_tuple): # remove the merged record  
  120.             log_tuple.remove('Merged')  
  121.         num_interval=len(log_tuple)  
  122.     split_points=[record[0for record in log_tuple]  
  123.     return(split_points)  
  124.   
  125.   
  126. def discrete(path):  
  127.     ''''' ChiMerege discretization of the Iris plants database '''  
  128.     Instances=read(path)  
  129.     max_interval=6  
  130.     num_log=4  
  131.     for i in range(num_log):  
  132.         log_tuple=collect(Instances,i) # collect data for discretization  
  133.         split_points=ChiMerge(log_tuple,max_interval) # discretize data using ChiMerge algorithm   
  134.         print(split_points)  
  135.       
  136.   
  137. if __name__=='__main__':  
  138.     print('Start: ' + ctime())  
  139.     discrete('c:\\Python33\\iris.data')  
  140.     print('End: ' + ctime())  


函数说明:

1. collect(Instances,i)

读入鸢尾花数据集,取第i个特征构造一个数据结构,以便使用ChiMerge算法。这个数据结构 形如 [('4.3', [1, 0, 0]), ('4.4', [3, 0, 0]),...]的列表,每一个元素是一个元组,元组的第一项是字符串,表示区间左端点,元组的第二项是一个列表,表示在此区间各个类别的实例数目

2. ChiMerge(log_tuple,max_interval)

ChiMerge算法,返回区间的分裂点


程序运行结果:

[python] view plain copy print?
  1. >>> ================================ RESTART ================================  
  2. >>>   
  3. Start: Mon Mar 25 21:31:40 2013  
  4. ['4.3''4.9''5.0''5.5''5.8''7.1']  
  5. ['2.0''2.3''2.5''2.9''3.0''3.4']  
  6. ['1.0''3.0''4.5''4.8''5.0''5.2']  
  7. ['0.1''1.0''1.4''1.7''1.8''1.9']  
  8. End: Mon Mar 25 21:31:40 2013  
  9. >>> 
0 0
原创粉丝点击