diff --git a/docs/features/models/transformers.md b/docs/features/models/transformers.md
index 77c5d4241..deeed93ba 100644
--- a/docs/features/models/transformers.md
+++ b/docs/features/models/transformers.md
@@ -12,10 +12,11 @@ title: Transformers
 
 ## Model Initialization
 
-To load the model, you can use the `from_transformers` function. It takes 2 arguments:
+To load the model, you can use the `from_transformers` function. It takes 3 arguments:
 
 - `model`: a `transformers` model (created with `AutoModelForCausalLM` for instance)
 - `tokenizer_or_processor`: a `transformers` tokenizer (created with `AutoTokenizer` for instance, it must be an instance of either `PreTrainedTokenizer` or `PreTrainedTokenizerFast`)
+- `device_dtype` (optional): the `torch.dtype` the model inputs are cast to before inference. If not provided, the inputs keep their default dtype.
 
 For instance:
 
diff --git a/docs/features/models/transformers_multimodal.md b/docs/features/models/transformers_multimodal.md
index 3d34a0f78..78e6c51a1 100644
--- a/docs/features/models/transformers_multimodal.md
+++ b/docs/features/models/transformers_multimodal.md
@@ -12,6 +12,7 @@
-To load the model, you can use the `from_transformers` function. It takes 2 arguments:
+To load the model, you can use the `from_transformers` function. It takes 3 arguments:
 
 - `model`: a `transformers` model (created with `AutoModelForImageTextToText` for instance)
 - `tokenizer_or_processor`: a `transformers` processor (created with `AutoProcessor` for instance, it must be an instance of `ProcessorMixin`)
+- `device_dtype` (optional): the `torch.dtype` the model inputs are cast to before inference. If not provided, the inputs keep their default dtype.
 
 For instance:
diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py
index 4d6387a40..42f05553c 100644
--- a/outlines/models/transformers.py
+++ b/outlines/models/transformers.py
@@ -209,6 +209,8 @@ def __init__(
         self,
         model: "PreTrainedModel",
         tokenizer: "PreTrainedTokenizer",
+        *,
+        device_dtype: Optional["torch.dtype"] = None,
     ):
         """
         Parameters:
@@ -219,6 +221,9 @@
         tokenizer
             A `PreTrainedTokenizer`, or any tokenizer that is compatible with
             the `transformers` API for tokenizers.
+        device_dtype
+            The dtype to cast the model inputs to before inference. If not
+            provided, the inputs keep their default dtype.
         """
 
         # We need to handle the cases in which jax/flax or tensorflow
@@ -237,6 +242,7 @@
         self.model = model
         self.hf_tokenizer = tokenizer
         self.tokenizer = TransformerTokenizer(tokenizer)
+        self.device_dtype = device_dtype
         self.type_adapter = TransformersTypeAdapter(tokenizer=tokenizer)
 
         if (
@@ -287,7 +293,13 @@ def _prepare_model_inputs(
         input_ids, attention_mask = self.tokenizer.encode(prompts)
         inputs = {
             "input_ids": input_ids.to(self.model.device),
-            "attention_mask": attention_mask.to(self.model.device),
+            # Only the attention mask is cast: `input_ids` must keep their
+            # integer dtype to remain valid embedding indices.
+            "attention_mask": (
+                attention_mask.to(self.model.device, dtype=self.device_dtype)
+                if self.device_dtype is not None
+                else attention_mask.to(self.model.device)
+            ),
         }
 
         return prompts, inputs
@@ -600,7 +610,13 @@ class TransformersMultiModal(Transformers):
     """
 
-    def __init__(self, model: "PreTrainedModel", processor):
+    def __init__(
+        self,
+        model: "PreTrainedModel",
+        processor,
+        *,
+        device_dtype: Optional["torch.dtype"] = None,
+    ):
         """Create a TransformersMultiModal model instance
 
         We rely on the `__init__` method of the `Transformers` class to handle
@@ -614,6 +630,9 @@ def __init__(self, model: "PreTrainedModel", processor):
             `transformers` API for models.
         processor
             A `ProcessorMixin` instance.
+        device_dtype
+            The dtype to cast the model inputs to before inference. If not
+            provided, the inputs keep their default dtype.
""" self.processor = processor @@ -622,7 +641,7 @@ def __init__(self, model: "PreTrainedModel", processor): tokenizer: "PreTrainedTokenizer" = self.processor.tokenizer - super().__init__(model, tokenizer) + super().__init__(model, tokenizer, device_dtype=device_dtype) self.type_adapter = TransformersMultiModalTypeAdapter( tokenizer=tokenizer @@ -655,7 +674,11 @@ def _prepare_model_inputs( inputs = self.processor( **merged_prompts, padding=True, return_tensors="pt" - ).to(self.model.device) + ) + if self.device_dtype is not None: + inputs = inputs.to(self.model.device, dtype=self.device_dtype) + else: + inputs = inputs.to(self.model.device) return merged_prompts["text"], inputs @@ -663,6 +686,8 @@ def _prepare_model_inputs( def from_transformers( model: "PreTrainedModel", tokenizer_or_processor: Union["PreTrainedTokenizer", "ProcessorMixin"], + *, + device_dtype: Optional["torch.dtype"] = None, ) -> Union[Transformers, TransformersMultiModal]: """Create an Outlines `Transformers` or `TransformersMultiModal` model instance from a `PreTrainedModel` instance and a `PreTrainedTokenizer` or @@ -679,6 +704,9 @@ def from_transformers( tokenizer_or_processor A `transformers.PreTrainedTokenizer` or `transformers.ProcessorMixin` instance. + device_dtype + The dtype to use for the model. If not provided, the model will use + the default dtype. Returns ------- @@ -693,10 +721,10 @@ def from_transformers( tokenizer_or_processor, (PreTrainedTokenizer, PreTrainedTokenizerFast) ): tokenizer = tokenizer_or_processor - return Transformers(model, tokenizer) + return Transformers(model, tokenizer, device_dtype=device_dtype) elif isinstance(tokenizer_or_processor, ProcessorMixin): processor = tokenizer_or_processor - return TransformersMultiModal(model, processor) + return TransformersMultiModal(model, processor, device_dtype=device_dtype) else: raise ValueError( "We could determine whether the model passed to `from_transformers`" diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index 10f572794..91b7ff7ab 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -3,6 +3,7 @@ from pydantic import BaseModel import pytest +import torch import transformers import outlines @@ -47,15 +48,17 @@ def test_transformers_instantiate_mamba(): assert isinstance(model, Transformers) -def test_transformers_instantiate_tokenizer_kwargs(): +def test_transformers_instantiate_tokenizer_kwargs_dtype(): model = outlines.from_transformers( transformers.AutoModelForCausalLM.from_pretrained(TEST_MODEL), transformers.AutoTokenizer.from_pretrained( TEST_MODEL, additional_special_tokens=["", ""] ), + device_dtype=torch.bfloat16, ) assert "" in model.tokenizer.special_tokens assert "" in model.tokenizer.special_tokens + assert model.device_dtype == torch.bfloat16 @pytest.fixture @@ -88,6 +91,10 @@ def test_transformers_call(model, model_bart): result = model("Respond with one word. Not more.") assert isinstance(result, str) + model.device_dtype = torch.bfloat16 + result = model("Respond with one word. Not more.") + assert isinstance(result, str) + result = model_bart("Respond with one word. 
Not more.") assert isinstance(result, str) diff --git a/tests/models/test_transformers_multimodal.py b/tests/models/test_transformers_multimodal.py index ce968e1f6..4bdc7282e 100644 --- a/tests/models/test_transformers_multimodal.py +++ b/tests/models/test_transformers_multimodal.py @@ -2,6 +2,7 @@ import io import re +import torch from enum import Enum import pytest @@ -47,15 +48,17 @@ def model(): return model -def test_transformers_multimodal_instantiate_simple(): +def test_transformers_multimodal_instantiate(): model = outlines.from_transformers( LlavaForConditionalGeneration.from_pretrained(TEST_MODEL), AutoProcessor.from_pretrained(TEST_MODEL), + device_dtype=torch.bfloat16, ) assert isinstance(model, TransformersMultiModal) assert isinstance(model.tokenizer, TransformerTokenizer) assert isinstance(model.type_adapter, TransformersMultiModalTypeAdapter) assert model.tensor_library_name == "torch" + assert model.device_dtype == torch.bfloat16 def test_transformers_multimodal_simple(model, image): @@ -74,6 +77,13 @@ def test_transformers_multimodal_call(model, image): ) assert isinstance(result, str) + model.device_dtype = torch.bfloat16 + result = model( + ["Describe this image in one sentence:", Image(image)], + max_new_tokens=2, + ) + assert isinstance(result, str) + def test_transformers_multimodal_wrong_number_image(model, image): with pytest.raises(ValueError):