Skip to content

Commit 12f26fa

Browse files
committed
fix transformers deprecated warning + make tests run
1 parent bff15f2 commit 12f26fa

File tree

4 files changed

+750
-126
lines changed

4 files changed

+750
-126
lines changed

bergson/distributed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def setup_model_and_peft(cfg: IndexConfig, rank: int, dtype: torch.dtype) -> tup
110110
cfg.model,
111111
device_map=device_map,
112112
quantization_config=quantization_config,
113-
torch_dtype=dtype,
113+
dtype=dtype,
114114
)
115115
target_modules = None
116116

@@ -120,7 +120,7 @@ def setup_model_and_peft(cfg: IndexConfig, rank: int, dtype: torch.dtype) -> tup
120120
peft_config.base_model_name_or_path, # type: ignore
121121
device_map=device_map,
122122
quantization_config=quantization_config,
123-
torch_dtype=dtype,
123+
dtype=dtype,
124124
)
125125

126126
model = PeftModel.from_pretrained(

0 commit comments

Comments
 (0)