diff --git a/crates/bpe-openai/Cargo.toml b/crates/bpe-openai/Cargo.toml
new file mode 100644
index 0000000..2975731
--- /dev/null
+++ b/crates/bpe-openai/Cargo.toml
@@ -0,0 +1,24 @@
+[package]
+name = "bpe-openai"
+version = "0.1.0"
+edition = "2021"
+description = "Prebuilt fast byte-pair encoders for OpenAI."
+repository = "https://github.com/github/rust-gems"
+license = "MIT"
+keywords = ["tokenizer", "algorithm", "encoding", "bpe"]
+categories = ["algorithms", "data-structures", "encoding", "science"]
+
+[lib]
+crate-type = ["lib", "staticlib"]
+bench = false
+
+[dependencies]
+bpe = { version = "0.1.0", path = "../bpe" }
+rmp-serde = "1"
+serde = { version = "1" }
+
+[build-dependencies]
+bpe = { version = "0.1.0", path = "../bpe", features = ["tiktoken-rs"] }
+rmp-serde = "1"
+tiktoken-rs = { version = "0.5" }
+serde = { version = "1" }
diff --git a/crates/bpe-openai/README.md b/crates/bpe-openai/README.md
new file mode 100644
index 0000000..e06d488
--- /dev/null
+++ b/crates/bpe-openai/README.md
@@ -0,0 +1,57 @@
+# OpenAI Byte Pair Encoders
+
+Fast tokenizers for OpenAI token sets based on the [bpe](https://crates.io/crates/bpe) crate.
+Serialized BPE instances are generated during build and lazily loaded at runtime as static values.
+The overhead of loading the tokenizers is small because it happens only once per process and only requires deserialization (as opposed to actually building the internal data structures).
+For convenience it re-exports the `bpe` crate so that depending on this crate is enough to use these tokenizers.
+
+Supported token sets:
+
+- r50k
+- p50k
+- cl100k
+- o200k
+
+## Usage
+
+Add a dependency by running
+
+```sh
+cargo add bpe-openai
+```
+
+or by adding the following to `Cargo.toml`
+
+```toml
+[dependencies]
+bpe-openai = "0.1"
+```
+
+Counting tokens is as simple as:
+
+```rust
+use bpe_openai::cl100k;
+
+fn main() {
+    let bpe = cl100k();
+    let count = bpe.count("Hello, world!".as_bytes());
+    println!("{count}");
+}
+```
+
+For more detailed documentation, see the [bpe](https://crates.io/crates/bpe) crate.
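+
+To get the tokens themselves, here is a minimal sketch using the
+`encode_via_backtracking` encoder that the re-exported `bpe` crate provides
+(see the `bpe` documentation for the other encoders):
+
+```rust
+use bpe_openai::cl100k;
+
+fn main() {
+    let bpe = cl100k();
+    // Returns the sequence of token ids for the given input bytes.
+    let tokens = bpe.encode_via_backtracking("Hello, world!".as_bytes());
+    println!("{tokens:?}");
+}
+```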
diff --git a/crates/bpe-openai/build.rs b/crates/bpe-openai/build.rs
new file mode 100644
index 0000000..b4f3837
--- /dev/null
+++ b/crates/bpe-openai/build.rs
@@ -0,0 +1,45 @@
+use std::env;
+use std::fs::File;
+use std::path::PathBuf;
+
+use bpe::byte_pair_encoding::BytePairEncoding;
+use serde::Serialize;
+use tiktoken_rs::CoreBPE;
+
+fn main() {
+    serialize_tokens(
+        "r50k",
+        &tiktoken_rs::r50k_base().expect("tiktoken initialization must not fail!"),
+        50256,
+        1,
+    );
+    serialize_tokens(
+        "p50k",
+        &tiktoken_rs::p50k_base().expect("tiktoken initialization must not fail!"),
+        50280,
+        1,
+    );
+    serialize_tokens(
+        "cl100k",
+        &tiktoken_rs::cl100k_base().expect("tiktoken initialization must not fail!"),
+        100256,
+        17846336922010275747,
+    );
+    serialize_tokens(
+        "o200k",
+        &tiktoken_rs::o200k_base().expect("tiktoken initialization must not fail!"),
+        199998,
+        17846336922010275747,
+    );
+    println!("cargo::rerun-if-changed=build.rs");
+}
+
+fn serialize_tokens(name: &str, bpe: &CoreBPE, num_tokens: usize, hash_factor: u64) {
+    let mut path = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is set during build"));
+    path.push(format!("bpe_{name}.dict"));
+    let file = File::create(path).expect("can create output file");
+    let mut serializer = rmp_serde::Serializer::new(file);
+    let bpe = BytePairEncoding::from_tiktoken(bpe, num_tokens, Some(hash_factor));
+    bpe.serialize(&mut serializer)
+        .expect("serialization succeeds");
+}
diff --git a/crates/bpe-openai/src/lib.rs b/crates/bpe-openai/src/lib.rs
new file mode 100644
index 0000000..65c3619
--- /dev/null
+++ b/crates/bpe-openai/src/lib.rs
@@ -0,0 +1,66 @@
+use std::sync::LazyLock;
+
+use bpe::byte_pair_encoding::BytePairEncoding;
+
+static BPE_R50K: LazyLock<BytePairEncoding> = LazyLock::new(|| {
+    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_r50k.dict"));
+    rmp_serde::from_slice(bytes).expect("valid bpe data")
+});
+
+static BPE_P50K: LazyLock<BytePairEncoding> = LazyLock::new(|| {
+    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_p50k.dict"));
+    rmp_serde::from_slice(bytes).expect("valid bpe data")
+});
+
+static BPE_CL100K: LazyLock<BytePairEncoding> = LazyLock::new(|| {
+    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k.dict"));
+    rmp_serde::from_slice(bytes).expect("valid bpe data")
+});
+
+static BPE_O200K: LazyLock<BytePairEncoding> = LazyLock::new(|| {
+    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_o200k.dict"));
+    rmp_serde::from_slice(bytes).expect("valid bpe data")
+});
+
+pub use bpe::*;
+
+pub fn r50k() -> &'static BytePairEncoding {
+    &BPE_R50K
+}
+
+pub fn p50k() -> &'static BytePairEncoding {
+    &BPE_P50K
+}
+
+pub fn cl100k() -> &'static BytePairEncoding {
+    &BPE_CL100K
+}
+
+pub fn o200k() -> &'static BytePairEncoding {
+    &BPE_O200K
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn can_load_r50k() {
+        r50k().count("".as_bytes());
+    }
+
+    #[test]
+    fn can_load_p50k() {
+        p50k().count("".as_bytes());
+    }
+
+    #[test]
+    fn can_load_cl100k() {
+        cl100k().count("".as_bytes());
+    }
+
+    #[test]
+    fn can_load_o200k() {
+        o200k().count("".as_bytes());
+    }
+}
diff --git a/crates/bpe/Cargo.toml b/crates/bpe/Cargo.toml
index dd27c53..f48ad10 100644
--- a/crates/bpe/Cargo.toml
+++ b/crates/bpe/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "bpe"
-version = "0.0.1"
+version = "0.1.0"
 edition = "2021"
 description = "Fast byte-pair encoding implementation."
repository = "https://github.com/github/rust-gems" @@ -16,6 +16,7 @@ bench = false name = "performance" path = "benches/performance.rs" harness = false +test = false [features] rand = ["dep:rand"] diff --git a/crates/bpe/README.md b/crates/bpe/README.md index 0cd4c58..a43c56c 100644 --- a/crates/bpe/README.md +++ b/crates/bpe/README.md @@ -227,6 +227,12 @@ If the requirement of correct BPE output can be relaxed, then the Greedy approac ![encoding runtime comparison](./benches/result/encoding-o200k.svg) +The graph below shows encoding results for input that is particularly challenging for tiktoken. +The input consists of random ranges taken from the continuous list of all Unicode code points excluding whitespace. +This inhibits tiktoken ability to split the input before applying BPE revealing its quadratic runtime complexity. + +![worst-case encoding runtime comparison](./benches/result/worstcase-o200k.svg) + ### Incremental encoding Incremental encoding tokenizes a text while appending bytes. diff --git a/crates/bpe/benches/performance.rs b/crates/bpe/benches/performance.rs index 4cff09c..b4f1acc 100644 --- a/crates/bpe/benches/performance.rs +++ b/crates/bpe/benches/performance.rs @@ -10,21 +10,28 @@ use criterion::{ use rand::{thread_rng, Rng}; use tiktoken_rs::CoreBPE; -static TOKENIZERS: LazyLock<[(&'static str, &'static BytePairEncoding, CoreBPE); 2]> = - LazyLock::new(|| { - [ - ( - "cl100k", - BytePairEncoding::cl100k(), - tiktoken_rs::cl100k_base().unwrap(), +static TOKENIZERS: LazyLock<[(&'static str, BytePairEncoding, CoreBPE); 2]> = LazyLock::new(|| { + [ + ( + "cl100k", + BytePairEncoding::from_tiktoken( + &tiktoken_rs::cl100k_base_singleton().lock(), + 100256, + Some(17846336922010275747), ), - ( - "o200k", - BytePairEncoding::o200k(), - tiktoken_rs::o200k_base().unwrap(), + tiktoken_rs::cl100k_base().unwrap(), + ), + ( + "o200k", + BytePairEncoding::from_tiktoken( + &tiktoken_rs::o200k_base_singleton().lock(), + 199998, + Some(17846336922010275747), ), - ] - }); + tiktoken_rs::o200k_base().unwrap(), + ), + ] +}); fn counting_benchmark(c: &mut Criterion) { for (name, bpe, _) in TOKENIZERS.iter() { @@ -160,6 +167,31 @@ fn appending_benchmark(c: &mut Criterion) { } } +fn worstcase_benchmark(c: &mut Criterion) { + for (name, bpe, tiktoken) in TOKENIZERS.iter() { + let text: String = ('\0'..char::MAX).filter(|c| !c.is_whitespace()).collect(); + let input = text.as_bytes(); + + let mut group = c.benchmark_group(format!("worstcase-{name}")); + for bytes in [10, 100, 1000, 5000, 10000, 25000, 50000, 75000, 100000] { + group.throughput(criterion::Throughput::Bytes(bytes as u64)); + group.bench_with_input( + BenchmarkId::new("backtracking", bytes), + &bytes, + |b, bytes| b.iter(|| bpe.encode_via_backtracking(select_test_bytes(input, *bytes))), + ); + group.bench_with_input(BenchmarkId::new("tiktoken", bytes), &bytes, |b, bytes| { + b.iter_batched( + || select_test_bytes(input, *bytes), + |input| tiktoken.encode_ordinary(std::str::from_utf8(input).unwrap()), + criterion::BatchSize::SmallInput, + ) + }); + } + group.finish(); + } +} + fn is_char_boundary(b: u8) -> bool { // Single byte encodings satisfy the bit pattern 0xxxxxxx, i.e. b < 128 // Continuation bytes satisfy the bit pattern 10xxxxxx, i.e. 
@@ -188,12 +220,24 @@ fn create_test_string(bpe: &BytePairEncoding, tokens: usize) -> String {
     text
 }
 
+fn select_test_bytes(input: &[u8], bytes: usize) -> &[u8] {
+    let mut start = thread_rng().gen_range(0..input.len() - bytes);
+    while start > 0 && !is_char_boundary(input[start]) {
+        start -= 1;
+    }
+    let mut end = start + bytes;
+    while end < input.len() && !is_char_boundary(input[end]) {
+        end += 1;
+    }
+    &input[start..end]
+}
+
 criterion_group!(
     name = benches;
     config = Criterion::default()
         .warm_up_time(Duration::from_millis(500))
-        .measurement_time(Duration::from_millis(1000))
+        .measurement_time(Duration::from_millis(4000))
        .nresamples(1000);
-    targets = counting_benchmark, encoding_benchmark, appending_benchmark
+    targets = counting_benchmark, encoding_benchmark, appending_benchmark, worstcase_benchmark
 );
 criterion_main!(benches);
diff --git a/crates/bpe/benches/result/appending-o200k.svg b/crates/bpe/benches/result/appending-o200k.svg
index f358527..5474718 100644
--- a/crates/bpe/benches/result/appending-o200k.svg
+++ b/crates/bpe/benches/result/appending-o200k.svg
[SVG line-chart markup updated; path data not recoverable]
diff --git a/crates/bpe/benches/result/counting-o200k.svg b/crates/bpe/benches/result/counting-o200k.svg
index deaf497..9b93d5f 100644
--- a/crates/bpe/benches/result/counting-o200k.svg
+++ b/crates/bpe/benches/result/counting-o200k.svg
[SVG line-chart markup updated; path data not recoverable]
diff --git a/crates/bpe/benches/result/encoding-o200k.svg b/crates/bpe/benches/result/encoding-o200k.svg
index 468755c..d0ffc09 100644
--- a/crates/bpe/benches/result/encoding-o200k.svg
+++ b/crates/bpe/benches/result/encoding-o200k.svg
[SVG line-chart markup updated; path data not recoverable]
diff --git a/crates/bpe/benches/result/worstcase-o200k.svg b/crates/bpe/benches/result/worstcase-o200k.svg
new file mode 100644
index 0000000..7da8fca
--- /dev/null
+++ b/crates/bpe/benches/result/worstcase-o200k.svg
[new SVG chart: worst-case encoding runtime comparison; markup not recoverable]
diff --git a/crates/bpe/script/copy-benchmark-results b/crates/bpe/script/copy-benchmark-results
index df9e97f..ae045ed 100755
--- a/crates/bpe/script/copy-benchmark-results
+++ b/crates/bpe/script/copy-benchmark-results
@@ -6,6 +6,6 @@ result_dir="benches/result"
 
 mkdir -p "$result_dir"
 
-for i in {counting,encoding,appending}-o200k; do
+for i in {counting,encoding,appending,worstcase}-o200k; do
     rsvg-convert --format svg --output "$result_dir/$i.svg" --background-color white "target/criterion/reports/$i/lines.svg"
 done
diff --git a/crates/bpe/src/appendable_encoder.rs b/crates/bpe/src/appendable_encoder.rs
index f75fde8..b0752b5 100644
--- a/crates/bpe/src/appendable_encoder.rs
+++ b/crates/bpe/src/appendable_encoder.rs
@@ -90,13 +90,13 @@ impl<'a> AppendableEncoder<'a> {
 
 #[cfg(test)]
 mod tests {
-    use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
+    use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K};
 
     use super::AppendableEncoder;
 
     #[test]
     fn test_appendable_encoder() {
-        let bpe = BytePairEncoding::cl100k();
+        let bpe = &BPE_CL100K;
         let mut enc = AppendableEncoder::new(bpe);
         let input_string = create_test_bytes(bpe, 100);
         for (i, c) in input_string.iter().enumerate() {
diff --git a/crates/bpe/src/byte_pair_encoding.rs b/crates/bpe/src/byte_pair_encoding.rs
index e00ab28..f18468e 100644
--- a/crates/bpe/src/byte_pair_encoding.rs
+++ b/crates/bpe/src/byte_pair_encoding.rs
@@ -2,7 +2,6 @@ use std::cmp::Reverse;
 use std::collections::BinaryHeap;
 use std::hash::{Hash, Hasher};
 use std::ops::Range;
-use std::sync::LazyLock;
 
 use aneubeck_daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder};
 use fnv::{FnvHashMap, FnvHasher};
@@ -12,19 +11,26 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
 use crate::backtrack_encoder::BacktrackEncoder;
 use crate::bitfield::BitField;
-use crate::byte_pair_encoding::data::TokenDict;
 
-static BPE_CL100K: LazyLock<BytePairEncoding> = LazyLock::new(|| {
-    let bytes = include_bytes!("data/bpe_cl100k.dict");
-    let dict: TokenDict = rmp_serde::from_slice(bytes).expect("");
-    dict.into_bpe()
-});
+#[cfg(test)]
+pub(crate) static BPE_CL100K: std::sync::LazyLock<BytePairEncoding> =
+    std::sync::LazyLock::new(|| {
+        BytePairEncoding::from_tiktoken(
+            &tiktoken_rs::cl100k_base_singleton().lock(),
+            100256,
+            Some(17846336922010275747),
+        )
+    });
 
-static BPE_O200K: LazyLock<BytePairEncoding> = LazyLock::new(|| {
-    let bytes = include_bytes!("data/bpe_o200k.dict");
-    let dict: TokenDict = rmp_serde::from_slice(bytes).expect("");
-    dict.into_bpe()
-});
+#[cfg(test)]
+pub(crate) static BPE_O200K: std::sync::LazyLock<BytePairEncoding> =
+    std::sync::LazyLock::new(|| {
+        BytePairEncoding::from_tiktoken(
+            &tiktoken_rs::o200k_base_singleton().lock(),
+            199998,
+            Some(17846336922010275747),
+        )
+    });
 
 /// Representation of the byte pair dictionary.
 /// This struct provides various conversions.
@@ -215,14 +221,6 @@ fn find_token_by_bytes(
 }
 
 impl BytePairEncoding {
-    pub fn cl100k() -> &'static Self {
-        &BPE_CL100K
-    }
-
-    pub fn o200k() -> &'static Self {
-        &BPE_O200K
-    }
-
     /// Construct a BytePairEncoding instance from a tiktoken dictionary.
     /// A suitable hash factor may be necessary to prevent hash collisions,
     /// which can be found using [`find_hash_factor_for_tiktoken`].
@@ -570,12 +568,12 @@ mod tests {
     use std::time::Instant;
 
    use itertools::Itertools;
-    use tiktoken_rs::cl100k_base_singleton;
+    use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton};
 
-    use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
+    use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K, BPE_O200K};
 
     #[test]
-    fn test_correctness() {
+    fn test_correctness_cl100k() {
         // This is quite a challenging test case...
         let test_string = std::str::from_utf8(&[
             125, 34, 10, 10, 46, 109, 107, 100, 105, 114, 115, 32, 102, 100, 115, 32, 97, 100, 105,
@@ -585,8 +583,8 @@ mod tests {
         ])
         .unwrap();
         let time = Instant::now();
-        let bpe = BytePairEncoding::cl100k();
+        let bpe = &BPE_CL100K;
         println!("{:?}", time.elapsed());
         let encoded1 = cl100k_base_singleton()
             .lock()
             .encode_ordinary(test_string)
@@ -601,9 +599,36 @@ mod tests {
         assert_eq!(encoded1, encoded4);
     }
 
+    #[test]
+    fn test_correctness_o200k() {
+        // This is quite a challenging test case...
+        let test_string = std::str::from_utf8(&[
+            125, 34, 10, 10, 46, 109, 107, 100, 105, 114, 115, 32, 102, 100, 115, 32, 97, 100, 105,
+            112, 105, 115, 105, 99, 105, 110, 103, 105, 116, 121, 69, 110, 103, 105, 110, 101, 32,
+            69, 67, 105, 114, 105, 101, 32, 111, 112, 116, 105, 109, 97, 108, 95, 68, 65, 32, 111,
+            102, 102, 101, 110, 100,
+        ])
+        .unwrap();
+        let time = Instant::now();
+        let bpe = &BPE_O200K;
+        println!("{:?}", time.elapsed());
+        let encoded1 = o200k_base_singleton()
+            .lock()
+            .encode_ordinary(test_string)
+            .into_iter()
+            .map(|t| t as u32)
+            .collect_vec();
+        let encoded2 = bpe.encode_via_backtracking(test_string.as_bytes());
+        assert_eq!(encoded1, encoded2);
+        let encoded3 = bpe.encode_via_table(test_string.as_bytes());
+        assert_eq!(encoded1, encoded3);
+        let encoded4 = bpe.encode_via_bitfield(test_string.as_bytes());
+        assert_eq!(encoded1, encoded4);
+    }
+
     #[test]
     fn test_bpe_equivalence() {
-        let bpe = BytePairEncoding::cl100k();
+        let bpe = &BPE_CL100K;
         for tokens in [10, 1000, 10000] {
             for _ in 0..5 {
                 let test_input = create_test_bytes(bpe, tokens);
@@ -614,68 +639,3 @@ mod tests {
         }
     }
 }
-
-mod data {
-    use serde::{Deserialize, Serialize};
-
-    use crate::byte_pair_encoding::BytePairEncoding;
-
-    #[derive(Serialize, Deserialize)]
-    pub(crate) struct TokenDict {
-        tokens: Vec<Vec<u8>>,
-        hash_factor: u64,
-    }
-
-    impl TokenDict {
-        pub(crate) fn into_bpe(self) -> BytePairEncoding {
-            BytePairEncoding::from_dictionary(self.tokens, Some(self.hash_factor))
-        }
-    }
-
-    #[test]
-    fn update_token_dicts() {
-        serialize_tokens(
-            "cl100k",
-            &tiktoken_rs::cl100k_base().expect("tiktoken initialization must not fail!"),
-            100256,
-            17846336922010275747,
-        );
-        serialize_tokens(
-            "o200k",
-            &tiktoken_rs::o200k_base().expect("tiktoken initialization must not fail!"),
-            199998,
-            17846336922010275747,
-        );
-    }
-
-    #[cfg(test)]
-    #[track_caller]
-    fn serialize_tokens(
-        name: &str,
-        bpe: &tiktoken_rs::CoreBPE,
-        num_tokens: usize,
-        hash_factor: u64,
-    ) {
-        use std::fs::File;
-        use std::path::PathBuf;
-
-        use itertools::Itertools;
-        use serde::Serialize;
-
-        let path = PathBuf::from(file!());
-        let dir = path.parent().unwrap();
-        let data_file = dir.join(format!("data/bpe_{name}.dict"));
-        let current_dir = std::env::current_dir().unwrap();
-        let abs_path = current_dir.parent().unwrap().parent().unwrap();
-        let file = File::create(abs_path.join(data_file)).unwrap();
-        let mut serializer = rmp_serde::Serializer::new(file);
-        let tokens = (0..num_tokens)
-            .map(|i| bpe._decode_native(&[i]))
-            .collect_vec();
-        let dict = TokenDict {
-            tokens,
-            hash_factor,
-        };
-        dict.serialize(&mut serializer).unwrap();
-    }
-}
diff --git a/crates/bpe/src/data/bpe_cl100k.dict b/crates/bpe/src/data/bpe_cl100k.dict
deleted file mode 100644
index ab7a16e..0000000
Binary files a/crates/bpe/src/data/bpe_cl100k.dict and /dev/null differ
diff --git a/crates/bpe/src/data/bpe_o200k.dict b/crates/bpe/src/data/bpe_o200k.dict
deleted file mode 100644
index f7ab9fc..0000000
Binary files a/crates/bpe/src/data/bpe_o200k.dict and /dev/null differ
diff --git a/crates/bpe/src/interval_encoding.rs b/crates/bpe/src/interval_encoding.rs
index 5c2f248..05bf79f 100644
--- a/crates/bpe/src/interval_encoding.rs
+++ b/crates/bpe/src/interval_encoding.rs
@@ -86,13 +86,13 @@ impl<'a> IntervalEncoding<'a> {
 mod tests {
     use rand::{thread_rng, Rng};
 
-    use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
+    use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K};
 
     use super::IntervalEncoding;
 
     #[test]
     fn test_interval_count() {
-        let bpe = BytePairEncoding::cl100k();
+        let bpe = &BPE_CL100K;
         let text = create_test_bytes(bpe, 10000);
         let intervals = IntervalEncoding::new(bpe, &text);
         for _ in 0..1000 {
diff --git a/crates/bpe/src/prependable_encoder.rs b/crates/bpe/src/prependable_encoder.rs
index f229d32..ce13e40 100644
--- a/crates/bpe/src/prependable_encoder.rs
+++ b/crates/bpe/src/prependable_encoder.rs
@@ -90,13 +90,13 @@ impl<'a> PrependableEncoder<'a> {
 
 #[cfg(test)]
 mod tests {
-    use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
+    use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K};
 
     use super::PrependableEncoder;
 
     #[test]
     fn test_prependable_encoder() {
-        let bpe = BytePairEncoding::cl100k();
+        let bpe = &BPE_CL100K;
         let mut enc = PrependableEncoder::new(bpe);
         let input_string = create_test_bytes(bpe, 100);
         for (i, c) in input_string.iter().enumerate().rev() {