通过FP-Tree找到置信度最高的组合

来源:互联网 发布:美国 非农 数据 美元 编辑:程序博客网 时间:2024/05/17 08:20
package com.winning.dm.pathway;import java.io.BufferedReader;import java.io.FileReader;import java.io.IOException;import java.sql.Connection;import java.sql.ResultSet;import java.sql.SQLException;import java.sql.Statement;import java.util.ArrayDeque;import java.util.ArrayList;import java.util.Collections;import java.util.Comparator;import java.util.Deque;import java.util.HashMap;import java.util.HashSet;import java.util.LinkedList;import java.util.List;import java.util.Map;import java.util.Set;import java.util.Map.Entry;public class FindPath {    /** 频繁模式的最小支持数 **/    private int minSuport;    private TreeNode treeRoot;    static Map<String, TreeNode> header;    public List<String> lt1;    public List<String> lt2 = new ArrayList<String>();    public int getMinSuport() {        return minSuport;    }    public void setMinSuport(int minSuport) {        this.minSuport = minSuport;    }    public double getConfident() {        return confident;    }    public void setConfident(double confident) {        this.confident = confident;    }    /** 关联规则的最小置信度 **/    private double confident;    private Map<String, Integer> freqMaps;    /**     * @作者:      * @时间: 2016-9-20 下午1:41:00     * @描述: 从若干个文件中读入Transaction Record,同时把所有项设置为decideAttr     * @param filenames     * @return     * @备注:     */    public List<List<String>> readTransRocords(String[] filenames) {        Set<String> set = new HashSet<String>();        List<List<String>> transaction = null;        if (filenames.length > 0) {            transaction = new LinkedList<List<String>>();            for (String filename : filenames) {                try {                    FileReader fr = new FileReader(filename);                    BufferedReader br = new BufferedReader(fr);                    try {                        String line = null;                        // 一项事务占一行                        while ((line = br.readLine()) != null) {                            if (line.trim().length() > 0) {                                // 每个item之间用","分隔                                String[] str = line.split(",");                                // 每一项事务中的重复项需要排重                                Set<String> record = new HashSet<String>();                                for (String w : str) {                                    record.add(w);                                    set.add(w);                                }                                List<String> rl = new ArrayList<String>();                                rl.addAll(record);                                transaction.add(rl);                            }                        }                    } finally {                        br.close();                    }                } catch (IOException ex) {                    System.out.println("Read transaction records failed."                            + ex.getMessage());                    System.exit(1);                }            }        }        return transaction;    }    /**     * 计算事务集中每一项的频数     *      * @param transRecords     * @return     */    private Map<String, Integer> getFrequency(List<List<String>> transRecords) {        Map<String, Integer> rect = new HashMap<String, Integer>();        for (List<String> record : transRecords) {            for (String item : record) {                Integer cnt = rect.get(item);                if (cnt == null) {                    cnt = new Integer(0);                }                rect.put(item, ++cnt);            }        }        return rect;    }    /**     *      * @作者:      * @时间: 2016-9-20 下午1:40:14     * @描述: 构建FP-Tree     * @param transRecords     * @备注:     */    public void buildFPTree(List<List<String>> transRecords) {        transRecords.size();        // 计算每项的频数        final Map<String, Integer> freqMap = getFrequency(transRecords);        for (List<String> transRecord : transRecords) {            Collections.sort(transRecord, new Comparator<String>() {                @Override                public int compare(String o1, String o2) {                    return freqMap.get(o2) - freqMap.get(o1);                }            });        }        FPGrowth(transRecords, null);    }    /**     *      * @作者:      * @时间: 2016-9-20 下午1:57:29     * @描述: 构建FP-Tree     * @param cpb     * @param postModel     * @备注:     */    private void FPGrowth(List<List<String>> cpb, LinkedList<String> postModel) {        Map<String, Integer> freqMap = getFrequency(cpb);        freqMaps = freqMap;        Map<String, TreeNode> headers = new HashMap<String, TreeNode>();        for (Entry<String, Integer> entry : freqMap.entrySet()) {            String name = entry.getKey();            int cnt = entry.getValue();            // 每一次递归时都有可能出现一部分模式的频数低于阈值            if (cnt >= minSuport) {                TreeNode node = new TreeNode(name);                node.setCount(cnt);                headers.put(name, node);            }        }        treeRoot = buildSubTree(cpb, freqMap, headers);        header = headers;    }    /**     * 把所有事务插入到一个FP树当中     *      * @param transRecords     * @param F1     * @return     */    private TreeNode buildSubTree(List<List<String>> transRecords,            final Map<String, Integer> freqMap,            final Map<String, TreeNode> headers) {        // 虚根节点        TreeNode root = new TreeNode();        for (List<String> transRecord : transRecords) {            LinkedList<String> record = new LinkedList<String>(transRecord);            TreeNode subTreeRoot = root;            TreeNode tmpRoot = null;            if (root.getChildren() != null) {                // 延已有的分支,令各节点计数加1                while (!record.isEmpty()                        && (tmpRoot = subTreeRoot.findChild(record.peek())) != null) {                    tmpRoot.countIncrement(1);                    subTreeRoot = tmpRoot;                    record.poll();                }            }            // 长出新的节点            addNodes(subTreeRoot, record, headers);        }        return root;    }    /**     * 往特定的节点下插入一串后代节点,同时维护表头项到同名节点的链表指针     *      * @param ancestor     * @param record     * @param headers     */    private void addNodes(TreeNode ancestor, LinkedList<String> record,            final Map<String, TreeNode> headers) {        while (!record.isEmpty()) {            String item = (String) record.poll();            // 单个项的出现频数必须大于最小支持数,否则不允许插入FP树。达到最小支持度的项都在headers中。每一次递归根据条件模式基本建立新的FPTree时,把要把频数低于minSuport的排除在外,这也正是FPTree比穷举法快的真正原因            if (headers.containsKey(item)) {                TreeNode leafnode = new TreeNode(item);                leafnode.setCount(1);                leafnode.setParent(ancestor);                ancestor.addChild(leafnode);                TreeNode header = headers.get(item);                TreeNode tail = header.getTail();                if (tail != null) {                    tail.setNextHomonym(leafnode);                } else {                    header.setNextHomonym(leafnode);                }                header.setTail(leafnode);                addNodes(leafnode, record, headers);            }        }    }    /**     *      * @作者:      * @时间: 2016-9-20 下午1:39:05     * @描述: 找到最优的子路径     * @param ls     * @return     * @备注:     */    public List<TreeNode> findRode(List<TreeNode> ls) {        TreeNode ts = new TreeNode();        if (ls == null) {            return null;        } else {            int i = 0;            // 遍历所有孩子节点,找到孩子节点中置信度最高的节点            for (TreeNode tn : ls) {                // 设置第一个节点是置信度最高的节点                if (i == 0) {                    ts.setName(tn.getName());                    ts.setCount(tn.getCount());                } else {                    // 通过比较找到置信度最高的节点                    if (tn.getCount() > ts.getCount()) {                        ts.setName(tn.getName());                        ts.setCount(tn.getCount());                    }                }                i++;            }        }        TreeNode tm = new TreeNode();        // 通过name找到计算的置信度最高节点,并且加入链表        for (TreeNode tn : ls) {            if (tn.getName().equals(ts.getName())) {                tm = tn;                this.lt2.add(tm.getName());            }        }        if (tm.getChildren() == null) {            return null;        } else {            // 通过迭代找到置信度最高路径            findRode(tm.getChildren());            return tm.getChildren();        }    }    /**     *      * @作者:      * @时间: 2016-9-18 下午3:10:07     * @描述: 广度优先算法找到树中项目名称为tn的所有子节点。     * @param root     * @备注:     */    public void breadthFirst(TreeNode root, TreeNode tn) {        // 通过队列能够更好的进行广度优先搜素        Deque<TreeNode> nodeDeque = new ArrayDeque<TreeNode>();        TreeNode node = root;        nodeDeque.add(node);        while (!nodeDeque.isEmpty()) {            // 每次只计算第一个节点            node = nodeDeque.peekFirst();            if ((node.getName()).equals(tn.getName())) {                tn.addChild(node);            }            // 搜索完一个节点就去除这个节点并且添加它的子节点            nodeDeque.remove();            // 获得节点的子节点,对于二叉树就是获得节点的左子结点和右子节点            List<TreeNode> children = node.getChildren();            if (children != null && !children.isEmpty()) {                for (TreeNode child : children) {                    nodeDeque.add(child);                }            }        }    }    /**     *      * @作者:      * @时间: 2016-9-18 下午4:41:42     * @描述: 找到项目使用率最高的疾病     * @param root     *            树的根节点     * @param it     *            输入的数据,条件数据包含输入数据的置信度最高的路径     * @备注:     */    public List<String> findPath(TreeNode root, int[] it) {        // 如果it为null,那么赋值root节点为2,因为数据的根节点是2        if (it == null || it.length == 0) {            int[] its = new int[1];            its[0] = 2;            it = its;        }        int last = it[it.length - 1];        TreeNode tn = new TreeNode();        tn.setName(String.valueOf(last));        // 通过遍历找到最后数据的树中所有对应项目        this.breadthFirst(root, tn);        // 对tn.child进行排序分析        List<TreeNode> lt = tn.getChildren();        int[] sm = new int[lt.size()];        for (int i = 0; i < sm.length; i++) {            sm[i] = lt.get(i).getCount();        }        Sort.Sort2(sm);        for (int i = 0; i < sm.length; i++) {            if (!isMatch(lt, sm[i], it)) {                continue;            } else {                // 找到最后的数据的                for (TreeNode tmm : lt) {                    if (tmm.getCount() == sm[i]) {                        findRode(tmm.getChildren());                    }                }                return lt1;            }        }        return null;    }    /**     *      * @作者:      * @时间: 2016-9-20 下午1:07:43     * @描述: 找到父节点     * @param children     * @return     * @备注:     */    public List<String> findFather(TreeNode children) {        List<String> lt1 = new ArrayList<String>();        TreeNode tm = children;        while (tm.getParent() != null) {            lt1.add(tm.getName());            tm = tm.getParent();        }        return lt1;    }    /**     *      * @作者:      * @时间: 2016-9-20 下午1:08:09     * @描述: 判断最优路径是否包含所有输入项目     * @param lt     * @param sm     * @param it     * @return     * @备注:     */    public boolean isMatch(List<TreeNode> lt, int sm, int[] it) {        this.lt1 = new ArrayList<String>();        // 通过对比判断最优路径是否包含所有int数组里面的项目如果包含返回true,如果不包含返回false        for (TreeNode ts : lt) {            if (ts.getCount() == sm) {                lt1 = findFather(ts);                for (int n = 0; n < it.length; n++) {                    if (!lt1.contains(String.valueOf(it[n]))) {                        return false;                    }                }            }        }        return true;    }    /**     *      * @作者:      * @时间: 2016-9-20 下午1:09:08     * @描述: 找到最优路径,如果没有项目输入则返回全局最优,如果输入不为null,返回包含输入项目的最优路径     * @param str     * @return     * @备注:     */    public List<String> findPath(String[] str,String itemfilepath) {        List<String> ss = new ArrayList<String>();//      String infile = "E:\\路径分析\\全部排序数据.txt";        FindPath bt = new FindPath();        List<List<String>> trans = bt.readTransRocords(new String[] { itemfilepath });        bt.setConfident(0);        bt.setMinSuport(0);        bt.buildFPTree(trans);        // 判断输入是否为Null        if (str == null || str.equals(null)) {            ss = this.findPath(bt.treeRoot.getChildren().get(0), null);        } else {            int[] it = new int[str.length];            for (int i = 0; i < str.length; i++) {                it[i] = Integer.parseInt(str[i]);            }            Sort.Sort(it);            ss = this.findPath(bt.treeRoot.getChildren().get(0), it);        }        // 合并两个字符串,两个字符串一个是置信度高项目组合,一个是置信度低的项目组合        for (int i = 0; i < this.lt2.size(); i++) {            ss.add(lt2.get(i));        }        return ss;    }    /**     *      * @作者:      * @时间: 2016-9-20 下午1:43:19     * @描述: 把List<String>排序后变成String,中间以逗号隔开。     * @return     * @备注:     */    public StringBuffer changeToString(List<String> ls) {        StringBuffer result = new StringBuffer();        int[] it = new int[ls.size()];        for (int i = 0; i < it.length; i++) {            it[i] = Integer.valueOf(ls.get(i));        }        // 把list中的数据进行排序,然后添加,目的是为了匹配数据找到对应的GHDJID        Sort.Sort1(it);        for (int i = 0; i < it.length; i++) {            if (i == it.length - 1) {                result = result.append(it[i]);            } else {                result = result.append(it[i] + ",");            }        }        return result;    }    /**     *      * @作者:      * @时间: 2016-9-28 上午11:12:47     * @描述: 根据使用的项目排序后对应的字符串进行匹配找到对应的GHDJID     * @param sb     *            排序之后的项目项目之间以","隔开     * @return GHDJID     * @throws IOException     * @备注:     */    public String matchGHDJID(StringBuffer sb,String itemfilepath,String patientfile) throws IOException {        String GHDJID = new String();        int row = 0;        String sn = sb.toString();        // 匹配排序后的项目        BufferedReader reader1 = new BufferedReader(new FileReader(                itemfilepath));        for (int i = 1; !sn.equals(reader1.readLine()); i++) {            row = i;        }        // 匹配GHDJID        BufferedReader reader2 = new BufferedReader(new FileReader(                patientfile));        String str = "";        String[] strs = new String[4];        for (int i = 1; (str = reader2.readLine()) != null; i++) {            if (i == row) {                strs = str.split(",");                GHDJID = strs[1];            }        }        GHDJID = GHDJID.substring(1, GHDJID.length() - 1);        return GHDJID;    }    /**     *      * @作者:      * @时间: 2016-9-20 下午3:03:04     * @描述: 根据GHDJID找到疾病的诊断方法     * @param GHDJID     * @return 返回值是map,map中存储项目名称和项目对应的天数。     * @throws SQLException     * @throws ClassNotFoundException     * @备注:     */    public Map<String, List<String>> finalPath(String GHDJID)            throws ClassNotFoundException, SQLException {        // Map<String, Integer> result = new HashMap<String, Integer>();        Map<String, List<String>> map = new HashMap<String, List<String>>();        Connection connect = ConnectDatabase.connectDatabase();        Statement statement = connect.createStatement();        String sql = "Select XMMC,RQ,dense_rank() over(ORDER BY RQ) day from bnzh13_mx where GHDJID="                + GHDJID + "ORDER BY RQ";        ResultSet resultSet = statement.executeQuery(sql);        while (resultSet.next()) {            if (map.containsKey(resultSet.getString("XMMC"))) {                map.get(resultSet.getString("XMMC")).add(                        resultSet.getString("day"));            } else {                List<String> list = new ArrayList<String>();                list.add(resultSet.getString("day"));                map.put(resultSet.getString("XMMC"), list);            }        }        return map;    }}

节点代码

package com.winning.dm.pathway;import java.util.ArrayList;import java.util.List;/** *  * @Description: FP树的节点 * @Author orisun * @Date Jun 23, 2016 */class TreeNode {    /**节点名称**/    private String name;    /**频数**/    private int count;    private TreeNode parent;    private List<TreeNode> children;    /**下一个节点(由表头项维护的那个链表)**/    private TreeNode nextHomonym;    /**末节点(由表头项维护的那个链表)**/    private TreeNode tail;    @Override    public String toString() {        return name;    }    public TreeNode() {    }    public TreeNode(String name) {        this.name = name;    }    public String getName() {        return this.name;    }    public void setName(String name) {        this.name = name;    }    public int getCount() {        return this.count;    }    public void setCount(int count) {        this.count = count;    }    public TreeNode getParent() {        return this.parent;    }    public void setParent(TreeNode parent) {        this.parent = parent;    }    public List<TreeNode> getChildren() {        return this.children;    }    public void addChild(TreeNode child) {        if (getChildren() == null) {            List<TreeNode> list = new ArrayList<TreeNode>();            list.add(child);            setChildren(list);        } else {            getChildren().add(child);        }    }    public TreeNode findChild(String name) {        List<TreeNode> children = getChildren();        if (children != null) {            for (TreeNode child : children) {                if (child.getName().equals(name)) {                    return child;                }            }        }        return null;    }    public void setChildren(List<TreeNode> children) {        this.children = children;    }    public void printChildrenName() {        List<TreeNode> children = getChildren();        if (children != null) {            for (TreeNode child : children)                System.out.print(child.getName() + " ");        } else            System.out.print("null");    }    public TreeNode getNextHomonym() {        return this.nextHomonym;    }    public void setNextHomonym(TreeNode nextHomonym) {        this.nextHomonym = nextHomonym;    }    public void countIncrement(int n) {        this.count += n;    }    public TreeNode getTail() {        return tail;    }    public void setTail(TreeNode tail) {        this.tail = tail;    }}
0 0
原创粉丝点击