docs/features/models/transformers.md (3 changes: 2 additions & 1 deletion)
@@ -12,10 +12,11 @@ title: Transformers

## Model Initialization

-To load the model, you can use the `from_transformers` function. It takes 2 arguments:
+To load the model, you can use the `from_transformers` function. It takes 3 arguments:

- `model`: a `transformers` model (created with `AutoModelForCausalLM` for instance)
- `tokenizer_or_processor`: a `transformers` tokenizer (created with `AutoTokenizer` for instance, it must be an instance of either `PreTrainedTokenizer` or `PreTrainedTokenizerFast`)
+- `device_dtype` (optional): the tensor dtype to use for inference. If not provided, the model will use the default dtype.

For instance:
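A minimal sketch of how the updated signature can be called, assuming a standard Hugging Face causal LM (the checkpoint name below is only illustrative):

```python
import torch
import transformers
import outlines

# Load the underlying model and tokenizer with the transformers library.
hf_model = transformers.AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
hf_tokenizer = transformers.AutoTokenizer.from_pretrained("openai-community/gpt2")

# `device_dtype` is optional and keyword-only; omit it to keep the model's default dtype.
model = outlines.from_transformers(hf_model, hf_tokenizer, device_dtype=torch.bfloat16)
```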

docs/features/models/transformers_multimodal.md (1 change: 1 addition & 0 deletions)
@@ -12,6 +12,7 @@ To load the model, you can use the `from_transformers` function. It takes 2 argu

- `model`: a `transformers` model (created with `AutoModelForImageTextToText` for instance)
- `tokenizer_or_processor`: a `transformers` processor (created with `AutoProcessor` for instance, it must be an instance of `ProcessorMixin`)
+- `device_dtype` (optional): the tensor dtype to use for inference. If not provided, the model will use the default dtype.

For instance:
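A similar sketch for the multimodal case, assuming an image-text-to-text checkpoint (the model name is only illustrative):

```python
import torch
import outlines
from transformers import AutoModelForImageTextToText, AutoProcessor

# Load a vision-language model and its processor with the transformers library.
hf_model = AutoModelForImageTextToText.from_pretrained("llava-hf/llava-interleave-qwen-0.5b-hf")
processor = AutoProcessor.from_pretrained("llava-hf/llava-interleave-qwen-0.5b-hf")

# Passing a processor selects the multimodal wrapper; `device_dtype` behaves as above.
model = outlines.from_transformers(hf_model, processor, device_dtype=torch.float16)
```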

outlines/models/transformers.py (40 changes: 34 additions & 6 deletions)
@@ -209,6 +209,8 @@ def __init__(
self,
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
+*,
+device_dtype: Optional["torch.dtype"] = None,
):
"""
Parameters:
@@ -219,6 +221,9 @@
tokenizer
A `PreTrainedTokenizer`, or any tokenizer that is compatible with
the `transformers` API for tokenizers.
+device_dtype
+The dtype to use for the model. If not provided, the model will use
+the default dtype.

"""
# We need to handle the cases in which jax/flax or tensorflow
@@ -237,6 +242,7 @@
self.model = model
self.hf_tokenizer = tokenizer
self.tokenizer = TransformerTokenizer(tokenizer)
+self.device_dtype = device_dtype
self.type_adapter = TransformersTypeAdapter(tokenizer=tokenizer)

if (
@@ -287,7 +293,11 @@ def _prepare_model_inputs(
input_ids, attention_mask = self.tokenizer.encode(prompts)
inputs = {
"input_ids": input_ids.to(self.model.device),
"attention_mask": attention_mask.to(self.model.device),
"attention_mask": (
attention_mask.to(self.model.device, dtype=self.device_dtype)
if self.device_dtype is not None
else attention_mask.to(self.model.device)
),
}

return prompts, inputs
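For context on the branch added above: `torch.Tensor.to` accepts a device and a dtype in the same call, so the two branches differ only in whether a cast is applied. A small standalone sketch of the pattern (values are illustrative):

```python
import torch

attention_mask = torch.ones(1, 4, dtype=torch.long)
device_dtype = torch.bfloat16  # illustrative; None would mean "keep the default dtype"

# Move to the target device, casting only when a dtype was requested.
moved = (
    attention_mask.to("cpu", dtype=device_dtype)
    if device_dtype is not None
    else attention_mask.to("cpu")
)
print(moved.dtype)  # torch.bfloat16 here; torch.int64 if device_dtype were None
```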
@@ -600,7 +610,13 @@ class TransformersMultiModal(Transformers):

"""

-def __init__(self, model: "PreTrainedModel", processor):
+def __init__(
+self,
+model: "PreTrainedModel",
+processor,
+*,
+device_dtype: Optional["torch.dtype"] = None,
+):
"""Create a TransformersMultiModal model instance

We rely on the `__init__` method of the `Transformers` class to handle
@@ -614,6 +630,9 @@ def __init__(self, model: "PreTrainedModel", processor):
`transformers` API for models.
processor
A `ProcessorMixin` instance.
+device_dtype
+The dtype to use for the model. If not provided, the model will use
+the default dtype.

"""
self.processor = processor
@@ -622,7 +641,7 @@ def __init__(self, model: "PreTrainedModel", processor):

tokenizer: "PreTrainedTokenizer" = self.processor.tokenizer

-super().__init__(model, tokenizer)
+super().__init__(model, tokenizer, device_dtype=device_dtype)

self.type_adapter = TransformersMultiModalTypeAdapter(
tokenizer=tokenizer
@@ -655,14 +674,20 @@ def _prepare_model_inputs(

inputs = self.processor(
**merged_prompts, padding=True, return_tensors="pt"
-).to(self.model.device)
+)
+if self.device_dtype is not None:
+inputs = inputs.to(self.model.device, dtype=self.device_dtype)
+else:
+inputs = inputs.to(self.model.device)

return merged_prompts["text"], inputs


def from_transformers(
model: "PreTrainedModel",
tokenizer_or_processor: Union["PreTrainedTokenizer", "ProcessorMixin"],
+*,
+device_dtype: Optional["torch.dtype"] = None,
) -> Union[Transformers, TransformersMultiModal]:
"""Create an Outlines `Transformers` or `TransformersMultiModal` model
instance from a `PreTrainedModel` instance and a `PreTrainedTokenizer` or
@@ -679,6 +704,9 @@ def from_transformers(
tokenizer_or_processor
A `transformers.PreTrainedTokenizer` or
`transformers.ProcessorMixin` instance.
+device_dtype
+The dtype to use for the model. If not provided, the model will use
+the default dtype.

Returns
-------
@@ -693,10 +721,10 @@
tokenizer_or_processor, (PreTrainedTokenizer, PreTrainedTokenizerFast)
):
tokenizer = tokenizer_or_processor
-return Transformers(model, tokenizer)
+return Transformers(model, tokenizer, device_dtype=device_dtype)
elif isinstance(tokenizer_or_processor, ProcessorMixin):
processor = tokenizer_or_processor
-return TransformersMultiModal(model, processor)
+return TransformersMultiModal(model, processor, device_dtype=device_dtype)
else:
raise ValueError(
"We could determine whether the model passed to `from_transformers`"
tests/models/test_transformers.py (9 changes: 8 additions & 1 deletion)
@@ -3,6 +3,7 @@

from pydantic import BaseModel
import pytest
+import torch
import transformers

import outlines
@@ -47,15 +48,17 @@ def test_transformers_instantiate_mamba():
assert isinstance(model, Transformers)


-def test_transformers_instantiate_tokenizer_kwargs():
+def test_transformers_instantiate_tokenizer_kwargs_dtype():
model = outlines.from_transformers(
transformers.AutoModelForCausalLM.from_pretrained(TEST_MODEL),
transformers.AutoTokenizer.from_pretrained(
TEST_MODEL, additional_special_tokens=["<t1>", "<t2>"]
),
+device_dtype=torch.bfloat16,
)
assert "<t1>" in model.tokenizer.special_tokens
assert "<t2>" in model.tokenizer.special_tokens
+assert model.device_dtype == torch.bfloat16


@pytest.fixture
@@ -88,6 +91,10 @@ def test_transformers_call(model, model_bart):
result = model("Respond with one word. Not more.")
assert isinstance(result, str)

+model.device_dtype = torch.bfloat16
+result = model("Respond with one word. Not more.")
+assert isinstance(result, str)

result = model_bart("Respond with one word. Not more.")
assert isinstance(result, str)

tests/models/test_transformers_multimodal.py (12 changes: 11 additions & 1 deletion)
@@ -2,6 +2,7 @@

import io
import re
+import torch
from enum import Enum

import pytest
@@ -47,15 +48,17 @@ def model():
return model


-def test_transformers_multimodal_instantiate_simple():
+def test_transformers_multimodal_instantiate():
model = outlines.from_transformers(
LlavaForConditionalGeneration.from_pretrained(TEST_MODEL),
AutoProcessor.from_pretrained(TEST_MODEL),
+device_dtype=torch.bfloat16,
)
assert isinstance(model, TransformersMultiModal)
assert isinstance(model.tokenizer, TransformerTokenizer)
assert isinstance(model.type_adapter, TransformersMultiModalTypeAdapter)
assert model.tensor_library_name == "torch"
+assert model.device_dtype == torch.bfloat16


def test_transformers_multimodal_simple(model, image):
@@ -74,6 +77,13 @@ def test_transformers_multimodal_call(model, image):
)
assert isinstance(result, str)

+model.device_dtype = torch.bfloat16
+result = model(
+["<image>Describe this image in one sentence:", Image(image)],
+max_new_tokens=2,
+)
+assert isinstance(result, str)


def test_transformers_multimodal_wrong_number_image(model, image):
with pytest.raises(ValueError):