diff --git a/src/test/java/ch/digitalfondue/jfiveparse/DoubleArrayTrie.java b/src/test/java/ch/digitalfondue/jfiveparse/DoubleArrayTrie.java new file mode 100644 index 0000000..21d3837 --- /dev/null +++ b/src/test/java/ch/digitalfondue/jfiveparse/DoubleArrayTrie.java @@ -0,0 +1,355 @@ +package ch.digitalfondue.jfiveparse; + +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.IntUnaryOperator; + +// from https://medium.com/@lchang1994/deep-dive-dat-double-array-trie-f51e5e5f006c +// FIXME deduplicate the inserted value, currently at each insert, we reserve a new spot for the value, do a linear probe. +public class DoubleArrayTrie { + int[] check; + long[] base; + long[] values; + int valuesCount; + + public DoubleArrayTrie() { + check = new int[128]; + base = new long[128]; + values = new long[128]; + + Arrays.fill(check, -1); // -1 means empty + Arrays.fill(base, 0L); + + setBaseAt(1, 1); // Root is 1 + // base[1] = 1; + check[0] = -1; // Root has no parent + } + + private void setBaseAt(int idx, int offset) { + setBaseAt(idx, offset, -1); + } + + private void setBaseAt(int idx, int offset, int idxValue) { + base[idx] = pack1Or2IntToLong(offset, idxValue); + } + + private int getBaseAt(int idx) { + return extractHigh(base[idx]); + } + + private int getValueIdxAt(int idx) { + return extractLow(base[idx]); + } + + private void expand(int index) { + if (index < base.length) { + return; + } + int oldSize = base.length; + int newSize = Math.max(index + 1, oldSize * 2); + + // Resize check array and fill new slots with -1 + check = Arrays.copyOf(check, newSize); + Arrays.fill(check, oldSize, newSize, -1); + + // Resize base array (new slots automatically defaults to 0) + base = Arrays.copyOf(base, newSize); + } + + public void insert(String word, long value) { + int node = 1; + for (int i = 0; i < word.length(); i++) { + int code = charId(word.charAt(i)); + if (code == -1) { + throw new IllegalStateException("Character " + word.charAt(i) + " is not mapped"); + } + + // Ensure base has a valid offset (0 means no children yet) + if (getBaseAt(node) == 0) { + // base[node] = 1; // Default offset 1 + setBaseAt(node, 1); + } + + // Use Math.abs() because base might be negative if it's a terminal node + int offset = Math.abs(getBaseAt(node)); + int nextNode = offset + code; + expand(nextNode); + + if (check[nextNode] == -1) { + // Free spot + check[nextNode] = node; + } else if (check[nextNode] != node) { + // Collision + resolveCollision(node, code); + // Re-calculate after resolution + offset = Math.abs(getBaseAt(node)); + nextNode = offset + code; + expand(nextNode); + check[nextNode] = node; + } + + node = nextNode; + } + setTerminal(node, value); + } + + private void setTerminal(int node, long value) { + int idxValue = valuesCount; + for (int i = 0; i < valuesCount; i++) { + if (values[i] == value) { + idxValue = i; + break; + } + } + + if (getBaseAt(node) == 0) { + // base[node] = -1; // Terminal, offset 1 + setBaseAt(node, -1, idxValue); + } else if (getBaseAt(node) > 0) { + // base[node] = -base[node]; + setBaseAt(node, getBaseAt(node), idxValue); + } + if (idxValue >= values.length) { + values = Arrays.copyOf(values, values.length * 2); + } + values[idxValue] = value; + if (idxValue == valuesCount) { + valuesCount++; + } + } + + private int isTerminal(int node) { + if (node >= base.length) { + return 0; + } + return getBaseAt(node) < 0 ? 1 : 0; + } + + private void resolveCollision(int node, int newCharCode) { + List children = getChildren(node); + children.add(newCharCode); + + int newBaseOffset = findValidBase(children); + int oldBaseOffset = Math.abs(getBaseAt(node)); + int baseValue = getValueIdxAt(node); + + // Apply new base, maintaining terminal status + if (getBaseAt(node) < 0) { + setBaseAt(node, -newBaseOffset, baseValue); + //base[node] = -newBaseOffset; + } else { + setBaseAt(node, newBaseOffset, baseValue); + // base[node] = newBaseOffset; + } + + for (int c : children) { + if (c == newCharCode) { + continue; + } + + int oldIdx = oldBaseOffset + c; + int newIdx = newBaseOffset + c; + expand(newIdx); + + // Move node info: base, check + // CHECK, this should be the only place where we relocate the whole info (high/low) + base[newIdx] = base[oldIdx]; + check[newIdx] = node; // Parent is still 'node' + + // If the moved child has children, update their parent pointers (check) + if (getBaseAt(oldIdx) != 0) { + int childBaseOffset = Math.abs(getBaseAt(oldIdx)); + List grandchildren = getChildrenOfOffset(oldIdx, childBaseOffset); + for (int gcCode : grandchildren) { + int gcIdx = childBaseOffset + gcCode; + check[gcIdx] = newIdx; + } + } + + // Clear old spot + check[oldIdx] = -1; + // base[oldIdx] = 0; + setBaseAt(oldIdx, 0); + } + } + + private int findValidBase(List children) { + int q = 1; + while (true) { + boolean ok = true; + for (int c : children) { + int idx = q + c; + if (idx < check.length && check[idx] != -1) { + ok = false; + break; + } + } + if (ok) { + return q; + } + q++; + } + } + + private List getChildren(int node) { + return getChildrenOfOffset(node, Math.abs(getBaseAt(node))); + } + + private List getChildrenOfOffset(int nodeIdx, int baseOffset) { + List children = new ArrayList<>(); + if (baseOffset == 0) { + return children; + } + // the chars map to 1 up to MAX_CHAR_ID + for (int c = 1; c <= MAX_CHAR_ID; c++) { + int idx = baseOffset + c; + if (idx < check.length && check[idx] == nodeIdx) { + children.add(c); + } + } + return children; + } + + // return: + // -1 don't exist + // 0 is a sub element (e.g "ab" is a subelement of "abc") + // 1 is terminal + public int lookup(String word) { + return lookup(word, this::isTerminal); + } + + // can only be called when lookupp as returned a terminal node + public long lookupValue(String word) { + return values[lookup(word, this::getValueIdxAt)]; + } + + private int lookup(String word, IntUnaryOperator op) { + int node = 1; + for (int i = 0; i < word.length(); i++) { + int code = charId(word.charAt(i)); + if (code == -1) { + return -1; + } + int offset = Math.abs(getBaseAt(node)); + if (offset == 0) { + return -1; + } + + int nextNode = offset + code; + if (nextNode >= check.length || check[nextNode] != node) { + return -1; + } + node = nextNode; + } + return op.applyAsInt(node); + } + + private static long pack1Or2IntToLong(int high, int low) { + return ((long) high << 32) | (low & 0xFFFFFFFFL); + } + + // extract from a long v the high value part + static int extractHigh(long v) { + return (int) (v >> 32); + } + + // extract from a long v the low value part + static int extractLow(long v) { + return (int) v; + } + + + @Test + @Disabled + void check() { + var dat = new DoubleArrayTrie(); + + var words = List.of("dog", "apple", "app", "banana", "band", "b"); + for (var v : words) { + dat.insert(v, 0); + } + var testWords = List.of("dog", "apple", "app", "banana", "band", "b", "appl", "ban", "c", "", "&", "."); + for (var w : testWords) { + var found = dat.lookup(w); + System.err.println(w + " " + found); + } + } + + + // generated code + + private static final int MAX_CHAR_ID = 62; + public int charId(char c) { + return c >= 38 && c <= 122 ? entitiesCharIdx[c - 38] : -1; + } + private static final int[] entitiesCharIdx = new int[85]; + static { + Arrays.fill(entitiesCharIdx, -1); + entitiesCharIdx[0] = 1; + entitiesCharIdx[11] = 2; + entitiesCharIdx[12] = 3; + entitiesCharIdx[13] = 4; + entitiesCharIdx[14] = 5; + entitiesCharIdx[15] = 6; + entitiesCharIdx[16] = 7; + entitiesCharIdx[17] = 8; + entitiesCharIdx[18] = 9; + entitiesCharIdx[21] = 10; + entitiesCharIdx[27] = 11; + entitiesCharIdx[28] = 12; + entitiesCharIdx[29] = 13; + entitiesCharIdx[30] = 14; + entitiesCharIdx[31] = 15; + entitiesCharIdx[32] = 16; + entitiesCharIdx[33] = 17; + entitiesCharIdx[34] = 18; + entitiesCharIdx[35] = 19; + entitiesCharIdx[36] = 20; + entitiesCharIdx[37] = 21; + entitiesCharIdx[38] = 22; + entitiesCharIdx[39] = 23; + entitiesCharIdx[40] = 24; + entitiesCharIdx[41] = 25; + entitiesCharIdx[42] = 26; + entitiesCharIdx[43] = 27; + entitiesCharIdx[44] = 28; + entitiesCharIdx[45] = 29; + entitiesCharIdx[46] = 30; + entitiesCharIdx[47] = 31; + entitiesCharIdx[48] = 32; + entitiesCharIdx[49] = 33; + entitiesCharIdx[50] = 34; + entitiesCharIdx[51] = 35; + entitiesCharIdx[52] = 36; + entitiesCharIdx[59] = 37; + entitiesCharIdx[60] = 38; + entitiesCharIdx[61] = 39; + entitiesCharIdx[62] = 40; + entitiesCharIdx[63] = 41; + entitiesCharIdx[64] = 42; + entitiesCharIdx[65] = 43; + entitiesCharIdx[66] = 44; + entitiesCharIdx[67] = 45; + entitiesCharIdx[68] = 46; + entitiesCharIdx[69] = 47; + entitiesCharIdx[70] = 48; + entitiesCharIdx[71] = 49; + entitiesCharIdx[72] = 50; + entitiesCharIdx[73] = 51; + entitiesCharIdx[74] = 52; + entitiesCharIdx[75] = 53; + entitiesCharIdx[76] = 54; + entitiesCharIdx[77] = 55; + entitiesCharIdx[78] = 56; + entitiesCharIdx[79] = 57; + entitiesCharIdx[80] = 58; + entitiesCharIdx[81] = 59; + entitiesCharIdx[82] = 60; + entitiesCharIdx[83] = 61; + entitiesCharIdx[84] = 62; + } +} diff --git a/src/test/java/ch/digitalfondue/jfiveparse/GenerateEntities.java b/src/test/java/ch/digitalfondue/jfiveparse/GenerateEntities.java index 77711ef..ee1101f 100644 --- a/src/test/java/ch/digitalfondue/jfiveparse/GenerateEntities.java +++ b/src/test/java/ch/digitalfondue/jfiveparse/GenerateEntities.java @@ -17,6 +17,9 @@ import com.google.gson.GsonBuilder; import com.google.gson.reflect.TypeToken; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; @@ -24,7 +27,7 @@ import java.lang.reflect.Type; import java.nio.file.Files; import java.nio.file.Paths; -import java.util.Map; +import java.util.*; import java.util.zip.GZIPOutputStream; /** @@ -37,13 +40,162 @@ private static class EntityValues { int[] codepoints; } - public static void main(String[] args) throws IOException { + + // this generates the supporting array for mapping a character to a unique id + // for the double array trie + // step 1 + @Test + @Disabled + void generateCharactersMappingArray() throws IOException { + Map m = getEntitiesMap(); + + var allUsedChars = new TreeSet(); + for (String key: m.keySet()) { + key.chars().forEach(allUsedChars::add); + } + var allChars = new ArrayList<>(allUsedChars); + var firstChar = allChars.get(0); + var lastChar = allChars.get(allChars.size() - 1); + // mapping function + System.out.println("private static final int MAX_CHAR_ID = " + allChars.size() + ";"); + System.out.println("public int charId(char c) {"); + System.out.printf(" return c >= %d && c <= %d ? entitiesCharIdx[c - %d] : -1;\n", firstChar, lastChar, firstChar); + System.out.println("}"); + // + + System.out.println("private static final int[] entitiesCharIdx = new int[" + (lastChar-firstChar+1) + "];"); + System.out.println("static {"); + // create support array + System.out.println("Arrays.fill(entitiesCharIdx, -1);"); + for (int i = 0; i < allChars.size(); i++) { + System.out.println("entitiesCharIdx[" + (allChars.get(i) - firstChar)+"] = "+ (i+1) + ";"); + } + System.out.println("}"); + // + } + + private static Map getEntitiesMap() throws IOException { Type type = (new TypeToken>() { }).getType(); String json = Files.readString(Paths.get("src/test/resources/entities.json")); Map m = new GsonBuilder().create().fromJson(json, type); + return m; + } + + // long l = (((long)x) << 32) | (y & 0xffffffffL); + // int x = (int)(l >> 32); + // int y = (int)l; + + private static long pack1Or2IntToLong(int[] codepoint) { + if (codepoint.length == 1) { + return ((long) codepoint[0] << 32); + } else if (codepoint.length == 2) { + return ((long) codepoint[0] << 32) | (codepoint[1] & 0xFFFFFFFFL); + } + throw new IllegalStateException("1 or 2 codepoint"); + } + + @Test + void checkValueSub() { + var dat = new DoubleArrayTrie(); + dat.insert("Æ", 40); + dat.insert("Æ", 41); + dat.insert("&", 42); + dat.insert("&", 43); + dat.lookupValue("&"); + } + + + // step 2 + @Test + @Disabled + void pregenerateSupportArray() throws IOException { + var entities = getEntitiesMap(); + var dat = new DoubleArrayTrie(); + for (String key : entities.keySet()) { + var codepoints = entities.get(key).codepoints; + dat.insert(key, pack1Or2IntToLong(codepoints)); + } + //System.out.println(Arrays.toString(dat.base)); + //System.out.println(Arrays.toString(dat.check)); + + // right size base and check array + int base_check = 0; + for (int i = dat.base.length - 1; i >= 0; i--) { + if (dat.base[i] != 0) { + base_check = i+1; + System.out.println("base length should be " + (i+1)); + break; + } + } + + for (int i = dat.check.length - 1; i >= 0; i--) { + if (dat.check[i] != -1) { + System.out.println("check length should be " + (i+1)); + break; + } + } + + ByteArrayOutputStream baosOneCodePoint = new ByteArrayOutputStream(); + GZIPOutputStream osOneCodePoint = new GZIPOutputStream(baosOneCodePoint); + DataOutputStream daos = new DataOutputStream(osOneCodePoint); + daos.writeInt(base_check); + for(int i = 0; i < base_check; i++) { + daos.writeLong(dat.base[i]); + } + for(int i = 0; i < base_check; i++) { + daos.writeInt(dat.check[i]); + } + daos.writeInt(dat.valuesCount); + for(int i = 0; i < dat.valuesCount; i++) { + daos.writeLong(dat.values[i]); + } + daos.flush(); + daos.close(); + Files.write(Paths.get("double-trie-array"), baosOneCodePoint.toByteArray()); + + + + /*System.out.println("base"); + System.out.println(Arrays.stream(Arrays.copyOfRange(dat.base, 0, base_check)) + .mapToObj(v -> ""+v + "L") + .toList()); + System.out.println("check"); + System.out.println(Arrays.toString(Arrays.copyOfRange(dat.check, 0, base_check))); + System.out.println("values"); + System.out.println(Arrays.stream(Arrays.copyOfRange(dat.values, 0, dat.valuesCount)) + .mapToObj(v -> ""+v + "L") + .toList()); + + System.out.println("values length is " + dat.valuesCount); + */ + //dat.lookup("&"); + //dat.lookupValue("&"); + + // all entities must be accounted for + for(String key : entities.keySet()) { + Assertions.assertEquals(1, dat.lookup(key)); + System.err.println("key " + key); + var k = dat.lookupValue(key); + var h = DoubleArrayTrie.extractHigh(k); + var l = DoubleArrayTrie.extractLow(k); + int[] codePoints = l == 0 ? new int[] {h} : new int[] {h, l}; + Assertions.assertArrayEquals(entities.get(key).codepoints, codePoints); + } + Assertions.assertEquals(-1, dat.lookup("&lol;")); + Assertions.assertEquals(0, dat.lookup("&am")); + } + + public static void main(String[] args) throws IOException { + Map m = getEntitiesMap(); + + + // we can iterate for making a and array for mapping char -> id! + // and also generate the range! + EntitiesPrefix p = new EntitiesPrefix(null); + //DoubleArrayTrie dat = new DoubleArrayTrie(); ByteArrayOutputStream baosOneCodePoint = new ByteArrayOutputStream(); GZIPOutputStream osOneCodePoint = new GZIPOutputStream(baosOneCodePoint); @@ -53,6 +205,7 @@ public static void main(String[] args) throws IOException { int twoCodePointLength = 0; for (String key : m.keySet()) { + //dat.insert(key); p.addWord(key, m.get(key).codepoints); if (m.get(key).codepoints.length == 1) {