DAT(双数组 Trie,Double-Array Trie)的实现

来源:互联网 发布:百度seo排名 编辑:程序博客网 时间:2024/06/06 04:51
手痒,自己实现了一下双数组 Trie(DAT),单元测试(UT)已经通过。基于 Lucene 4(BytesRef、UnicodeUtil)和 GNU Trove 实现,加上接口不到 300 行代码。
package com.dp.junhao.jhsegmenter;import gnu.trove.iterator.TByteIterator;import gnu.trove.list.array.TByteArrayList;import gnu.trove.procedure.TByteProcedure;import gnu.trove.set.TByteSet;import gnu.trove.set.hash.TByteHashSet;import org.apache.lucene.util.BytesRef;import org.apache.lucene.util.UnicodeUtil;import java.nio.charset.Charset;import java.util.*;/** * Created by junhao.zhang on 17/7/27. */public class DAT {    private final int[] base;    private final int[] check;    // base数组默认填为0,check数组默认填为-1    // 看节点是否存在,只看check[i]==-1    // 是否是tail节点,看check[i]最高位是否为0    public DAT(int[] base, int[] check) {        this.base = base;        this.check = check;    }    // for test.    int[] getBaseArray() { return base; }    // for test.    int[] getCheckArray() { return check; }    public boolean containsNode(String term) {        BytesRef bytes = new BytesRef();        UnicodeUtil.UTF16toUTF8(term, 0, term.length(), bytes);        return containsNode(bytes);    }    public boolean containsNode(byte[] bytes, int offset, int length) {        int node = 0, parent;        for (int i = offset; i < offset + length; ++i) {            parent = node;            node = base[node] + bytes[i];            if (node >= check.length || (check[node] & 0x7fffffff) != parent) {                return false;            }        }        return true;    }    public boolean containsNode(BytesRef bytes) {        return containsNode(bytes.bytes, bytes.offset, bytes.length);    }    public boolean containsTerm(String term) {        BytesRef bytes = new BytesRef();        UnicodeUtil.UTF16toUTF8(term, 0, term.length(), bytes);        return containsTerm(bytes);    }    public boolean containsTerm(BytesRef bytes) {        return containsTerm(bytes.bytes, bytes.offset, bytes.length);    }    public boolean containsTerm(byte[] bytes, int offset, int length) {        int node = 0, parent;        for (int i = offset; i < offset + length; ++i) {            parent = node;            node = 
base[node] + bytes[i];            if (node >= check.length || (check[node] & 0x7fffffff) != parent) {                return false;            }        }        return check[node] >= 0;    }    public double getCompactRate() {        int count = 0;        for (int e : check) {            if (e != -1) {                ++count;            }        }        return (double) count / check.length;    }    public static class Builder {        private static final BytesRef EMPTY_BYTESREF = new BytesRef(BytesRef.EMPTY_BYTES, 0, 0);        private static final int INITIAL_CAPACITY = 32;        private static final int[] table2pow = new int[]{0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80,                0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000,                0x10000, 0x20000, 0x40000, 0x80000, 0x100000, 0x200000, 0x400000, 0x800000,                0x1000000, 0x2000000, 0x4000000, 0x8000000, 0x10000000, 0x20000000, 0x40000000};  // 最高位留作词标志位.        private TreeSet<BytesRef> terms = new TreeSet<BytesRef>();        public Builder addTerm(String term) {            if (term.isEmpty()) {                return this;            }            BytesRef bytesRef = new BytesRef();            UnicodeUtil.UTF16toUTF8(term, 0, term.length(), bytesRef);            terms.add(bytesRef);            return this;        }        private static BytesRef nextBytesRef(BytesRef bytes) {            BytesRef newBytes = new BytesRef();            newBytes.copyBytes(bytes);            int advance = 1;            for (int i = newBytes.length - 1; i >= 0; --i) {                if ((newBytes.bytes[i] & 0xff) + advance <= 255) {                    newBytes.bytes[i] += advance;                    break;                }                newBytes.bytes[i] = 0;                advance = 1;            }            return newBytes;        }        private static boolean isTopmost(BytesRef bytes) {            for (int i = 0; i < bytes.length; ++i) {                if (bytes.bytes[bytes.offset + i] < 
255) {                    return false;                }            }            return true;        }        private static int getBestBaseValue(AutoExpandIntArray check, int parentNode, TByteArrayList list) {            for (int i = parentNode + 1 - list.get(0); ;++i) {                TByteIterator itr = list.iterator();                int countdown = list.size();                while (itr.hasNext()) {                    if (check.get(i + itr.next(), -1) != -1) {                        break;                    }                    --countdown;                }                if (countdown == 0) {                    return i;                }            }        }        private static int getProper2PowNum(int minSize) {            if (minSize > table2pow[table2pow.length - 1]) {                throw new IllegalArgumentException(String.format("Array size: %d too large", minSize));            }            int low = 0, high = table2pow.length - 1;            while (low <= high) {                int mid = (low + high) / 2;                if (minSize == table2pow[mid]) {                    return table2pow[mid];                } else if (minSize < table2pow[mid]) {                    high = mid - 1;                } else {                    low = mid + 1;                }            }            return table2pow[low];        }        private SortedSet<BytesRef> getTermsWithPrefix(BytesRef bytes) {            if (isTopmost(bytes)) {                return terms.tailSet(bytes, false);            }            return terms.subSet(bytes, false, nextBytesRef(bytes), false);        }        private static class Entry {            BytesRef bytes;            int node;            Entry(BytesRef bytes, int node) {                this.bytes = bytes;                this.node = node;            }        }        private static class AutoExpandIntArray {            int[] data;            AutoExpandIntArray(int initialSize, int fillValue) {                data = new 
int[initialSize];                if (fillValue != 0) {                    Arrays.fill(data, fillValue);                }            }            void expand(int minSize, int fillValue) {                int[] newData = new int[getProper2PowNum(minSize)];                System.arraycopy(data, 0, newData, 0, data.length);                if (fillValue != 0) {                    Arrays.fill(newData, data.length, newData.length, fillValue);                }                data = newData;            }            int set(int pos, int value, int fillValue) {                if (pos >= data.length) {                    expand(pos + 1, fillValue);                }                return data[pos] = value;            }            int get(int pos, int defaultValue) {                if (pos >= data.length) {                    expand(pos + 1, defaultValue);                }                return data[pos];            }        }        public DAT build() {            // BFS to iterate all nodes.            final Queue<Entry> queue = new LinkedList<Entry>();            final AutoExpandIntArray base = new AutoExpandIntArray(INITIAL_CAPACITY, 0);  // TODO: calc proper capacity.            
final AutoExpandIntArray check = new AutoExpandIntArray(INITIAL_CAPACITY, -1);            queue.add(new Entry(EMPTY_BYTESREF, 0));            while (!queue.isEmpty()) {                final Entry elem = queue.poll();                SortedSet<BytesRef> tailSet = getTermsWithPrefix(elem.bytes);                final int length = elem.bytes.length;                Iterator<BytesRef> itr = tailSet.iterator();                TByteArrayList siblings = new TByteArrayList();                final TByteSet termSet = new TByteHashSet();                while (itr.hasNext()) {                    BytesRef term = itr.next();                    byte b = term.bytes[length];                    if (term.length == length + 1) {                        termSet.add(b);                    }                    if (!siblings.isEmpty() && siblings.get(siblings.size() - 1) == b) {                        continue;                    }                    siblings.add(b);                }                if (siblings.isEmpty()) {                    continue;                }                final int baseValue = base.set(elem.node, getBestBaseValue(check, elem.node, siblings), 0);                final int parentNode = elem.node;                siblings.forEach(new TByteProcedure() {                    @Override                    public boolean execute(byte b) {                        byte[] newBytes = new byte[length + 1];                        System.arraycopy(elem.bytes.bytes, elem.bytes.offset, newBytes, 0, length);                        newBytes[length] = b;                        queue.add(new Entry(new BytesRef(newBytes), baseValue + b));                        if (termSet.contains(b)) {                            check.set(baseValue + b, parentNode, -1);                        } else {                            check.set(baseValue + b, (parentNode | 0x80000000), -1);                        }                        return true;                    }                });            }           
 return new DAT(base.data, check.data);        }    }}


单元测试(UT)部分:

package com.dp.junhao.jhsegmenter;import org.apache.lucene.util.CharsRef;import org.apache.lucene.util.UnicodeUtil;import org.testng.annotations.Test;import static junit.framework.Assert.assertEquals;import static junit.framework.Assert.assertFalse;import static org.testng.Assert.assertTrue;/** * Created by junhao.zhang on 17/8/10. */public class DATTest {    @Test    public void testSimpleDAT() {        DAT.Builder builder = new DAT.Builder();        builder.addTerm("ABC").addTerm("AC").addTerm("ACE")                .addTerm("ACFF").addTerm("AD").addTerm("BBC").addTerm("CD").addTerm("CF").addTerm("ZQ");        DAT dat = builder.build();        assertTrue(dat.containsTerm("ABC"));        assertTrue(dat.containsTerm("AC"));        assertTrue(dat.containsTerm("ACE"));        assertTrue(dat.containsTerm("ACFF"));        assertTrue(dat.containsTerm("AD"));        assertTrue(dat.containsTerm("BBC"));        assertTrue(dat.containsTerm("CD"));        assertTrue(dat.containsTerm("CF"));        assertTrue(dat.containsTerm("ZQ"));        assertFalse(dat.containsTerm("BB"));    }    @Test    public void testSimpleDATWhiteBox() {        DAT.Builder builder = new DAT.Builder();        builder.addTerm("ABC").addTerm("AC").addTerm("ACE")                .addTerm("ACFF").addTerm("AD").addTerm("BBC").addTerm("CD").addTerm("CF").addTerm("ZQ");        DAT dat = builder.build();        int[] base = dat.getBaseArray();        int[] check = dat.getCheckArray();        assertEquals(base.length, 32);        assertEquals(check.length, 32);        assertEquals(base[0], -64);        assertEquals(base[1], -62);        assertEquals(base[2], -59);        assertEquals(base[3], -60);        assertEquals(base[4], -58);        assertEquals(base[5], -58);        assertEquals(base[6], 0);        assertEquals(base[7], -54);        assertEquals(base[8], 0);        assertEquals(base[9], 0);        assertEquals(base[10], 0);        assertEquals(base[11], 0);        assertEquals(base[12], -56);        
assertEquals(base[13], 0);        assertEquals(base[14], 0);        assertEquals(base[15], 0);        assertEquals(base[16], 0);        assertEquals(base[17], 0);        assertEquals(base[18], 0);        assertEquals(base[19], 0);        assertEquals(base[20], 0);        assertEquals(base[21], 0);        assertEquals(base[22], 0);        assertEquals(base[23], 0);        assertEquals(base[24], 0);        assertEquals(base[25], 0);        assertEquals(base[26], -54);        assertEquals(base[27], 0);        assertEquals(base[28], 0);        assertEquals(base[29], 0);        assertEquals(base[30], 0);        assertEquals(base[31], 0);        assertEquals(check[0], -1);        assertEquals(check[1], (0 | 0x80000000));        assertEquals(check[2], (0 | 0x80000000));        assertEquals(check[3], (0 | 0x80000000));        assertEquals(check[4], (1 | 0x80000000));        assertEquals(check[5], 1);  // ABC        assertEquals(check[6], 1);  // AD        assertEquals(check[7], (2 | 0x80000000));        assertEquals(check[8], 3);  // CD        assertEquals(check[9], 4);  // ABC        assertEquals(check[10], 3);  // CF        assertEquals(check[11], 5);  // ACE        assertEquals(check[12], (5 | 0x80000000));        assertEquals(check[13], 7);  // BBC        assertEquals(check[14], 12);  // ACFF        assertEquals(check[15], -1);        assertEquals(check[16], -1);        assertEquals(check[17], -1);        assertEquals(check[18], -1);        assertEquals(check[19], -1);        assertEquals(check[20], -1);        assertEquals(check[21], -1);        assertEquals(check[22], -1);        assertEquals(check[23], -1);        assertEquals(check[24], -1);        assertEquals(check[25], -1);        assertEquals(check[26], (0 | 0x80000000));        assertEquals(check[27], 26);  // ZQ        assertEquals(check[28], -1);        assertEquals(check[29], -1);        assertEquals(check[30], -1);        assertEquals(check[31], -1);        assertEquals(dat.getCompactRate(), 0.5);    }}