数据挖掘 Apriori算法的Java代码实现

来源:互联网 发布:大数据双创中心是什么 编辑:程序博客网 时间:2024/05/01 19:23

简单说明

学院开了一门课《数据挖掘与机器学习》,要求我们计算机1、2两个班的全部同学选修这门课,包括课程实验。教材采用王振武、徐慧编著的《数据挖掘算法原理与实现》。教材里面提供的代码是C++代码,而由于本人更习惯使用Java语言编程,为了深入理解算法原理和过程,完成实验任务,于是用Java语言实现了Apriori关联规则挖掘算法。

Apriori算法

Apriori算法的基本思想是通过对数据库的多次扫描来计算项集的支持度,发现所有的频繁项集从而生成关联规则。

其实就是从一堆数据里面找出出现次数最多的数据组合,找出来的组合就是强关联的。

产生频繁项集的过程包括连接和剪枝两步。

连接步:
假设有两个有序3-项集L1 = (A, B, C),L2 = (A, B, D)。则L1和L2可连接产生4-项集C1 = (A, B, C, D)。
剪枝步:
频繁k-项集的任何自己必须是频繁项集,根据这个性质去除连接步产生的不满足支持度的k-项集。

代码如下:

//Item.javaimport java.util.ArrayList;/** * 项集 */@SuppressWarnings("hiding")public class Item<String> extends ArrayList<String> {    private static final long serialVersionUID = 1L;    /**     * 判断本项集与next项集是否可连接     *      * @param next     * @return     */    public boolean linkable(Item<String> next) {        if (this.size() != next.size())            return false;        for (int i = 0; i < this.size() - 1; i++) {            if (!get(i).equals(next.get(i)))                return false;        }        return true;    }    /**     * 对项集去重     */    public void unique() {        String s = get(0);        for (int i = 1; i < size(); i++) {            String t = get(i);            while (t.equals(s)) {                remove(t);                if (i < size())                    t = get(i);                else {                    break;                }            }            s = t;        }    }}
//Apriori.javaimport java.util.ArrayList;import java.util.Comparator;import java.util.HashMap;import java.util.Iterator;import java.util.Set;/** * 算法实体 */public class Apriori {    private HashMap<String, Integer> oneElementSet; // 一项集    private ArrayList<Item<String>> sourceItems; // 原始数据    private ArrayList<HashMap<Item<String>, Integer>> rankFrequentSets; // 各级频繁项集    private int minValue; // 最小阈值    Apriori(int size, int minValue) {        oneElementSet = new HashMap<>();        sourceItems = new Item<>();        rankFrequentSets = new Item<>();        this.minValue = minValue;    }    /**     * 添加项集     *      * @param item     */    public void addItem(Item<String> item) {        // 对项集排序后添加        item.sort(new Comparator<String>() {            @Override            public int compare(String arg0, String arg1) {                return arg0.compareTo(arg1);            }        });        sourceItems.add(item);    }    public ArrayList<HashMap<Item<String>, Integer>> getRankFrequentSets() {        return rankFrequentSets;    }    /**     * 找出一项集     *      * @return     */    public HashMap<String, Integer> findOneElementItems() {        for (Item<String> list : sourceItems) {            for (String s : list) {                if (!oneElementSet.containsKey(s)) {                    oneElementSet.put(s, 1);                } else {                    oneElementSet.put(s, oneElementSet.get(s) + 1);                }            }        }        return oneElementSet;    }    /**     * 产生频繁一项集     *      * @return     */    public HashMap<Item<String>, Integer> obtainFrequentOneElementSet() {        HashMap<Item<String>, Integer> map = new HashMap<>();        for (String key : oneElementSet.keySet()) {            int value = oneElementSet.get(key);            if (value >= minValue) {                Item<String> item = new Item<>();                item.add(key);                map.put(item, value);            }        }        rankFrequentSets.add(0, map);        return map;    }    /**     * 产生频繁K项集 剪枝步     *      * @param k     * @return     */    public HashMap<Item<String>, Integer> obtainFrequentSet(int k) {        Item<Item<String>> items = link(k);        HashMap<Item<String>, Integer> freSet = new HashMap<>();        for (Item<String> item : items) {            int count = 0;            for (Item<String> source : sourceItems) {                boolean flag = true;                for (String s : item) {                    if (!source.contains(s)) {                        flag = false;                        break;                    }                }                if (flag) {                    count++;                }            }            if (count >= minValue) {                freSet.put(item, count);            }        }        if (freSet.size() <= 0)            return null;        rankFrequentSets.add(k - 1, freSet);        return freSet;    }    /**     * 连接产生K项集     *      * @param k     * @return     */    public Item<Item<String>> link(int k) {        Item<Item<String>> items = new Item<>();        HashMap<Item<String>, Integer> map = rankFrequentSets.get(k - 2);        Set<Item<String>> keys = map.keySet();        Iterator<Item<String>> iterator = keys.iterator();        if (k == 2) {            for (int i = 0; i < keys.size(); i++) {                Item<String> item = iterator.next();                Iterator<Item<String>> iterator2 = keys.iterator();                for (int j = 0; j < i + 1; j++) {                    iterator2.next();                }                for (int j = i + 1; j < keys.size(); j++) {                    Item<String> item2 = iterator2.next();                    Item<String> instance = new Item<>();                    instance.add(item.get(0));                    instance.add(item2.get(0));                    items.add(instance);                }            }            return items;        } else {            for (int i = 0; i < keys.size() - 1; i++) {                Item<String> item = iterator.next();                Iterator<Item<String>> iterator2 = keys.iterator();                for (int j = 0; j < i + 1; j++) {                    iterator2.next();                }                for (int j = i + 1; j < keys.size(); j++) {                    Item<String> item2 = iterator2.next();                    if (item.linkable(item2)) {                        Item<String> instance = new Item<>();                        for (int n = 0; n < k - 1; n++) {                            instance.add(item.get(n));                        }                        instance.add(item2.get(k - 2));                        items.add(instance);                    }                }            }        }        return items;    }}
//Main.javaimport java.util.ArrayList;import java.util.HashMap;import java.util.Iterator;import java.util.Scanner;public class Main {    public static void main(String[] args) {        int size;        int minValue;        Scanner scanner = new Scanner(System.in);        System.out.print("事务数:");        size = scanner.nextInt();        System.out.print("最小阈值:");        minValue = scanner.nextInt();        Apriori apriori = new Apriori(size, minValue);        scanner.nextLine();        for (int i = 0; i < size; i++) {            Item<String> item = new Item<>();            System.out.print("输入第" + (i + 1) + "项:");            String line = scanner.nextLine();            Scanner scanner2 = new Scanner(line);            while (scanner2.hasNext()) {                item.add(scanner2.next());            }            scanner2.close();            item.unique();//对输入的项集去重            apriori.addItem(item);        }        scanner.close();        HashMap<String, Integer> oneElementSet = apriori.findOneElementItems();        Iterator<String> iterator = oneElementSet.keySet().iterator();        while (iterator.hasNext()) {            String key = iterator.next();            System.out.println(key + ":" + oneElementSet.get(key));        }        apriori.obtainFrequentOneElementSet();        int k = 2;        while (apriori.obtainFrequentSet(k++) != null)            ;        ArrayList<HashMap<Item<String>, Integer>> rankSets = apriori.getRankFrequentSets();        Item<String> item = null;        HashMap<Item<String>, Integer> map = null;        for (int i = 0; i < k - 2; i++) {            map = rankSets.get(i);            System.out.println("第 " + (i + 1) + " 级频繁项集:");            Iterator<Item<String>> iterator2 = map.keySet().iterator();            while (iterator2.hasNext()) {                item = iterator2.next();                System.out.print("{ ");                for (String s : item) {                    System.out.print(s + " ");                }                System.out.print("}\t");                System.out.println(map.get(item));            }        }        System.out.println("最终频繁项集:");        Iterator<Item<String>> iterator2 = map.keySet().iterator();        while (iterator2.hasNext()) {            item = iterator2.next();            System.out.print("{ ");            for (String s : item) {                System.out.print(s + " ");            }            System.out.print("}\t");            System.out.println(map.get(item));        }    }}
0 0
原创粉丝点击