apriori java实现

来源:互联网 发布:网络性能测试 编辑:程序博客网 时间:2024/06/13 08:56
/** * 频繁项集 */public class FrequentNode {    //包含哪些项    private String[] subjects;    //几项集    private int k;    //支持度计数    private int count = 0;    public FrequentNode(String subject,int count){        this.subjects = new String[]{subject};        this.k = 1;        this.count = count;    }    public FrequentNode(String[] subjects){        this.subjects = subjects;        this.k = subjects.length;        Arrays.sort(this.subjects);//排序,方便生成k+1    }    /**     * 生成k+1项集     * @param node     * @return     */    public FrequentNode merge(FrequentNode node){        if(k==1){            if(this.subjects[0].compareTo(node.subjects[0]) < 0){                return new FrequentNode(new String[]{subjects[0],node.subjects[0]});            }else{                return new FrequentNode(new String[]{node.subjects[0],subjects[0]});                            }        }        //前k-1项相同才连接        for(int i=0;i<k-1;i++){            if(!StringUtils.equals(this.subjects[i],node.subjects[i])){                return null;            }        }        //最后一项不同才连接        if(StringUtils.equals(this.subjects[this.k-1], node.subjects[this.k-1])){            return null;                }        String[] newFre = new String[this.k+1];        System.arraycopy(this.subjects, 0, newFre, 0, this.k-1);        if(this.subjects[k-1].compareTo(node.subjects[k-1]) < 0){            newFre[k-1] = subjects[k-1];            newFre[k] = node.subjects[k-1];        }else{            newFre[k-1] = node.subjects[k-1];            newFre[k] = subjects[k-1];        }        return new FrequentNode(newFre);    }    /**     * 给出自己的k-1子集     * @return     */    public List<FrequentNode> getChildren(){        List<FrequentNode> list = new ArrayList<FrequentNode>();        if(k==2){            return list;        }        if(k==3){            list.add(new FrequentNode(new String[]{subjects[1],subjects[2]}));        }        for(int i=0;i<k-2;i++){            String[] child = new String[k-1];            System.arraycopy(subjects, 0, child, 0, i);            System.arraycopy(subjects, i+1, child, i, k-i-1);            list.add(new FrequentNode(child));        }        return list;    }    /**     * 扫描文件 如果找到 把计数+1     * @param line     * @return     */    public void countIncrement(String[] line){        if(line.length < k){            return ;        }        for(String subject:subjects){            boolean flag = false;            for(String str:line){                if(StringUtils.equals(str, subject)){                    flag = true;                    break;                }            }            if(!flag){                return;            }        }        this.count = this.count + 1;    }    public int getK() {        return k;    }    public void setK(int k) {        this.k = k;    }    public int getCount() {        return count;    }    public void setCount(int count) {        this.count = count;    }    /**     * 两个频繁模式是否相同,我只看有哪些物品。计数和大小不看     */    @Override    public int hashCode() {        return Arrays.hashCode(subjects);    }    @Override    public boolean equals(Object obj) {        if(obj == null){            return false;        }        if(obj instanceof FrequentNode){            return Arrays.equals(this.subjects, ((FrequentNode)obj).subjects);        }        return false;    }    @Override    public String toString() {        return StringUtils.join(subjects,",")+"\t"+count;    }}
/** * 简单实现aprior算法 */public class Aprior {    //最大的项集长度    private int maxLength = 0;    //支持度阈值    private int support = 3;    //总共多少购物篮    private int totalCount = 0;    //    private String filePath = "D:\\R\\aprior.txt";    public static void main(String[] args) {        Aprior a = new Aprior();        a.startMining();    }    private void startMining(){        //step1 遍历文件产生一项集,记录下支持度,并且统计最大项集长度        List<FrequentNode> freq = getFrequent1();        //printResult(freq);        //step2开始挖        while(freq.size() > 0){            //连接            List<FrequentNode> next = contact(freq);            //printResult(next);            //剪枝            prune(next,freq);            //输出结果            printResult(next);            freq = next;        }    }    /**     * 产生1项集     * @param filePath     * @return     */    private List<FrequentNode> getFrequent1(){        Iterator<String> ite;        try {            ite = FileUtils.lineIterator(new File(filePath));        } catch (IOException e) {            e.printStackTrace();            return null;        }        Map<String,Integer> map = new HashMap<String,Integer>();        while(ite.hasNext()){            totalCount++;            String line = ite.next();            String[] subjects = line.split(",");            maxLength = Math.max(maxLength, subjects.length);            for(String subject:subjects){                Integer count = map.get(subject);                if(count == null){                    map.put(subject,1);                }else{                    map.put(subject, count +1);                }            }        }        List<FrequentNode> frequent1 = new ArrayList<FrequentNode>();        for(Entry<String,Integer> entry:map.entrySet()){            if(entry.getValue() >= support){//去掉支持度不够的1项集                frequent1.add(new FrequentNode(entry.getKey(),entry.getValue()));            }        }        return frequent1;    }    /**     * 连接k-1 ,生成k项集,根据k-1剪枝.     * @param src     * @return     */    private List<FrequentNode> contact(List<FrequentNode> src){        List<FrequentNode> next = new ArrayList<FrequentNode>();        for(int i=0;i<src.size()-1;i++){            for(int j=i+1;j<src.size();j++){                FrequentNode newNode = src.get(i).merge(src.get(j));                if(newNode != null){                    next.add(newNode);                }            }        }        return next;    }    private void prune(List<FrequentNode> next,List<FrequentNode> prev){        Iterator<FrequentNode> ite = next.iterator();        while(ite.hasNext()){            FrequentNode newNode = ite.next();            //扫描k-1看有没有,没有移除            boolean flag = false;            for(FrequentNode child:newNode.getChildren()){                if(!prev.contains(child)){                    flag = true;                    break;                }            }            if(flag){                ite.remove();            }        }        //扫描文件,看有没有        Iterator<String> fileIte;        try {            fileIte = FileUtils.lineIterator(new File(filePath));        } catch (IOException e) {            e.printStackTrace();            return;        }        while(fileIte.hasNext()){            String line = fileIte.next();            String[] subjects = line.split(",");            Iterator<FrequentNode> ite1 = next.iterator();            while(ite1.hasNext()){                FrequentNode newNode = ite1.next();                newNode.countIncrement(subjects);            }        }        Iterator<FrequentNode> ite2 = next.iterator();        while(ite2.hasNext()){            FrequentNode newNode = ite2.next();            if(newNode.getCount() < support){                ite2.remove();            }        }    }    private void printResult(List<FrequentNode> list){        for(FrequentNode node:list){            System.out.println(node.toString());        }    }}
0 0
原创粉丝点击