From 89d62129883e108b558a569060898de5e1228a3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Tue, 26 Dec 2023 18:39:01 +0200 Subject: [PATCH 1/9] Add a dedicated fragment splitter for the cl100k encoding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The "'(?:[sdmt]|ll|ve|re)|[^\r\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]++[\r\n]*|\\s*[\r\n]|\\s+(?!\\S)|\\s+" regex was replaced by a dedicated parser for each segment with optimized Unicode category detection. The first category is the contractions, followed by words, numbers, punctuation and whitespace. The start and end indexes are collected, and when a match is found we convert it to UTF-8 in a reusable array, which is presented to the fragmentConsumer (which will attempt to tokenize it). For the last segment, the whitespace is first completely consumed and split by newlines; if the last (non-newline) whitespace is followed by a non-whitespace character (e.g. "\n a"), we pop off the last space for the next token. The isLetter, isNumeric, isLetterOrNumeric, isWhitespace, isNewline, isNotWhitespaceOrLetterOrNumeric, isNotNewlineOrLetterOrNumeric helpers are highly optimized (detecting the common cases first, before doing the heavy calculations) to detect whether the next character's category is a match. addUtf8Bytes needed to be reimplemented since I couldn't find any available way to convert a string directly into a reusable list (which avoids creating so much temporary garbage). 
Before: Benchmark (dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBase data ss 10 4.368 ± 0.046 s/op SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount data ss 10 3.845 ± 0.082 s/op SingleThreadedBenchmark.benchmarkP50kBase data ss 10 5.355 ± 0.100 s/op SingleThreadedBenchmark.benchmarkP50kEdit data ss 10 5.314 ± 0.093 s/op SingleThreadedBenchmark.benchmarkR50kBase data ss 10 4.982 ± 0.069 s/op After: Benchmark (dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBase data ss 10 3.704 ± 0.076 s/op SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount data ss 10 2.871 ± 0.070 s/op SingleThreadedBenchmark.benchmarkP50kBase data ss 10 5.459 ± 0.065 s/op SingleThreadedBenchmark.benchmarkP50kEdit data ss 10 5.437 ± 0.088 s/op SingleThreadedBenchmark.benchmarkR50kBase data ss 10 5.097 ± 0.090 s/op --- .../jtokkit/Cl100kParserBenchmark.java | 66 +++++ .../com/knuddels/jtokkit/Cl100kParser.java | 247 ++++++++++++++++++ .../com/knuddels/jtokkit/EncodingFactory.java | 32 ++- .../knuddels/jtokkit/GptBytePairEncoding.java | 4 +- .../com/knuddels/jtokkit/TokenEncoder.java | 8 +- .../knuddels/jtokkit/api/EncodingResult.java | 4 +- .../knuddels/jtokkit/Cl100kParserTest.java | 210 +++++++++++++++ .../java/com/knuddels/jtokkit/Cl100kTest.java | 22 +- 8 files changed, 571 insertions(+), 22 deletions(-) create mode 100644 benchmark/src/jmh/java/com/knuddels/jtokkit/Cl100kParserBenchmark.java create mode 100644 lib/src/main/java/com/knuddels/jtokkit/Cl100kParser.java create mode 100644 lib/src/test/java/com/knuddels/jtokkit/Cl100kParserTest.java diff --git a/benchmark/src/jmh/java/com/knuddels/jtokkit/Cl100kParserBenchmark.java b/benchmark/src/jmh/java/com/knuddels/jtokkit/Cl100kParserBenchmark.java new file mode 100644 index 00000000..e83d9a38 --- /dev/null +++ b/benchmark/src/jmh/java/com/knuddels/jtokkit/Cl100kParserBenchmark.java @@ -0,0 +1,66 @@ +package com.knuddels.jtokkit; + + +import 
org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.infra.Blackhole; + +import java.util.ArrayList; + +public class Cl100kParserBenchmark { + @Benchmark + public void benchmarkIsLetter(BenchmarkingState state, Blackhole bh) { + for (var fileContent : state.fileContents) { + fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isLetter(cp))); + } + } + + @Benchmark + public void benchmarkIsLetterOrNumeric(BenchmarkingState state, Blackhole bh) { + for (var fileContent : state.fileContents) { + fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isLetterOrNumeric(cp))); + } + } + + @Benchmark + public void benchmarkIsNewline(BenchmarkingState state, Blackhole bh) { + for (var fileContent : state.fileContents) { + fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isNewline(cp))); + } + } + + @Benchmark + public void benchmarkIsNotNewlineOrLetterOrNumeric(BenchmarkingState state, Blackhole bh) { + for (var fileContent : state.fileContents) { + fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isNotNewlineOrLetterOrNumeric(cp))); + } + } + + @Benchmark + public void benchmarkIsNotWhitespaceOrLetterOrNumeric(BenchmarkingState state, Blackhole bh) { + for (var fileContent : state.fileContents) { + fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isNotWhitespaceOrLetterOrNumeric(cp))); + } + } + + @Benchmark + public void benchmarkIsNumeric(BenchmarkingState state, Blackhole bh) { + for (var fileContent : state.fileContents) { + fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isNumeric(cp))); + } + } + + @Benchmark + public void benchmarkIsWhitespace(BenchmarkingState state, Blackhole bh) { + for (var fileContent : state.fileContents) { + fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isWhitespace(cp))); + } + } + + @Benchmark + public void benchmarkToUtf8Conversion(BenchmarkingState state, Blackhole bh) { + var dst = new 
ArrayList(); + for (var fileContent : state.fileContents) { + bh.consume(Cl100kParser.addUtf8Bytes(fileContent, 0, fileContent.length(), dst)); + } + } +} diff --git a/lib/src/main/java/com/knuddels/jtokkit/Cl100kParser.java b/lib/src/main/java/com/knuddels/jtokkit/Cl100kParser.java new file mode 100644 index 00000000..6aeb8afd --- /dev/null +++ b/lib/src/main/java/com/knuddels/jtokkit/Cl100kParser.java @@ -0,0 +1,247 @@ +package com.knuddels.jtokkit; + + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Predicate; + +import static java.lang.Character.*; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Arrays.binarySearch; + +public class Cl100kParser { + private static final String SDTM = "sdtmSDTMſ"; + private static final String SIMPLE_WHITESPACES = "\t\n\u000B\u000C\r"; + private static final int[] REMAINING_WHITESPACES = "\u1680\u2000\u2001\u2002\u2003\u2004\u2005\u2006\u2007\u2008\u2009\u200A\u2028\u2029\u202F\u205F\u3000".codePoints().sorted().toArray(); + + public static void split(String input, Predicate> fragmentConsumer) { + assert isValidUTF8(input) : "Input is not UTF-8: " + input; + List utf8Bytes = new ArrayList<>(); + boolean finished = false; + for (int endIndex = 0; endIndex < input.length() && !finished; ) { + int startIndex = endIndex; + int c0 = input.codePointAt(startIndex); + int cc0 = charCount(c0); + int nextIndex = startIndex + cc0; + int c1 = (nextIndex < input.length()) ? 
input.codePointAt(nextIndex) : -1; + + if ((c0 == '\'') && c1 > 0) { + if (isShortContraction(c1)) { + // 1) `'[sdtm]` - contractions, such as the suffixes of `he's`, `I'd`, `'tis`, `I'm` + endIndex += 2; + finished = fragmentConsumer.test(addUtf8Bytes(input, startIndex, endIndex, utf8Bytes)); + continue; + } else if ((startIndex + 2) < input.length() && isLongContraction(c1, input.codePointAt(startIndex + 2))) { + // 1) `'(?:ll|ve|re)` - contractions, such as the suffixes of `you'll`, `we've`, `they're` + endIndex += 3; + finished = fragmentConsumer.test(addUtf8Bytes(input, startIndex, endIndex, utf8Bytes)); + continue; + } + } + + int cc1 = charCount(c1); + if ((isNotNewlineOrLetterOrNumeric(c0) && isLetter(c1)) || isLetter(c0)) { + // 2) `[^\r\n\p{L}\p{N}]?+\p{L}+` - words such as ` of`, `th`, `It`, ` not` + endIndex += cc0; + if (isLetter(c1)) { + endIndex += cc1; + while ((endIndex < input.length()) && isLetter(c0 = input.codePointAt(endIndex))) { + endIndex += charCount(c0); + } + } + finished = fragmentConsumer.test(addUtf8Bytes(input, startIndex, endIndex, utf8Bytes)); + } else if (isNumeric(c0)) { + // 3) `\p{N}{1,3}` - numbers, such as `4`, `235` or `3½` + endIndex += cc0; + if (isNumeric(c1)) { + endIndex += cc1; + if ((endIndex < input.length()) && isNumeric(c0 = input.codePointAt(endIndex))) { + endIndex += charCount(c0); + } + } + finished = fragmentConsumer.test(addUtf8Bytes(input, startIndex, endIndex, utf8Bytes)); + } else if (isNotWhitespaceOrLetterOrNumeric(c0) || ((c0 == ' ') && isNotWhitespaceOrLetterOrNumeric(c1))) { + // 4) ` ?[^\s\p{L}\p{N}]++[\r\n]*` - punctuation, such as `,`, ` .`, `"` + endIndex += cc0; + if ((endIndex < input.length()) && isNotWhitespaceOrLetterOrNumeric(c1)) { + endIndex += cc1; + while ((endIndex < input.length()) && isNotWhitespaceOrLetterOrNumeric(c0 = input.codePointAt(endIndex))) { + endIndex += charCount(c0); + } + } + while ((endIndex < input.length()) && isNewline(input.codePointAt(endIndex))) { + endIndex++; + 
} + finished = fragmentConsumer.test(addUtf8Bytes(input, startIndex, endIndex, utf8Bytes)); + } else { + // 5) `\s*[\r\n]+` - line endings such as `\r\n \r\n` + // 6) `\s+(?!\S)` - whitespaces such as ` ` or ` ` + // 7) `\s+` - unmatched remaining spaces, such as ` ` + assert isWhitespace(c0) : "Invalid character: " + Arrays.toString(toChars(c0)); + int lastNewLineIndex = isNewline(c0) ? endIndex : -1; + endIndex += cc0; + if (isWhitespace(c1)) { + lastNewLineIndex = isNewline(c1) ? endIndex : lastNewLineIndex; + endIndex += cc1; + while (endIndex < input.length() && isWhitespace(c0 = input.codePointAt(endIndex))) { + lastNewLineIndex = isNewline(c0) ? endIndex : lastNewLineIndex; + endIndex += charCount(c0); + } + } + + if (lastNewLineIndex > -1) { + int finalEndIndex = endIndex; + endIndex = lastNewLineIndex + 1; + if (endIndex < finalEndIndex) { + assert startIndex < endIndex; + finished = fragmentConsumer.test(addUtf8Bytes(input, startIndex, endIndex, utf8Bytes)); + startIndex = endIndex; + endIndex = finalEndIndex; + } + } + if (!finished) { + if (lastNewLineIndex + 1 < endIndex && !isWhitespace(c0)) { + endIndex--; + } + if (startIndex < endIndex) { + finished = fragmentConsumer.test(addUtf8Bytes(input, startIndex, endIndex, utf8Bytes)); + } + } + } + } + } + + + static boolean isShortContraction(int ch) { + return SDTM.indexOf(ch) >= 0; + } + + static boolean isLongContraction(int ch1, int ch2) { + if (((ch1 == 'l') && (ch2 == 'l')) + || ((ch1 == 'v') && (ch2 == 'e')) + || ((ch1 == 'r') && (ch2 == 'e'))) { + return true; + } else { + int lch1 = toUpperCase(ch1); + int lch2 = toUpperCase(ch2); + return ((lch1 == 'L') && (lch2 == 'L')) + || ((lch1 == 'V') && (lch2 == 'E')) + || ((lch1 == 'R') && (lch2 == 'E')); + } + } + + public static boolean isValidUTF8(String input) { + return UTF_8.newEncoder().canEncode(input); + } + + public static boolean isLetter(int ch) { + if (ch < 0xaa) { + return ((ch >= 'a') && (ch <= 'z')) + || ((ch >= 'A') && (ch <= 'Z')); + } 
else if (ch <= 0x323af) { + switch (getType(ch)) { + case UPPERCASE_LETTER: + case LOWERCASE_LETTER: + case TITLECASE_LETTER: + case MODIFIER_LETTER: + case OTHER_LETTER: + return true; + } + } + return false; + } + + public static boolean isNumeric(int ch) { + if (ch < 0xb2) { + return (ch >= '0') && (ch <= '9'); + } else if (ch <= 0x1fbf9) { + switch (getType(ch)) { + case DECIMAL_DIGIT_NUMBER: + case LETTER_NUMBER: + case OTHER_NUMBER: + return true; + } + } + return false; + } + + static boolean isLetterOrNumeric(int ch) { + if (ch < 0xaa) { + return ((ch >= 'a') && (ch <= 'z')) + || ((ch >= 'A') && (ch <= 'Z')) + || ((ch >= '0') && (ch <= '9')); + } else if (ch <= 0x323af) { + switch (getType(ch)) { + case UPPERCASE_LETTER: + case LOWERCASE_LETTER: + case TITLECASE_LETTER: + case MODIFIER_LETTER: + case OTHER_LETTER: + case DECIMAL_DIGIT_NUMBER: + case LETTER_NUMBER: + case OTHER_NUMBER: + return true; + } + } + return false; + } + + public static boolean isWhitespace(int ch) { + if (ch <= '\r') { + return SIMPLE_WHITESPACES.indexOf(ch) >= 0; + } else if (ch < '\u0085') { + return ch == ' '; + } else { + return (ch == '\u0085') + || (ch == '\u00A0') + || ((ch >= '\u1680') && (ch <= '\u3000') && (binarySearch(REMAINING_WHITESPACES, ch) >= 0)); + } + } + + static boolean isNewline(int ch) { + return (ch == '\r') + || (ch == '\n'); + } + + public static boolean isNotWhitespaceOrLetterOrNumeric(int ch) { + if (ch < '0') { + return ch >= 0 && ch != ' ' && (ch > '\r' || ch < '\t'); + } else { + return !isLetterOrNumeric(ch) && !isWhitespace(ch); + } + } + + public static boolean isNotNewlineOrLetterOrNumeric(int ch) { + if (ch < '0') { + return ch >= 0 && (ch == ' ' || !isNewline(ch)); + } else { + return !isLetterOrNumeric(ch); + } + } + + static List addUtf8Bytes(String input, int start, int end, List dst) { + dst.clear(); + for (int i = start; i < end; i++) { + int cp = input.codePointAt(i); + if (cp < 0x80) { + dst.add((byte) cp); + } else if (cp < 0x800) { + 
dst.add((byte) (0xc0 | (cp >> 0x6))); + dst.add((byte) (0x80 | (cp & 0x3f))); + } else if (cp < MIN_SUPPLEMENTARY_CODE_POINT) { + dst.add((byte) (0xe0 | (cp >> 0xc))); + dst.add((byte) (0x80 | ((cp >> 0x6) & 0x3f))); + dst.add((byte) (0x80 | (cp & 0x3f))); + } else { + assert cp < (MAX_CODE_POINT + 1) : "Invalid code point: " + cp; + dst.add((byte) (0xf0 | (cp >> 0x12))); + dst.add((byte) (0x80 | ((cp >> 0xc) & 0x3f))); + dst.add((byte) (0x80 | ((cp >> 0x6) & 0x3f))); + dst.add((byte) (0x80 | (cp & 0x3f))); + i++; + } + } + return dst; + } +} \ No newline at end of file diff --git a/lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java b/lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java index 7cb93fc0..235acfa6 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java +++ b/lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java @@ -103,13 +103,10 @@ public static Encoding p50kEdit() { * @return an {@link Encoding} instance for the cl100k_base encoding */ public static Encoding cl100kBase() { - return fromPredefinedParameters( - "cl100k_base", - "'(?:[sdmt]|ll|ve|re)|[^\r\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]++[\r\n]*|\\s*[\r\n]|\\s+(?!\\S)|\\s+", - "/com/knuddels/jtokkit/cl100k_base.tiktoken", - SPECIAL_TOKENS_CL100K_BASE, - true - ); + // "'(?:[sdmt]|ll|ve|re)|[^\r\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]++[\r\n]*|\\s*[\r\n]|\\s+(?!\\S)|\\s+" + Map mergeableRanks = loadMergeableRanks("/com/knuddels/jtokkit/cl100k_base.tiktoken"); + GptBytePairEncodingParams params = new GptBytePairEncodingParams("cl100k_base", null, mergeableRanks, SPECIAL_TOKENS_CL100K_BASE); + return new Cl100kGptBytePairEncoding(params); } /** @@ -176,4 +173,25 @@ public static Map loadMergeableRanks(String fileName) { throw new IllegalStateException("Could not load " + fileName + " from resources", e); } } + + private static class Cl100kGptBytePairEncoding extends GptBytePairEncoding { + public 
Cl100kGptBytePairEncoding(GptBytePairEncodingParams params) { + super(params); + } + + @Override + int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, List out) { + int[] tokenCount = {0}; + ArrayList ranks = new ArrayList<>(); + Cl100kParser.split(text, utf8BytesList -> { + byte[] utf8Bytes = new byte[utf8BytesList.size()]; + for (int i = 0; i < utf8BytesList.size(); i++) { + utf8Bytes[i] = utf8BytesList.get(i); + } + tokenCount[0] += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, utf8Bytes, out, ranks); + return tokenCount[0] >= maxTokenCount; + }); + return tokenCount[0]; + } + } } diff --git a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java index 2562bb91..34da2537 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java +++ b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java @@ -19,9 +19,9 @@ */ class GptBytePairEncoding implements Encoding { + final TokenEncoder encoder; private final String name; private final Pattern pattern; - private final TokenEncoder encoder; private final SpecialEncoder specialEncoder; /** @@ -78,7 +78,7 @@ private EncodingResult encodeOrdinaryInternal(String text, int maxTokenCount, bo // Make sure we didn't break the multibyte character for (int tokensToRemove = 0; tokensToRemove <= out.size(); tokensToRemove++) { int size = out.size() - tokensToRemove; - List tokens = new ArrayList<>(size); + ArrayList tokens = new ArrayList<>(size); for (int i = 0; i < size; i++) { tokens.add(out.get(i)); } diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index 9aefc200..7a432c4c 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ -8,17 +8,17 @@ import static java.util.Collections.emptyMap; public final class TokenEncoder { - public 
static final String VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY = "VERY_LARGE_TOKENIZER_BYTE_THRESHOLD"; - public static final int DUMMY_RANK = MAX_VALUE; public static final int MAX_RANK = MAX_VALUE - 1; + static final String VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY = "VERY_LARGE_TOKENIZER_BYTE_THRESHOLD"; + static final int DUMMY_RANK = MAX_VALUE; private final Map[] encoders; private final Map decoder; private int VERY_LARGE_TOKENIZER_BYTE_THRESHOLD; - public TokenEncoder(Map encoder) { + TokenEncoder(Map encoder) { if (!encoder.isEmpty()) { - VERY_LARGE_TOKENIZER_BYTE_THRESHOLD = parseInt(System.getProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, "500")); + VERY_LARGE_TOKENIZER_BYTE_THRESHOLD = parseInt(System.getProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, "1500")); TreeMap> tempEncoders = new TreeMap<>(); encoder.forEach((k, v) -> { ByteArrayWrapper key = new ByteArrayWrapper(k); diff --git a/lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java b/lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java index f2713af3..7fb4dc1a 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java +++ b/lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java @@ -44,8 +44,8 @@ public boolean isTruncated() { @Override public String toString() { return "EncodingResult{" - + "tokens=" + getTokens() - + ", tokenCount=" + getTokenCount() + + "tokens=" + tokens + + ", tokenCount=" + tokenCount + ", truncated=" + truncated + '}'; } diff --git a/lib/src/test/java/com/knuddels/jtokkit/Cl100kParserTest.java b/lib/src/test/java/com/knuddels/jtokkit/Cl100kParserTest.java new file mode 100644 index 00000000..8376a95f --- /dev/null +++ b/lib/src/test/java/com/knuddels/jtokkit/Cl100kParserTest.java @@ -0,0 +1,210 @@ +package com.knuddels.jtokkit; + +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.net.URL; +import java.util.Arrays; +import 
java.util.HashMap; +import java.util.Map; + +import static com.knuddels.jtokkit.Cl100kParser.addUtf8Bytes; +import static com.knuddels.jtokkit.Cl100kParser.isValidUTF8; +import static com.knuddels.jtokkit.EncodingFactory.compileRegex; +import static java.lang.Character.MAX_CODE_POINT; +import static java.lang.Character.MIN_CODE_POINT; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.jupiter.api.Assertions.*; + +public class Cl100kParserTest { + + public static Map fetchUnicodeData() { + var url = "https://www.unicode.org/Public/UCD/latest/ucd/UnicodeData.txt"; + Map unicodeMap = new HashMap<>(); + + try (var br = new BufferedReader(new InputStreamReader(new URL(url).openStream()))) { + String line; + while ((line = br.readLine()) != null) { + var parts = line.split(";"); + assert parts.length > 1; + var codePoint = Integer.parseInt(parts[0], 16); + var name = parts[1]; + unicodeMap.put(codePoint, name); + } + } catch (Exception e) { + throw new IllegalStateException(e); + } + return unicodeMap; + } + + @Disabled // Takes too long + @Test + public void testToUtf8BytesOnFetchedUnicodeData() throws Exception { + fetchUnicodeData().entrySet().stream().parallel().forEach(e -> { + var expected = Character.toString(e.getKey()); + if (isValidUTF8(expected)) { + var dst = new ByteArrayList(); + addUtf8Bytes(expected, 0, expected.length(), dst); + + assertArrayEquals(expected.getBytes(UTF_8), dst.toArray(), () -> "Expected `" + Arrays.toString(expected.getBytes(UTF_8)) + "` (`" + expected + "`) but was `" + Arrays.toString(dst.toArray()) + "`"); + } else { + System.out.println("Skipping invalid UTF-8: " + e.getValue() + " (" + e.getKey() + ")"); + } + }); + } + + @Test + public void testIsApostophed() { + var count = 0; + var pattern = compileRegex("^(?:'s|'t|'re|'ve|'m|'ll|'d)$", true); + + System.out.println("isShortContraction"); + for (var cp1 = MIN_CODE_POINT; cp1 <= MAX_CODE_POINT; cp1++) { // Seems 'ſ is also a contraction... 
+ var asString = "'" + Character.toString(cp1); + var matchesRegex = pattern.matcher(asString).matches(); + var actual = Cl100kParser.isShortContraction(cp1); + if (matchesRegex) { + count++; + } + assertEquals(matchesRegex, actual, "Mismatch at code point: `" + asString + "` (" + cp1 + ")"); + } + + if (false) { // Takes too long + System.out.println("isLongContraction"); + for (var cp1 = MIN_CODE_POINT; cp1 <= MAX_CODE_POINT; cp1++) { + for (var cp2 = MIN_CODE_POINT; cp2 <= MAX_CODE_POINT; cp2++) { + var asString = "'" + Character.toString(cp1) + Character.toString(cp2); + var matchesRegex = pattern.matcher(asString).matches(); + var actual = Cl100kParser.isLongContraction(cp1, cp2); + if (matchesRegex) { + count++; + } + assertEquals(matchesRegex, actual, "Mismatch at code point: `" + asString + "` (" + cp1 + ", " + cp2 + ")"); + } + } + System.out.println(count); + } + } + + @Test + public void testIsNumeric() { + var count = 0; + assertFalse(Cl100kParser.isNumeric(-1)); + var pattern = compileRegex("^\\p{N}$", true); + for (var cp = MIN_CODE_POINT; cp <= MAX_CODE_POINT; cp++) { + var charAsString = Character.toString(cp); + var matchesRegex = pattern.matcher(charAsString).matches(); + var actual = Cl100kParser.isNumeric(cp); + if (matchesRegex) { + count++; + } + + assertEquals(matchesRegex, actual, "Mismatch at code point: `" + charAsString + "` (" + cp + ")"); + } + System.out.println(count); + } + + @Test + public void testIsLetter() { + var count = 0; + assertFalse(Cl100kParser.isLetter(-1)); + var pattern = compileRegex("^\\p{L}$", true); + for (var cp = MIN_CODE_POINT; cp <= MAX_CODE_POINT; cp++) { + var charAsString = Character.toString(cp); + var matchesRegex = pattern.matcher(charAsString).matches(); + var actual = Cl100kParser.isLetter(cp); + if (matchesRegex) { + count++; + } + assertEquals(matchesRegex, actual, "Mismatch at code point: `" + charAsString + "` (" + cp + ")"); + } + System.out.println(count); + } + + @Test + public void 
testIsUnicodeWhitespace() { + var count = 0; + assertFalse(Cl100kParser.isWhitespace(-1)); + var pattern = compileRegex("^\\s$", true); + for (var cp = MIN_CODE_POINT; cp <= MAX_CODE_POINT; cp++) { + var charAsString = Character.toString(cp); + var matchesRegex = pattern.matcher(charAsString).matches(); + var actual = Cl100kParser.isWhitespace(cp); + if (matchesRegex) { + count++; + } + assertEquals(matchesRegex, actual, "Mismatch at code point: `" + charAsString + "` (" + cp + ")"); + } + System.out.println(count); + } + + @Test + public void testIsLetterOrNumeric() { + var count = 0; + assertFalse(Cl100kParser.isLetterOrNumeric(-1)); + var pattern = compileRegex("^[\\p{L}\\p{N}]$", true); + for (var cp = MIN_CODE_POINT; cp <= MAX_CODE_POINT; cp++) { + var charAsString = Character.toString(cp); + var matchesRegex = pattern.matcher(charAsString).matches(); + var actual = Cl100kParser.isLetterOrNumeric(cp); + if (matchesRegex) { + count++; + } + assertEquals(matchesRegex, actual, "Mismatch at code point: `" + charAsString + "` (" + cp + ")"); + } + System.out.println(count); + } + + @Test + public void testIsNotWhitespaceOrLetterOrNumeric() { + var count = 0; + assertFalse(Cl100kParser.isNotWhitespaceOrLetterOrNumeric(-1)); + var pattern = compileRegex("^[^\\s\\p{L}\\p{N}]$", true); + for (var cp = MIN_CODE_POINT; cp <= MAX_CODE_POINT; cp++) { + var charAsString = Character.toString(cp); + var matchesRegex = pattern.matcher(charAsString).matches(); + var actual = Cl100kParser.isNotWhitespaceOrLetterOrNumeric(cp); + if (matchesRegex) { + count++; + } + assertEquals(matchesRegex, actual, "Mismatch at code point: `" + charAsString + "` (" + cp + ")"); + } + System.out.println(count); + } + + @Test + public void testIsNotNewlineOrLetterOrNumeric() { + var count = 0; + assertFalse(Cl100kParser.isNotNewlineOrLetterOrNumeric(-1)); + var pattern = compileRegex("^[^\r\n\\p{L}\\p{N}]$", true); + for (var cp = MIN_CODE_POINT; cp <= MAX_CODE_POINT; cp++) { + var charAsString = 
Character.toString(cp); + var matchesRegex = pattern.matcher(charAsString).matches(); + var actual = Cl100kParser.isNotNewlineOrLetterOrNumeric(cp); + if (matchesRegex) { + count++; + } + assertEquals(matchesRegex, actual, "Mismatch at code point: `" + charAsString + "` (" + cp + ")"); + } + System.out.println(count); + } + + @Test + public void testIsNewline() { + var count = 0; + assertFalse(Cl100kParser.isNewline(-1)); + var pattern = compileRegex("^[\r\n]$", true); + for (var cp = MIN_CODE_POINT; cp <= MAX_CODE_POINT; cp++) { + var charAsString = Character.toString(cp); + var matchesRegex = pattern.matcher(charAsString).matches(); + var isNewline = Cl100kParser.isNewline(cp); + if (matchesRegex) { + count++; + } + assertEquals(matchesRegex, isNewline, "Mismatch at code point: `" + charAsString + "` (" + cp + ")"); + } + System.out.println(count); + } +} \ No newline at end of file diff --git a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java index e96d3b1b..0cbfd33e 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.TreeMap; import java.util.concurrent.ThreadLocalRandom; +import java.util.function.IntPredicate; import java.util.stream.IntStream; import static java.lang.Character.*; @@ -17,17 +18,24 @@ class Cl100kTest { private static final String PUNCTUATION = "'\".,?!:()"; - private static final String LETTERS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZő你好ſ ½"; - private static final String NUMBERS = "0123456789½"; + private static final String LETTERS = generateUnicodeCategoryString(Cl100kParser::isLetter); + private static final String NUMBERS = generateUnicodeCategoryString(Cl100kParser::isNumeric); + private static final String WHITESPACES = generateUnicodeCategoryString(Cl100kParser::isWhitespace); private static final String NEWLINES = "\n\r"; - 
private static final String WHITESPACES = " \t " + NEWLINES; - private static final String NOT_NEWLINE_OR_LETTER_OR_NUMERIC = " \t🤚🏾😩" + PUNCTUATION; - private static final String NOT_WHITESPACE_OR_LETTER_OR_NUMERIC = NOT_NEWLINE_OR_LETTER_OR_NUMERIC + NEWLINES; + private static final String NOT_NEWLINE_OR_LETTER_OR_NUMERIC = generateUnicodeCategoryString(Cl100kParser::isNotNewlineOrLetterOrNumeric); + private static final String NOT_WHITESPACE_OR_LETTER_OR_NUMERIC = generateUnicodeCategoryString(Cl100kParser::isNotWhitespaceOrLetterOrNumeric); private static final List SPECIAL = List.of("'s", "'t", "'re", "'ve", "'m", "'ll", "'d", "'ſ", "'x", "🤚🏾", "😩", " ", "½"); - private static final Encoding ENCODING = EncodingFactory.cl100kBase(); - private static String normalizeStringForTesting(String testString) { + private static String generateUnicodeCategoryString(IntPredicate characterProperty) { + return IntStream.range(MIN_CODE_POINT, MAX_CODE_POINT) + .filter(Character::isDefined) + .filter(characterProperty) + .collect(StringBuilder::new, StringBuilder::appendCodePoint, StringBuilder::append) + .toString(); + } + + static String normalizeStringForTesting(String testString) { return testString .replaceAll("\\r", "\\\\r") .replaceAll("\\n", "\\\\n") From 5d177b646df148a4c00acad96cdd211fe074277b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Tue, 26 Dec 2023 20:55:43 +0200 Subject: [PATCH 2/9] Avoid primitive boxing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before: Benchmark (dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBase data ss 10 3.704 ± 0.076 s/op SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount data ss 10 2.871 ± 0.070 s/op SingleThreadedBenchmark.benchmarkP50kBase data ss 10 5.459 ± 0.065 s/op SingleThreadedBenchmark.benchmarkP50kEdit data ss 10 5.437 ± 0.088 s/op SingleThreadedBenchmark.benchmarkR50kBase data ss 10 5.097 ± 0.090 s/op After: Benchmark 
(dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBase data ss 10 3.263 ± 0.286 s/op SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount data ss 10 2.688 ± 0.054 s/op SingleThreadedBenchmark.benchmarkP50kBase data ss 10 5.335 ± 0.106 s/op SingleThreadedBenchmark.benchmarkP50kEdit data ss 10 5.277 ± 0.067 s/op SingleThreadedBenchmark.benchmarkR50kBase data ss 10 5.002 ± 0.091 s/op --- .../jtokkit/Cl100kParserBenchmark.java | 4 +- .../com/knuddels/jtokkit/ByteArrayList.java | 38 ++++++++++++++ .../com/knuddels/jtokkit/Cl100kParser.java | 8 ++- .../com/knuddels/jtokkit/EncodingFactory.java | 6 +-- .../com/knuddels/jtokkit/TokenEncoder.java | 4 +- .../knuddels/jtokkit/ByteArrayListTest.java | 50 +++++++++++++++++++ 6 files changed, 95 insertions(+), 15 deletions(-) create mode 100644 lib/src/main/java/com/knuddels/jtokkit/ByteArrayList.java create mode 100644 lib/src/test/java/com/knuddels/jtokkit/ByteArrayListTest.java diff --git a/benchmark/src/jmh/java/com/knuddels/jtokkit/Cl100kParserBenchmark.java b/benchmark/src/jmh/java/com/knuddels/jtokkit/Cl100kParserBenchmark.java index e83d9a38..c4a1c361 100644 --- a/benchmark/src/jmh/java/com/knuddels/jtokkit/Cl100kParserBenchmark.java +++ b/benchmark/src/jmh/java/com/knuddels/jtokkit/Cl100kParserBenchmark.java @@ -4,8 +4,6 @@ import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.infra.Blackhole; -import java.util.ArrayList; - public class Cl100kParserBenchmark { @Benchmark public void benchmarkIsLetter(BenchmarkingState state, Blackhole bh) { @@ -58,7 +56,7 @@ public void benchmarkIsWhitespace(BenchmarkingState state, Blackhole bh) { @Benchmark public void benchmarkToUtf8Conversion(BenchmarkingState state, Blackhole bh) { - var dst = new ArrayList(); + var dst = new ByteArrayList(); for (var fileContent : state.fileContents) { bh.consume(Cl100kParser.addUtf8Bytes(fileContent, 0, fileContent.length(), dst)); } diff --git 
a/lib/src/main/java/com/knuddels/jtokkit/ByteArrayList.java b/lib/src/main/java/com/knuddels/jtokkit/ByteArrayList.java new file mode 100644 index 00000000..a4341bba --- /dev/null +++ b/lib/src/main/java/com/knuddels/jtokkit/ByteArrayList.java @@ -0,0 +1,38 @@ +package com.knuddels.jtokkit; + +import java.util.Arrays; + +public class ByteArrayList { + private byte[] array; + private int size; + + public ByteArrayList() { + array = new byte[10]; + size = 0; + } + + public void clear() { + size = 0; + } + + public void add(byte element) { + if (size >= array.length) { + resize(); + } + array[size++] = element; + } + + private void resize() { + byte[] newArray = new byte[array.length * 2]; + System.arraycopy(array, 0, newArray, 0, size); + array = newArray; + } + + int length() { + return size; + } + + public byte[] toByteArray() { + return Arrays.copyOf(array, size); + } +} diff --git a/lib/src/main/java/com/knuddels/jtokkit/Cl100kParser.java b/lib/src/main/java/com/knuddels/jtokkit/Cl100kParser.java index 6aeb8afd..5b9296f4 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/Cl100kParser.java +++ b/lib/src/main/java/com/knuddels/jtokkit/Cl100kParser.java @@ -1,9 +1,7 @@ package com.knuddels.jtokkit; -import java.util.ArrayList; import java.util.Arrays; -import java.util.List; import java.util.function.Predicate; import static java.lang.Character.*; @@ -15,9 +13,9 @@ public class Cl100kParser { private static final String SIMPLE_WHITESPACES = "\t\n\u000B\u000C\r"; private static final int[] REMAINING_WHITESPACES = "\u1680\u2000\u2001\u2002\u2003\u2004\u2005\u2006\u2007\u2008\u2009\u200A\u2028\u2029\u202F\u205F\u3000".codePoints().sorted().toArray(); - public static void split(String input, Predicate> fragmentConsumer) { + public static void split(String input, Predicate fragmentConsumer) { assert isValidUTF8(input) : "Input is not UTF-8: " + input; - List utf8Bytes = new ArrayList<>(); + ByteArrayList utf8Bytes = new ByteArrayList(); boolean finished = false; for 
(int endIndex = 0; endIndex < input.length() && !finished; ) { int startIndex = endIndex; @@ -220,7 +218,7 @@ public static boolean isNotNewlineOrLetterOrNumeric(int ch) { } } - static List addUtf8Bytes(String input, int start, int end, List dst) { + static ByteArrayList addUtf8Bytes(String input, int start, int end, ByteArrayList dst) { dst.clear(); for (int i = start; i < end; i++) { int cp = input.codePointAt(i); diff --git a/lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java b/lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java index 235acfa6..bdddfce3 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java +++ b/lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java @@ -184,11 +184,7 @@ int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings int[] tokenCount = {0}; ArrayList ranks = new ArrayList<>(); Cl100kParser.split(text, utf8BytesList -> { - byte[] utf8Bytes = new byte[utf8BytesList.size()]; - for (int i = 0; i < utf8BytesList.size(); i++) { - utf8Bytes[i] = utf8BytesList.get(i); - } - tokenCount[0] += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, utf8Bytes, out, ranks); + tokenCount[0] += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, utf8BytesList.toByteArray(), out, ranks); return tokenCount[0] >= maxTokenCount; }); return tokenCount[0]; diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index 7a432c4c..60d3d9f0 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ -99,8 +99,8 @@ private static int getPreviousIndex(List ranks, int previousIndex) { return previousIndex; } - int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] utf8Bytes, List out, ArrayList ranks) { - ByteArrayWrapper match = new ByteArrayWrapper(utf8Bytes); + int addTokensAndGetCount(int maxTokenCount, boolean 
keepEncodings, byte[] byteArray, List out, ArrayList ranks) { + ByteArrayWrapper match = new ByteArrayWrapper(byteArray); int encoded = encode(match); if (encoded != MAX_RANK) { if (keepEncodings) { diff --git a/lib/src/test/java/com/knuddels/jtokkit/ByteArrayListTest.java b/lib/src/test/java/com/knuddels/jtokkit/ByteArrayListTest.java new file mode 100644 index 00000000..012aa424 --- /dev/null +++ b/lib/src/test/java/com/knuddels/jtokkit/ByteArrayListTest.java @@ -0,0 +1,50 @@ +package com.knuddels.jtokkit; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class ByteArrayListTest { + + private static byte randomByte(Random random) { + return (byte) (random.nextInt() & 0xFF); + } + + @Test + public void testArrayListOperations() { + var byteArrayList = new ByteArrayList(); + var standardList = new ArrayList(); + var random = new Random(); + + for (var i = 0; i < 1_000; i++) { + // Add + if (randomByte(random) % 2 == 0) { + var element = randomByte(random); + var lastIndex = standardList.size(); + byteArrayList.add(element); + standardList.add(element); + assertEquals(standardList.get(lastIndex), byteArrayList.toByteArray()[lastIndex]); + } + + // Size + assertEquals(standardList.size(), byteArrayList.length()); + + // Clear + if (randomByte(random) % 10 == 0) { + byteArrayList.clear(); + standardList.clear(); + assertEquals(standardList.size(), byteArrayList.length()); + } + } + + assertEquals(standardList.size(), byteArrayList.length()); + var byteArray = byteArrayList.toByteArray(); + assertEquals(standardList.size(), byteArray.length); + for (var i = 0; i < byteArrayList.length(); i++) { + assertEquals(standardList.get(i), byteArray[i]); + } + } +} From 9f56cf595ab0025981596883f257af4d20769323 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Tue, 26 Dec 2023 22:51:21 +0200 Subject: [PATCH 3/9] Add IntArrayList to store tokens without boxing 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before: Benchmark (dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBase data ss 10 3.263 ± 0.286 s/op SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount data ss 10 2.688 ± 0.054 s/op SingleThreadedBenchmark.benchmarkP50kBase data ss 10 5.335 ± 0.106 s/op SingleThreadedBenchmark.benchmarkP50kEdit data ss 10 5.277 ± 0.067 s/op SingleThreadedBenchmark.benchmarkR50kBase data ss 10 5.002 ± 0.091 s/op After: Benchmark (dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBase data ss 10 2.498 ± 0.019 s/op SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount data ss 10 2.223 ± 0.014 s/op SingleThreadedBenchmark.benchmarkP50kBase data ss 10 4.354 ± 0.122 s/op SingleThreadedBenchmark.benchmarkP50kEdit data ss 10 4.341 ± 0.076 s/op SingleThreadedBenchmark.benchmarkR50kBase data ss 10 4.068 ± 0.020 s/op --- README.md | 2 +- .../knuddels/jtokkit/AbstractBenchmark.java | 3 +- .../AbstractMultiThreadedBenchmark.java | 56 ++++----- .../com/knuddels/jtokkit/DataDownloader.java | 4 +- .../jtokkit/SingleThreadedBenchmark.java | 3 +- docs/docs/getting-started/usage.md | 6 +- .../com/knuddels/jtokkit/ByteArrayList.java | 83 ++++++++++++-- .../com/knuddels/jtokkit/EncodingFactory.java | 7 +- .../knuddels/jtokkit/GptBytePairEncoding.java | 35 +++--- .../com/knuddels/jtokkit/TokenEncoder.java | 22 ++-- .../knuddels/jtokkit/TokenEncoderLarge.java | 5 +- .../com/knuddels/jtokkit/api/Encoding.java | 10 +- .../knuddels/jtokkit/api/EncodingResult.java | 12 +- .../knuddels/jtokkit/api/IntArrayList.java | 107 ++++++++++++++++++ .../jtokkit/BaseEncodingRegistryTest.java | 9 +- .../knuddels/jtokkit/ByteArrayListTest.java | 47 +++++--- .../jtokkit/ByteArrayWrapperTest.java | 18 +-- .../jtokkit/Cl100kLargeTokenizerTest.java | 1 - .../knuddels/jtokkit/Cl100kParserTest.java | 21 ++-- .../java/com/knuddels/jtokkit/Cl100kTest.java | 2 +- 
.../knuddels/jtokkit/IntArrayListTest.java | 66 +++++++++++ .../knuddels/jtokkit/reference/TestUtils.java | 12 +- 22 files changed, 398 insertions(+), 133 deletions(-) create mode 100644 lib/src/main/java/com/knuddels/jtokkit/api/IntArrayList.java create mode 100644 lib/src/test/java/com/knuddels/jtokkit/IntArrayListTest.java diff --git a/README.md b/README.md index 7fe6e4e3..b654807b 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ retrieve the encoding you want to use. You can then use the `encode` and ```java EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); Encoding enc = registry.getEncoding(EncodingType.CL100K_BASE); -List encoded = enc.encode("This is a sample sentence."); +IntArrayList encoded = enc.encode("This is a sample sentence."); // encoded = [2028, 374, 264, 6205, 11914, 13] String decoded = enc.decode(encoded); diff --git a/benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractBenchmark.java b/benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractBenchmark.java index 7e2c338c..014b3789 100644 --- a/benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractBenchmark.java +++ b/benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractBenchmark.java @@ -1,6 +1,7 @@ package com.knuddels.jtokkit; import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.IntArrayList; import org.openjdk.jmh.annotations.Benchmark; import java.util.List; @@ -34,5 +35,5 @@ public Object benchmarkCl100kBase(BenchmarkingState state) { * @param fileContents the file contents to encode * @return a list of encoded token lists */ - protected abstract List> encodeAll(Encoding encoding, List fileContents); + protected abstract List encodeAll(Encoding encoding, List fileContents); } diff --git a/benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractMultiThreadedBenchmark.java b/benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractMultiThreadedBenchmark.java index a0cc3cfa..80353a54 100644 --- 
a/benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractMultiThreadedBenchmark.java +++ b/benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractMultiThreadedBenchmark.java @@ -1,46 +1,48 @@ package com.knuddels.jtokkit; import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.IntArrayList; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; + import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.stream.Collectors; -import org.openjdk.jmh.annotations.Scope; -import org.openjdk.jmh.annotations.Setup; -import org.openjdk.jmh.annotations.State; -import org.openjdk.jmh.annotations.TearDown; @State(Scope.Thread) public abstract class AbstractMultiThreadedBenchmark extends AbstractBenchmark { - private final int threads; - private ExecutorService executor; + private final int threads; + private ExecutorService executor; - public AbstractMultiThreadedBenchmark(final int threads) { - this.threads = threads; - } + public AbstractMultiThreadedBenchmark(int threads) { + this.threads = threads; + } - @Setup - public void setup() { - executor = Executors.newFixedThreadPool(threads); - } + @Setup + public void setup() { + executor = Executors.newFixedThreadPool(threads); + } - @TearDown - public void tearDown() { - executor.shutdown(); - } + @TearDown + public void tearDown() { + executor.shutdown(); + } - @Override - protected List> encodeAll(final Encoding encoding, final List fileContents) { - final var futures = fileContents.stream() - .map(it -> CompletableFuture.supplyAsync(() -> encoding.encode(it), executor)) - .collect(Collectors.toList()); + @Override + protected List encodeAll(Encoding encoding, List fileContents) { + var futures = fileContents.stream() + .map(it -> CompletableFuture.supplyAsync(() -> 
encoding.encode(it), executor)) + .toList(); - CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join(); + CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join(); - return futures.stream() - .map(CompletableFuture::join) - .collect(Collectors.toList()); - } + return futures.stream() + .map(CompletableFuture::join) + .collect(Collectors.toList()); + } } diff --git a/benchmark/src/jmh/java/com/knuddels/jtokkit/DataDownloader.java b/benchmark/src/jmh/java/com/knuddels/jtokkit/DataDownloader.java index 64a435d8..37dd7a93 100644 --- a/benchmark/src/jmh/java/com/knuddels/jtokkit/DataDownloader.java +++ b/benchmark/src/jmh/java/com/knuddels/jtokkit/DataDownloader.java @@ -149,8 +149,8 @@ public static void main(String[] args) throws Exception { var patterns = new String[]{ "'s", "'t", "'re", "'ve", "'m", "'ll", "'d", "'x", "x", + "ő", "123", - "a", ",.!?:; \n", "\n \n", " ", @@ -177,7 +177,7 @@ public static void main(String[] args) throws Exception { } var totalSize = calculateTotalFileSize(rootFolder); - if (totalSize != 99_925_295) { + if (totalSize != 99_945_274) { throw new AssertionError("Total size did not match expected value, actual: " + totalSize); } } diff --git a/benchmark/src/jmh/java/com/knuddels/jtokkit/SingleThreadedBenchmark.java b/benchmark/src/jmh/java/com/knuddels/jtokkit/SingleThreadedBenchmark.java index 896757af..c7ce4014 100644 --- a/benchmark/src/jmh/java/com/knuddels/jtokkit/SingleThreadedBenchmark.java +++ b/benchmark/src/jmh/java/com/knuddels/jtokkit/SingleThreadedBenchmark.java @@ -1,6 +1,7 @@ package com.knuddels.jtokkit; import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.IntArrayList; import org.openjdk.jmh.annotations.Benchmark; import java.util.List; @@ -18,7 +19,7 @@ public int benchmarkCl100kBaseTokenCount(BenchmarkingState state) { } @Override - protected List> encodeAll(final Encoding encoding, final List fileContents) { + protected List encodeAll(Encoding encoding, List 
fileContents) { return fileContents.stream() .map(encoding::encode) .toList(); diff --git a/docs/docs/getting-started/usage.md b/docs/docs/getting-started/usage.md index 152119b2..bdd2f6ec 100644 --- a/docs/docs/getting-started/usage.md +++ b/docs/docs/getting-started/usage.md @@ -45,7 +45,7 @@ Optional encoding = registry.getEncodingForModel("gpt_4"); You can use an `Encoding` to encode and decode text: ```java -List encoded = encoding.encode("This is a sample sentence."); +IntArrayList encoded = encoding.encode("This is a sample sentence."); // encoded = [2028, 374, 264, 6205, 11914, 13] String decoded = encoding.decode(encoded); @@ -87,13 +87,13 @@ int tokenCount = encoding.countTokensOrdinary("hello <|endoftext|> world"); If you want to only encode up until a specified amount of `maxTokens` and truncate after that amount, you can use `Encoding#encode(String, int)` or `Encoding#encodeOrdinary(String, int)`. These methods will truncate the encoded tokens to the specified length. They will automatically handle unicode characters that were split in half by the truncation by removing those tokens from the end of the list. 
```java -List encoded = encoding.encode("This is a sample sentence.", 3); +IntArrayList encoded = encoding.encode("This is a sample sentence.", 3); // encoded = [2028, 374, 264] String decoded = encoding.decode(encoded); // decoded = "This is a" -List encoded = encoding.encode("I love 🍕", 4); +IntArrayList encoded = encoding.encode("I love 🍕", 4); // encoded = [40, 3021] String decoded = encoding.decode(encoded); diff --git a/lib/src/main/java/com/knuddels/jtokkit/ByteArrayList.java b/lib/src/main/java/com/knuddels/jtokkit/ByteArrayList.java index a4341bba..69a66118 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/ByteArrayList.java +++ b/lib/src/main/java/com/knuddels/jtokkit/ByteArrayList.java @@ -1,14 +1,19 @@ package com.knuddels.jtokkit; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; public class ByteArrayList { private byte[] array; - private int size; + private int size = 0; public ByteArrayList() { - array = new byte[10]; - size = 0; + this(10); + } + + public ByteArrayList(int size) { + array = new byte[size]; } public void clear() { @@ -22,17 +27,81 @@ public void add(byte element) { array[size++] = element; } + public byte get(int index) { + return array[index]; + } + + public int set(int index, byte element) { + int old = array[index]; + array[index] = element; + return old; + } + private void resize() { - byte[] newArray = new byte[array.length * 2]; - System.arraycopy(array, 0, newArray, 0, size); + ensureCapacity(Math.max(1, array.length) * 2); + } + + public void ensureCapacity(int targetSize) { + if (targetSize <= size) { + return; + } + byte[] newArray = new byte[targetSize]; + if (size > 0) { + System.arraycopy(array, 0, newArray, 0, size); + } array = newArray; } - int length() { + public int size() { return size; } - public byte[] toByteArray() { + public boolean isEmpty() { + return size == 0; + } + + public byte[] toArray() { return Arrays.copyOf(array, size); } + + public List boxed() { + List list = new 
ArrayList<>(size); + for (int i = 0; i < size; i++) { + list.add(array[i]); + } + return list; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (o == null || getClass() != o.getClass()) { + return false; + } + ByteArrayList that = (ByteArrayList) o; + if (size != that.size) { + return false; + } + for (int i = 0; i < size; i++) { + if (array[i] != that.array[i]) { + return false; + } + } + return true; + } + + @Override + public int hashCode() { + int result = 1; + for (int i = 0; i < size; i++) { + result = 31 * result + array[i]; + } + return result; + } + + @Override + public String toString() { + return boxed().toString(); + } } diff --git a/lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java b/lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java index bdddfce3..cad8e718 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java +++ b/lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java @@ -2,6 +2,7 @@ import com.knuddels.jtokkit.api.Encoding; import com.knuddels.jtokkit.api.GptBytePairEncodingParams; +import com.knuddels.jtokkit.api.IntArrayList; import java.io.BufferedReader; import java.io.IOException; @@ -180,11 +181,11 @@ public Cl100kGptBytePairEncoding(GptBytePairEncodingParams params) { } @Override - int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, List out) { + int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, IntArrayList out) { int[] tokenCount = {0}; - ArrayList ranks = new ArrayList<>(); + IntArrayList ranks = new IntArrayList(); Cl100kParser.split(text, utf8BytesList -> { - tokenCount[0] += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, utf8BytesList.toByteArray(), out, ranks); + tokenCount[0] += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, utf8BytesList.toArray(), out, ranks); return tokenCount[0] >= maxTokenCount; }); return tokenCount[0]; diff --git 
a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java index 34da2537..5a1b4aa5 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java +++ b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java @@ -3,15 +3,12 @@ import com.knuddels.jtokkit.api.Encoding; import com.knuddels.jtokkit.api.EncodingResult; import com.knuddels.jtokkit.api.GptBytePairEncodingParams; +import com.knuddels.jtokkit.api.IntArrayList; -import java.io.ByteArrayOutputStream; -import java.util.ArrayList; -import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; /** @@ -37,7 +34,7 @@ class GptBytePairEncoding implements Encoding { } @Override - public List encode(String text) { + public IntArrayList encode(String text) { return encode(text, Integer.MAX_VALUE).getTokens(); } @@ -48,7 +45,7 @@ public EncodingResult encode(String text, int maxTokenCount) { private EncodingResult encodeInternal(String text, int maxTokenCount, boolean keepEncodings) { if (text == null) { - return new EncodingResult(emptyList(), -1, false); + return new EncodingResult(new IntArrayList(0), -1, false); } specialEncoder.checkForSpecialTokens(text); @@ -57,7 +54,7 @@ private EncodingResult encodeInternal(String text, int maxTokenCount, boolean ke } @Override - public List encodeOrdinary(String text) { + public IntArrayList encodeOrdinary(String text) { return encodeOrdinary(text, Integer.MAX_VALUE).getTokens(); } @@ -68,17 +65,17 @@ public EncodingResult encodeOrdinary(String text, int maxTokenCount) { private EncodingResult encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings) { if (text == null) { - return new EncodingResult(emptyList(), -1, false); + return new EncodingResult(new IntArrayList(0), -1, false); } - 
List out = new ArrayList<>(); + IntArrayList out = new IntArrayList(); int tokenCount = encodeOrdinaryInternal(text, maxTokenCount, keepEncodings, out); if (keepEncodings && maxTokenCount != Integer.MAX_VALUE) { // Make sure we didn't break the multibyte character for (int tokensToRemove = 0; tokensToRemove <= out.size(); tokensToRemove++) { int size = out.size() - tokensToRemove; - ArrayList tokens = new ArrayList<>(size); + IntArrayList tokens = new IntArrayList(size); for (int i = 0; i < size; i++) { tokens.add(out.get(i)); } @@ -93,9 +90,9 @@ private EncodingResult encodeOrdinaryInternal(String text, int maxTokenCount, bo return new EncodingResult(out, tokenCount, false); } - int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, List out) { + int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, IntArrayList out) { int tokenCount = 0; - ArrayList ranks = new ArrayList<>(); // reused to avoid allocations + IntArrayList ranks = new IntArrayList(); // reused to avoid allocations for (Matcher matcher = pattern.matcher(text); tokenCount < maxTokenCount && matcher.find(); ) { byte[] bytes = matcher.group().getBytes(UTF_8); tokenCount += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, bytes, out, ranks); @@ -109,20 +106,20 @@ public int countTokens(String text) { } @Override - public String decode(List tokens) { + public String decode(IntArrayList tokens) { return new String(decodeBytes(tokens), UTF_8); } @Override - public byte[] decodeBytes(List tokens) { - ByteArrayOutputStream out = new ByteArrayOutputStream(10 * tokens.size()); - for (int token : tokens) { - byte[] decodedToken = decodeToken(token); + public byte[] decodeBytes(IntArrayList tokens) { + ByteArrayList out = new ByteArrayList(10 * tokens.size()); + for (int i = 0; i < tokens.size(); i++) { + byte[] decodedToken = decodeToken(tokens.get(i)); for (byte b : decodedToken) { - out.write(b); + out.add(b); } } - return out.toByteArray(); + 
return out.toArray(); } @Override diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index 60d3d9f0..9f7782dc 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ -1,6 +1,10 @@ package com.knuddels.jtokkit; -import java.util.*; +import com.knuddels.jtokkit.api.IntArrayList; + +import java.util.HashMap; +import java.util.Map; +import java.util.TreeMap; import static com.knuddels.jtokkit.TokenEncoderLarge.calculateTokensLarge; import static java.lang.Integer.MAX_VALUE; @@ -9,7 +13,7 @@ public final class TokenEncoder { public static final int MAX_RANK = MAX_VALUE - 1; - static final String VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY = "VERY_LARGE_TOKENIZER_BYTE_THRESHOLD"; + public static final String VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY = "VERY_LARGE_TOKENIZER_BYTE_THRESHOLD"; static final int DUMMY_RANK = MAX_VALUE; private final Map[] encoders; private final Map decoder; @@ -33,11 +37,11 @@ public final class TokenEncoder { } else { //noinspection unchecked encoders = new Map[0]; // for testing - this.decoder = emptyMap(); + decoder = emptyMap(); } } - private static int getMinRankIndex(List ranks) { + private static int getMinRankIndex(IntArrayList ranks) { int minRankIndex = -1; int minRank = MAX_RANK; @@ -85,21 +89,21 @@ private static int getMinRankIndex(List ranks) { return minRankIndex; } - private static int getNextIndex(List ranks, int nextIndex) { + private static int getNextIndex(IntArrayList ranks, int nextIndex) { while (nextIndex < ranks.size() && ranks.get(nextIndex) == DUMMY_RANK) { nextIndex++; } return nextIndex; } - private static int getPreviousIndex(List ranks, int previousIndex) { + private static int getPreviousIndex(IntArrayList ranks, int previousIndex) { while (previousIndex >= 0 && ranks.get(previousIndex) == DUMMY_RANK) { previousIndex--; } return previousIndex; } - int 
addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] byteArray, List out, ArrayList ranks) { + int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] byteArray, IntArrayList out, IntArrayList ranks) { ByteArrayWrapper match = new ByteArrayWrapper(byteArray); int encoded = encode(match); if (encoded != MAX_RANK) { @@ -117,7 +121,7 @@ int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] byteAr } } - private int calculateTokensSmall(int maxTokenCount, boolean keepEncodings, List out, ArrayList ranks, ByteArrayWrapper match, int length) { + private int calculateTokensSmall(int maxTokenCount, boolean keepEncodings, IntArrayList out, IntArrayList ranks, ByteArrayWrapper match, int length) { assert length > 1 : "Already filtered out"; ranks.clear(); ranks.ensureCapacity(length + 1); @@ -149,7 +153,7 @@ private int calculateTokensSmall(int maxTokenCount, boolean keepEncodings, List< return tokenCount; } - int mergeBytesAndGetTokenCount(ByteArrayWrapper piece, int length, List ranks, int validRanks, int minRankIndex) { + int mergeBytesAndGetTokenCount(ByteArrayWrapper piece, int length, IntArrayList ranks, int validRanks, int minRankIndex) { assert getMinRankIndex(ranks) == minRankIndex; while (validRanks > 0) { assert minRankIndex >= 0; diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java index ada9b364..4df67ff5 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java @@ -1,14 +1,15 @@ package com.knuddels.jtokkit; -import java.util.List; +import com.knuddels.jtokkit.api.IntArrayList; + import java.util.Map.Entry; import java.util.TreeMap; import static com.knuddels.jtokkit.TokenEncoder.MAX_RANK; final class TokenEncoderLarge { - static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, boolean keepEncodings, List out, 
ByteArrayWrapper match, int length) { + static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, boolean keepEncodings, IntArrayList out, ByteArrayWrapper match, int length) { assert length > 1 : "Already filtered out"; TreeMap> rankMap = new TreeMap<>(); diff --git a/lib/src/main/java/com/knuddels/jtokkit/api/Encoding.java b/lib/src/main/java/com/knuddels/jtokkit/api/Encoding.java index 9fe467cd..728faa5f 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/api/Encoding.java +++ b/lib/src/main/java/com/knuddels/jtokkit/api/Encoding.java @@ -1,7 +1,5 @@ package com.knuddels.jtokkit.api; -import java.util.List; - public interface Encoding { /** @@ -26,7 +24,7 @@ public interface Encoding { * @return the list of token ids * @throws UnsupportedOperationException if the text contains special tokens which are not supported for now */ - List encode(String text); + IntArrayList encode(String text); /** * Encodes the given text into a list of token ids. @@ -77,7 +75,7 @@ public interface Encoding { * @param text the text to encode * @return the list of token ids */ - List encodeOrdinary(String text); + IntArrayList encodeOrdinary(String text); /** * Encodes the given text into a list of token ids, ignoring special tokens. @@ -139,7 +137,7 @@ public interface Encoding { * @return the decoded text * @throws IllegalArgumentException if the list contains invalid token ids */ - String decode(List tokens); + String decode(IntArrayList tokens); /** * Decodes the given list of token ids into a byte array. @@ -156,7 +154,7 @@ public interface Encoding { * @return the decoded byte array * @throws IllegalArgumentException if the list contains invalid token ids */ - byte[] decodeBytes(List tokens); + byte[] decodeBytes(IntArrayList tokens); /** * Returns the name of this encoding. 
This is the name which is used to identify diff --git a/lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java b/lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java index 7fb4dc1a..80953c1c 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java +++ b/lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java @@ -1,17 +1,15 @@ package com.knuddels.jtokkit.api; -import java.util.List; - /** * The result of encoding operation. */ public final class EncodingResult { - private final List tokens; + private final IntArrayList tokens; private final boolean truncated; private int tokenCount; - public EncodingResult(List tokens, int tokenCount, boolean truncated) { + public EncodingResult(IntArrayList tokens, int tokenCount, boolean truncated) { this.tokens = tokens; this.tokenCount = tokenCount; this.truncated = truncated; @@ -20,7 +18,7 @@ public EncodingResult(List tokens, int tokenCount, boolean truncated) { /** * @return the list of token ids */ - public List getTokens() { + public IntArrayList getTokens() { if (tokens.size() != getTokenCount()) { throw new IllegalStateException("Token count does not match token list size (tokenCount=" + tokenCount + ", tokens size=" + tokens.size() + ")"); } @@ -44,8 +42,8 @@ public boolean isTruncated() { @Override public String toString() { return "EncodingResult{" - + "tokens=" + tokens - + ", tokenCount=" + tokenCount + + "tokens=" + getTokens() + + ", tokenCount=" + getTokenCount() + ", truncated=" + truncated + '}'; } diff --git a/lib/src/main/java/com/knuddels/jtokkit/api/IntArrayList.java b/lib/src/main/java/com/knuddels/jtokkit/api/IntArrayList.java new file mode 100644 index 00000000..809e3dc6 --- /dev/null +++ b/lib/src/main/java/com/knuddels/jtokkit/api/IntArrayList.java @@ -0,0 +1,107 @@ +package com.knuddels.jtokkit.api; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class IntArrayList { + private int[] array; + private int size = 0; + 
+ public IntArrayList() { + this(10); + } + + public IntArrayList(int size) { + array = new int[size]; + } + + public void clear() { + size = 0; + } + + public void add(int element) { + if (size >= array.length) { + resize(); + } + array[size++] = element; + } + + public int get(int index) { + return array[index]; + } + + public int set(int index, int element) { + int old = array[index]; + array[index] = element; + return old; + } + + private void resize() { + ensureCapacity(Math.max(1, array.length) * 2); + } + + public void ensureCapacity(int targetSize) { + if (targetSize <= size) { + return; + } + int[] newArray = new int[targetSize]; + if (size > 0) { + System.arraycopy(array, 0, newArray, 0, size); + } + array = newArray; + } + + public int size() { + return size; + } + + public boolean isEmpty() { + return size == 0; + } + + public int[] toArray() { + return Arrays.copyOf(array, size); + } + + public List boxed() { + List list = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + list.add(array[i]); + } + return list; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (o == null || getClass() != o.getClass()) { + return false; + } + IntArrayList that = (IntArrayList) o; + if (size != that.size) { + return false; + } + for (int i = 0; i < size; i++) { + if (array[i] != that.array[i]) { + return false; + } + } + return true; + } + + @Override + public int hashCode() { + int result = 1; + for (int i = 0; i < size; i++) { + result = 31 * result + array[i]; + } + return result; + } + + @Override + public String toString() { + return boxed().toString(); + } +} diff --git a/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java b/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java index 0c9fca76..f34a6f6c 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java @@ -5,7 +5,6 @@ 
import org.junit.jupiter.api.Test; import java.util.Collections; -import java.util.List; import java.util.function.Consumer; import java.util.regex.Pattern; @@ -143,7 +142,7 @@ void getEncodingReturnsEmptyOptionalForNonExistingEncodingName() { private static class DummyEncoding implements Encoding { @Override - public List encode(String text) { + public IntArrayList encode(String text) { return null; } @@ -153,7 +152,7 @@ public EncodingResult encode(String text, int maxTokens) { } @Override - public List encodeOrdinary(String text) { + public IntArrayList encodeOrdinary(String text) { return null; } @@ -168,12 +167,12 @@ public int countTokens(String text) { } @Override - public String decode(List tokens) { + public String decode(IntArrayList tokens) { return null; } @Override - public byte[] decodeBytes(List tokens) { + public byte[] decodeBytes(IntArrayList tokens) { return new byte[0]; } diff --git a/lib/src/test/java/com/knuddels/jtokkit/ByteArrayListTest.java b/lib/src/test/java/com/knuddels/jtokkit/ByteArrayListTest.java index 012aa424..c788dcba 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/ByteArrayListTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/ByteArrayListTest.java @@ -5,7 +5,7 @@ import java.util.ArrayList; import java.util.Random; -import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.*; class ByteArrayListTest { @@ -14,37 +14,56 @@ private static byte randomByte(Random random) { } @Test - public void testArrayListOperations() { + void testArrayListOperations() { var byteArrayList = new ByteArrayList(); var standardList = new ArrayList(); var random = new Random(); + assertTrue(byteArrayList.isEmpty()); + for (var i = 0; i < 1_000; i++) { // Add - if (randomByte(random) % 2 == 0) { - var element = randomByte(random); - var lastIndex = standardList.size(); - byteArrayList.add(element); - standardList.add(element); - assertEquals(standardList.get(lastIndex), 
byteArrayList.toByteArray()[lastIndex]); + var element = randomByte(random); + byteArrayList.add(element); + standardList.add(element); + assertEquals(standardList.get(standardList.size() - 1), byteArrayList.get(byteArrayList.size() - 1)); + + // Set + if (!byteArrayList.isEmpty() && random.nextBoolean()) { + var randomIndex = random.nextInt(byteArrayList.size()); + var newElement = randomByte(random); + byteArrayList.set(randomIndex, newElement); + standardList.set(randomIndex, newElement); + assertEquals(standardList.get(randomIndex), byteArrayList.get(randomIndex)); } - // Size - assertEquals(standardList.size(), byteArrayList.length()); + // Size and IsEmpty + assertEquals(standardList.size(), byteArrayList.size()); + assertEquals(standardList.isEmpty(), byteArrayList.isEmpty()); // Clear if (randomByte(random) % 10 == 0) { byteArrayList.clear(); standardList.clear(); - assertEquals(standardList.size(), byteArrayList.length()); + assertEquals(standardList.size(), byteArrayList.size()); } } - assertEquals(standardList.size(), byteArrayList.length()); - var byteArray = byteArrayList.toByteArray(); + // Test toArray + var byteArray = byteArrayList.toArray(); assertEquals(standardList.size(), byteArray.length); - for (var i = 0; i < byteArrayList.length(); i++) { + for (var i = 0; i < byteArrayList.size(); i++) { assertEquals(standardList.get(i), byteArray[i]); } + + // Test Equals and HashCode + var anotherByteArrayList = new ByteArrayList(); + standardList.forEach(anotherByteArrayList::add); + + assertEquals(byteArrayList, anotherByteArrayList); + if (!byteArrayList.isEmpty()) { + assertNotEquals(byteArrayList, new ByteArrayList()); + } + assertEquals(byteArrayList.hashCode(), anotherByteArrayList.hashCode()); } } diff --git a/lib/src/test/java/com/knuddels/jtokkit/ByteArrayWrapperTest.java b/lib/src/test/java/com/knuddels/jtokkit/ByteArrayWrapperTest.java index 053b2f51..71924051 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/ByteArrayWrapperTest.java +++ 
b/lib/src/test/java/com/knuddels/jtokkit/ByteArrayWrapperTest.java @@ -5,33 +5,33 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; -public class ByteArrayWrapperTest { +class ByteArrayWrapperTest { @Test - public void getBytesBetweenReturnsCorrectSliceOfArray() { - final ByteArrayWrapper byteArray = new ByteArrayWrapper(new byte[]{1, 2, 3, 4, 5, 6}); + void getBytesBetweenReturnsCorrectSliceOfArray() { + ByteArrayWrapper byteArray = new ByteArrayWrapper(new byte[]{1, 2, 3, 4, 5, 6}); assertEquals(new ByteArrayWrapper(new byte[]{4, 5, 6}), byteArray.getBytesBetween(3, 6)); } @Test - public void getBytesBetweenThrowsWhenInclusiveStartIndexOutOfBounds() { - final ByteArrayWrapper byteArray = new ByteArrayWrapper(new byte[]{1, 2, 3, 4, 5, 6}); + void getBytesBetweenThrowsWhenInclusiveStartIndexOutOfBounds() { + ByteArrayWrapper byteArray = new ByteArrayWrapper(new byte[]{1, 2, 3, 4, 5, 6}); assertThrows(IndexOutOfBoundsException.class, () -> byteArray.getBytesBetween(-1, 6)); assertThrows(IndexOutOfBoundsException.class, () -> byteArray.getBytesBetween(9, 10)); } @Test - public void getBytesBetweenThrowsWhenExclusiveEndIndexOutOfBounds() { - final ByteArrayWrapper byteArray = new ByteArrayWrapper(new byte[]{1, 2, 3, 4, 5, 6}); + void getBytesBetweenThrowsWhenExclusiveEndIndexOutOfBounds() { + ByteArrayWrapper byteArray = new ByteArrayWrapper(new byte[]{1, 2, 3, 4, 5, 6}); assertThrows(IndexOutOfBoundsException.class, () -> byteArray.getBytesBetween(0, 7)); assertThrows(IndexOutOfBoundsException.class, () -> byteArray.getBytesBetween(0, -1)); } @Test - public void getBytesBetweenThrowsWhenStartIndexIsGreaterThanEndIndex() { - final ByteArrayWrapper byteArray = new ByteArrayWrapper(new byte[]{1, 2, 3, 4, 5, 6}); + void getBytesBetweenThrowsWhenStartIndexIsGreaterThanEndIndex() { + ByteArrayWrapper byteArray = new ByteArrayWrapper(new byte[]{1, 2, 3, 4, 5, 6}); 
assertThrows(IllegalArgumentException.class, () -> byteArray.getBytesBetween(3, 2)); } diff --git a/lib/src/test/java/com/knuddels/jtokkit/Cl100kLargeTokenizerTest.java b/lib/src/test/java/com/knuddels/jtokkit/Cl100kLargeTokenizerTest.java index debe1c89..ef78a580 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/Cl100kLargeTokenizerTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/Cl100kLargeTokenizerTest.java @@ -1,7 +1,6 @@ package com.knuddels.jtokkit; import com.knuddels.jtokkit.api.Encoding; -import com.knuddels.jtokkit.api.EncodingType; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; diff --git a/lib/src/test/java/com/knuddels/jtokkit/Cl100kParserTest.java b/lib/src/test/java/com/knuddels/jtokkit/Cl100kParserTest.java index 8376a95f..5236c68f 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/Cl100kParserTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/Cl100kParserTest.java @@ -18,8 +18,7 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.jupiter.api.Assertions.*; -public class Cl100kParserTest { - +class Cl100kParserTest { public static Map fetchUnicodeData() { var url = "https://www.unicode.org/Public/UCD/latest/ucd/UnicodeData.txt"; Map unicodeMap = new HashMap<>(); @@ -41,7 +40,7 @@ public static Map fetchUnicodeData() { @Disabled // Takes too long @Test - public void testToUtf8BytesOnFetchedUnicodeData() throws Exception { + void testToUtf8BytesOnFetchedUnicodeData() throws Exception { fetchUnicodeData().entrySet().stream().parallel().forEach(e -> { var expected = Character.toString(e.getKey()); if (isValidUTF8(expected)) { @@ -56,7 +55,7 @@ public void testToUtf8BytesOnFetchedUnicodeData() throws Exception { } @Test - public void testIsApostophed() { + void testIsApostophed() { var count = 0; var pattern = compileRegex("^(?:'s|'t|'re|'ve|'m|'ll|'d)$", true); @@ -89,7 +88,7 @@ public void testIsApostophed() { } @Test - public void testIsNumeric() { + void testIsNumeric() { var 
count = 0; assertFalse(Cl100kParser.isNumeric(-1)); var pattern = compileRegex("^\\p{N}$", true); @@ -107,7 +106,7 @@ public void testIsNumeric() { } @Test - public void testIsLetter() { + void testIsLetter() { var count = 0; assertFalse(Cl100kParser.isLetter(-1)); var pattern = compileRegex("^\\p{L}$", true); @@ -124,7 +123,7 @@ public void testIsLetter() { } @Test - public void testIsUnicodeWhitespace() { + void testIsUnicodeWhitespace() { var count = 0; assertFalse(Cl100kParser.isWhitespace(-1)); var pattern = compileRegex("^\\s$", true); @@ -141,7 +140,7 @@ public void testIsUnicodeWhitespace() { } @Test - public void testIsLetterOrNumeric() { + void testIsLetterOrNumeric() { var count = 0; assertFalse(Cl100kParser.isLetterOrNumeric(-1)); var pattern = compileRegex("^[\\p{L}\\p{N}]$", true); @@ -158,7 +157,7 @@ public void testIsLetterOrNumeric() { } @Test - public void testIsNotWhitespaceOrLetterOrNumeric() { + void testIsNotWhitespaceOrLetterOrNumeric() { var count = 0; assertFalse(Cl100kParser.isNotWhitespaceOrLetterOrNumeric(-1)); var pattern = compileRegex("^[^\\s\\p{L}\\p{N}]$", true); @@ -175,7 +174,7 @@ public void testIsNotWhitespaceOrLetterOrNumeric() { } @Test - public void testIsNotNewlineOrLetterOrNumeric() { + void testIsNotNewlineOrLetterOrNumeric() { var count = 0; assertFalse(Cl100kParser.isNotNewlineOrLetterOrNumeric(-1)); var pattern = compileRegex("^[^\r\n\\p{L}\\p{N}]$", true); @@ -192,7 +191,7 @@ public void testIsNotNewlineOrLetterOrNumeric() { } @Test - public void testIsNewline() { + void testIsNewline() { var count = 0; assertFalse(Cl100kParser.isNewline(-1)); var pattern = compileRegex("^[\r\n]$", true); diff --git a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java index 0cbfd33e..36ce5d96 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java @@ -35,7 +35,7 @@ private static String 
generateUnicodeCategoryString(IntPredicate characterProper .toString(); } - static String normalizeStringForTesting(String testString) { + private static String normalizeStringForTesting(String testString) { return testString .replaceAll("\\r", "\\\\r") .replaceAll("\\n", "\\\\n") diff --git a/lib/src/test/java/com/knuddels/jtokkit/IntArrayListTest.java b/lib/src/test/java/com/knuddels/jtokkit/IntArrayListTest.java new file mode 100644 index 00000000..909f6a7f --- /dev/null +++ b/lib/src/test/java/com/knuddels/jtokkit/IntArrayListTest.java @@ -0,0 +1,66 @@ +package com.knuddels.jtokkit; + +import com.knuddels.jtokkit.api.IntArrayList; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.*; + +class IntArrayListTest { + + @Test + void testArrayListOperations() { + var byteArrayList = new IntArrayList(); + var standardList = new ArrayList(); + var random = new Random(); + + assertTrue(byteArrayList.isEmpty()); + + for (var i = 0; i < 1_000; i++) { + // Add + var element = random.nextInt(); + byteArrayList.add(element); + standardList.add(element); + assertEquals(standardList.get(standardList.size() - 1), byteArrayList.get(byteArrayList.size() - 1)); + + // Set + if (!byteArrayList.isEmpty() && random.nextBoolean()) { + var randomIndex = random.nextInt(byteArrayList.size()); + var newElement = random.nextInt(); + byteArrayList.set(randomIndex, newElement); + standardList.set(randomIndex, newElement); + assertEquals(standardList.get(randomIndex), byteArrayList.get(randomIndex)); + } + + // Size and IsEmpty + assertEquals(standardList.size(), byteArrayList.size()); + assertEquals(standardList.isEmpty(), byteArrayList.isEmpty()); + + // Clear + if (random.nextInt() % 10 == 0) { + byteArrayList.clear(); + standardList.clear(); + assertEquals(standardList.size(), byteArrayList.size()); + } + } + + // Test toArray + var byteArray = byteArrayList.toArray(); + 
assertEquals(standardList.size(), byteArray.length); + for (var i = 0; i < byteArrayList.size(); i++) { + assertEquals(standardList.get(i), byteArray[i]); + } + + // Test Equals and HashCode + var anotherIntArrayList = new IntArrayList(); + standardList.forEach(anotherIntArrayList::add); + + assertEquals(byteArrayList, anotherIntArrayList); + if (!byteArrayList.isEmpty()) { + assertNotEquals(byteArrayList, new IntArrayList()); + } + assertEquals(byteArrayList.hashCode(), anotherIntArrayList.hashCode()); + } +} diff --git a/lib/src/test/java/com/knuddels/jtokkit/reference/TestUtils.java b/lib/src/test/java/com/knuddels/jtokkit/reference/TestUtils.java index 3b8961f1..ae8c48a9 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/reference/TestUtils.java +++ b/lib/src/test/java/com/knuddels/jtokkit/reference/TestUtils.java @@ -1,17 +1,21 @@ package com.knuddels.jtokkit.reference; +import com.knuddels.jtokkit.api.IntArrayList; + import java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; class TestUtils { - static List parseEncodingString(String encodingString) { - return Arrays.stream( + static IntArrayList parseEncodingString(String encodingString) { + List list = Arrays.stream( encodingString.substring(1, encodingString.length() - 1) .replaceAll(" ", "") .split(",") ).map(Integer::parseInt) - .collect(Collectors.toList()); + .toList(); + var result = new IntArrayList(list.size()); + list.forEach(result::add); + return result; } } From 58b8de1185bfb41e9ba7eb39c2f1651b0fa11b3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Wed, 3 Jan 2024 13:27:58 +0100 Subject: [PATCH 4/9] Replace internal TreeMap with faster LinkedHashMap iteration without ceilingEntry in TokenEncoderLarge MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before: Benchmark (dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBase data ss 10 2.596 ± 0.269 s/op 
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount data ss 10 2.191 ± 0.028 s/op SingleThreadedBenchmark.benchmarkP50kBase data ss 10 4.442 ± 0.061 s/op SingleThreadedBenchmark.benchmarkP50kEdit data ss 10 4.466 ± 0.032 s/op SingleThreadedBenchmark.benchmarkR50kBase data ss 10 4.081 ± 0.106 s/op After: Benchmark (dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBase data ss 10 2.365 ± 0.019 s/op SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount data ss 10 2.130 ± 0.024 s/op SingleThreadedBenchmark.benchmarkP50kBase data ss 10 4.393 ± 0.026 s/op SingleThreadedBenchmark.benchmarkP50kEdit data ss 10 4.408 ± 0.015 s/op SingleThreadedBenchmark.benchmarkR50kBase data ss 10 4.073 ± 0.017 s/op --- .../com/knuddels/jtokkit/TokenEncoder.java | 2 +- .../knuddels/jtokkit/TokenEncoderLarge.java | 33 ++++++++++--------- .../java/com/knuddels/jtokkit/Cl100kTest.java | 2 +- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index 9f7782dc..1325ef8a 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ -22,7 +22,7 @@ public final class TokenEncoder { TokenEncoder(Map encoder) { if (!encoder.isEmpty()) { - VERY_LARGE_TOKENIZER_BYTE_THRESHOLD = parseInt(System.getProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, "1500")); + VERY_LARGE_TOKENIZER_BYTE_THRESHOLD = parseInt(System.getProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, "500")); TreeMap> tempEncoders = new TreeMap<>(); encoder.forEach((k, v) -> { ByteArrayWrapper key = new ByteArrayWrapper(k); diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java index 4df67ff5..74f9877e 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java +++ 
b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java @@ -3,16 +3,19 @@ import com.knuddels.jtokkit.api.IntArrayList; -import java.util.Map.Entry; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.Map; import java.util.TreeMap; import static com.knuddels.jtokkit.TokenEncoder.MAX_RANK; +import static java.util.Objects.requireNonNull; final class TokenEncoderLarge { static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, boolean keepEncodings, IntArrayList out, ByteArrayWrapper match, int length) { assert length > 1 : "Already filtered out"; - TreeMap> rankMap = new TreeMap<>(); + TreeMap> rankMap = new TreeMap<>(); RankNode head = null; RankNode prevNode = null; @@ -31,14 +34,12 @@ static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, bo } prevNode = node; - rankMap.computeIfAbsent(encoded, k -> new TreeMap<>()).put(i, node); + rankMap.computeIfAbsent(encoded, k -> new LinkedHashMap<>()).put(i, node); } while (validRanks > 0) { - TreeMap minNodes = rankMap.pollFirstEntry().getValue(); - int firstIndex; - for (Entry entry = minNodes.firstEntry(); entry != null; entry = minNodes.ceilingEntry(firstIndex)) { - RankNode minNode = entry.getValue(); + for (Iterator it = rankMap.pollFirstEntry().getValue().values().iterator(); it.hasNext(); ) { + RankNode minNode = it.next(); int minRank = minNode.rank; assert minRank != MAX_RANK; @@ -56,7 +57,7 @@ static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, bo assert previousNode.rank != minRank; removeNode(rankMap.get(previousNode.rank), rankMap, previousNode); previousNode.rank = newRank; - rankMap.computeIfAbsent(newRank, k -> new TreeMap<>()).put(previousNode.index, previousNode); + rankMap.computeIfAbsent(newRank, k -> new LinkedHashMap<>()).put(previousNode.index, previousNode); } } @@ -64,9 +65,8 @@ static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, bo if (newRank == MAX_RANK) { 
validRanks--; } - firstIndex = minNode.index + 1; minNode.rank = newRank; - rankMap.computeIfAbsent(newRank, k -> new TreeMap<>()).put(minNode.index, minNode); + rankMap.computeIfAbsent(newRank, k -> new LinkedHashMap<>()).put(minNode.index, minNode); minNode.next = nextNextNode; if (nextNode != null) { @@ -77,9 +77,10 @@ static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, bo validRanks--; if (nextNode.rank != minRank) { removeNode(rankMap.get(nextNode.rank), rankMap, nextNode); + } else { + it.next(); } } - firstIndex = nextNode.index + 1; } length--; @@ -99,8 +100,8 @@ static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, bo return length; } - static void removeNode(TreeMap nodeMap, TreeMap> rankMap, RankNode nextNode) { - if (nodeMap.size() == 1) { + static void removeNode(Map nodeMap, Map> rankMap, RankNode nextNode) { + if (requireNonNull(nodeMap).size() == 1) { assert nodeMap.containsKey(nextNode.index); rankMap.remove(nextNode.rank); } else { @@ -121,9 +122,9 @@ private static class RankNode { @Override public String toString() { return "RankNode{" + - "rank=" + rank + - ", index=" + index + - '}'; + "rank=" + rank + + ", index=" + index + + '}'; } } } \ No newline at end of file diff --git a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java index 36ce5d96..41a109d8 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java @@ -57,7 +57,7 @@ void measureEncodingSpeeds() { var measurements = new TreeMap(); var iterations = 20; - for (var i = 1.0; i < 2_000; i = Math.max(i + 1, i * 1.01)) { + for (var i = 1.0; i < 1000; i = Math.max(i + 1, i * 1.01)) { while (input.length() < i) { input.append("a"); } From 36c5fa71da481b0294cb982c6a59c169ec1df325 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Thu, 4 Jan 2024 19:20:39 +0100 Subject: [PATCH 5/9] Simplify 
TokenEncoder.encode when we're attempting to encode the whole piece MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Since whole pieces were already checked, we don't have to try to reencode them Before: Benchmark (dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBase data ss 10 2.365 ± 0.019 s/op SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount data ss 10 2.130 ± 0.024 s/op SingleThreadedBenchmark.benchmarkP50kBase data ss 10 4.393 ± 0.026 s/op SingleThreadedBenchmark.benchmarkP50kEdit data ss 10 4.408 ± 0.015 s/op SingleThreadedBenchmark.benchmarkR50kBase data ss 10 4.073 ± 0.017 s/op After: Benchmark (dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBase data ss 10 2.340 ± 0.023 s/op SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount data ss 10 2.096 ± 0.029 s/op SingleThreadedBenchmark.benchmarkP50kBase data ss 10 4.385 ± 0.017 s/op SingleThreadedBenchmark.benchmarkP50kEdit data ss 10 4.372 ± 0.041 s/op SingleThreadedBenchmark.benchmarkR50kBase data ss 10 4.059 ± 0.026 s/op --- lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index 1325ef8a..118184af 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ -205,10 +205,8 @@ private int encode(ByteArrayWrapper payload) { } int encode(ByteArrayWrapper piece, int start, int end) { - if (end > piece.length()) { + if (end > piece.length() || end - start == piece.length()) { return MAX_RANK; - } else if (end - start == piece.length()) { - return encode(piece); } else { return encode(piece.getBytesBetween(start, end)); } From 72e7aad6486ad66db1de747cad58d78b52dad5ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Sat, 6 Jan 
2024 20:10:28 +0100 Subject: [PATCH 6/9] Update non-standard csv examples --- lib/src/test/resources/cl100k_base_encodings.csv | 4 ++-- lib/src/test/resources/p50k_base_encodings.csv | 4 ++-- lib/src/test/resources/p50k_edit_encodings.csv | 4 ++-- lib/src/test/resources/r50k_base_encodings.csv | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/lib/src/test/resources/cl100k_base_encodings.csv b/lib/src/test/resources/cl100k_base_encodings.csv index 6ca65e20..b92f9d0b 100644 --- a/lib/src/test/resources/cl100k_base_encodings.csv +++ b/lib/src/test/resources/cl100k_base_encodings.csv @@ -420,5 +420,5 @@ Quel est votre caractère chinois préféré ? Et comment le dessiner ?,"[2232, "Olá, como vai você?","[43819, 1995, 11, 8112, 40586, 25738, 30]","[43819, 1995, 11, 8112, 40586, 25738, 30]" "Здравствуй, как поживаете?","[36551, 7094, 28086, 20812, 83680, 11, 52770, 5173, 21956, 28089, 28007, 1532, 30]","[36551, 7094, 28086, 20812, 83680, 11, 52770, 5173, 21956, 28089]" "Hola, ¿cómo estás?","[69112, 11, 29386, 66, 72561, 1826, 7206, 30]","[69112, 11, 29386, 66, 72561, 1826, 7206, 30]" -"  ", "[44529]", "[44529]" -"  a", "[23249, 23249, 64]", "[23249, 23249, 64]" +"  ","[44529]","[44529]" +"  a","[23249, 23249, 64]","[23249, 23249, 64]" diff --git a/lib/src/test/resources/p50k_base_encodings.csv b/lib/src/test/resources/p50k_base_encodings.csv index ca1dc49a..9602392e 100644 --- a/lib/src/test/resources/p50k_base_encodings.csv +++ b/lib/src/test/resources/p50k_base_encodings.csv @@ -420,5 +420,5 @@ Quel est votre caractère chinois préféré ? 
Et comment le dessiner ?,"[48, 27 "Olá, como vai você?","[30098, 6557, 11, 401, 78, 410, 1872, 12776, 25792, 30]","[30098, 6557, 11, 401, 78, 410, 1872, 12776, 25792, 30]" "Здравствуй, как поживаете?","[140, 245, 43666, 21169, 16142, 38857, 21727, 20375, 38857, 35072, 140, 117, 11, 12466, 118, 16142, 31583, 12466, 123, 25443, 114, 18849, 38857, 16142, 16843, 20375, 16843, 30]","[140, 245, 43666, 21169, 16142, 38857, 21727, 20375, 38857, 35072]" "Hola, ¿cómo estás?","[39, 5708, 11, 1587, 123, 66, 10205, 5908, 1556, 40138, 30]","[39, 5708, 11, 1587, 123, 66, 10205, 5908, 1556, 40138]" -"  ", "[5099, 222, 5099, 222]", "[5099, 222, 5099, 222]" -"  a", "[5099, 222, 5099, 222, 64]", "[5099, 222, 5099, 222, 64]" +"  ","[5099, 222, 5099, 222]","[5099, 222, 5099, 222]" +"  a","[5099, 222, 5099, 222, 64]","[5099, 222, 5099, 222, 64]" diff --git a/lib/src/test/resources/p50k_edit_encodings.csv b/lib/src/test/resources/p50k_edit_encodings.csv index 2179bc8e..88e62280 100644 --- a/lib/src/test/resources/p50k_edit_encodings.csv +++ b/lib/src/test/resources/p50k_edit_encodings.csv @@ -420,6 +420,6 @@ Quel est votre caractère chinois préféré ? 
Et comment le dessiner ?,"[48, 27 "Olá, como vai você?","[30098, 6557, 11, 401, 78, 410, 1872, 12776, 25792, 30]","[30098, 6557, 11, 401, 78, 410, 1872, 12776, 25792, 30]" "Здравствуй, как поживаете?","[140, 245, 43666, 21169, 16142, 38857, 21727, 20375, 38857, 35072, 140, 117, 11, 12466, 118, 16142, 31583, 12466, 123, 25443, 114, 18849, 38857, 16142, 16843, 20375, 16843, 30]","[140, 245, 43666, 21169, 16142, 38857, 21727, 20375, 38857, 35072]" "Hola, ¿cómo estás?","[39, 5708, 11, 1587, 123, 66, 10205, 5908, 1556, 40138, 30]","[39, 5708, 11, 1587, 123, 66, 10205, 5908, 1556, 40138]" -"  ", "[5099, 222, 5099, 222]", "[5099, 222, 5099, 222]" -"  a", "[5099, 222, 5099, 222, 64]", "[5099, 222, 5099, 222, 64]" +"  ","[5099, 222, 5099, 222]","[5099, 222, 5099, 222]" +"  a","[5099, 222, 5099, 222, 64]","[5099, 222, 5099, 222, 64]" diff --git a/lib/src/test/resources/r50k_base_encodings.csv b/lib/src/test/resources/r50k_base_encodings.csv index 9e858f13..585f8a60 100644 --- a/lib/src/test/resources/r50k_base_encodings.csv +++ b/lib/src/test/resources/r50k_base_encodings.csv @@ -420,5 +420,5 @@ Quel est votre caractère chinois préféré ? 
Et comment le dessiner ?,"[48, 27 "Olá, como vai você?","[30098, 6557, 11, 401, 78, 410, 1872, 12776, 25792, 30]","[30098, 6557, 11, 401, 78, 410, 1872, 12776, 25792, 30]" "Здравствуй, как поживаете?","[140, 245, 43666, 21169, 16142, 38857, 21727, 20375, 38857, 35072, 140, 117, 11, 12466, 118, 16142, 31583, 12466, 123, 25443, 114, 18849, 38857, 16142, 16843, 20375, 16843, 30]","[140, 245, 43666, 21169, 16142, 38857, 21727, 20375, 38857, 35072]" "Hola, ¿cómo estás?","[39, 5708, 11, 1587, 123, 66, 10205, 5908, 1556, 40138, 30]","[39, 5708, 11, 1587, 123, 66, 10205, 5908, 1556, 40138]" -"  ", "[5099, 222, 5099, 222]", "[5099, 222, 5099, 222]" -"  a", "[5099, 222, 5099, 222, 64]", "[5099, 222, 5099, 222, 64]" +"  ","[5099, 222, 5099, 222]","[5099, 222, 5099, 222]" +"  a","[5099, 222, 5099, 222, 64]","[5099, 222, 5099, 222, 64]" From 0411aaa1e73b800892df19e037591330a6c89aa8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Mon, 8 Jan 2024 12:51:29 +0100 Subject: [PATCH 7/9] Remove validRanks, there are simpler ways to keep track of iteration count We're also skipping the last getMinRankIndex calculation when we have 2 remaining tokens --- .../com/knuddels/jtokkit/TokenEncoder.java | 42 +++++++------------ .../knuddels/jtokkit/TokenEncoderLarge.java | 29 ++++--------- 2 files changed, 23 insertions(+), 48 deletions(-) diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index 118184af..e2872788 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ -112,26 +112,24 @@ int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] byteAr } return 1; } else { - int length = match.length(); - if (length < VERY_LARGE_TOKENIZER_BYTE_THRESHOLD) { - return calculateTokensSmall(maxTokenCount, keepEncodings, out, ranks, match, length); + if (match.length() < VERY_LARGE_TOKENIZER_BYTE_THRESHOLD) { + return 
calculateTokensSmall(maxTokenCount, keepEncodings, out, ranks, match); } else { - return calculateTokensLarge(this, maxTokenCount, keepEncodings, out, match, length); + return calculateTokensLarge(this, maxTokenCount, keepEncodings, out, match); } } } - private int calculateTokensSmall(int maxTokenCount, boolean keepEncodings, IntArrayList out, IntArrayList ranks, ByteArrayWrapper match, int length) { + private int calculateTokensSmall(int maxTokenCount, boolean keepEncodings, IntArrayList out, IntArrayList ranks, ByteArrayWrapper match) { + int length = match.length(); assert length > 1 : "Already filtered out"; ranks.clear(); ranks.ensureCapacity(length + 1); - int validRanks = 0; int minRankIndex = -1; for (int i = 0, minRank = MAX_RANK; i < length + 1; i++) { int encoded = encode(match, i, i + 2); if (encoded != MAX_RANK) { - validRanks++; if (encoded < minRank) { minRankIndex = i; minRank = encoded; @@ -139,7 +137,7 @@ private int calculateTokensSmall(int maxTokenCount, boolean keepEncodings, IntAr } ranks.add(encoded); } - int tokenCount = mergeBytesAndGetTokenCount(match, length, ranks, validRanks, minRankIndex); + int tokenCount = mergeBytesAndGetTokenCount(match, length, ranks, minRankIndex); if (keepEncodings) { for (int start = 0, end = 1; end < ranks.size() && out.size() < maxTokenCount; end++) { if (ranks.get(end) != DUMMY_RANK) { @@ -153,11 +151,9 @@ private int calculateTokensSmall(int maxTokenCount, boolean keepEncodings, IntAr return tokenCount; } - int mergeBytesAndGetTokenCount(ByteArrayWrapper piece, int length, IntArrayList ranks, int validRanks, int minRankIndex) { + int mergeBytesAndGetTokenCount(ByteArrayWrapper piece, int length, IntArrayList ranks, int minRankIndex) { assert getMinRankIndex(ranks) == minRankIndex; - while (validRanks > 0) { - assert minRankIndex >= 0; - + while (minRankIndex >= 0) { int previousIndex = getPreviousIndex(ranks, minRankIndex - 1); int nextIndex = getNextIndex(ranks, minRankIndex + 1); int nextNextIndex = 
getNextIndex(ranks, nextIndex + 1); @@ -166,26 +162,20 @@ int mergeBytesAndGetTokenCount(ByteArrayWrapper piece, int length, IntArrayList if (previousIndex >= 0) { assert ranks.get(previousIndex) != DUMMY_RANK; int newRank = encode(piece, previousIndex, nextNextIndex); - int oldRank = ranks.set(previousIndex, newRank); - if ((newRank == MAX_RANK) != (oldRank == MAX_RANK)) { - validRanks -= (newRank == MAX_RANK) ? 1 : -1; - } + ranks.set(previousIndex, newRank); } assert ranks.get(minRankIndex) != DUMMY_RANK; int newRank = encode(piece, minRankIndex, nextNextNextIndex); - int oldRank = ranks.set(minRankIndex, newRank); - if ((newRank == MAX_RANK) != (oldRank == MAX_RANK)) { - validRanks--; - } + ranks.set(minRankIndex, newRank); - int oldDeletedRank = ranks.set(nextIndex, DUMMY_RANK); - if (oldDeletedRank != MAX_RANK) { - validRanks--; - } + ranks.set(nextIndex, DUMMY_RANK); length--; - - minRankIndex = getMinRankIndex(ranks); + if (length < 3) { + break; // single tokens were already filtered out, let's skip a minimum calculation + } else { + minRankIndex = getMinRankIndex(ranks); + } } assert getMinRankIndex(ranks) < 0; return length; diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java index 74f9877e..dc766f89 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java @@ -12,19 +12,16 @@ import static java.util.Objects.requireNonNull; final class TokenEncoderLarge { - static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, boolean keepEncodings, IntArrayList out, ByteArrayWrapper match, int length) { + static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, boolean keepEncodings, IntArrayList out, ByteArrayWrapper match) { + int length = match.length(); assert length > 1 : "Already filtered out"; TreeMap> rankMap = new TreeMap<>(); RankNode head = null; 
RankNode prevNode = null; - int validRanks = 0; for (int i = 0; i < length + 1; i++) { int encoded = tokenEncoder.encode(match, i, i + 2); - if (encoded != MAX_RANK) { - validRanks++; - } RankNode node = new RankNode(encoded, i); if (head == null) { head = node; @@ -36,8 +33,8 @@ static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, bo rankMap.computeIfAbsent(encoded, k -> new LinkedHashMap<>()).put(i, node); } - - while (validRanks > 0) { + assert rankMap.containsKey(MAX_RANK); + while (rankMap.size() > 1) { for (Iterator it = rankMap.pollFirstEntry().getValue().values().iterator(); it.hasNext(); ) { RankNode minNode = it.next(); int minRank = minNode.rank; @@ -51,9 +48,6 @@ static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, bo if (previousNode != null) { int newRank = tokenEncoder.encode(match, previousNode.index, nextNextNode != null ? nextNextNode.index : Integer.MAX_VALUE); if (previousNode.rank != newRank) { - if ((newRank == MAX_RANK) != (previousNode.rank == MAX_RANK)) { - validRanks -= (newRank == MAX_RANK) ? 1 : -1; - } assert previousNode.rank != minRank; removeNode(rankMap.get(previousNode.rank), rankMap, previousNode); previousNode.rank = newRank; @@ -62,19 +56,13 @@ static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, bo } int newRank = tokenEncoder.encode(match, minNode.index, nextNextNextNode != null ? 
nextNextNextNode.index : Integer.MAX_VALUE); - if (newRank == MAX_RANK) { - validRanks--; - } minNode.rank = newRank; rankMap.computeIfAbsent(newRank, k -> new LinkedHashMap<>()).put(minNode.index, minNode); minNode.next = nextNextNode; - if (nextNode != null) { - if (nextNextNode != null) { - nextNextNode.prev = minNode; - } + if (nextNode != null && nextNextNode != null) { + nextNextNode.prev = minNode; if (nextNode.rank != MAX_RANK) { - validRanks--; if (nextNode.rank != minRank) { removeNode(rankMap.get(nextNode.rank), rankMap, nextNode); } else { @@ -121,10 +109,7 @@ private static class RankNode { @Override public String toString() { - return "RankNode{" + - "rank=" + rank + - ", index=" + index + - '}'; + return "RankNode{rank=" + rank + ", index=" + index + '}'; } } } \ No newline at end of file From d04d6602795072373dcf9da42e49d233a5789110 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Fri, 12 Jan 2024 13:10:44 +0100 Subject: [PATCH 8/9] Polish TokenEncoderLarge --- .../knuddels/jtokkit/TokenEncoderLarge.java | 70 ++++++++----------- 1 file changed, 31 insertions(+), 39 deletions(-) diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java index dc766f89..ec2eca09 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java @@ -13,40 +13,35 @@ final class TokenEncoderLarge { static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, boolean keepEncodings, IntArrayList out, ByteArrayWrapper match) { - int length = match.length(); - assert length > 1 : "Already filtered out"; - TreeMap> rankMap = new TreeMap<>(); - RankNode head = null; - RankNode prevNode = null; - for (int i = 0; i < length + 1; i++) { - int encoded = tokenEncoder.encode(match, i, i + 2); - RankNode node = new RankNode(encoded, i); - if (head == null) { - head = node; - } else { - prevNode.next = 
node; - node.prev = prevNode; + RankNode prev = null; + for (int i = 0; i < match.length() + 1; i++) { + int rank = tokenEncoder.encode(match, i, i + 2); + RankNode node = new RankNode(rank, i, prev); + if (prev != null) { + prev.next = node; } - prevNode = node; + prev = node; - rankMap.computeIfAbsent(encoded, k -> new LinkedHashMap<>()).put(i, node); + rankMap.computeIfAbsent(rank, k -> new LinkedHashMap<>()).put(i, node); } assert rankMap.containsKey(MAX_RANK); - while (rankMap.size() > 1) { + + int tokenCount = match.length(); + while (tokenCount > 2 && rankMap.size() > 1) { for (Iterator it = rankMap.pollFirstEntry().getValue().values().iterator(); it.hasNext(); ) { RankNode minNode = it.next(); int minRank = minNode.rank; assert minRank != MAX_RANK; - RankNode previousNode = minNode.prev; - RankNode nextNode = minNode.next; - RankNode nextNextNode = nextNode != null ? nextNode.next : null; - RankNode nextNextNextNode = nextNextNode != null ? nextNextNode.next : null; + RankNode previousNode = minNode.prev, + nextNode = minNode.next, + nextNextNode = nextNode.next, + nextNextNextNode = nextNextNode.next; if (previousNode != null) { - int newRank = tokenEncoder.encode(match, previousNode.index, nextNextNode != null ? 
nextNextNode.index : Integer.MAX_VALUE); + int newRank = tokenEncoder.encode(match, previousNode.index, nextNextNode.index); if (previousNode.rank != newRank) { assert previousNode.rank != minRank; removeNode(rankMap.get(previousNode.rank), rankMap, previousNode); @@ -60,40 +55,36 @@ static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, bo rankMap.computeIfAbsent(newRank, k -> new LinkedHashMap<>()).put(minNode.index, minNode); minNode.next = nextNextNode; - if (nextNode != null && nextNextNode != null) { - nextNextNode.prev = minNode; - if (nextNode.rank != MAX_RANK) { - if (nextNode.rank != minRank) { - removeNode(rankMap.get(nextNode.rank), rankMap, nextNode); - } else { - it.next(); - } + nextNextNode.prev = minNode; + if (nextNode.rank != MAX_RANK) { + if (nextNode.rank != minRank) { + removeNode(rankMap.get(nextNode.rank), rankMap, nextNode); + } else { + it.next(); } } - length--; + tokenCount--; } } - assert rankMap.firstEntry().getValue().values().iterator().next().rank == MAX_RANK; if (keepEncodings) { - while (head.next != null && out.size() < maxTokenCount) { + for (RankNode head = rankMap.get(MAX_RANK).get(0); head.next != null && out.size() < maxTokenCount; head = head.next) { int token = tokenEncoder.encode(match, head.index, head.next.index); assert token != MAX_RANK : "Token should not be MAX_RANK"; out.add(token); - head = head.next; } } - return length; + return tokenCount; } - static void removeNode(Map nodeMap, Map> rankMap, RankNode nextNode) { + static void removeNode(Map nodeMap, Map> rankMap, RankNode node) { if (requireNonNull(nodeMap).size() == 1) { - assert nodeMap.containsKey(nextNode.index); - rankMap.remove(nextNode.rank); + assert nodeMap.containsKey(node.index); + rankMap.remove(node.rank); } else { - nodeMap.remove(nextNode.index); + nodeMap.remove(node.index); } } @@ -102,9 +93,10 @@ private static class RankNode { int index; RankNode prev, next; - RankNode(int rank, int index) { + RankNode(int rank, int 
index, RankNode prev) { this.rank = rank; this.index = index; + this.prev = prev; } @Override From 781345227759348d45bc2f5753ad378caa1aee0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Sun, 14 Jan 2024 12:17:26 +0100 Subject: [PATCH 9/9] Cover remaining primitive list methods with tests fully --- .../knuddels/jtokkit/ByteArrayListTest.java | 18 +++++++++++++++--- .../com/knuddels/jtokkit/IntArrayListTest.java | 17 ++++++++++++++--- ...nizerTest.java => Cl100kLargeBaseTest.java} | 2 +- 3 files changed, 30 insertions(+), 7 deletions(-) rename lib/src/test/java/com/knuddels/jtokkit/reference/{Cl100kLargeTokenizerTest.java => Cl100kLargeBaseTest.java} (93%) diff --git a/lib/src/test/java/com/knuddels/jtokkit/ByteArrayListTest.java b/lib/src/test/java/com/knuddels/jtokkit/ByteArrayListTest.java index c788dcba..74cb8c8f 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/ByteArrayListTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/ByteArrayListTest.java @@ -41,6 +41,10 @@ void testArrayListOperations() { assertEquals(standardList.size(), byteArrayList.size()); assertEquals(standardList.isEmpty(), byteArrayList.isEmpty()); + // Boxed and ToString + assertEquals(standardList, byteArrayList.boxed()); + assertEquals(standardList.toString(), byteArrayList.toString()); + // Clear if (randomByte(random) % 10 == 0) { byteArrayList.clear(); @@ -48,6 +52,10 @@ void testArrayListOperations() { assertEquals(standardList.size(), byteArrayList.size()); } } + var element = randomByte(random); + byteArrayList.add(element); + standardList.add(element); + // Test toArray var byteArray = byteArrayList.toArray(); @@ -60,10 +68,14 @@ void testArrayListOperations() { var anotherByteArrayList = new ByteArrayList(); standardList.forEach(anotherByteArrayList::add); + assertEquals(byteArrayList, byteArrayList); assertEquals(byteArrayList, anotherByteArrayList); - if (!byteArrayList.isEmpty()) { - assertNotEquals(byteArrayList, new ByteArrayList()); - } 
assertEquals(byteArrayList.hashCode(), anotherByteArrayList.hashCode()); + + assertNotEquals(byteArrayList, new Object()); + anotherByteArrayList.set(0, (byte) (byteArrayList.get(0) + 1)); + assertNotEquals(byteArrayList, anotherByteArrayList); + + assertNotEquals(byteArrayList, new ByteArrayList()); } } diff --git a/lib/src/test/java/com/knuddels/jtokkit/IntArrayListTest.java b/lib/src/test/java/com/knuddels/jtokkit/IntArrayListTest.java index 909f6a7f..4b56b92b 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/IntArrayListTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/IntArrayListTest.java @@ -38,6 +38,10 @@ void testArrayListOperations() { assertEquals(standardList.size(), byteArrayList.size()); assertEquals(standardList.isEmpty(), byteArrayList.isEmpty()); + // Boxed and ToString + assertEquals(standardList, byteArrayList.boxed()); + assertEquals(standardList.toString(), byteArrayList.toString()); + // Clear if (random.nextInt() % 10 == 0) { byteArrayList.clear(); @@ -45,6 +49,9 @@ void testArrayListOperations() { assertEquals(standardList.size(), byteArrayList.size()); } } + var element = random.nextInt(); + byteArrayList.add(element); + standardList.add(element); // Test toArray var byteArray = byteArrayList.toArray(); @@ -57,10 +64,14 @@ void testArrayListOperations() { var anotherIntArrayList = new IntArrayList(); standardList.forEach(anotherIntArrayList::add); + assertEquals(byteArrayList, byteArrayList); assertEquals(byteArrayList, anotherIntArrayList); - if (!byteArrayList.isEmpty()) { - assertNotEquals(byteArrayList, new IntArrayList()); - } assertEquals(byteArrayList.hashCode(), anotherIntArrayList.hashCode()); + + assertNotEquals(byteArrayList, new Object()); + anotherIntArrayList.set(0, byteArrayList.get(0) + 1); + assertNotEquals(byteArrayList, anotherIntArrayList); + + assertNotEquals(byteArrayList, new IntArrayList()); } } diff --git a/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeTokenizerTest.java 
b/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeBaseTest.java similarity index 93% rename from lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeTokenizerTest.java rename to lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeBaseTest.java index c7ea36fa..69a304cc 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeTokenizerTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeBaseTest.java @@ -8,7 +8,7 @@ import static com.knuddels.jtokkit.TokenEncoder.VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY; -class Cl100kLargeTokenizerTest extends Cl100kBaseTest { +class Cl100kLargeBaseTest extends Cl100kBaseTest { public static Encoding ENCODING;