# 获取频繁项集和关联规则的Python实现【先验算法】
# (Python implementation of frequent-itemset and association-rule mining — Apriori)
#
# 来源: 互联网  编辑: 程序博客网  时间: 2024/06/03 05:37

# -*- coding: utf-8 -*-
"""Frequent-itemset and association-rule mining (Apriori).

Reads a CSV-like transaction file (one transaction per line, items
separated by commas), mines every itemset whose support reaches
``minsup``, keeps only the closed frequent itemsets, and derives
association rules ranked descending by lift.

Ported from Python 2 to Python 3; the data flow and on-disk output
formats are unchanged.
"""

# ---------------------------------------------------------------- settings
data_file = 'F:\\user_match_stat\\itemset.txt'
# Input format (CSV-like): item1,item2,item3 — one transaction per line.
frequent_itemsets_save_file = 'F:\\user_match_stat\\frequent_itemsets.txt'
rules_readable_file_dest = 'F:\\user_match_stat\\rules_readable.txt'
rules_csv_file_dest = 'F:\\user_match_stat\\rules_csv.txt'
rules_ranked_desc_by_liftrate = 'F:\\user_match_stat\\rules_liftrate_desc.txt'
# Rule CSV format: itemset A,itemset B,support,confidence,liftrate
# (items inside one itemset are joined with '|').

minsup = 0.01       # minimum support — every rule must have support >= this
minconf = 0.000001  # minimum confidence

# ------------------------------------------------------ computed statistics
transaction_cnt = 0           # total number of transactions (input lines)
min_sup_cnt = 0               # minimum support expressed as an absolute count
transaction_cnt_distinct = 0  # number of distinct transactions

# --------------------------------------------------- global data structures
transaction_cnt_dict = {}         # {tranidx: multiplicity of that distinct transaction}
frequent_itemsets_verified = {}   # {itemset tuple: support count} — current level
frequent_itemsets_candidate = {}  # {itemset tuple: set(items)} — next-level candidates
frequent_itemsets = {}            # {k: {itemset tuple: support count}} — all frequent itemsets by size
closed_frequent_itemsets = {}     # reserved for closed frequent itemsets
distinct_item_in_candidate_itemsets = set()
distinct_item_in_transaction_cnt_dict = set()
item_transaction_list_dict = {}   # {item: set(tranidx keys of transaction_cnt_dict)}
hitted_transaction_set = set()    # tranidx values covered by the last verified level


def prepare_data():
    """Read the transaction file and build the global transaction tables.

    Fills ``transaction_cnt_dict`` (running index -> multiplicity of each
    distinct transaction) and ``item_transaction_list_dict`` (inverted
    index: item -> transaction indices), and computes ``transaction_cnt``,
    ``transaction_cnt_distinct`` and ``min_sup_cnt``.
    """
    global transaction_cnt
    global min_sup_cnt
    global transaction_cnt_distinct

    print('Reading data from ' + data_file + '...')
    pre_transaction_cnt_dict = {}  # {sorted item tuple: multiplicity}
    n = 0
    with open(data_file) as src:
        for line in src:
            line = line.strip()  # drop the trailing '\n'
            if line == '':
                continue
            n += 1
            # Canonical (sorted) item order so duplicate transactions merge.
            tp = tuple(sorted(line.split(',')))
            pre_transaction_cnt_dict[tp] = pre_transaction_cnt_dict.get(tp, 0) + 1

    transaction_cnt = n
    print('Totally read ' + str(n) + ' lines.')

    # Build transaction_cnt_dict and the item -> transactions inverted index.
    tranidx = 1
    for tp in pre_transaction_cnt_dict:
        transaction_cnt_dict[tranidx] = pre_transaction_cnt_dict[tp]
        for item in tp:
            item_transaction_list_dict.setdefault(item, set()).add(tranidx)
        tranidx += 1
    del pre_transaction_cnt_dict

    transaction_cnt_distinct = len(transaction_cnt_dict)
    min_sup_cnt = int(transaction_cnt * minsup)
    print('The number of total transactions is ' + str(transaction_cnt) + '.')
    print('The number of distinct transactions is ' + str(transaction_cnt_distinct) + '.')
    print('The min support count is ' + str(min_sup_cnt) + '.')
    print('Function prepare_data done.')
    return


def get_frequent_itemsets_1():
    """Compute the frequent 1-itemsets straight from the inverted index.

    Stores them in ``frequent_itemsets[1]`` and records every transaction
    containing a frequent item in ``hitted_transaction_set``.
    """
    hitted_transaction_set.clear()
    frequent_itemsets[1] = {}
    for item in item_transaction_list_dict:
        # Weight each distinct transaction by its multiplicity —
        # len() of the index set alone would under-count duplicates.
        cnt = 0
        for tranidx in item_transaction_list_dict[item]:
            cnt += transaction_cnt_dict[tranidx]
        if cnt >= min_sup_cnt:
            frequent_itemsets[1][(item,)] = cnt
            hitted_transaction_set.update(item_transaction_list_dict[item])
    print('Function get_frequent_itemsets_1 done')
    return


def get_candidates(k):
    """Generate level-k candidates from the verified (k-1)-itemsets.

    Uses the F(k-1) x F(k-1) join, prunes candidates having an infrequent
    (k-1)-subset, then clears ``frequent_itemsets_verified``.
    Returns -1 when no candidate survives, 0 otherwise.
    """
    frequent_itemsets_candidate.clear()

    # Number every itemset and join it only with higher-numbered ones,
    # so each unordered pair is considered exactly once.
    idx_outer = 0
    for tp_out in frequent_itemsets_verified:
        idx_outer += 1
        idx_inner = 0
        for tp_in in frequent_itemsets_verified:
            idx_inner += 1
            if idx_inner <= idx_outer:
                continue
            if k == 2:
                # 1-itemsets are a special case: concatenate with the
                # larger item last to keep the candidate tuple sorted.
                if tp_out[0] > tp_in[0]:
                    tmp_tuple = tp_in + tp_out
                else:
                    tmp_tuple = tp_out + tp_in
                frequent_itemsets_candidate[tmp_tuple] = set(tmp_tuple)
            elif tp_out[:-1] == tp_in[:-1]:
                # Same first k-2 items: merge into one k-itemset, with the
                # larger last item placed last to keep the tuple sorted.
                if tp_out[-1] > tp_in[-1]:
                    tmp_tuple = tp_out[:-1] + tp_in[-1:] + tp_out[-1:]
                else:
                    tmp_tuple = tp_out[:-1] + tp_out[-1:] + tp_in[-1:]
                frequent_itemsets_candidate[tmp_tuple] = set(tmp_tuple)

    if len(frequent_itemsets_candidate) == 0:
        return -1

    # Apriori pruning: every (k-1)-subset of a candidate must itself be
    # frequent.  Cuts well over half of the candidates in practice.
    if k != 2:
        del_list = []
        for tp in frequent_itemsets_candidate:
            for i in range(len(tp)):
                if tp[:i] + tp[i + 1:] not in frequent_itemsets_verified:
                    del_list.append(tp)
                    break
        print('-------------------------------------------------')
        print('........Total ' + str(len(frequent_itemsets_candidate)) + ' candidates before cut.')
        print('........Cut ' + str(len(del_list)) + ' candidates.')
        print('-------------------------------------------------')
        for tp in del_list:
            del frequent_itemsets_candidate[tp]

    if len(frequent_itemsets_candidate) == 0:
        return -1

    frequent_itemsets_verified.clear()
    return 0


def check_candidates_1():
    """Count support for every candidate and keep the frequent ones.

    Moves candidates whose support count reaches ``min_sup_cnt`` into
    ``frequent_itemsets_verified``, records the transactions they hit in
    ``hitted_transaction_set`` and clears ``frequent_itemsets_candidate``.
    Returns -1 when no candidate is frequent, 0 otherwise.
    """
    print('Start check candidates.')
    total_candidates = len(frequent_itemsets_candidate)
    print('Total ' + str(total_candidates) + ' candidates need to check.')

    frequent_itemsets_verified.clear()
    hitted_transaction_set.clear()

    cnt = 0
    pct = 0
    for tp in frequent_itemsets_candidate:
        # Progress indicator (integer percent).
        cnt += 1
        new_pct = cnt * 100 // total_candidates
        if new_pct != pct:
            pct = new_pct
            print(str(pct) + '%')

        # Transactions containing the whole itemset = intersection of the
        # per-item transaction index sets.
        tmp_set = None
        for item in tp:
            if item not in item_transaction_list_dict:
                print('Error!!!')
            elif tmp_set is None:
                tmp_set = item_transaction_list_dict[item]
            else:
                tmp_set = tmp_set & item_transaction_list_dict[item]

        support_cnt = 0
        if tmp_set:
            for ele in tmp_set:
                support_cnt += transaction_cnt_dict[ele]

        if support_cnt >= min_sup_cnt:
            frequent_itemsets_verified[tp] = support_cnt
            # Remember the hit transactions so filter_data() can later
            # drop the ones that no longer matter.
            for ele in tmp_set:
                hitted_transaction_set.add(ele)

    frequent_itemsets_candidate.clear()
    if len(frequent_itemsets_verified) == 0:
        return -1
    print('Finish check candidates.')
    return 0


def save_frequent_itemsets(k):
    """Append the verified level-k itemsets to ``frequent_itemsets``.

    Also deletes non-closed itemsets from level k-1: an itemset is not
    closed when some superset at level k has the same support count.
    """
    del_list_k_1 = []
    if k - 1 >= 1:
        for tp in frequent_itemsets[k - 1]:
            sup = frequent_itemsets[k - 1][tp]
            for tp_in in frequent_itemsets_verified:
                # NOTE(review): exact equality; the original suggests the
                # condition could perhaps be relaxed to "close enough".
                if sup == frequent_itemsets_verified[tp_in]:
                    if set(tp).issubset(set(tp_in)):
                        del_list_k_1.append(tp)
                        break
    print('-------------------------------------------------')
    print('...Cutting unclosed frequent itemsets in k = ' + str(k - 1) + '.')
    print('........Total ' + str(len(frequent_itemsets[k - 1])) + ' itemsets before cut.')
    print('........Cut ' + str(len(del_list_k_1)) + ' itemsets for not closed.')
    print('-------------------------------------------------')

    for tp_del in del_list_k_1:
        del frequent_itemsets[k - 1][tp_del]

    if k not in frequent_itemsets:
        frequent_itemsets[k] = {}
    for tp in frequent_itemsets_verified:
        frequent_itemsets[k][tp] = frequent_itemsets_verified[tp]
    return


def get_distinct_item_in_candidate_itemsets():
    """Collect the distinct items appearing in the current candidates."""
    distinct_item_in_candidate_itemsets.clear()
    for tp in frequent_itemsets_candidate:
        for item in tp:
            distinct_item_in_candidate_itemsets.add(item)
    return


def filter_data(k):
    """Rebuild the transaction tables, dropping now-useless data.

    Removes items absent from the current candidates, transactions not
    hit at level k-1, transactions left with fewer than k items, and
    merges transactions that became identical after the cut.
    """
    print('Function filter_data begin.')

    # Stats before the cut.
    print('---------------------------------------------------')
    print('...Stat data before cut data.')
    item_num = len(item_transaction_list_dict)
    print('......Total ' + str(item_num) + ' items in item_transaction_list_dict.')
    tmp_num = 0
    for item in item_transaction_list_dict:
        tmp_num += len(item_transaction_list_dict[item])
    print('......The average length of transaction set for each item is ' + str(round(tmp_num / item_num)) + '.')
    print('......The number of total transactions is ' + str(len(transaction_cnt_dict)) + '.')
    print('---------------------------------------------------')

    # Rebuild the data directly.
    tran_del_list = set()      # diagnostics only
    item_del_list = set()      # diagnostics only
    tranidx_itemset_dict = {}  # {tranidx: set(item)}
    for item in item_transaction_list_dict:
        if item in distinct_item_in_candidate_itemsets:  # keep only items present in candidates
            for tranidx in item_transaction_list_dict[item]:
                if tranidx in hitted_transaction_set:  # keep only transactions hit last level
                    tranidx_itemset_dict.setdefault(tranidx, set()).add(item)
                else:
                    tran_del_list.add(tranidx)
        else:
            item_del_list.add(item)
    print('...' + str(len(item_del_list)) + ' items were cut for no appearence in candidates.')
    print('...' + str(len(tran_del_list)) + ' transactions were cut for no match in k-1 level.')

    new_itemset_cnt_dict = {}  # {item tuple: count} — the rebuilt data
    merge_cnt = 0
    lt_k_cnt = 0
    for tranidx in tranidx_itemset_dict:
        # Keep only transactions whose surviving item count is >= k
        # (their intersection with the candidate items is >= k).
        if len(tranidx_itemset_dict[tranidx]) >= k:
            # Sorted so that equal item sets always merge into one key.
            tp = tuple(sorted(tranidx_itemset_dict[tranidx]))
            if tp in new_itemset_cnt_dict:
                merge_cnt += 1
                new_itemset_cnt_dict[tp] += transaction_cnt_dict[tranidx]
            else:
                new_itemset_cnt_dict[tp] = transaction_cnt_dict[tranidx]
        else:
            lt_k_cnt += 1
    del tranidx_itemset_dict  # no longer needed
    print('...' + str(lt_k_cnt) + ' transactions were cut for item number less than k.')
    print('...' + str(merge_cnt) + ' transactions were cut for merge.')

    transaction_cnt_dict.clear()
    item_transaction_list_dict.clear()
    tranidx = 1
    for tp in new_itemset_cnt_dict:
        transaction_cnt_dict[tranidx] = new_itemset_cnt_dict[tp]
        for item in tp:
            item_transaction_list_dict.setdefault(item, set()).add(tranidx)
        tranidx += 1
    del new_itemset_cnt_dict

    # Stats after the cut.
    print('---------------------------------------------------')
    print('...Stat data after cut data.')
    item_num = len(item_transaction_list_dict)
    print('......Total ' + str(item_num) + ' items in item_transaction_list_dict.')
    tmp_num = 0
    for item in item_transaction_list_dict:
        tmp_num += len(item_transaction_list_dict[item])
    print('......The average length of transaction set for each item is ' + str(round(tmp_num / item_num)) + '.')
    print('......The length of transaction_cnt_dict is ' + str(len(transaction_cnt_dict)) + '.')
    print('---------------------------------------------------')

    print('Function filter_data done.')
    return


def get_rules():
    """Derive association rules from the mined frequent itemsets.

    For every frequent itemset L and proper frequent subset S with
    confidence(S -> L-S) >= minconf, writes a human-readable rule file,
    a CSV rule file, and a third file with the rules sorted descending
    by lift.  Lift = confidence / support(L-S); it is None when the
    consequent itself is absent from the frequent-itemset table.
    """
    rule_list = []  # [(lift, formatted rule line)]
    layer_cnt = len(frequent_itemsets)

    with open(rules_readable_file_dest, 'w') as file_rule_readable, \
            open(rules_csv_file_dest, 'w') as file_rule_csv:
        # From the largest itemsets down to size 2.
        for kk in range(layer_cnt, 1, -1):
            for itemset_kk in frequent_itemsets[kk]:
                set_itemset_kk = set(itemset_kk)
                sup_kk = round(frequent_itemsets[kk][itemset_kk] / transaction_cnt, 4)
                # Scan every smaller level for proper subsets.
                for kkk in range(kk - 1, 0, -1):
                    for itemset_kkk in frequent_itemsets[kkk]:
                        set_itemset_kkk = set(itemset_kkk)
                        if not set_itemset_kk.issuperset(set_itemset_kkk):
                            continue
                        tmp_conf = round(frequent_itemsets[kk][itemset_kk]
                                         / frequent_itemsets[kkk][itemset_kkk], 4)
                        if tmp_conf < minconf:
                            continue
                        # Consequent = L - S, in canonical sorted order.
                        tp_dest = tuple(sorted(set_itemset_kk - set_itemset_kkk))
                        tmp_length = len(tp_dest)
                        if tp_dest in frequent_itemsets[tmp_length]:
                            tmp_sup = round(frequent_itemsets[tmp_length][tp_dest] / transaction_cnt, 4)
                            # lift = P(A and B) / (P(A)*P(B)) == conf / sup(B);
                            # < 1: mutually inhibiting, > 1: mutually promoting.
                            tmp_lift = round(tmp_conf / tmp_sup, 4)
                        else:
                            tmp_lift = None
                            tmp_sup = None

                        tmp_str = str(itemset_kkk) + '  >>>>  ' + str(tp_dest) + '   with support: ' + str(sup_kk) + \
                            '   confidence: ' + str(tmp_conf) + '   liftrate: ' + str(tmp_lift) + '   origin: ' + str(tmp_sup) + '\n'
                        file_rule_readable.write(tmp_str)
                        rule_list.append((tmp_lift, tmp_str))

                        file_rule_csv.write('|'.join(itemset_kkk) + ',' + '|'.join(tp_dest) + ',' + str(sup_kk) +
                                            ',' + str(tmp_conf) + ',' + str(tmp_lift) + '\n')

    # Highest lift first; rules without a computable lift go last.
    # (The original sorted a list of dicts ascending, which neither ranked
    # by lift nor worked on Python 3.)
    rule_list.sort(key=lambda pair: pair[0] if pair[0] is not None else float('-inf'),
                   reverse=True)
    with open(rules_ranked_desc_by_liftrate, 'w') as file_rule_liftrate_desc:
        for _, rule_str in rule_list:
            file_rule_liftrate_desc.write(rule_str)
    return


if __name__ == '__main__':
    prepare_data()
    get_frequent_itemsets_1()

    # Prepare for the level-wise loop.
    # -------------------------------------------------------------------------
    distinct_item_cnt_now = len(item_transaction_list_dict)
    k = 1
    # Seed frequent_itemsets_verified with the frequent 1-itemsets.
    for item in frequent_itemsets[1]:
        frequent_itemsets_verified[item] = frequent_itemsets[1][item]
    print('')
    print('')
    # -------------------------------------------------------------------------

    print('Enter loop.')
    while True:
        # Report the previous level.
        print('Totally ' + str(len(frequent_itemsets[k])) + ' candidates are verified.')
        print('##########################################')
        print('')

        k = k + 1

        print('')
        print('##########################################')
        print('k = ' + str(k))

        # Generate candidates; stop when none can be formed.
        if get_candidates(k) == -1:
            print('Get no candidates.')
            break
        print('Totally get ' + str(len(frequent_itemsets_candidate)) + ' candidates.')

        # Filter the transaction data against the candidates so that
        # transaction_cnt_dict stays as small as possible.
        get_distinct_item_in_candidate_itemsets()
        distinct_item_cnt_new = len(distinct_item_in_candidate_itemsets)
        print('There are ' + str(distinct_item_cnt_now) + ' distinct items in transaction_cnt_dict.')
        print('There are ' + str(distinct_item_cnt_new) + ' distinct items in candidates.')

        # Originally gated on a >30% drop in distinct items; now always on.
        if True:
            filter_data(k)
            distinct_item_cnt_now = distinct_item_cnt_new

        # Count support; stop when no candidate is frequent.
        if check_candidates_1() == -1:
            print('No candidates is frequent.')
            break

        save_frequent_itemsets(k)
    print('Exit loop.')
    # -------------------------------------------------------------------------

    # Persist every frequent itemset: k,index,item1|item2|...,support_count
    with open(frequent_itemsets_save_file, 'w') as out:
        for tmp_k in frequent_itemsets:
            idx = 0
            for itemset in frequent_itemsets[tmp_k]:
                idx += 1
                out.write(str(tmp_k) + ',' + str(idx) + ',' + '|'.join(itemset) + ',' +
                          str(frequent_itemsets[tmp_k][itemset]) + '\n')
    print('All the frequent itemsets were saved in ' + frequent_itemsets_save_file + '.')

    get_rules()