字典树的构建

来源:互联网 发布:西南大学网络学费多少 编辑:程序博客网 时间:2024/06/07 16:18

摘要

  该部分主要讲述基于Java语言构建字典树,包括字典树的剪枝与遍历操作。字典树原理不再赘述,代码实现部分如下:

实现部分

工具类Tools.java,主要实现对大数据集的采样,以及对数据规模的统计

package main;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.HashSet;

/**
 * File helpers for working with a large line-oriented data set: taking a
 * sample of its leading lines, extracting one tab-separated column, and
 * reporting how many distinct lines and characters a file contains.
 *
 * <p>All readers and writers are created without an explicit charset, so the
 * platform default charset is used for both decoding and encoding.
 */
public class Tools {

    /**
     * Copies the leading lines of {@code src} into {@code des}, one per line.
     *
     * <p>At most {@code size} lines are copied. If the source holds fewer
     * lines, copying stops at end of file instead of failing on the missing
     * line.
     *
     * @param src  path of the file to read
     * @param des  path of the file to create or overwrite
     * @param size maximum number of lines to copy; values &lt;= 0 produce an empty file
     * @throws IOException if either file cannot be opened, read or written
     */
    public static void getSample(String src, String des, long size) throws IOException {
        // A long counter is required: an int counter could never reach a size above Integer.MAX_VALUE.
        long written = 0;
        try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(src)));
             BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(des)))) {
            String line;
            while (written < size && (line = br.readLine()) != null) {
                bw.write(line);
                bw.write(System.lineSeparator());
                written++;
            }
        }
        // Report what was actually copied, which is smaller than size for a short source file.
        System.out.println("get sample successful, size is: " + written);
    }

    /**
     * Extracts one tab-separated column of {@code src} and writes each
     * non-empty, trimmed value to {@code des} on its own line.
     *
     * <p>Rows that do not contain the requested column are skipped, as are
     * rows whose value is empty after trimming.
     *
     * @param src    path of the tab-separated file to read
     * @param des    path of the file to create or overwrite
     * @param colNum zero-based index of the column to extract
     * @throws IllegalArgumentException if {@code colNum} is negative
     * @throws IOException              if either file cannot be opened, read or written
     */
    public static void getColumn(String src, String des, int colNum) throws IOException {
        if (colNum < 0) {
            throw new IllegalArgumentException("colNum must be >= 0, got " + colNum);
        }
        try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(src)));
             BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(des)))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] cells = line.split("\\t");
                // split() drops trailing empty fields, so a short row simply has no such column.
                if (colNum >= cells.length) {
                    continue;
                }
                String column = cells[colNum].trim();
                if (!column.isEmpty()) {
                    bw.write(column);
                    bw.write(System.lineSeparator());
                }
            }
        }
        System.out.println("extract column to file successful, cloumnNum is: " + colNum);
    }

    /**
     * Prints the number of distinct characters and the number of distinct
     * lines found in {@code src}.
     *
     * @param src path of the file to analyse
     * @throws IOException if the file cannot be opened or read
     */
    public static void getSize(String src) throws IOException {
        HashSet<String> chars_set = new HashSet<String>();  // distinct single characters
        HashSet<String> terms_set = new HashSet<String>();  // distinct whole lines
        try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(src)))) {
            String line;
            while ((line = br.readLine()) != null) {
                terms_set.add(line);
                for (String piece : line.split("")) {
                    // "".split("") yields one empty string, which is not a character.
                    if (!piece.isEmpty()) {
                        chars_set.add(piece);
                    }
                }
            }
        }
        System.out.println("chars_set.size():" + chars_set.size());
        System.out.println("terms_set.size():" + terms_set.size());
    }

    /**
     * Extracts column index 1 of the input file and prints size statistics
     * for the extracted file.
     *
     * @param args optional: {@code args[0]} input path, {@code args[1]} output
     *             path; a missing argument falls back to the empty path
     * @throws IOException if a file cannot be opened, read or written
     */
    public static void main(String[] args) throws IOException {
        String path1 = args.length > 0 ? args[0] : "";
        String path2 = args.length > 1 ? args[1] : "";
        // Columns are zero-based, so index 1 is the second tab-separated field.
        getColumn(path1, path2, 1);
        getSize(path2);
    }
}

节点类Node.java,树节点

package main;

import java.util.LinkedList;

/**
 * One node of the trie: a character plus the bookkeeping counters and the
 * links to its parent and direct children.
 */
public class Node {
    /** Character held by this node. */
    char val;
    /** Pass-through counter; every constructor initialises it to 1. */
    int count;
    /** True when some word ends exactly at this node. */
    boolean isEnd;
    /** Parent node, or null when none was supplied. */
    Node parent;
    /** Direct children of this node. */
    LinkedList<Node> childList;
    /** Raw frequency of this node; starts at 0. */
    int org_count;
    /** Largest org_count recorded along this node's path; starts at 0. */
    int max_org_count;

    /**
     * Builds a node holding {@code val} and attached to {@code parent}.
     * The new node starts with count 1, no children, both frequency
     * counters at 0 and isEnd false. The parent's child list is not touched.
     *
     * @param val    character to store
     * @param parent parent of the new node, may be null
     */
    public Node(char val, Node parent) {
        this.parent = parent;
        this.val = val;
        childList = new LinkedList<>();
        isEnd = false;
        count = 1;
        org_count = 0;
        max_org_count = 0;
    }

    /**
     * Builds a parentless node holding a single space character. All other
     * fields get the same initial values as in {@link #Node(char, Node)},
     * which means count starts at 1.
     */
    public Node() {
        this(' ', null);
    }

    /**
     * Looks up the direct child that stores {@code val}.
     *
     * @param val character to search for
     * @return the first matching child in list order, or null if there is none
     */
    public Node getNode(char val) {
        for (Node candidate : childList) {
            if (candidate.val == val) {
                return candidate;
            }
        }
        return null;
    }

    /** Renders the node as {@code <val>:<count>}. */
    @Override
    public String toString() {
        return String.valueOf(val) + ':' + count;
    }
}

字典树类TrieTree.java

package main;

import java.util.AbstractMap.SimpleEntry;
import java.util.LinkedList;
import java.util.Map.Entry;

/**
 * Trie built from {@link Node} objects.
 *
 * <p>The root carries no character. Each insertion increments {@code count}
 * on the existing nodes it walks through, and increments {@code org_count} on
 * the node where the word ends.
 *
 * @author stevinpan
 */
public class TrieTree {
    /** Root of the tree; it stores no character information. */
    public Node root = null;
    /** Threshold set by {@link #cart(int)} and also read by {@link #toParents(Node)}. */
    private int min_count = 0;

    /** Creates an empty tree consisting of a root node only. */
    public TrieTree() {
        root = new Node();
    }

    /**
     * Tells whether {@code word} is stored in the tree as a complete word.
     *
     * @param word string to look up
     * @return true if every character can be followed from the root and the
     *         last node reached is flagged as a word end; false otherwise,
     *         including when {@code word} or the root is null
     */
    public boolean isExist(String word) {
        Node curr = root;
        if (curr == null || word == null) {
            return false;
        }
        for (int index = 0; index < word.length(); index++) {
            curr = curr.getNode(word.charAt(index));
            if (curr == null) {     // path breaks off: the word is not present
                return false;
            }
        }
        // A pure query must not write to stdout, so only the flag is returned.
        return curr.isEnd;
    }

    /**
     * Inserts {@code word}; inserting the same string again is allowed and
     * updates the counters again.
     *
     * @param word string to insert
     * @return false if {@code word} is null, true otherwise
     */
    public boolean insert(String word) {
        if (word == null) {
            return false;
        }
        if (root == null) {
            root = new Node();
        }
        Node curr = root;
        for (int index = 0; index < word.length(); index++) {
            char ch = word.charAt(index);
            Node next = curr.getNode(ch);
            if (next != null) {
                next.count++;               // existing node: one more string passes through it
            } else {
                next = new Node(ch, curr);  // a new node already starts with count 1
                curr.childList.add(next);
            }
            curr = next;
        }
        // Flag the end of the word and record one more raw occurrence of it.
        curr.isEnd = true;
        curr.org_count++;
        return true;
    }

    /**
     * Prints every word-end node below the root, depth first, one per line as
     * {@code count<TAB>path}, where the path is produced by {@link #toRoot(Node)}.
     */
    public void printAll() {
        printAll(root);
    }

    private void printAll(Node node) {
        if (node == null) {
            return;
        }
        for (Node child : node.childList) {
            if (child.isEnd) {
                System.out.println(child.count + "\t" + toRoot(child));
            }
            printAll(child);
        }
    }

    /**
     * Collects one entry per word-end node, visiting the tree depth first
     * (parents before children).
     *
     * <p>The key of an entry is the node's {@code count}. The value is the
     * node's {@code org_count}, a tab, the node's path to the root, and then
     * the tab-prefixed paths returned by {@link #toParents(Node)} for the
     * node's parent.
     *
     * @return the collected entries in visiting order
     */
    public LinkedList<SimpleEntry<Integer, String>> level() {
        LinkedList<SimpleEntry<Integer, String>> result = new LinkedList<SimpleEntry<Integer, String>>();
        level(root, result);
        return result;
    }

    private void level(Node node, LinkedList<SimpleEntry<Integer, String>> results) {
        if (node == null) {
            return;
        }
        if (node.isEnd) {
            String full_path = toRoot(node);
            // node.parent is null when the root itself is a word end (after insert("")).
            String parent_paths = toParents(node.parent);
            results.add(new SimpleEntry<Integer, String>(node.count,
                    node.org_count + "\t" + full_path + parent_paths));
        }
        for (Node child : node.childList) {
            level(child, results);
        }
    }

    /**
     * Concatenates the characters found when walking from {@code node} up to,
     * but not including, the root.
     *
     * @param node starting node; null or the root yields the empty string
     * @return the characters in bottom-up order
     */
    public String toRoot(Node node) {
        StringBuilder path = new StringBuilder();
        for (Node curr = node; curr != null && curr != root; curr = curr.parent) {
            path.append(curr.val);
        }
        return path.toString();
    }

    /**
     * Walks from {@code node} up to the root and returns the path of every
     * visited word-end node whose {@code org_count} is strictly greater than
     * the current pruning threshold. Each path is prefixed with a tab.
     *
     * @param node first node to inspect; null or the root yields the empty string
     * @return the tab-prefixed paths, nearest node first
     */
    public String toParents(Node node) {
        StringBuilder res = new StringBuilder();
        for (Node curr = node; curr != null && curr != root; curr = curr.parent) {
            if (curr.isEnd && curr.org_count > this.min_count) {
                res.append('\t').append(toRoot(curr));
            }
        }
        return res.toString();
    }

    /**
     * Collects all nodes that have no children. In a tree without any
     * inserted characters the root itself is returned as the only leaf.
     *
     * @return the leaves in depth-first order
     */
    public LinkedList<Node> findLeaf() {
        LinkedList<Node> leafList = new LinkedList<>();
        findLeaf(root, leafList);
        return leafList;
    }

    private void findLeaf(Node node, LinkedList<Node> leafList) {
        if (node == null) {
            return;
        }
        if (node.childList.isEmpty()) {
            leafList.add(node);
            return;
        }
        for (Node child : node.childList) {
            findLeaf(child, leafList);
        }
    }

    /**
     * Builds one entry per leaf by walking from the leaf up to the root.
     *
     * <p>The key is the leaf's {@code count}. The value is a tab, that count,
     * and then the tab-prefixed path of every word-end node met on the way up.
     *
     * @return one entry per leaf, in the order produced by {@link #findLeaf()}
     */
    public LinkedList<SimpleEntry<Integer, String>> back() {
        LinkedList<SimpleEntry<Integer, String>> result = new LinkedList<>();
        for (Node leaf : findLeaf()) {
            int count = leaf.count;
            StringBuilder sb = new StringBuilder();
            sb.append('\t').append(count);
            for (Node curr = leaf; curr != null && curr != root; curr = curr.parent) {
                if (curr.isEnd) {
                    sb.append('\t').append(toRoot(curr));
                }
            }
            // The buffer always holds at least the count, so an entry is added for every leaf.
            result.add(new SimpleEntry<Integer, String>(count, sb.toString()));
        }
        return result;
    }

    /**
     * Recomputes {@code max_org_count} for every word-end node, top down: it
     * becomes the larger of the node's own {@code org_count} and the
     * {@code max_org_count} of its nearest word-end ancestor, or just the
     * node's own {@code org_count} when no such ancestor exists.
     */
    public void updateMaxOrgCount() {
        updateMaxOrgCount(root);
    }

    private void updateMaxOrgCount(Node node) {
        if (node == null) {
            return;
        }
        if (node.isEnd) {
            Node parent_isEnd = getParent(node);
            node.max_org_count = (parent_isEnd == null)
                    ? node.org_count
                    : Math.max(node.org_count, parent_isEnd.max_org_count);
        }
        // Children are processed after their parent so they can read its fresh value.
        for (Node child : node.childList) {
            updateMaxOrgCount(child);
        }
    }

    /**
     * Finds the nearest ancestor of {@code node} that is flagged as a word end.
     *
     * @param node node whose ancestors are searched
     * @return that ancestor, or null if no ancestor is a word end
     */
    public Node getParent(Node node) {
        Node parent = node.parent;
        while (parent != null && !parent.isEnd) {
            parent = parent.parent;
        }
        return parent;
    }

    /**
     * Prunes the tree logically: every node below the root whose
     * {@code max_org_count} is smaller than {@code min_count} loses its
     * word-end flag. No node is removed.
     *
     * <p>{@code max_org_count} is only filled in by
     * {@link #updateMaxOrgCount()}, so call that method first. The threshold
     * is remembered and subsequently used by {@link #toParents(Node)}.
     *
     * @param min_count pruning threshold
     */
    public void cart(int min_count) {
        this.min_count = min_count;
        cart(root);
    }

    private void cart(Node node) {
        if (node == null) {
            return;
        }
        for (Node child : node.childList) {
            if (child.isEnd && child.max_org_count < this.min_count) {
                child.isEnd = false;
            }
            cart(child);
        }
    }

//  Example usage, kept disabled as in the original listing:
//
//  public static void main(String[] args) {
//      TrieTree tree = new TrieTree();
//      tree.insert(new StringBuilder("博士生导师").reverse().toString());
//      tree.insert(new StringBuilder("硕士生导师").reverse().toString());
//      tree.insert(new StringBuilder("硕士生导师").reverse().toString());
//      tree.insert(new StringBuilder("博士硕士生导师").reverse().toString());
//      tree.insert(new StringBuilder("导师").reverse().toString());
//      tree.insert(new StringBuilder("导师").reverse().toString());
//      tree.insert(new StringBuilder("导师").reverse().toString());
//      tree.insert(new StringBuilder("老师").reverse().toString());
//      tree.insert(new StringBuilder("高级工程师").reverse().toString());
//
//      tree.updateMaxOrgCount();
//
//      System.out.println("***************************");
//      LinkedList<SimpleEntry<Integer, String>> paths = tree.level();
//      for (SimpleEntry<Integer, String> entry : paths) {
//          System.out.println(entry.getKey()+"\t"+entry.getValue());
//      }
//  }
}

主要操作类Main.java

package main;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.AbstractMap.SimpleEntry;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;

/**
 * Program entry point: builds a trie from a term file, traverses it, sorts
 * the traversal result by count and writes it to a file.
 *
 * <p>Files are read and written with the platform default charset because no
 * charset is passed to the readers and writers.
 */
public class Main {

    /**
     * Runs the whole pipeline.
     *
     * @param args optional: {@code args[0]} input path, {@code args[1]} output
     *             path; a missing argument falls back to the empty path
     * @throws Exception if reading the input or writing the output fails
     */
    public static void main(String[] args) throws Exception {
        String src = args.length > 0 ? args[0] : "";
        String des = args.length > 1 ? args[1] : "";
        int min_count = 100;            // pruning threshold, only used by the disabled cart() call below

        // Terms are taken from column index 1 (columns are counted from 0).
        TrieTree tree = treeInit2(src, 1);

        // max_org_count must be refreshed before any pruning takes place.
        tree.updateMaxOrgCount();
//      tree.cart(min_count);

        // Traverse, then order the entries by count, largest first.
        LinkedList<SimpleEntry<Integer, String>> results = tree.level();
        Collections.sort(results, new Comparator<SimpleEntry<Integer, String>>() {
            @Override
            public int compare(SimpleEntry<Integer, String> o1, SimpleEntry<Integer, String> o2) {
                // Integer.compare avoids the overflow that int subtraction can produce.
                return Integer.compare(o2.getKey(), o1.getKey());
            }
        });

        saveToFile(results, des);
        System.out.println("process sucessful");
    }

    /**
     * Builds a trie from a file holding one term per line. Empty lines are
     * skipped; every other line is reversed and inserted.
     *
     * @param src term file, one term per line
     * @return the populated tree
     * @throws IOException if the file cannot be opened or read
     */
    public static TrieTree treeInit1(String src) throws IOException {
        TrieTree tree = new TrieTree();
        try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(src)))) {
            String line;
            while ((line = br.readLine()) != null) {
                if (!line.isEmpty()) {
                    tree.insert(new StringBuilder(line).reverse().toString());
                }
            }
        }
        return tree;
    }

    /**
     * Builds a trie from one column of a tab-separated file. The value is
     * trimmed; rows without that column or with an empty value are skipped;
     * every other value is reversed and inserted.
     *
     * @param src    tab-separated input file
     * @param colNum zero-based index of the column holding the term
     * @return the populated tree
     * @throws IllegalArgumentException if {@code colNum} is negative
     * @throws IOException              if the file cannot be opened or read
     */
    public static TrieTree treeInit2(String src, int colNum) throws IOException {
        if (colNum < 0) {
            throw new IllegalArgumentException("colNum must be >= 0, got " + colNum);
        }
        TrieTree tree = new TrieTree();
        try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(src)))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] cells = line.split("\\t");
                // split() drops trailing empty fields, so a short row simply has no such column.
                if (colNum >= cells.length) {
                    continue;
                }
                String term = cells[colNum].trim();
                if (!term.isEmpty()) {
                    tree.insert(new StringBuilder(term).reverse().toString());
                }
            }
        }
        return tree;
    }

    /**
     * Writes the traversal result to a file, one {@code key<TAB>value} line
     * per entry, and prints a running entry counter to standard output.
     *
     * @param list entries to write
     * @param des  path of the file to create or overwrite
     * @throws IOException if the file cannot be opened or written
     * @author stevinpan
     */
    public static void saveToFile(LinkedList<SimpleEntry<Integer, String>> list, String des) throws IOException {
        try (BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(des)))) {
            long counter = 1;
            for (SimpleEntry<Integer, String> entry : list) {
                bw.write(entry.getKey() + "\t" + entry.getValue());
                bw.write(System.lineSeparator());
                System.out.println(counter++);
            }
        }
    }
}

工程文件地址为:https://github.com/panshan/TrieTree.git

原创粉丝点击