1+ mod alibi;
12#[ cfg( feature = "cuda" ) ]
23mod compute_cap;
34#[ cfg( feature = "cuda" ) ]
@@ -9,7 +10,9 @@ mod models;
910use crate :: compute_cap:: { incompatible_compute_cap, COMPILE_COMPUTE_CAP , RUNTIME_COMPUTE_CAP } ;
1011#[ cfg( feature = "cuda" ) ]
1112use crate :: models:: FlashBertModel ;
12- use crate :: models:: { BertModel , EmbeddingModel , PositionEmbeddingType , QuantBertModel } ;
13+ use crate :: models:: {
14+ BertModel , EmbeddingModel , JinaBertModel , PositionEmbeddingType , QuantBertModel ,
15+ } ;
1316use candle:: { DType , Device } ;
1417use candle_nn:: VarBuilder ;
1518use models:: Config ;
@@ -47,8 +50,6 @@ impl CandleBackend {
4750
4851 let model: Box < dyn EmbeddingModel + Send > = match device {
4952 Device :: Cpu => {
50- tracing:: info!( "Starting Bert model on CPU" ) ;
51-
5253 if & dtype == "float32" || & dtype == "float16" {
5354 let dtype = if & dtype == "float32" {
5455 DType :: F32
@@ -70,14 +71,21 @@ impl CandleBackend {
7071 }
7172 . s ( ) ?;
7273
73- Box :: new ( BertModel :: load ( vb, & config, pool) . s ( ) ?)
74+ if config. position_embedding_type == PositionEmbeddingType :: Alibi {
75+ tracing:: info!( "Starting JinaBert model on CPU" ) ;
76+ Box :: new ( JinaBertModel :: load ( vb, & config, pool) . s ( ) ?)
77+ } else {
78+ tracing:: info!( "Starting Bert model on CPU" ) ;
79+ Box :: new ( BertModel :: load ( vb, & config, pool) . s ( ) ?)
80+ }
7481 } else if & dtype == "q6k" {
7582 let vb = candle_transformers:: quantized_var_builder:: VarBuilder :: from_gguf (
7683 model_path. join ( "ggml-model-q6k.bin" ) ,
7784 )
7885 . map_err ( |err| BackendError :: Start ( err. to_string ( ) ) ) ?;
7986 tracing:: info!( "vb" ) ;
8087
88+ tracing:: info!( "Starting QuantBert model on CPU" ) ;
8189 Box :: new ( QuantBertModel :: load ( vb, & config, pool) . s ( ) ?)
8290 } else {
8391 return Err ( BackendError :: Start ( format ! (
@@ -130,6 +138,9 @@ impl CandleBackend {
130138 {
131139 tracing:: info!( "Starting FlashBert model on Cuda" ) ;
132140 Box :: new ( FlashBertModel :: load ( vb, & config, pool) . s ( ) ?)
141+ } else if config. position_embedding_type == PositionEmbeddingType :: Alibi {
142+ tracing:: info!( "Starting JinaBert model on Cuda" ) ;
143+ Box :: new ( JinaBertModel :: load ( vb, & config, pool) . s ( ) ?)
133144 } else {
134145 tracing:: info!( "Starting Bert model on Cuda" ) ;
135146 Box :: new ( BertModel :: load ( vb, & config, pool) . s ( ) ?)
0 commit comments