@@ -12,15 +12,18 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
12
12
13
13
use crate :: backtrack_encoder:: BacktrackEncoder ;
14
14
use crate :: bitfield:: BitField ;
15
+ use crate :: byte_pair_encoding:: data:: TokenDict ;
15
16
16
17
static BPE_CL100K : LazyLock < BytePairEncoding > = LazyLock :: new ( || {
17
18
let bytes = include_bytes ! ( "data/bpe_cl100k.dict" ) ;
18
- rmp_serde:: from_slice ( bytes) . expect ( "" )
19
+ let dict: TokenDict = rmp_serde:: from_slice ( bytes) . expect ( "" ) ;
20
+ dict. into_bpe ( )
19
21
} ) ;
20
22
21
23
static BPE_O200K : LazyLock < BytePairEncoding > = LazyLock :: new ( || {
22
24
let bytes = include_bytes ! ( "data/bpe_o200k.dict" ) ;
23
- rmp_serde:: from_slice ( bytes) . expect ( "" )
25
+ let dict: TokenDict = rmp_serde:: from_slice ( bytes) . expect ( "" ) ;
26
+ dict. into_bpe ( )
24
27
} ) ;
25
28
26
29
/// Representation of the byte pair dictionary.
@@ -612,15 +615,23 @@ mod tests {
612
615
}
613
616
}
614
617
615
- #[ cfg( test) ]
616
618
mod data {
617
- use std:: fs:: File ;
618
- use std:: path:: PathBuf ;
619
-
620
- use serde:: Serialize ;
619
+ use serde:: { Deserialize , Serialize } ;
621
620
622
621
use crate :: byte_pair_encoding:: BytePairEncoding ;
623
622
623
+ #[ derive( Serialize , Deserialize ) ]
624
+ pub ( crate ) struct TokenDict {
625
+ tokens : Vec < Vec < u8 > > ,
626
+ hash_factor : u64 ,
627
+ }
628
+
629
+ impl TokenDict {
630
+ pub ( crate ) fn into_bpe ( self ) -> BytePairEncoding {
631
+ BytePairEncoding :: from_dictionary ( self . tokens , Some ( self . hash_factor ) )
632
+ }
633
+ }
634
+
624
635
#[ test]
625
636
fn update_token_dicts ( ) {
626
637
serialize_tokens (
@@ -637,22 +648,34 @@ mod data {
637
648
) ;
638
649
}
639
650
651
/// Regenerates the dictionary file `data/bpe_{name}.dict` that sits next to this source file.
///
/// Every token id in `0..num_tokens` is decoded to its raw bytes through `bpe`. The bytes are
/// bundled with `hash_factor` into a [`TokenDict`] and written to the file with `rmp_serde`.
///
/// The target path is `file!()`'s directory joined onto the grandparent of the current working
/// directory.
///
/// # Panics
///
/// Panics if the path cannot be resolved, the file cannot be created, or serializing or
/// flushing the dictionary fails.
#[cfg(test)]
#[track_caller]
fn serialize_tokens(
    name: &str,
    bpe: &tiktoken_rs::CoreBPE,
    num_tokens: usize,
    hash_factor: u64,
) {
    use std::fs::File;
    use std::io::{BufWriter, Write};
    use std::path::PathBuf;

    use itertools::Itertools;
    use serde::Serialize;

    // Decode everything before touching the file system, so that a failure while decoding
    // does not leave a truncated dictionary file behind.
    let tokens = (0..num_tokens)
        .map(|i| bpe._decode_native(&[i]))
        .collect_vec();
    let dict = TokenDict {
        tokens,
        hash_factor,
    };

    let path = PathBuf::from(file!());
    let dir = path.parent().unwrap();
    let data_file = dir.join(format!("data/bpe_{name}.dict"));
    let current_dir = std::env::current_dir().unwrap();
    let abs_path = current_dir.parent().unwrap().parent().unwrap();
    let file = File::create(abs_path.join(data_file)).unwrap();

    // The serializer emits many small writes (one or more per token); buffer them rather
    // than hitting the file directly for each one.
    let mut writer = BufWriter::new(file);
    let mut serializer = rmp_serde::Serializer::new(&mut writer);
    dict.serialize(&mut serializer).unwrap();
    // Flush explicitly: dropping a `BufWriter` discards any write error.
    writer.flush().unwrap();
}
658
681
}
0 commit comments