Skip to content
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ retrieve the encoding you want to use. You can then use the `encode` and
```java
EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
Encoding enc = registry.getEncoding(EncodingType.CL100K_BASE);
List<Integer> encoded = enc.encode("This is a sample sentence.");
IntArrayList encoded = enc.encode("This is a sample sentence.");
// encoded = [2028, 374, 264, 6205, 11914, 13]

String decoded = enc.decode(encoded);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.knuddels.jtokkit;

import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.IntArrayList;
import org.openjdk.jmh.annotations.Benchmark;

import java.util.List;
Expand Down Expand Up @@ -34,5 +35,5 @@ public Object benchmarkCl100kBase(BenchmarkingState state) {
* @param fileContents the file contents to encode
* @return a list of encoded token lists
*/
protected abstract List<List<Integer>> encodeAll(Encoding encoding, List<String> fileContents);
protected abstract List<IntArrayList> encodeAll(Encoding encoding, List<String> fileContents);
}
Original file line number Diff line number Diff line change
@@ -1,46 +1,48 @@
package com.knuddels.jtokkit;

import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.IntArrayList;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;

@State(Scope.Thread)
public abstract class AbstractMultiThreadedBenchmark extends AbstractBenchmark {

private final int threads;
private ExecutorService executor;
private final int threads;
private ExecutorService executor;

public AbstractMultiThreadedBenchmark(final int threads) {
this.threads = threads;
}
public AbstractMultiThreadedBenchmark(int threads) {
this.threads = threads;
}

@Setup
public void setup() {
executor = Executors.newFixedThreadPool(threads);
}
@Setup
public void setup() {
executor = Executors.newFixedThreadPool(threads);
}

@TearDown
public void tearDown() {
executor.shutdown();
}
@TearDown
public void tearDown() {
executor.shutdown();
}

@Override
protected List<List<Integer>> encodeAll(final Encoding encoding, final List<String> fileContents) {
final var futures = fileContents.stream()
.map(it -> CompletableFuture.supplyAsync(() -> encoding.encode(it), executor))
.collect(Collectors.toList());
@Override
protected List<IntArrayList> encodeAll(Encoding encoding, List<String> fileContents) {
var futures = fileContents.stream()
.map(it -> CompletableFuture.supplyAsync(() -> encoding.encode(it), executor))
.toList();

CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join();
CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join();

return futures.stream()
.map(CompletableFuture::join)
.collect(Collectors.toList());
}
return futures.stream()
.map(CompletableFuture::join)
.collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package com.knuddels.jtokkit;


import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.infra.Blackhole;

public class Cl100kParserBenchmark {
@Benchmark
public void benchmarkIsLetter(BenchmarkingState state, Blackhole bh) {
for (var fileContent : state.fileContents) {
fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isLetter(cp)));
}
}

@Benchmark
public void benchmarkIsLetterOrNumeric(BenchmarkingState state, Blackhole bh) {
for (var fileContent : state.fileContents) {
fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isLetterOrNumeric(cp)));
}
}

@Benchmark
public void benchmarkIsNewline(BenchmarkingState state, Blackhole bh) {
for (var fileContent : state.fileContents) {
fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isNewline(cp)));
}
}

@Benchmark
public void benchmarkIsNotNewlineOrLetterOrNumeric(BenchmarkingState state, Blackhole bh) {
for (var fileContent : state.fileContents) {
fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isNotNewlineOrLetterOrNumeric(cp)));
}
}

@Benchmark
public void benchmarkIsNotWhitespaceOrLetterOrNumeric(BenchmarkingState state, Blackhole bh) {
for (var fileContent : state.fileContents) {
fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isNotWhitespaceOrLetterOrNumeric(cp)));
}
}

@Benchmark
public void benchmarkIsNumeric(BenchmarkingState state, Blackhole bh) {
for (var fileContent : state.fileContents) {
fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isNumeric(cp)));
}
}

@Benchmark
public void benchmarkIsWhitespace(BenchmarkingState state, Blackhole bh) {
for (var fileContent : state.fileContents) {
fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isWhitespace(cp)));
}
}

@Benchmark
public void benchmarkToUtf8Conversion(BenchmarkingState state, Blackhole bh) {
var dst = new ByteArrayList();
for (var fileContent : state.fileContents) {
bh.consume(Cl100kParser.addUtf8Bytes(fileContent, 0, fileContent.length(), dst));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ public static void main(String[] args) throws Exception {
var patterns = new String[]{
"'s", "'t", "'re", "'ve", "'m", "'ll", "'d", "'x",
"x",
"ő",
"123",
"a",
",.!?:; \n",
"\n \n",
" ",
Expand All @@ -177,7 +177,7 @@ public static void main(String[] args) throws Exception {
}

var totalSize = calculateTotalFileSize(rootFolder);
if (totalSize != 99_925_295) {
if (totalSize != 99_945_274) {
throw new AssertionError("Total size did not match expected value, actual: " + totalSize);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.knuddels.jtokkit;

import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.IntArrayList;
import org.openjdk.jmh.annotations.Benchmark;

import java.util.List;
Expand All @@ -18,7 +19,7 @@ public int benchmarkCl100kBaseTokenCount(BenchmarkingState state) {
}

@Override
protected List<List<Integer>> encodeAll(final Encoding encoding, final List<String> fileContents) {
protected List<IntArrayList> encodeAll(Encoding encoding, List<String> fileContents) {
return fileContents.stream()
.map(encoding::encode)
.toList();
Expand Down
6 changes: 3 additions & 3 deletions docs/docs/getting-started/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Optional<Encoding> encoding = registry.getEncodingForModel("gpt_4");
You can use an `Encoding` to encode and decode text:

```java
List<Integer> encoded = encoding.encode("This is a sample sentence.");
IntArrayList encoded = encoding.encode("This is a sample sentence.");
// encoded = [2028, 374, 264, 6205, 11914, 13]

String decoded = encoding.decode(encoded);
Expand Down Expand Up @@ -87,13 +87,13 @@ int tokenCount = encoding.countTokensOrdinary("hello <|endoftext|> world");
If you want to only encode up until a specified amount of `maxTokens` and truncate after that amount, you can use `Encoding#encode(String, int)` or `Encoding#encodeOrdinary(String, int)`. These methods will truncate the encoded tokens to the specified length. They will automatically handle unicode characters that were split in half by the truncation by removing those tokens from the end of the list.

```java
List<Integer> encoded = encoding.encode("This is a sample sentence.", 3);
IntArrayList encoded = encoding.encode("This is a sample sentence.", 3);
// encoded = [2028, 374, 264]

String decoded = encoding.decode(encoded);
// decoded = "This is a"

List<Integer> encoded = encoding.encode("I love 🍕", 4);
IntArrayList encoded = encoding.encode("I love 🍕", 4);
// encoded = [40, 3021]

String decoded = encoding.decode(encoded);
Expand Down
107 changes: 107 additions & 0 deletions lib/src/main/java/com/knuddels/jtokkit/ByteArrayList.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package com.knuddels.jtokkit;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ByteArrayList {
private byte[] array;
private int size = 0;

public ByteArrayList() {
this(10);
}

public ByteArrayList(int size) {
array = new byte[size];
}

public void clear() {
size = 0;
}

public void add(byte element) {
if (size >= array.length) {
resize();
}
array[size++] = element;
}

public byte get(int index) {
return array[index];
}

public int set(int index, byte element) {
int old = array[index];
array[index] = element;
return old;
}

private void resize() {
ensureCapacity(Math.max(1, array.length) * 2);
}

public void ensureCapacity(int targetSize) {
if (targetSize <= size) {
return;
}
byte[] newArray = new byte[targetSize];
if (size > 0) {
System.arraycopy(array, 0, newArray, 0, size);
}
array = newArray;
}

public int size() {
return size;
}

public boolean isEmpty() {
return size == 0;
}

public byte[] toArray() {
return Arrays.copyOf(array, size);
}

public List<Byte> boxed() {
List<Byte> list = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
list.add(array[i]);
}
return list;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
} else if (o == null || getClass() != o.getClass()) {
return false;
}
ByteArrayList that = (ByteArrayList) o;
if (size != that.size) {
return false;
}
for (int i = 0; i < size; i++) {
if (array[i] != that.array[i]) {
return false;
}
}
return true;
}

@Override
public int hashCode() {
int result = 1;
for (int i = 0; i < size; i++) {
result = 31 * result + array[i];
}
return result;
}

@Override
public String toString() {
return boxed().toString();
}
}
Loading