diff --git a/adalflow/adalflow/components/model_client/transformers_client.py b/adalflow/adalflow/components/model_client/transformers_client.py index f681f23f0..7ef85ab05 100644 --- a/adalflow/adalflow/components/model_client/transformers_client.py +++ b/adalflow/adalflow/components/model_client/transformers_client.py @@ -6,7 +6,6 @@ import re import warnings - from adalflow.core.model_client import ModelClient from adalflow.core.types import GeneratorOutput, ModelType, Embedding, EmbedderOutput from adalflow.core.functional import get_top_k_indices_scores @@ -14,116 +13,46 @@ # optional import from adalflow.utils.lazy_import import safe_import, OptionalPackages - -transformers = safe_import( - OptionalPackages.TRANSFORMERS.value[0], OptionalPackages.TRANSFORMERS.value[1] -) -torch = safe_import(OptionalPackages.TORCH.value[0], OptionalPackages.TORCH.value[1]) - -import torch - import torch.nn.functional as F from torch import Tensor +import torch from transformers import ( + PreTrainedModel, + PreTrainedTokenizer, + PreTrainedTokenizerFast, AutoTokenizer, AutoModel, + AutoModelForCausalLM, AutoModelForSequenceClassification, + pipeline, +) +from os import getenv as get_env_variable + +transformers = safe_import( + OptionalPackages.TRANSFORMERS.value[0], OptionalPackages.TRANSFORMERS.value[1] ) +torch = safe_import(OptionalPackages.TORCH.value[0], OptionalPackages.TORCH.value[1]) log = logging.getLogger(__name__) -def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor: +def average_pool(last_hidden_states: Tensor, attention_mask: list) -> Tensor: last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] -# TODO: provide a standard api for embedding and chat models used in local model SDKs -class TransformerEmbedder: - """Local model SDK for transformers. - - - There are two ways to run transformers: - (1) model and then run model inference - (2) Pipeline and then run pipeline inference - - This file demonstrates how to - (1) create a torch model inference component: TransformerEmbedder which equalize to OpenAI(), the SyncAPIClient - (2) Convert this model inference component to LightRAG API client: TransformersClient - - The is now just an exmplary component that initialize a certain model from transformers and run inference on it. - It is not tested on all transformer models yet. It might be necessary to write one for each model. 
- - References: - - transformers: https://huggingface.co/docs/transformers/en/index - - thenlper/gte-base model:https://huggingface.co/thenlper/gte-base - """ - - models: Dict[str, type] = {} - - def __init__(self, model_name: Optional[str] = "thenlper/gte-base"): - super().__init__() - - if model_name is not None: - self.init_model(model_name=model_name) - - @lru_cache(None) - def init_model(self, model_name: str): - try: - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - self.model = AutoModel.from_pretrained(model_name) - # register the model - self.models[model_name] = self.model - log.info(f"Done loading model {model_name}") - - except Exception as e: - log.error(f"Error loading model {model_name}: {e}") - raise e - - def infer_gte_base_embedding( - self, - input=Union[str, List[str]], - tolist: bool = True, - ): - model = self.models.get("thenlper/gte-base", None) - if model is None: - # initialize the model - self.init_model("thenlper/gte-base") - - if isinstance(input, str): - input = [input] - # Tokenize the input texts - batch_dict = self.tokenizer( - input, max_length=512, padding=True, truncation=True, return_tensors="pt" - ) - outputs = model(**batch_dict) - embeddings = average_pool( - outputs.last_hidden_state, batch_dict["attention_mask"] - ) - # (Optionally) normalize embeddings - embeddings = F.normalize(embeddings, p=2, dim=1) - if tolist: - embeddings = embeddings.tolist() - return embeddings - - def __call__(self, **kwargs): - if "model" not in kwargs: - raise ValueError("model is required") - - if "mock" in kwargs and kwargs["mock"]: - import numpy as np - - embeddings = np.array([np.random.rand(768).tolist()]) - return embeddings - # load files and models, cache it for the next inference - model_name = kwargs["model"] - # inference the model - if model_name == "thenlper/gte-base": - return self.infer_gte_base_embedding(kwargs["input"]) - else: - raise ValueError(f"model {model_name} is not supported") +def mean_pooling(model_output: dict, attention_mask) -> Tensor: + token_embeddings = model_output[ + 0 + ] # First element of model_output contains all token embeddings + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( + input_mask_expanded.sum(1), min=1e-9 + ) def get_device(): @@ -144,296 +73,394 @@ def get_device(): def clean_device_cache(): import torch - if torch.has_mps: + if torch.backends.mps.is_built(): torch.mps.empty_cache() torch.mps.set_per_process_memory_fraction(1.0) -class TransformerReranker: - __doc__ = r"""Local model SDK for a reranker model using transformers. +class TransformerEmbeddingModelClient(ModelClient): + __doc__ = r"""LightRAG API client for embedding models using HuggingFace's transformers library. - References: - - model: https://huggingface.co/BAAI/bge-reranker-base - - paper: https://arxiv.org/abs/2309.07597 + Use: ``ls ~/.cache/huggingface/hub `` to see the cached models. - note: - If you are using Macbook M1 series chips, you need to ensure ``torch.device("mps")`` is set. + Some modeles are gated, you will need to their page to get the access token. + Find how to apply tokens here: https://huggingface.co/docs/hub/security-tokens + Once you have a token and have access, put the token in the environment variable HF_TOKEN. 
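    Example (a minimal usage sketch mirroring the accompanying tests; it assumes the ``thenlper/gte-base`` checkpoint is cached locally or can be downloaded):

    .. code-block:: python

        from adalflow.core import Embedder
        from adalflow.components.model_client.transformers_client import (
            TransformerEmbeddingModelClient,
        )

        model_client = TransformerEmbeddingModelClient(
            model_name="thenlper/gte-base",
            tokenizer_kwargs={"max_length": 512, "padding": True, "truncation": True},
        )
        embedder = Embedder(
            model_client=model_client,
            model_kwargs={"model": "thenlper/gte-base"},
        )
        # Returns an EmbedderOutput with one L2-normalized embedding per input string.
        output = embedder("Hello world")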
""" - models: Dict[str, type] = {} - def __init__(self, model_name: Optional[str] = "BAAI/bge-reranker-base"): - self.model_name = model_name or "BAAI/bge-reranker-base" - if model_name is not None: - self.init_model(model_name=model_name) + # + # Model initialisation + # + def __init__( + self, + model_name: Optional[str] = None, + tokenizer_kwargs: Optional[dict] = None, + auto_model_kwargs: Optional[dict] = None, + auto_tokenizer_kwargs: Optional[dict] = None, + auto_model: Optional[type] = AutoModel, + auto_tokenizer: Optional[type] = AutoTokenizer, + local_files_only: Optional[bool] = False, + custom_model: Optional[PreTrainedModel] = None, + custom_tokenizer: Optional[ + Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + ] = None, + ): + + super().__init__() + self.model_name = model_name + self.tokenizer_kwargs = tokenizer_kwargs or dict() + self.auto_model_kwargs = auto_model_kwargs or dict() + self.auto_tokenizer_kwargs = auto_tokenizer_kwargs or dict() + if "return_tensors" not in self.tokenizer_kwargs: + self.tokenizer_kwargs["return_tensors"] = "pt" + self.auto_model = auto_model + self.auto_tokenizer = auto_tokenizer + self.local_files_only = local_files_only + self.custom_model = custom_model + self.custom_tokenizer = custom_tokenizer + + # Check if there is conflicting arguments + self.use_auto_model = auto_model is not None + self.use_auto_tokenizer = auto_tokenizer is not None + self.use_cusom_model = custom_model is not None + self.use_cusom_tokenizer = custom_tokenizer is not None + self.model_name_exit = model_name is not None + + ## arguments related to model + if self.use_auto_model and self.use_cusom_model: + raise ValueError("Cannot specify 'auto_model' and 'custom_model'.") + elif (not self.use_auto_model) and (not self.use_cusom_model): + raise ValueError("Need to specify either 'auto_model' or 'custom_model'.") + elif self.use_auto_model and (not self.model_name_exit): + raise ValueError( + "When 'auto_model' is specified 'model_name' must be specified too." + ) + + ## arguments related to tokenizer + if self.use_auto_tokenizer and self.use_cusom_tokenizer: + raise Exception("Cannot specify 'auto_tokenizer' and 'custom_tokenizer'.") + elif (not self.use_auto_tokenizer) and (not self.use_cusom_tokenizer): + raise Exception( + "Need to specify either'auto_tokenizer' and 'custom_tokenizer'." + ) + elif self.use_auto_tokenizer and (not self.model_name_exit): + raise ValueError( + "When 'auto_tokenizer' is specified 'model_name' must be specified too." 
+ ) + + self.init_sync_client() + + def init_sync_client(self): + self.init_model( + model_name=self.model_name, + auto_model=self.auto_model, + auto_tokenizer=self.auto_tokenizer, + custom_model=self.custom_model, + custom_tokenizer=self.custom_tokenizer, + ) + + @lru_cache(None) + def init_model( + self, + model_name: Optional[str] = None, + auto_model: Optional[type] = AutoModel, + auto_tokenizer: Optional[type] = AutoTokenizer, + custom_model: Optional[PreTrainedModel] = None, + custom_tokenizer: Optional[ + PreTrainedTokenizer | PreTrainedTokenizerFast + ] = None, + ): - def init_model(self, model_name: str): try: - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - self.model = AutoModelForSequenceClassification.from_pretrained(model_name) - # Check device availability and set the device - device = get_device() + if self.use_auto_model: + self.model = auto_model.from_pretrained( + model_name, + local_files_only=self.local_files_only, + **self.auto_model_kwargs, + ) + else: + self.model = custom_model + + if self.use_auto_tokenizer: + self.tokenizer = auto_tokenizer.from_pretrained( + model_name, + local_files_only=self.local_files_only, + **self.auto_tokenizer_kwargs, + ) + else: + self.tokenizer = custom_tokenizer - # Move model to the selected device - self.device = device - self.model.to(device) - self.model.eval() - # register the model - self.models[model_name] = self.model # TODO: better model registration log.info(f"Done loading model {model_name}") except Exception as e: log.error(f"Error loading model {model_name}: {e}") raise e - def infer_bge_reranker_base( + # + # Inference code + # + def infer_embedding( self, - # input=List[Tuple[str, str]], # list of pairs of the query and the candidate - query: str, - documents: List[str], - ) -> List[float]: - model = self.models.get(self.model_name, None) - if model is None: - # initialize the model - self.init_model(self.model_name) + input=Union[str, List[str], List[List[str]]], + tolist: bool = True, + ) -> Union[List, Tensor]: + model = self.model - # convert the query and documents to pair input - input = [(query, doc) for doc in documents] + self.handle_input(input) + batch_dict = self.tokenize_inputs(input, kwargs=self.tokenizer_kwargs) + outputs = self.compute_model_outputs(batch_dict, model) + embeddings = self.compute_embeddings(outputs, batch_dict) + + # normalize embeddings + embeddings = F.normalize(embeddings, p=2, dim=1) + if tolist: + embeddings = embeddings.tolist() + return embeddings + + def handle_input( + self, input: Union[str, List[str], List[List[str]]] + ) -> Union[List[str], List[List[str]]]: + if isinstance(input, str): + input = [input] + return input + def tokenize_inputs( + self, + input: Union[str, List[str], List[List[str]]], + kwargs: Optional[dict] = None, + ) -> dict: + kwargs = kwargs or dict() + batch_dict = self.tokenizer(input, **kwargs) + return batch_dict + + def compute_model_outputs(self, batch_dict: dict, model: PreTrainedModel) -> dict: with torch.no_grad(): + outputs = model(**batch_dict) + return outputs - inputs = self.tokenizer( - input, - padding=True, - truncation=True, - return_tensors="pt", - max_length=512, - ) - inputs = {k: v.to(self.device) for k, v in inputs.items()} - scores = ( - model(**inputs, return_dict=True) - .logits.view( - -1, - ) - .float() - ) - # apply sigmoid to get the scores - scores = F.sigmoid(scores) + def compute_embeddings(self, outputs: dict, batch_dict: dict): + embeddings = mean_pooling(outputs, batch_dict["attention_mask"]) + return 
embeddings - scores = scores.tolist() - return scores + # + # Preprocessing, postprocessing and call for inference code + # + def call( + self, + api_kwargs: Dict = None, + model_type: Optional[ModelType] = ModelType.UNDEFINED, + ) -> Union[List, Tensor]: - def __call__(self, **kwargs): - r"""Ensure "model" and "input" are in the kwargs.""" - if "model" not in kwargs: - raise ValueError("model is required") + api_kwargs = api_kwargs or dict() + if "model" not in api_kwargs: + raise ValueError("model must be specified in api_kwargs") + # I don't think it is useful anymore + # if ( + # model_type == ModelType.EMBEDDER + # # and "model" in api_kwargs + # ): + if "mock" in api_kwargs and api_kwargs["mock"]: + import numpy as np - # if "mock" in kwargs and kwargs["mock"]: - # import numpy as np + embeddings = np.array([np.random.rand(768).tolist()]) + return embeddings - # scores = np.array([np.random.rand(1).tolist()]) - # return scores - # load files and models, cache it for the next inference - model_name = kwargs["model"] # inference the model - if model_name == self.model_name: - assert "query" in kwargs, "query is required" - assert "documents" in kwargs, "documents is required" - scores = self.infer_bge_reranker_base(kwargs["query"], kwargs["documents"]) - return scores - else: - raise ValueError(f"model {model_name} is not supported") + return self.infer_embedding(api_kwargs["input"]) + def parse_embedding_response(self, response: Union[List, Tensor]) -> EmbedderOutput: + embeddings: List[Embedding] = [] + for idx, emb in enumerate(response): + embeddings.append(Embedding(index=idx, embedding=emb)) + response = EmbedderOutput(data=embeddings) + return response -class TransformerLLM: - __doc__ = r"""Local model SDK for transformers LLM. - - NOTE: - This inference component is only specific to the HuggingFaceH4/zephyr-7b-beta model. + def convert_inputs_to_api_kwargs( + self, + input: Any, # for retriever, it is a single query, + model_kwargs: dict = {}, + model_type: Optional[ModelType] = ModelType.UNDEFINED, + ) -> dict: + final_model_kwargs = model_kwargs.copy() + # if model_type == ModelType.EMBEDDER: + final_model_kwargs["input"] = input + return final_model_kwargs - The example raw output: - # <|system|> - # You are a friendly chatbot who always responds in the style of a pirate. - # <|user|> - # How many helicopters can a human eat in one sitting? - # <|assistant|> - # Ah, me hearty matey! But yer question be a puzzler! A human cannot eat a helicopter in one sitting, as helicopters are not edible. They be made of metal, plastic, and other materials, not food! +class TransformerLLMModelClient(ModelClient): + __doc__ = r"""LightRAG API client for text generation models using HuggingFace's transformers library. - References: - - model: https://huggingface.co/HuggingFaceH4/zephyr-7b-beta - - https://huggingface.co/google/gemma-2b - - https://huggingface.co/google/gemma-2-2b + Use: ``ls ~/.cache/huggingface/hub `` to see the cached models. + Some modeles are gated, you will need to their page to get the access token. + Find how to apply tokens here: https://huggingface.co/docs/hub/security-tokens + Once you have a token and have access, put the token in the environment variable HF_TOKEN. 
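    Example (a minimal sketch based on the accompanying tests; the tiny ``roneneldan/TinyStories-1M`` checkpoint is used purely for illustration, and ``apply_chat_template``/``chat_template`` can be passed for models that expect a chat format):

    .. code-block:: python

        from adalflow.core import Generator
        from adalflow.components.model_client.transformers_client import (
            TransformerLLMModelClient,
        )

        # init_from="autoclass" loads AutoModelForCausalLM/AutoTokenizer directly;
        # init_from="pipeline" uses transformers.pipeline("text-generation", ...).
        model_client = TransformerLLMModelClient(init_from="autoclass")
        generator = Generator(
            model_client=model_client,
            model_kwargs={
                "model": "roneneldan/TinyStories-1M",
                "temperature": 0.1,
                "do_sample": True,
            },
        )
        output = generator(prompt_kwargs={"input_str": "Where is Brian?"})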
""" - models: Dict[str, type] = {} # to register the model - tokenizer: Dict[str, type] = {} - - model_to_init_func = { - "HuggingFaceH4/zephyr-7b-beta": "use_pipeline", - "google/gemma-2-2b": "use_pipeline", - } + # + # Model initialisation + # def __init__( self, model_name: Optional[str] = None, + tokenizer_decode_kwargs: Optional[dict] = None, + tokenizer_kwargs: Optional[dict] = None, + auto_model_kwargs: Optional[dict] = None, + auto_tokenizer_kwargs: Optional[dict] = None, + init_from: Optional[str] = "autoclass", + apply_chat_template: bool = False, + chat_template: Optional[str] = None, + chat_template_kwargs: Optional[dict] = None, + use_token: bool = False, + torch_dtype: Optional[Any] = torch.bfloat16, + local_files_only: Optional[bool] = False, ): super().__init__() self.model_name = model_name # current model to use - - if model_name is not None and model_name not in self.models: + self.tokenizer_decode_kwargs = tokenizer_decode_kwargs or dict() + self.tokenizer_kwargs = tokenizer_kwargs or dict() + self.auto_model_kwargs = auto_model_kwargs or dict() + self.auto_tokenizer_kwargs = auto_tokenizer_kwargs or dict() + if "return_tensors" not in self.tokenizer_kwargs: + self.tokenizer_kwargs["return_tensors"] = "pt" + self.use_token = use_token + self.torch_dtype = torch_dtype + self.init_from = init_from + self.apply_chat_template = apply_chat_template + self.chat_template = chat_template + self.chat_template_kwargs = chat_template_kwargs or dict( + tokenize=False, add_generation_prompt=True + ) + self.local_files_only = local_files_only + self.model = None + if model_name is not None: self.init_model(model_name=model_name) def _check_token(self, token: str): - import os - - if os.getenv(token) is None: + if get_env_variable(token) is None: warnings.warn( f"{token} is not set. You may not be able to access the model." ) - def _init_from_pipeline(self, model_name: str): - from transformers import pipeline + def _get_token_if_relevant(self) -> Union[str, bool]: + if self.use_token: + self._check_token("HF_TOKEN") + token = get_env_variable("HF_TOKEN") + else: + token = False + return token + + def _init_from_pipeline(self): clean_device_cache() - self._check_token("HF_TOKEN") - try: - import os - - pipe = pipeline( - "text-generation", - model=model_name, - torch_dtype=torch.bfloat16, - device=get_device(), - token=os.getenv("HF_TOKEN"), - ) - self.models[model_name] = pipe - except Exception as e: - log.error(f"Error loading model {model_name}: {e}") - raise e + token = self._get_token_if_relevant() # return a token string or False + self.model = pipeline( + "text-generation", + model=self.model_name, + torch_dtype=self.torch_dtype, + device=get_device(), + token=token, + ) - def _init_from_automodelcasual_lm(self, model_name: str): - try: - from transformers import AutoTokenizer, AutoModelForCausalLM - except ImportError: - raise ImportError( - "transformers is not installed. Please install it with `pip install transformers`" - ) + def _init_from_automodelcasual_lm(self): - try: - import os + token = self._get_token_if_relevant() # return a token str or False - if os.getenv("HF_TOKEN") is None: - warnings.warn( - "HF_TOKEN is not set. You may not be able to access the model." 
- ) - - tokenizer = AutoTokenizer.from_pretrained( - model_name, token=os.getenv("HF_TOKEN") - ) - model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.bfloat16, - device_map="auto", - token=os.getenv("HF_TOKEN"), - ) - self.models[model_name] = model - self.tokenizer[model_name] = tokenizer - except Exception as e: - log.error(f"Error loading model {model_name}: {e}") - raise e + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_name, + token=token, + local_files_only=self.local_files_only, + **self.auto_tokenizer_kwargs, + ) + self.model = AutoModelForCausalLM.from_pretrained( + self.model_name, + torch_dtype=self.torch_dtype, + device_map="auto", + token=token, + local_files_only=self.local_files_only, + **self.auto_model_kwargs, + ) + # Set pad token if it's not already set + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token # common fallback + self.model.config.pad_token_id = ( + self.tokenizer.eos_token_id + ) # ensure consistency in the model config @lru_cache(None) def init_model(self, model_name: str): - log.debug(f"Loading model {model_name}") - - model_setup = self.model_to_init_func.get(model_name, None) - if model_setup: - if model_setup == "use_pipeline": - self._init_from_pipeline(model_name) - else: - self._init_from_automodelcasual_lm(model_name) - else: - raise ValueError(f"Model {model_name} is not supported") - - def _parse_chat_completion_from_pipeline(self, completion: Any) -> str: - - text = completion[0]["generated_text"] - - pattern = r"(?<=\|assistant\|>).*" - - match = re.search(pattern, text) - - if match: - text = match.group().strip().lstrip("\\n") - return text - else: - return "" - def _parse_chat_completion_from_automodelcasual_lm(self, completion: Any) -> str: - print(f"completion: {completion}") - return completion[0] - - def parse_chat_completion(self, completion: Any) -> str: - model_name = self.model_name - model_setup = self.model_to_init_func.get(model_name, None) - if model_setup: - if model_setup == "use_pipeline": - return self._parse_chat_completion_from_pipeline(completion) + log.debug(f"Loading model {model_name}") + try: + if self.init_from == "autoclass": + self._init_from_automodelcasual_lm() + elif self.init_from == "pipeline": + self._init_from_pipeline() else: - return self._parse_chat_completion_from_automodelcasual_lm(completion) - else: - raise ValueError(f"Model {model_name} is not supported") + raise ValueError( + "argument 'init_from' must be one of 'autoclass' or 'pipeline'." 
+ ) + except Exception as e: + log.error(f"Error loading model {model_name}: {e}") + raise e + # + # Inference code + # def _infer_from_pipeline( self, *, model: str, messages: Sequence[Dict[str, str]], max_tokens: Optional[int] = None, + apply_chat_template: bool = False, + chat_template: Optional[str] = None, + chat_template_kwargs: Optional[dict] = None, **kwargs, ): - if not model: - raise ValueError("Model is not provided.") - if model not in self.models: + if not self.model: self.init_model(model_name=model) - model_to_use = self.models[model] - log.info( f"Start to infer model {model}, messages: {messages}, kwargs: {kwargs}" ) - - if model == "HuggingFaceH4/zephyr-7b-beta": - - prompt = model_to_use.tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + # TO DO: add default values in doc + final_kwargs = { + "max_new_tokens": max_tokens or 256, + "do_sample": True, + "temperature": kwargs.get("temperature", 0.7), + "top_k": kwargs.get("top_k", 50), + "top_p": kwargs.get("top_p", 0.95), + } + if apply_chat_template: + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_name, + token=self._get_token_if_relevant(), + local_files_only=self.local_files_only, ) - - final_kwargs = { - "max_new_tokens": max_tokens or 256, - "do_sample": True, - "temperature": kwargs.get("temperature", 0.7), - "top_k": kwargs.get("top_k", 50), - "top_p": kwargs.get("top_p", 0.95), - } - outputs = model_to_use(prompt, **final_kwargs) - elif model == "google/gemma-2-2b": - final_kwargs = { - "max_new_tokens": max_tokens or 256, - "do_sample": True, - "temperature": kwargs.get("temperature", 0.7), - "top_k": kwargs.get("top_k", 50), - "top_p": kwargs.get("top_p", 0.95), - } - text = messages[0]["content"] - outputs = model_to_use( - text, - **final_kwargs, + # Set pad token if it's not already set + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token # common fallback + self.model.config.pad_token_id = ( + self.tokenizer.eos_token_id + ) # ensure consistency in the model config + + model_input = self._handle_input( + messages, + apply_chat_template=True, + chat_template=chat_template, + chat_template_kwargs=chat_template_kwargs, ) + else: + model_input = self._handle_input(messages) + outputs = self.model( + model_input, + **final_kwargs, + ) log.info(f"Outputs: {outputs}") return outputs @@ -442,26 +469,61 @@ def _infer_from_automodelcasual_lm( *, model: str, messages: Sequence[Dict[str, str]], + max_tokens: Optional[int] = None, max_length: Optional[int] = 8192, # model-agnostic + apply_chat_template: bool = False, + chat_template: Optional[str] = None, + chat_template_kwargs: Optional[dict] = None, **kwargs, ): - if not model: - raise ValueError("Model is not provided.") - if model not in self.models: + if not self.model: self.init_model(model_name=model) - model_to_use = self.models[model] - tokenizer_to_use = self.tokenizer[model] - input_ids = tokenizer_to_use(messages[0]["content"], return_tensors="pt").to( + if apply_chat_template: + model_input = self._handle_input( + messages, + apply_chat_template=True, + chat_template_kwargs=chat_template_kwargs, + chat_template=chat_template, + ) + else: + model_input = self._handle_input(messages) + input_ids = self.tokenizer(model_input, **self.tokenizer_kwargs).to( get_device() ) - print(input_ids) - outputs_tokens = model_to_use.generate(**input_ids, max_length=max_length) + outputs_tokens = self.model.generate( + **input_ids, max_length=max_length, max_new_tokens=max_tokens, 
**kwargs + ) outputs = [] - for i, output in enumerate(outputs_tokens): - outputs.append(tokenizer_to_use.decode(output)) + for output in outputs_tokens: + outputs.append(self.tokenizer.decode(output)) return outputs + def _handle_input( + self, + messages: Sequence[Dict[str, str]], + apply_chat_template: bool = False, + chat_template_kwargs: dict = None, + chat_template: Optional[str] = None, + ) -> str: + + if apply_chat_template: + if chat_template is not None: + self.tokenizer.chat_template = chat_template + prompt = self.tokenizer.apply_chat_template( + messages, **chat_template_kwargs + ) + if ("tokenize" in chat_template_kwargs) and ( + chat_template_kwargs["tokenize"] == True + ): + prompt = self.tokenizer.decode(prompt, **self.tokenizer_decode_kwargs) + return prompt + else: + return prompt + else: + text = messages[-1]["content"] + return text + def infer_llm( self, *, @@ -470,41 +532,104 @@ def infer_llm( max_tokens: Optional[int] = None, **kwargs, ): - # TODO: generalize the code for more models - model_setup = self.model_to_init_func.get(model, None) - if model_setup: - if model_setup == "use_pipeline": - return self._infer_from_pipeline( - model=model, messages=messages, max_tokens=max_tokens, **kwargs - ) - else: - return self._infer_from_automodelcasual_lm( - model=model, messages=messages, max_tokens=max_tokens, **kwargs - ) - else: - raise ValueError(f"Model {model} is not supported") - def __call__(self, **kwargs): - r"""Ensure "model" and "input" are in the kwargs.""" - log.debug(f"kwargs: {kwargs}") - if "model" not in kwargs: - raise ValueError("model is required") + if self.init_from == "pipeline": + return self._infer_from_pipeline( + model=model, + messages=messages, + max_tokens=max_tokens, + apply_chat_template=self.apply_chat_template, + chat_template=self.chat_template, + chat_template_kwargs=self.chat_template_kwargs, + **kwargs, + ) + else: + return self._infer_from_automodelcasual_lm( + model=model, + messages=messages, + max_tokens=max_tokens, + apply_chat_template=self.apply_chat_template, + chat_template=self.chat_template, + chat_template_kwargs=self.chat_template_kwargs, + **kwargs, + ) - if "messages" not in kwargs: - raise ValueError("messages is required") + # + # Preprocessing, postprocessing and call for inference code + # + def call( + self, + api_kwargs: Dict = None, + model_type: Optional[ModelType] = ModelType.UNDEFINED, + ): + api_kwargs = api_kwargs or dict() + if "model" not in api_kwargs: + raise ValueError("model must be specified in api_kwargs") - model_name = kwargs["model"] - if model_name != self.model_name: - # need to initialize the model and update the model_name + model_name = api_kwargs["model"] + if (model_name != self.model_name) and (self.model_name is not None): + # need to update the model_name + log.warning( + f"The model passed in 'model_kwargs' is different that the one that has been previously initialised: Updating model from {self.model_name} to {model_name}." 
+ ) + self.model_name = model_name + self.init_model(model_name=model_name) + elif (model_name != self.model_name) and (self.model_name is None): + # need to initialize the model for the first time self.model_name = model_name self.init_model(model_name=model_name) - output = self.infer_llm(**kwargs) + output = self.infer_llm(**api_kwargs) return output + def _parse_chat_completion_from_pipeline(self, completion: Any) -> str: + + text = completion[0]["generated_text"] + pattern = r"(?<=\|assistant\|>).*" + match = re.search(pattern, text) + + if match: + text = match.group().strip().lstrip("\\n") + return text + else: + return "" + + def _parse_chat_completion_from_automodelcasual_lm( + self, completion: Any + ) -> GeneratorOutput: + print(f"completion: {completion}") + return completion[0] + + def parse_chat_completion(self, completion: Any) -> str: + try: + if self.init_from == "pipeline": + output = self._parse_chat_completion_from_pipeline(completion) + else: + output = self._parse_chat_completion_from_automodelcasual_lm(completion) + return GeneratorOutput(data=output, raw_response=str(completion)) + except Exception as e: + log.error(f"Error parsing chat completion: {e}") + return GeneratorOutput(data=None, raw_response=str(completion), error=e) + + def convert_inputs_to_api_kwargs( + self, + input: Any, # for retriever, it is a single query, + model_kwargs: dict = None, + model_type: Optional[ModelType] = ModelType.UNDEFINED, + ) -> dict: + model_kwargs = model_kwargs or dict() + final_model_kwargs = model_kwargs.copy() + assert "model" in final_model_kwargs, "model must be specified" + # messages = [{"role": "system", "content": input}] + messages = [ + {"role": "user", "content": input} + ] # Not sure, but it seems to make more sense + final_model_kwargs["messages"] = messages + return final_model_kwargs + -class TransformersClient(ModelClient): - __doc__ = r"""LightRAG API client for transformers. +class TransformerRerankerModelClient(ModelClient): + __doc__ = r"""LightRAG API client for reranker (cross-encoder) models using HuggingFace's transformers library. Use: ``ls ~/.cache/huggingface/hub `` to see the cached models. @@ -513,273 +638,131 @@ class TransformersClient(ModelClient): Once you have a token and have access, put the token in the environment variable HF_TOKEN. """ - support_models = { - "thenlper/gte-base": { - "type": ModelType.EMBEDDER, - }, - "BAAI/bge-reranker-base": { - "type": ModelType.RERANKER, - }, - "HuggingFaceH4/zephyr-7b-beta": {"type": ModelType.LLM}, - "google/gemma-2-2b": {"type": ModelType.LLM}, - } - - def __init__(self, model_name: Optional[str] = None) -> None: - super().__init__() - self._model_name = model_name - if self._model_name: - assert ( - self._model_name in self.support_models - ), f"model {self._model_name} is not supported" - if self._model_name == "thenlper/gte-base": - self.sync_client = self.init_sync_client() - elif self._model_name == "BAAI/bge-reranker-base": - self.reranker_client = self.init_reranker_client() - elif self._model_name == "HuggingFaceH4/zephyr-7b-beta": - self.llm_client = self.init_llm_client() - self.async_client = None - - def init_sync_client(self): - return TransformerEmbedder() - - def init_reranker_client(self): - return TransformerReranker() - - def init_llm_client(self): - return TransformerLLM() - - def set_llm_client(self, llm_client: object): - r"""Allow user to pass a custom llm client. 
Here is an example of a custom llm client: - - Ensure you have parse_chat_completion and __call__ methods which will be applied to api_kwargs specified in transform_client.call(). - - .. code-block:: python - - class CustomizeLLM: - - def __init__(self) -> None: - pass - - def parse_chat_completion(self, completion: Any) -> str: - return completion - - def __call__(self, messages: Sequence[Dict[str, str]], model: str, **kwargs): - from transformers import AutoTokenizer, AutoModelForCausalLM - - tokenizer = AutoTokenizer.from_pretrained( - "deepseek-ai/deepseek-coder-1.3b-instruct", trust_remote_code=True - ) - model = AutoModelForCausalLM.from_pretrained( - "deepseek-ai/deepseek-coder-1.3b-instruct", - trust_remote_code=True, - torch_dtype=torch.bfloat16, - ).to(get_device()) - messages = [ - {"role": "user", "content": "write a quick sort algorithm in python."} - ] - inputs = tokenizer.apply_chat_template( - messages, add_generation_prompt=True, return_tensors="pt" - ).to(model.device) - # tokenizer.eos_token_id is the id of <|EOT|> token - outputs = model.generate( - inputs, - max_new_tokens=512, - do_sample=False, - top_k=50, - top_p=0.95, - num_return_sequences=1, - eos_token_id=tokenizer.eos_token_id, - ) - print( - tokenizer.decode(outputs[0][len(inputs[0]) :], skip_special_tokens=True) - ) - decoded_outputs = [] - for output in outputs: - decoded_outputs.append( - tokenizer.decode(output[len(inputs[0]) :], skip_special_tokens=True) - ) - return decoded_outputs - - llm_client = CustomizeLLM() - transformer_client.set_llm_client(llm_client) - # use in the generator - generator = Generator( - model_client=transformer_client, - model_kwargs=model_kwargs, - prompt_kwargs=prompt_kwargs, - ...) - - """ - self.llm_client = llm_client - - def parse_embedding_response(self, response: Any) -> EmbedderOutput: - embeddings: List[Embedding] = [] - for idx, emb in enumerate(response): - embeddings.append(Embedding(index=idx, embedding=emb)) - response = EmbedderOutput(data=embeddings) - return response + # + # Model initialisation + # + def __init__( + self, + model_name: Optional[str] = None, + tokenizer_kwargs: Optional[dict] = None, + auto_model_kwargs: Optional[dict] = None, + auto_tokenizer_kwargs: Optional[dict] = None, + auto_model: Optional[type] = AutoModelForSequenceClassification, + auto_tokenizer: Optional[type] = AutoTokenizer, + local_files_only: Optional[bool] = False, + ): + self.auto_model = auto_model + self.auto_model_kwargs = auto_model_kwargs or dict() + self.auto_tokenizer_kwargs = auto_tokenizer_kwargs or dict() + self.auto_tokenizer = auto_tokenizer + self.model_name = model_name + self.tokenizer_kwargs = tokenizer_kwargs or dict() + if "return_tensors" not in self.tokenizer_kwargs: + self.tokenizer_kwargs["return_tensors"] = "pt" + self.local_files_only = local_files_only + if model_name is not None: + self.init_model(model_name=model_name) - def parse_chat_completion(self, completion: Any) -> GeneratorOutput: + def init_model(self, model_name: str): try: - output = self.llm_client.parse_chat_completion(completion) + self.tokenizer = self.auto_tokenizer.from_pretrained( + self.model_name, + local_files_only=self.local_files_only, + **self.auto_tokenizer_kwargs, + ) + self.model = self.auto_model.from_pretrained( + self.model_name, + local_files_only=self.local_files_only, + **self.auto_model_kwargs, + ) + # Check device availability and set the device + device = get_device() + + # Move model to the selected device + self.device = device + self.model.to(device) + 
self.model.eval() + # register the model + log.info(f"Done loading model {model_name}") - return GeneratorOutput(data=output, raw_response=str(completion)) except Exception as e: - log.error(f"Error parsing chat completion: {e}") - return GeneratorOutput(data=None, raw_response=str(completion), error=e) + log.error(f"Error loading model {model_name}: {e}") + raise e + + # + # Inference code + # + + def infer_reranker( + self, + model: str, + query: str, + documents: List[str], + ) -> List[float]: + if not self.model: + self.init_model(model_name=model) + # convert the query and documents to pair input + input = [(query, doc) for doc in documents] + + with torch.no_grad(): + + inputs = self.tokenizer(input, **self.tokenizer_kwargs) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + scores = ( + self.model(**inputs, return_dict=True) + .logits.view( + -1, + ) + .float() + ) + # apply sigmoid to get the scores + scores = F.sigmoid(scores) + + scores = scores.tolist() + return scores - def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED): + # + # Preprocessing, postprocessing and call for inference code + # + def call(self, api_kwargs: Dict = None): + api_kwargs = api_kwargs or dict() if "model" not in api_kwargs: raise ValueError("model must be specified in api_kwargs") - if api_kwargs["model"] not in self.support_models: - raise ValueError(f"model {api_kwargs['model']} is not supported") - - if ( - model_type == ModelType.EMBEDDER - and "model" in api_kwargs - and api_kwargs["model"] == "thenlper/gte-base" - ): - if self.sync_client is None: - self.sync_client = self.init_sync_client() - return self.sync_client(**api_kwargs) - elif ( # reranker - model_type == ModelType.RERANKER - and "model" in api_kwargs - and api_kwargs["model"] == "BAAI/bge-reranker-base" - ): - if not hasattr(self, "reranker_client") or self.reranker_client is None: - self.reranker_client = self.init_reranker_client() - scores = self.reranker_client(**api_kwargs) - top_k_indices, top_k_scores = get_top_k_indices_scores( - scores, api_kwargs["top_k"] + + model_name = api_kwargs["model"] + if (model_name != self.model_name) and (self.model_name is not None): + # need to update the model_name + log.warning( + f"The model passed in 'model_kwargs' is different that the one that has been previously initialised: Updating model from {self.model_name} to {model_name}." 
) - return top_k_indices, top_k_scores - elif model_type == ModelType.LLM and "model" in api_kwargs: # LLM - if not hasattr(self, "llm_client") or self.llm_client is None: - self.llm_client = self.init_llm_client() - response = self.llm_client(**api_kwargs) - return response - else: - raise ValueError(f"model_type {model_type} is not supported") + self.model_name = model_name + self.init_model(model_name=model_name) + elif (model_name != self.model_name) and (self.model_name is None): + # need to initialize the model for the first time + self.model_name = model_name + self.init_model(model_name=model_name) + + assert "query" in api_kwargs, "query is required" + assert "documents" in api_kwargs, "documents is required" + assert "top_k" in api_kwargs, "top_k is required" + + top_k = api_kwargs.pop("top_k") + scores = self.infer_reranker(**api_kwargs) + top_k_indices, top_k_scores = get_top_k_indices_scores(scores, top_k) + log.warning(f"output: ({top_k_indices}, {top_k_scores})") + return top_k_indices, top_k_scores def convert_inputs_to_api_kwargs( self, input: Any, # for retriever, it is a single query, - model_kwargs: dict = {}, + model_kwargs: dict = None, model_type: ModelType = ModelType.UNDEFINED, ) -> dict: + model_kwargs = model_kwargs or dict() final_model_kwargs = model_kwargs.copy() - if model_type == ModelType.EMBEDDER: - final_model_kwargs["input"] = input - return final_model_kwargs - elif model_type == ModelType.RERANKER: - assert "model" in final_model_kwargs, "model must be specified" - assert "documents" in final_model_kwargs, "documents must be specified" - assert "top_k" in final_model_kwargs, "top_k must be specified" - final_model_kwargs["query"] = input - return final_model_kwargs - elif model_type == ModelType.LLM: - assert "model" in final_model_kwargs, "model must be specified" - messages = [{"role": "system", "content": input}] - final_model_kwargs["messages"] = messages - return final_model_kwargs - else: - raise ValueError(f"model_type {model_type} is not supported") - - -if __name__ == "__main__": - from adalflow.core import Generator - - import adalflow as adal - - adal.setup_env() - - rag_template = r""" -You are a helpful assistant. - -Your task is to answer the query that may or may not come with context information. -When context is provided, you should stick to the context and less on your prior knowledge to answer the query. 
- - - - {{input_str}} - - {% if context_str %} - - {{context_str}} - - {% endif %} - -""" - - template = """{{input_str}}""" - - model_kwargs = { - "model": "google/gemma-2-2b", - "temperature": 1, - "stream": False, - } - prompt_kwargs = { - "input_str": "Where is Brian?", - # "context_str": "Brian is in the kitchen.", - } - prompt_kwargs = { - "input_str": "What is the capital of France?", - } - - class CustomizeLLM: - - def __init__(self) -> None: - pass - - def parse_chat_completion(self, completion: Any) -> str: - return completion[0] - - def __call__(self, messages: Sequence[Dict[str, str]], model: str, **kwargs): - r"""take api key""" - from transformers import AutoTokenizer, AutoModelForCausalLM - - tokenizer = AutoTokenizer.from_pretrained( - "deepseek-ai/deepseek-coder-1.3b-instruct", trust_remote_code=True - ) - model = AutoModelForCausalLM.from_pretrained( - "deepseek-ai/deepseek-coder-1.3b-instruct", - trust_remote_code=True, - torch_dtype=torch.bfloat16, - ).to(get_device()) - messages = [ - {"role": "user", "content": "write a quick sort algorithm in python."} - ] - inputs = tokenizer.apply_chat_template( - messages, add_generation_prompt=True, return_tensors="pt" - ).to(model.device) - # tokenizer.eos_token_id is the id of <|EOT|> token - outputs = model.generate( - inputs, - max_new_tokens=512, - do_sample=False, - top_k=50, - top_p=0.95, - num_return_sequences=1, - eos_token_id=tokenizer.eos_token_id, - ) - - decoded_outputs = [] - for output in outputs: - decoded_outputs.append( - tokenizer.decode(output[len(inputs[0]) :], skip_special_tokens=True) - ) - return decoded_outputs - - transformer_client = TransformersClient() - transformer_client.set_llm_client(CustomizeLLM()) - generator = Generator( - model_client=transformer_client, - model_kwargs=model_kwargs, - # prompt_kwargs=prompt_kwargs, - template=template, - # output_processors=JsonParser(), - ) - output = generator(prompt_kwargs=prompt_kwargs) - print(output) + assert "model" in final_model_kwargs, "model must be specified" + assert "documents" in final_model_kwargs, "documents must be specified" + assert "top_k" in final_model_kwargs, "top_k must be specified" + final_model_kwargs["query"] = input + return final_model_kwargs diff --git a/adalflow/adalflow/core/embedder.py b/adalflow/adalflow/core/embedder.py index 89aac0c5f..588347baf 100644 --- a/adalflow/adalflow/core/embedder.py +++ b/adalflow/adalflow/core/embedder.py @@ -231,4 +231,4 @@ def call( input=batch_input, model_kwargs=model_kwargs ) embeddings.append(batch_output) - return embeddings + return embeddings \ No newline at end of file diff --git a/adalflow/adalflow/core/generator.py b/adalflow/adalflow/core/generator.py index 027651328..7868b44ec 100644 --- a/adalflow/adalflow/core/generator.py +++ b/adalflow/adalflow/core/generator.py @@ -6,7 +6,6 @@ import json from typing import Any, Dict, Optional, Union, Callable, Tuple, List -from copy import deepcopy import logging @@ -110,11 +109,6 @@ def __init__( ) template = template or DEFAULT_LIGHTRAG_SYSTEM_PROMPT - try: - prompt_kwargs = deepcopy(prompt_kwargs) - except Exception as e: - log.warning(f"Error copying the prompt_kwargs: {e}") - prompt_kwargs = prompt_kwargs # Cache model_str = ( @@ -125,8 +119,6 @@ def __init__( ) self.cache_path = os.path.join(_cache_path, f"cache_{model_str}.db") - print(f"cache_path: {self.cache_path}") - CachedEngine.__init__(self, cache_path=self.cache_path) Component.__init__(self) GradComponent.__init__(self) @@ -167,6 +159,10 @@ def __init__( } self._teacher: 
Optional["Generator"] = None + def get_cache_path(self) -> str: + r"""Get the cache path for the generator.""" + return self.cache_path + @staticmethod def _get_default_mapping( output: "GeneratorOutput" = None, @@ -269,11 +265,9 @@ def _compose_model_kwargs(self, **model_kwargs) -> Dict: return combined_model_kwargs def print_prompt(self, **kwargs) -> str: - # prompt_kwargs_str = _convert_prompt_kwargs_to_str(kwargs) return self.prompt.print_prompt(**kwargs) def get_prompt(self, **kwargs) -> str: - # prompt_kwargs_str = _convert_prompt_kwargs_to_str(kwargs) return self.prompt.call(**kwargs) def _extra_repr(self) -> str: @@ -420,8 +414,12 @@ def forward( if self.mock_output: output = GeneratorOutput(data=self.mock_output_data) else: - if self.teacher_mode: + if self.teacher_mode and not isinstance(self, BackwardEngine): if not self._teacher: + print( + f"prompt_kwargs: {prompt_kwargs}, model_kwargs: {model_kwargs}" + ) + print(f"names: {self.name}") raise ValueError("Teacher generator is not set.") log.info(f"Using teacher: {self._teacher}") input_args = { @@ -706,7 +704,6 @@ def _run_callbacks( model_kwargs=model_kwargs, ) if output.error: - print(f"call back on failure: {output}") self.trigger_callbacks( "on_failure", output=output, @@ -830,9 +827,23 @@ def __call__(self, *args, **kwargs) -> Union[GeneratorOutputType, Any]: return self.call(*args, **kwargs) def _extra_repr(self) -> str: + # Create the string for model_kwargs s = f"model_kwargs={self.model_kwargs}, " + + # Create the string for trainable prompt_kwargs + prompt_kwargs_repr = [ + k + for k, v in self.prompt_kwargs.items() + if isinstance(v, Parameter) and v.requires_opt + ] + + s += f"trainable_prompt_kwargs={prompt_kwargs_repr}" return s + def to_dict(self) -> Dict[str, Any]: + r"""Convert the generator to a dictionary.""" + # exclude default functions + @staticmethod def failure_message_to_backward_engine( gradient_response: GeneratorOutput, @@ -854,6 +865,8 @@ def __init__(self, **kwargs): kwargs = {} kwargs["template"] = FEEDBACK_ENGINE_TEMPLATE super().__init__(**kwargs) + self.name = "BackwardEngine" + self.teacher_mode = False @staticmethod def failure_message_to_optimizer( @@ -954,7 +967,6 @@ def create_teacher_generator( call_logger = GeneratorCallLogger(save_dir="traces") def on_complete(output, input, prompt_kwargs, model_kwargs, logger_call: Callable): - print(f"on_complet output: {output}") logger_call( output=output, input=input, @@ -963,13 +975,9 @@ def on_complete(output, input, prompt_kwargs, model_kwargs, logger_call: Callabl ) for model in [llama3_model, gpt_3_model, gemini_model, claude_model]: - print(f"""model: {model["model_kwargs"]["model"]}""") generator = Generator(**model) - print("_kwargs: ", generator._kwargs) - teacher = create_teacher_generator(generator, **claude_model) - print(f"teacher: {teacher}") call_logger.register_generator("generator", "generator_call") # setup the callback @@ -983,8 +991,7 @@ def on_complete(output, input, prompt_kwargs, model_kwargs, logger_call: Callabl "input_str": "Hello, world!", } ) - print(f"output: {output}") break # test the backward engine - # TODO: test ollama and transformer client to update the change + # TODO: test ollama and transformer client to update the change \ No newline at end of file diff --git a/adalflow/tests/test_transformer_client.py b/adalflow/tests/test_transformer_client.py index d8562454b..63ba019d4 100644 --- a/adalflow/tests/test_transformer_client.py +++ b/adalflow/tests/test_transformer_client.py @@ -1,13 +1,171 @@ import unittest 
import torch - +from adalflow.components.model_client.transformers_client import TransformerEmbeddingModelClient, TransformerLLMModelClient, TransformerRerankerModelClient +from adalflow.core.types import ModelType +from adalflow.core import Embedder, Generator # Set the number of threads for PyTorch, avoid segementation fault torch.set_num_threads(1) torch.set_num_interop_threads(1) -class TestTransformerModelClient(unittest.TestCase): +class TestTransformerEmbeddingModelClient(unittest.TestCase): + def setUp(self) -> None: + self.query = "what is panda?" + self.documents = [ + "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.", + "The red panda (Ailurus fulgens), also called the lesser panda, the red bear-cat, and the red cat-bear, is a mammal native to the eastern Himalayas and southwestern China.", + ] + + def test_execution(self): + test_input = "Hello word" + embedding_model = "thenlper/gte-base" + model_kwargs = {"model": embedding_model} + tokenizer_kwargs = { + "max_length": 512, + "padding": True, + "truncation": True, + "return_tensors": 'pt' + } + model_client = TransformerEmbeddingModelClient( + model_name=embedding_model, + tokenizer_kwargs=tokenizer_kwargs + ) + print( + f"Testing model client with model {embedding_model}" + ) + api_kwargs = model_client.convert_inputs_to_api_kwargs(input=test_input, model_kwargs=model_kwargs) + output = model_client.call(api_kwargs=api_kwargs) + print(output) + + def test_integration_with_embedder(self): + + test_input = "Hello word" + embedding_model = "thenlper/gte-base" + model_kwargs = {"model": embedding_model} + tokenizer_kwargs = { + "max_length": 512, + "padding": True, + "truncation": True, + "return_tensors": 'pt' + } + model_client = TransformerEmbeddingModelClient( + model_name=embedding_model, + tokenizer_kwargs=tokenizer_kwargs + ) + print( + f"Testing model client with model {embedding_model}" + ) + embedder = Embedder(model_client=model_client, + model_kwargs=model_kwargs + ) + output = embedder(test_input) + print(output) + +class TestTransformerLLMModelClient(unittest.TestCase): + + def setUp(self) -> None: + + self.model_kwargs = { + "model": "roneneldan/TinyStories-1M", + "temperature": 0.1, + "do_sample": True + } + self.tokenizer_kwargs = { + "max_length": True, + "truncation": True, + } + self.prompt_kwargs = { + "input_str": "Where is Brian?", # test input + } + self.chat_template_kwargs = { + "tokenize": False, + "add_generation_prompt": False + } + self.chat_template = """ + {%- for message in messages %} + {%- if message['role'] == 'user' %} + {{- bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }} + {%- elif message['role'] == 'system' %} + {{- '<>\\n' + message['content'].strip() + '\\n<>\\n\\n' }} + {%- elif message['role'] == 'assistant' %} + {{- '[ASST] ' + message['content'] + ' [/ASST]' + eos_token }} + {%- endif %} + {%- endfor %} + """ # Reference: https://huggingface.co/docs/transformers/main/en/chat_templating#how-do-i-create-a-chat-template + + def test_exectution(self): + model_client = TransformerLLMModelClient( + tokenizer_kwargs=self.tokenizer_kwargs, + local_files_only=False, + init_from="autoclass", + apply_chat_template=True, + chat_template=self.chat_template, + chat_template_kwargs=self.chat_template_kwargs + ) + api_kwargs = model_client.convert_inputs_to_api_kwargs(input="Where is brian?", model_kwargs=self.model_kwargs) + output = model_client.call(api_kwargs=api_kwargs) + print(output) + + def 
test_response(self): + + """Test the TransformerLLM model with roneneldan/TinyStories-1M for generating a response.""" + model_client = TransformerLLMModelClient( + ) + + # Define a sample input + input_text = "Hello, what's the weather today?" + + # Test generating a response, providing the 'model' keyword + # response = transformer_llm_model_component(input=input_text, model=transformer_llm_model) + api_kwargs = model_client.convert_inputs_to_api_kwargs(input_text, self.model_kwargs) + response = model_client.call(api_kwargs) + + # Check if the response is valid + self.assertIsInstance(response, list, "The response should be a list.") + self.assertTrue(all([isinstance(elmt, str) for elmt in response]), "all elements in the response list should be strings.") + self.assertTrue(len(response) > 0, "The response should not be empty.") + + # Optionally, print the response for visual verification during testing + print(f"Generated response: {response}") + + def test_integration_with_generator_autoclass(self): + model_client = TransformerLLMModelClient( + tokenizer_kwargs=self.tokenizer_kwargs, + local_files_only=False, + init_from="autoclass", + apply_chat_template=True, + chat_template=self.chat_template, + chat_template_kwargs=self.chat_template_kwargs + ) + generator = Generator( + model_client=model_client, + model_kwargs=self.model_kwargs, + # prompt_kwargs=prompt_kwargs, + # output_processors=JsonParser(), + ) + output = generator(prompt_kwargs=self.prompt_kwargs) + print(output) + + def test_integration_with_generator_pipeline(self): + model_client = TransformerLLMModelClient( + tokenizer_kwargs=self.tokenizer_kwargs, + local_files_only=False, + init_from="pipeline", + apply_chat_template=True, + chat_template=self.chat_template, + chat_template_kwargs=self.chat_template_kwargs + ) + generator = Generator( + model_client=model_client, + model_kwargs=self.model_kwargs, + # prompt_kwargs=prompt_kwargs, + # output_processors=JsonParser(), + ) + output = generator(prompt_kwargs=self.prompt_kwargs) + print(output) + +class TestTransformerRerankerModelClient(unittest.TestCase): def setUp(self) -> None: self.query = "what is panda?" 
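For reference, the reranker client exercised in the tests that follow is driven the same way as the other clients; a minimal sketch (assuming the ``BAAI/bge-reranker-base`` weights can be downloaded, with illustrative query/document strings):

.. code-block:: python

    from adalflow.components.model_client.transformers_client import (
        TransformerRerankerModelClient,
    )

    model_client = TransformerRerankerModelClient(tokenizer_kwargs={"padding": True})
    model_kwargs = {
        "model": "BAAI/bge-reranker-base",
        "documents": [
            "The giant panda is a bear species endemic to China.",
            "The red panda is a mammal native to the eastern Himalayas.",
        ],
        "top_k": 2,
    }
    api_kwargs = model_client.convert_inputs_to_api_kwargs(
        "what is panda?", model_kwargs=model_kwargs
    )
    # Returns (top_k_indices, top_k_scores); scores are sigmoid-activated logits.
    top_k_indices, top_k_scores = model_client.call(api_kwargs)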
@@ -16,107 +174,58 @@ def setUp(self) -> None: "The red panda (Ailurus fulgens), also called the lesser panda, the red bear-cat, and the red cat-bear, is a mammal native to the eastern Himalayas and southwestern China.", ] - # def test_transformer_embedder(self): - # transformer_embedder_model = "thenlper/gte-base" - # transformer_embedder_model_component = TransformerEmbedder( - # model_name=transformer_embedder_model - # ) - # print( - # f"Testing transformer embedder with model {transformer_embedder_model_component}" - # ) - # print("Testing transformer embedder") - # output = transformer_embedder_model_component( - # model=transformer_embedder_model, input="Hello world" - # ) - # print(output) - - # def test_transformer_client(self): - # transformer_client = TransformersClient() - # print("Testing transformer client") - # # run the model - # kwargs = { - # "model": "thenlper/gte-base", - # # "mock": False, - # } - # api_kwargs = transformer_client.convert_inputs_to_api_kwargs( - # input="Hello world", - # model_kwargs=kwargs, - # model_type=ModelType.EMBEDDER, - # ) - # # print(api_kwargs) - # output = transformer_client.call( - # api_kwargs=api_kwargs, model_type=ModelType.EMBEDDER - # ) - - # # print(transformer_client) - # # print(output) - - # def test_transformer_reranker(self): - # transformer_reranker_model = "BAAI/bge-reranker-base" - # transformer_reranker_model_component = TransformerReranker() - # # print( - # # f"Testing transformer reranker with model {transformer_reranker_model_component}" - # # ) - - # model_kwargs = { - # "model": transformer_reranker_model, - # "documents": self.documents, - # "query": self.query, - # "top_k": 2, - # } - - # output = transformer_reranker_model_component( - # **model_kwargs, - # ) - # # assert output is a list of float with length 2 - # self.assertEqual(len(output), 2) - # self.assertEqual(type(output[0]), float) - - # def test_transformer_reranker_client(self): - # transformer_reranker_client = TransformersClient( - # model_name="BAAI/bge-reranker-base" - # ) - # print("Testing transformer reranker client") - # # run the model - # kwargs = { - # "model": "BAAI/bge-reranker-base", - # "documents": self.documents, - # "top_k": 2, - # } - # api_kwargs = transformer_reranker_client.convert_inputs_to_api_kwargs( - # input=self.query, - # model_kwargs=kwargs, - # model_type=ModelType.RERANKER, - # ) - # print(api_kwargs) - # self.assertEqual(api_kwargs["model"], "BAAI/bge-reranker-base") - # output = transformer_reranker_client.call( - # api_kwargs=api_kwargs, model_type=ModelType.RERANKER - # ) - # self.assertEqual(type(output), tuple) - - # def test_transformer_llm_response(self): - # from adalflow.components.model_client.transformers_client import TransformerLLM - - # """Test the TransformerLLM model with zephyr-7b-beta for generating a response.""" - # transformer_llm_model = "HuggingFaceH4/zephyr-7b-beta" - # transformer_llm_model_component = TransformerLLM( - # model_name=transformer_llm_model - # ) - - # # Define a sample input - # input_text = "Hello, what's the weather today?" 
- - # # Test generating a response, providing the 'model' keyword - # # response = transformer_llm_model_component(input=input_text, model=transformer_llm_model) - # response = transformer_llm_model_component(input_text=input_text) - - # # Check if the response is valid - # self.assertIsInstance(response, str, "The response should be a string.") - # self.assertTrue(len(response) > 0, "The response should not be empty.") - - # # Optionally, print the response for visual verification during testing - # print(f"Generated response: {response}") + def test_execution(self): + transformer_reranker_model = "BAAI/bge-reranker-base" + transformer_reranker_model_client = TransformerRerankerModelClient( + tokenizer_kwargs={"padding": True} + ) + print( + f"Testing TransformerRerankerModelClient with model {transformer_reranker_model}" + ) + + model_kwargs = { + "model": transformer_reranker_model, + "documents": self.documents, + "top_k": 2, + } + + api_kwargs = transformer_reranker_model_client.convert_inputs_to_api_kwargs(self.query, model_kwargs=model_kwargs) + output = transformer_reranker_model_client.call(api_kwargs) + # assert output is a list of list with length 2 + self.assertEqual(len(output), 2) + self.assertEqual(type(output[0]), list) + self.assertEqual(type(output[1]), list) + # assert output[0] is a list of int of length top_k + tok_k = model_kwargs["top_k"] + self.assertTrue(all([isinstance(elmt, int) for elmt in output[0]])) + self.assertEqual(len(output[0]), tok_k) + # assert output[1] is a list of float of length top_k + tok_k = model_kwargs["top_k"] + self.assertTrue(all([isinstance(elmt, float) for elmt in output[1]])) + self.assertEqual(len(output[1]), tok_k) + + def test_transformer_reranker_client(self): + transformer_reranker_client = TransformerRerankerModelClient( + tokenizer_kwargs={"padding": True} + ) + print("Testing transformer reranker client") + # run the model + kwargs = { + "model": "BAAI/bge-reranker-base", + "documents": self.documents, + "top_k": 2, + } + api_kwargs = transformer_reranker_client.convert_inputs_to_api_kwargs( + input=self.query, + model_kwargs=kwargs, + + ) + print(api_kwargs) + self.assertEqual(api_kwargs["model"], "BAAI/bge-reranker-base") + output = transformer_reranker_client.call( + api_kwargs=api_kwargs + ) + self.assertEqual(type(output), tuple) if __name__ == "__main__": diff --git a/adalflow/tests/test_transformers_models.py b/adalflow/tests/test_transformers_models.py new file mode 100644 index 000000000..fc553f6e2 --- /dev/null +++ b/adalflow/tests/test_transformers_models.py @@ -0,0 +1,186 @@ +"""This tests that the new transformer_client compatibility with several models hosted on HuggingFace.""" +import unittest +import torch +from adalflow.components.model_client.transformers_client import TransformerEmbeddingModelClient, TransformerLLMModelClient, TransformerRerankerModelClient +from transformers import AutoModelForSequenceClassification + +class TestEmbeddingModels(unittest.TestCase): + def setUp(self) -> None: + self.test_input = "Hello world" + self.auto_tokenizer_kwargs = { + "max_length": 512, + "padding": True, + "truncation": True, + "return_tensors": 'pt' + } + def test_thenhelper_gte_base(self): + embedding_model = "thenlper/gte-base" + model_kwargs = {"model": embedding_model} + + model_client = TransformerEmbeddingModelClient( + model_name=embedding_model, + auto_tokenizer_kwargs=self.auto_tokenizer_kwargs + ) + print( + f"Testing model client with model {embedding_model}" + ) + api_kwargs = 
 
 
 if __name__ == "__main__":
diff --git a/adalflow/tests/test_transformers_models.py b/adalflow/tests/test_transformers_models.py
new file mode 100644
index 000000000..fc553f6e2
--- /dev/null
+++ b/adalflow/tests/test_transformers_models.py
@@ -0,0 +1,186 @@
+"""Tests the new transformer_client's compatibility with several models hosted on HuggingFace."""
+import unittest
+import torch
+from adalflow.components.model_client.transformers_client import TransformerEmbeddingModelClient, TransformerLLMModelClient, TransformerRerankerModelClient
+from transformers import AutoModelForSequenceClassification
+
+class TestEmbeddingModels(unittest.TestCase):
+    def setUp(self) -> None:
+        self.test_input = "Hello world"
+        self.auto_tokenizer_kwargs = {
+            "max_length": 512,
+            "padding": True,
+            "truncation": True,
+            "return_tensors": 'pt'
+        }
+
+    def test_thenlper_gte_base(self):
+        embedding_model = "thenlper/gte-base"
+        model_kwargs = {"model": embedding_model}
+
+        model_client = TransformerEmbeddingModelClient(
+            model_name=embedding_model,
+            auto_tokenizer_kwargs=self.auto_tokenizer_kwargs
+        )
+        print(
+            f"Testing model client with model {embedding_model}"
+        )
+        api_kwargs = model_client.convert_inputs_to_api_kwargs(input=self.test_input, model_kwargs=model_kwargs)
+        output = model_client.call(api_kwargs=api_kwargs)
+        print(output)
+
+    def test_jina_embeddings_V2_small_en(self):
+        embedding_model = "jinaai/jina-embeddings-v2-small-en"
+        model_kwargs = {"model": embedding_model}
+        model_client = TransformerEmbeddingModelClient(
+            model_name=embedding_model,
+            auto_tokenizer_kwargs=self.auto_tokenizer_kwargs
+        )
+        print(
+            f"Testing model client with model {embedding_model}"
+        )
+        api_kwargs = model_client.convert_inputs_to_api_kwargs(input=self.test_input, model_kwargs=model_kwargs)
+        output = model_client.call(api_kwargs=api_kwargs)
+        print(output)
+
+    def test_t5_small_standard_bahasa_cased(self):
+        embedding_model = "mesolitica/t5-small-standard-bahasa-cased"
+        model_kwargs = {"model": embedding_model}
+
+        # Subclass TransformerEmbeddingModelClient to adapt the client to an encoder-decoder architecture
+        class T5SmallStandardBahasaCased(TransformerEmbeddingModelClient):
+
+            def compute_model_outputs(self, batch_dict: dict, model) -> dict:
+                with torch.no_grad():
+                    outputs = model.encoder(**batch_dict)
+                return outputs
+
+        model_client = T5SmallStandardBahasaCased(
+            model_name=embedding_model,
+            auto_tokenizer_kwargs=self.auto_tokenizer_kwargs
+        )
+        print(
+            f"Testing model client with model {embedding_model}"
+        )
+        api_kwargs = model_client.convert_inputs_to_api_kwargs(input=self.test_input, model_kwargs=model_kwargs)
+        output = model_client.call(api_kwargs=api_kwargs)
+        print(output)
+
+    def test_sentence_transformers_all_miniLM_L6_V2(self):
+        embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
+        model_kwargs = {"model": embedding_model}
+
+        model_client = TransformerEmbeddingModelClient(
+            model_name=embedding_model,
+            auto_tokenizer_kwargs=self.auto_tokenizer_kwargs
+        )
+        print(
+            f"Testing model client with model {embedding_model}"
+        )
+        api_kwargs = model_client.convert_inputs_to_api_kwargs(input=self.test_input, model_kwargs=model_kwargs)
+        output = model_client.call(api_kwargs=api_kwargs)
+        print(output)
+
+
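+# The pattern in test_t5_small_standard_bahasa_cased above generalizes: a model that
+# needs a different forward pass can be supported by subclassing
+# TransformerEmbeddingModelClient and overriding compute_model_outputs. A minimal
+# sketch (illustrative only; the class name below is hypothetical and pooling is
+# assumed to be handled by the base client, as in the tests above):
+#
+#     class EncoderOnlyEmbeddingClient(TransformerEmbeddingModelClient):
+#         def compute_model_outputs(self, batch_dict: dict, model) -> dict:
+#             with torch.no_grad():  # inference only, no gradient tracking
+#                 return model(**batch_dict)
+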
+class TestLLMModels(unittest.TestCase):
+    """Both tests in this class depend on `accelerate`.
+    You might need to run the following command in the terminal:
+    `pip install accelerate`
+    """
+    def setUp(self) -> None:
+        self.input_text = "Where is Brian?"
+        self.auto_tokenizer_kwargs = {}
+        self.model_kwargs = {
+            "temperature": 0.1,
+            "do_sample": True
+        }
+        self.tokenizer_decode_kwargs = {
+            "max_length": True,
+            "truncation": True,
+        }
+        self.prompt_kwargs = {
+            "input_str": "Where is Brian?",  # test input
+        }
+
+    def test_roneneldan_tiny_stories_1M(self):
+        self.model_kwargs["model"] = "roneneldan/TinyStories-1M"
+        model_client = TransformerLLMModelClient(
+            auto_tokenizer_kwargs=self.auto_tokenizer_kwargs,
+            local_files_only=False,
+            init_from="autoclass",
+        )
+        print(
+            f"Testing model client with model {self.model_kwargs['model']}"
+        )
+        api_kwargs = model_client.convert_inputs_to_api_kwargs(input=self.input_text, model_kwargs=self.model_kwargs)
+        output = model_client.call(api_kwargs=api_kwargs)
+        print(output)
+
+    def test_nickypro_tinyllama_15m(self):
+        self.model_kwargs["model"] = "nickypro/tinyllama-15M"
+        model_client = TransformerLLMModelClient(
+            auto_tokenizer_kwargs=self.auto_tokenizer_kwargs,
+            local_files_only=False,
+            init_from="autoclass",
+        )
+        print(
+            f"Testing model client with model {self.model_kwargs['model']}"
+        )
+        api_kwargs = model_client.convert_inputs_to_api_kwargs(input=self.input_text, model_kwargs=self.model_kwargs)
+        output = model_client.call(api_kwargs=api_kwargs)
+        print(output)
+
+class TestRerankerModel(unittest.TestCase):
+    """This class depends on `sentencepiece`.
+    You might need to run the following command in the terminal:
+    `pip install transformers[sentencepiece]`
+    """
+    def setUp(self) -> None:
+        self.query = "Where is Brian?"
+        self.documents = [
+            "Brian is in the Kitchen.",
+            "Brian loves Adalflow.",
+            "Adalflow is a python library, not some food inside the kitchen.",
+        ]
+        self.model_kwargs = {
+            "documents": self.documents,
+            "top_k": 2,
+        }
+
+    def test_jina_reranker_v1_tiny_en(self):
+        self.model_kwargs["model"] = "jinaai/jina-reranker-v1-tiny-en"
+        model_client = TransformerRerankerModelClient(
+            tokenizer_kwargs={"padding": True},
+            auto_model_kwargs={"num_labels": 1}
+        )
+        print(
+            "Testing model client with model jinaai/jina-reranker-v1-tiny-en"
+        )
+        api_kwargs = model_client.convert_inputs_to_api_kwargs(self.query, model_kwargs=self.model_kwargs)
+        output = model_client.call(api_kwargs)
+
+    def test_baai_bge_reranker_base(self):
+        self.model_kwargs["model"] = "BAAI/bge-reranker-base"
+        model_client = TransformerRerankerModelClient(
+            tokenizer_kwargs={"padding": True},
+        )
+        print(
+            "Testing model client with model BAAI/bge-reranker-base"
+        )
+        api_kwargs = model_client.convert_inputs_to_api_kwargs(self.query, model_kwargs=self.model_kwargs)
+        output = model_client.call(api_kwargs)
+
+    def test_cross_encoder_ms_marco_minilm_L_2_V2(self):
+        self.model_kwargs["model"] = "cross-encoder/ms-marco-MiniLM-L-2-v2"
+        model_client = TransformerRerankerModelClient(
+            tokenizer_kwargs={"padding": True},
+        )
+        print(
+            "Testing model client with model cross-encoder/ms-marco-MiniLM-L-2-v2"
+        )
+        api_kwargs = model_client.convert_inputs_to_api_kwargs(self.query, model_kwargs=self.model_kwargs)
+        output = model_client.call(api_kwargs)
+
+if __name__ == "__main__":
+    unittest.main(verbosity=6)
\ No newline at end of file
diff --git a/tutorials/model_client.ipynb b/tutorials/model_client.ipynb
index 60ea6585b..f3f302d9f 100644
--- a/tutorials/model_client.ipynb
+++ b/tutorials/model_client.ipynb
@@ -24,9 +24,9 @@
    }
   ],
   "source": [
-    "from lightrag.components.model_client import OpenAIClient\n",
-    "from lightrag.core.types import ModelType\n",
-    "from lightrag.utils import 
setup_env\n", + "from adalflow.components.model_client import OpenAIClient\n", + "from adalflow.core.types import ModelType\n", + "from adalflow.utils import setup_env\n", "\n", "openai_client = OpenAIClient()\n", "\n", @@ -61,6 +61,170 @@ "print(f\"reponse_embedder_output: {reponse_embedder_output}\")\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For local models, we can use the client classes from the `transformers_client` module:\n", + "- TransformerEmbeddingModelClient\n", + "- TransformerLLMModelClient\n", + "- TransformerRerankerModelClient" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'adalflow'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/workspaces/transformer_client/AdalFlow/tutorials/model_client.ipynb Cell 4\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39madalflow\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mcomponents\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmodel_client\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mtransformers_client\u001b[39;00m \u001b[39mimport\u001b[39;00m (\n\u001b[1;32m 2\u001b[0m TransformerEmbeddingModelClient,\n\u001b[1;32m 3\u001b[0m TransformerLLMModelClient,\n\u001b[1;32m 4\u001b[0m TransformerRerankerModelClient\n\u001b[1;32m 5\u001b[0m )\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'adalflow'" + ] + } + ], + "source": [ + "from adalflow.components.model_client.transformers_client import (\n", + " TransformerEmbeddingModelClient,\n", + " TransformerLLMModelClient,\n", + " TransformerRerankerModelClient\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "query = \"Where is Brian?\"\n", + "documents = [\n", + " \"Brian is in the kitchen.\",\n", + " \"I love Adalflow.\",\n", + " \"Brian too.\"\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'TransformerEmbeddingModelClient' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/workspaces/transformer_client/AdalFlow/tutorials/model_client.ipynb Cell 6\u001b[0m line \u001b[0;36m9\n\u001b[1;32m 2\u001b[0m model_kwargs \u001b[39m=\u001b[39m {\u001b[39m\"\u001b[39m\u001b[39mmodel\u001b[39m\u001b[39m\"\u001b[39m: embedding_model}\n\u001b[1;32m 3\u001b[0m tokenizer_kwargs \u001b[39m=\u001b[39m {\n\u001b[1;32m 4\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mmax_length\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m512\u001b[39m,\n\u001b[1;32m 5\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mpadding\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mTrue\u001b[39;00m,\n\u001b[1;32m 6\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mtruncation\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mTrue\u001b[39;00m,\n\u001b[1;32m 7\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mreturn_tensors\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m'\u001b[39m\u001b[39mpt\u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m 8\u001b[0m }\n\u001b[0;32m----> 9\u001b[0m model_client \u001b[39m=\u001b[39m 
TransformerEmbeddingModelClient(\n\u001b[1;32m 10\u001b[0m model_name\u001b[39m=\u001b[39membedding_model,\n\u001b[1;32m 11\u001b[0m tokenizer_kwargs\u001b[39m=\u001b[39mtokenizer_kwargs\n\u001b[1;32m 12\u001b[0m )\n\u001b[1;32m 13\u001b[0m \u001b[39mprint\u001b[39m(\n\u001b[1;32m 14\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mTesting model client with model \u001b[39m\u001b[39m{\u001b[39;00membedding_model\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[1;32m 15\u001b[0m )\n\u001b[1;32m 16\u001b[0m api_kwargs \u001b[39m=\u001b[39m model_client\u001b[39m.\u001b[39mconvert_inputs_to_api_kwargs(\u001b[39minput\u001b[39m\u001b[39m=\u001b[39mquery, model_kwargs\u001b[39m=\u001b[39mmodel_kwargs)\n", + "\u001b[0;31mNameError\u001b[0m: name 'TransformerEmbeddingModelClient' is not defined" + ] + } + ], + "source": [ + "embedding_model = \"thenlper/gte-base\"\n", + "model_kwargs = {\"model\": embedding_model}\n", + "tokenizer_kwargs = {\n", + " \"max_length\": 512,\n", + " \"padding\": True,\n", + " \"truncation\": True,\n", + " \"return_tensors\": 'pt'\n", + "}\n", + "model_client = TransformerEmbeddingModelClient(\n", + " model_name=embedding_model,\n", + " tokenizer_kwargs=tokenizer_kwargs\n", + ")\n", + "print(\n", + " f\"Testing model client with model {embedding_model}\"\n", + ")\n", + "api_kwargs = model_client.convert_inputs_to_api_kwargs(input=query, model_kwargs=model_kwargs)\n", + "print(f\"api_kwargs: {api_kwargs}\")\n", + "output = model_client.call(api_kwargs=api_kwargs)\n", + "print(output)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'TransformerLLMModelClient' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/workspaces/transformer_client/AdalFlow/tutorials/model_client.ipynb Cell 7\u001b[0m line \u001b[0;36m1\n\u001b[1;32m 6\u001b[0m tokenizer_kwargs \u001b[39m=\u001b[39m {\n\u001b[1;32m 7\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mmax_length\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mTrue\u001b[39;00m,\n\u001b[1;32m 8\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mtruncation\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mTrue\u001b[39;00m,\n\u001b[1;32m 9\u001b[0m }\n\u001b[1;32m 10\u001b[0m prompt_kwargs \u001b[39m=\u001b[39m {\n\u001b[1;32m 11\u001b[0m \u001b[39m\"\u001b[39m\u001b[39minput_str\u001b[39m\u001b[39m\"\u001b[39m: query, \n\u001b[1;32m 12\u001b[0m }\n\u001b[0;32m---> 13\u001b[0m model_client \u001b[39m=\u001b[39m TransformerLLMModelClient(\n\u001b[1;32m 14\u001b[0m tokenizer_kwargs\u001b[39m=\u001b[39mtokenizer_kwargs,\n\u001b[1;32m 15\u001b[0m local_files_only\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m,\n\u001b[1;32m 16\u001b[0m init_from\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mautoclass\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 17\u001b[0m )\n\u001b[1;32m 18\u001b[0m api_kwargs \u001b[39m=\u001b[39m model_client\u001b[39m.\u001b[39mconvert_inputs_to_api_kwargs(\u001b[39minput\u001b[39m\u001b[39m=\u001b[39mquery, model_kwargs\u001b[39m=\u001b[39mmodel_kwargs)\n\u001b[1;32m 19\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mapi_kwargs: \u001b[39m\u001b[39m{\u001b[39;00mapi_kwargs\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'TransformerLLMModelClient' is not 
defined" + ] + } + ], + "source": [ + "model_kwargs = {\n", + " \"model\": \"roneneldan/TinyStories-1M\",\n", + " \"temperature\": 0.1,\n", + " \"do_sample\": True\n", + "}\n", + "tokenizer_kwargs = {\n", + " \"max_length\": True,\n", + " \"truncation\": True,\n", + "}\n", + "prompt_kwargs = {\n", + " \"input_str\": query, \n", + "}\n", + "model_client = TransformerLLMModelClient(\n", + " tokenizer_kwargs=tokenizer_kwargs,\n", + " local_files_only=False,\n", + " init_from=\"autoclass\",\n", + " )\n", + "api_kwargs = model_client.convert_inputs_to_api_kwargs(input=query, model_kwargs=model_kwargs)\n", + "print(f\"api_kwargs: {api_kwargs}\")\n", + "output = model_client.call(api_kwargs=api_kwargs)\n", + "print(f\"reponse_embedder_output: {reponse_embedder_output}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'TransformerRerankerModelClient' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/workspaces/transformer_client/AdalFlow/tutorials/model_client.ipynb Cell 8\u001b[0m line \u001b[0;36m2\n\u001b[1;32m 1\u001b[0m transformer_reranker_model \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mBAAI/bge-reranker-base\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m----> 2\u001b[0m model_client \u001b[39m=\u001b[39m TransformerRerankerModelClient(\n\u001b[1;32m 3\u001b[0m tokenizer_kwargs\u001b[39m=\u001b[39m{\u001b[39m\"\u001b[39m\u001b[39mpadding\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mTrue\u001b[39;00m}\n\u001b[1;32m 4\u001b[0m )\n\u001b[1;32m 6\u001b[0m model_kwargs \u001b[39m=\u001b[39m {\n\u001b[1;32m 7\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mmodel\u001b[39m\u001b[39m\"\u001b[39m: transformer_reranker_model,\n\u001b[1;32m 8\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mdocuments\u001b[39m\u001b[39m\"\u001b[39m: documents,\n\u001b[1;32m 9\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mtop_k\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m2\u001b[39m,\n\u001b[1;32m 10\u001b[0m }\n\u001b[1;32m 12\u001b[0m api_kwargs \u001b[39m=\u001b[39m model_client\u001b[39m.\u001b[39mconvert_inputs_to_api_kwargs(query, model_kwargs\u001b[39m=\u001b[39mmodel_kwargs)\n", + "\u001b[0;31mNameError\u001b[0m: name 'TransformerRerankerModelClient' is not defined" + ] + } + ], + "source": [ + "transformer_reranker_model = \"BAAI/bge-reranker-base\"\n", + "model_client = TransformerRerankerModelClient(\n", + " tokenizer_kwargs={\"padding\": True}\n", + ")\n", + "\n", + "model_kwargs = {\n", + " \"model\": transformer_reranker_model,\n", + " \"documents\": documents,\n", + " \"top_k\": 2,\n", + "}\n", + "\n", + "api_kwargs = model_client.convert_inputs_to_api_kwargs(query, model_kwargs=model_kwargs)\n", + "print(f\"api_kwargs: {api_kwargs}\")\n", + "output = model_client.call(api_kwargs)\n", + "print(f\"reponse_embedder_output: {reponse_embedder_output}\")" + ] + }, { "cell_type": "markdown", "metadata": {},