
Commit f0e7781

Clean-up kernel loading and dispatch (#40542)
* clean
* clean imports
* fix imports
* oops
* more imports
* more imports
* more
* move it to integrations
* fix
* style
* fix doc
1 parent f68eb5f commit f0e7781

11 files changed: 110 additions & 127 deletions

src/transformers/integrations/hub_kernels.py

Lines changed: 64 additions & 7 deletions
@@ -11,20 +11,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Union
+import re
+from functools import partial
+from typing import Optional, Union
+
+from ..modeling_flash_attention_utils import lazy_import_flash_attention
+from .flash_attention import flash_attention_forward
 
 
 try:
     from kernels import (
         Device,
         LayerRepository,
         Mode,
+        get_kernel,
         register_kernel_mapping,
         replace_kernel_forward_from_hub,
         use_kernel_forward_from_hub,
     )
 
-    _hub_kernels_available = True
+    _kernels_available = True
 
     _KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
         "MultiScaleDeformableAttention": {
@@ -82,8 +88,9 @@
 
     register_kernel_mapping(_KERNEL_MAPPING)
 
-
 except ImportError:
+    _kernels_available = False
+
     # Stub to make decorators int transformers work when `kernels`
     # is not installed.
     def use_kernel_forward_from_hub(*args, **kwargs):
@@ -104,16 +111,66 @@ def replace_kernel_forward_from_hub(*args, **kwargs):
     def register_kernel_mapping(*args, **kwargs):
         raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
 
-    _hub_kernels_available = False
+
+def is_kernel(attn_implementation: Optional[str]) -> bool:
+    """Check whether `attn_implementation` matches a kernel pattern from the hub."""
+    return (
+        attn_implementation is not None
+        and re.search(r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", attn_implementation) is not None
+    )
 
 
-def is_hub_kernels_available():
-    return _hub_kernels_available
+def load_and_register_kernel(attn_implementation: str) -> None:
+    """Load and register the kernel associated to `attn_implementation`."""
+    if not is_kernel(attn_implementation):
+        return
+    if not _kernels_available:
+        raise ImportError("`kernels` is not installed. Please install it with `pip install kernels`.")
+
+    # Need to be imported here as otherwise we have a circular import in `modeling_utils`
+    from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
+    from ..modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+    attention_wrapper = None
+    # FIXME: @ArthurZucker this is dirty, did not want to do a lof of extra work
+    actual_attn_name = attn_implementation
+    if "|" in attn_implementation:
+        attention_wrapper, actual_attn_name = attn_implementation.split("|")
+        # `transformers` has wrapper for sdpa, paged, flash, flex etc.
+        attention_wrapper = ALL_ATTENTION_FUNCTIONS.get(attention_wrapper)
+    # Extract repo_id and kernel_name from the string
+    if ":" in actual_attn_name:
+        repo_id, kernel_name = actual_attn_name.split(":")
+        kernel_name = kernel_name.strip()
+    else:
+        repo_id = actual_attn_name
+        kernel_name = None
+    repo_id = repo_id.strip()
+    # extract the rev after the @ if it exists
+    repo_id, _, rev = repo_id.partition("@")
+    repo_id = repo_id.strip()
+    rev = rev.strip() if rev else None
+
+    # Load the kernel from hub
+    try:
+        kernel = get_kernel(repo_id, revision=rev)
+    except Exception as e:
+        raise ValueError(f"An error occured while trying to load from '{repo_id}': {e}.")
+    # correctly wrap the kernel
+    if hasattr(kernel, "flash_attn_varlen_func"):
+        if attention_wrapper is None:
+            attention_wrapper = flash_attention_forward
+        kernel_function = partial(attention_wrapper, implementation=kernel)
+        lazy_import_flash_attention(kernel)
+    elif kernel_name is not None:
+        kernel_function = getattr(kernel, kernel_name)
+    # Register the kernel as a valid attention
+    ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
+    ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
 
 
 __all__ = [
     "LayerRepository",
-    "is_hub_kernels_available",
     "use_kernel_forward_from_hub",
     "register_kernel_mapping",
    "replace_kernel_forward_from_hub",

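For reference, the new `is_kernel` helper accepts hub kernel references of the form `org/repo`, optionally followed by `@revision` and/or `:kernel_name`, and `load_and_register_kernel` additionally understands a `wrapper|` prefix that is resolved against the wrappers already registered in `ALL_ATTENTION_FUNCTIONS`. Below is a minimal, standalone sketch of how such a string breaks down; the revision and kernel name are hypothetical, only `kernels-community/flash-attn` appears in this commit.

import re

# Same pattern as the `is_kernel` helper added above.
KERNEL_PATTERN = re.compile(r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$")


def split_kernel_ref(ref: str):
    """Mirror the parsing done in `load_and_register_kernel` (illustration only)."""
    wrapper, _, rest = ref.rpartition("|")            # optional "<wrapper>|" prefix
    repo_and_rev, _, kernel_name = rest.partition(":")  # optional ":kernel_name" suffix
    repo_id, _, rev = repo_and_rev.partition("@")        # optional "@revision"
    return wrapper or None, repo_id.strip(), rev.strip() or None, kernel_name.strip() or None


print(bool(KERNEL_PATTERN.search("kernels-community/flash-attn")))                   # True
print(split_kernel_ref("kernels-community/flash-attn@main:flash_attn_varlen_func"))  # hypothetical rev/name
# -> (None, 'kernels-community/flash-attn', 'main', 'flash_attn_varlen_func')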
src/transformers/modeling_flash_attention_utils.py

Lines changed: 3 additions & 3 deletions
@@ -126,10 +126,10 @@ def _lazy_define_process_function(flash_function):
 
 def lazy_import_flash_attention(implementation: Optional[str]):
     """
-    Lazy loading flash attention and returning the respective functions + flags back
+    Lazily import flash attention and return the respective functions + flags.
 
-    NOTE: For fullgraph, this needs to be called before compile while no fullgraph can
-    can work without preloading. See `_check_and_adjust_attn_implementation` in `modeling_utils`.
+    NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can
+    work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`.
     """
     global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn
     if any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):

src/transformers/modeling_utils.py

Lines changed: 27 additions & 80 deletions
@@ -44,12 +44,6 @@
 from torch.distributions import constraints
 from torch.utils.checkpoint import checkpoint
 
-from transformers.utils import is_torchao_available
-
-
-if is_torchao_available():
-    from torchao.quantization import Int4WeightOnlyConfig
-
 from .configuration_utils import PretrainedConfig
 from .distributed import DistributedConfig
 from .dynamic_module_utils import custom_object_save
@@ -61,6 +55,7 @@
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flash_paged import paged_attention_forward
 from .integrations.flex_attention import flex_attention_forward
+from .integrations.hub_kernels import is_kernel, load_and_register_kernel
 from .integrations.sdpa_attention import sdpa_attention_forward
 from .integrations.sdpa_paged import sdpa_attention_paged_forward
 from .integrations.tensor_parallel import (
@@ -73,17 +68,8 @@
     verify_tp_plan,
 )
 from .loss.loss_utils import LOSS_MAPPING
-from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
 from .modeling_flash_attention_utils import lazy_import_flash_attention
-from .pytorch_utils import (  # noqa: F401
-    Conv1D,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    id_tensor_storage,
-    prune_conv1d_layer,
-    prune_layer,
-    prune_linear_layer,
-)
+from .pytorch_utils import id_tensor_storage
 from .quantizers import HfQuantizer
 from .quantizers.auto import get_hf_quantizer
 from .quantizers.quantizers_utils import get_module_from_name
@@ -124,6 +110,7 @@
     is_torch_npu_available,
     is_torch_xla_available,
     is_torch_xpu_available,
+    is_torchao_available,
     logging,
 )
 from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
@@ -138,9 +125,8 @@
 from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
 
 
-XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
-XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
-
+if is_torchao_available():
+    from torchao.quantization import Int4WeightOnlyConfig
 
 if is_accelerate_available():
     from accelerate import dispatch_model, infer_auto_device_map
@@ -164,32 +150,14 @@
     from safetensors.torch import load_file as safe_load_file
     from safetensors.torch import save_file as safe_save_file
 
+if is_peft_available():
+    from .utils import find_adapter_config_file
 
-if is_kernels_available():
-    from kernels import get_kernel
-
-
-logger = logging.get_logger(__name__)
-
-
-_init_weights = True
-_is_quantized = False
-_is_ds_init_called = False
 _torch_distributed_available = torch.distributed.is_available()
-
 _is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
 if _is_dtensor_available:
     from torch.distributed.tensor import DTensor
 
-
-def is_local_dist_rank_0():
-    return (
-        torch.distributed.is_available()
-        and torch.distributed.is_initialized()
-        and int(os.environ.get("LOCAL_RANK", "-1")) == 0
-    )
-
-
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
     from smdistributed.modelparallel import __version__ as SMP_VERSION
@@ -198,11 +166,24 @@ def is_local_dist_rank_0():
 else:
     IS_SAGEMAKER_MP_POST_1_10 = False
 
-if is_peft_available():
-    from .utils import find_adapter_config_file
 
+logger = logging.get_logger(__name__)
 
+XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
+XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
 SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
+_init_weights = True
+_is_quantized = False
+_is_ds_init_called = False
+
+
+def is_local_dist_rank_0():
+    return (
+        torch.distributed.is_available()
+        and torch.distributed.is_initialized()
+        and int(os.environ.get("LOCAL_RANK", "-1")) == 0
+    )
+
 
 TORCH_INIT_FUNCTIONS = {
     "uniform_": nn.init.uniform_,
@@ -2801,44 +2782,10 @@ def _check_and_adjust_attn_implementation(
             and is_kernels_available()
         ):
             applicable_attn_implementation = "kernels-community/flash-attn"
-        if applicable_attn_implementation is not None and re.match(
-            r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", applicable_attn_implementation
-        ):
-            if not is_kernels_available():
-                raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
-            attention_wrapper = None
-            # FIXME: @ArthurZucker this is dirty, did not want to do a lof of extra work
-            actual_attn_name = applicable_attn_implementation
-            if "|" in applicable_attn_implementation:
-                attention_wrapper, actual_attn_name = applicable_attn_implementation.split("|")
-                # `transformers` has wrapper for sdpa, paged, flash, flex etc.
-                attention_wrapper = ALL_ATTENTION_FUNCTIONS.get(attention_wrapper)
-            # Extract repo_id and kernel_name from the string
-            if ":" in actual_attn_name:
-                repo_id, kernel_name = actual_attn_name.split(":")
-                kernel_name = kernel_name.strip()
-            else:
-                repo_id = actual_attn_name
-                kernel_name = None
-            repo_id = repo_id.strip()
-            # extract the rev after the @ if it exists
-            repo_id, _, rev = repo_id.partition("@")
-            repo_id = repo_id.strip()
-            rev = rev.strip() if rev else None
+        if is_kernel(applicable_attn_implementation):
             try:
-                kernel = get_kernel(repo_id, revision=rev)
-                if hasattr(kernel, "flash_attn_varlen_func"):
-                    if attention_wrapper is None:
-                        attention_wrapper = flash_attention_forward
-                    kernel_function = partial(attention_wrapper, implementation=kernel)
-                    lazy_import_flash_attention(kernel)
-                elif kernel_name is not None:
-                    kernel_function = getattr(kernel, kernel_name)
-                ALL_ATTENTION_FUNCTIONS.register(applicable_attn_implementation, kernel_function)
-                ALL_MASK_ATTENTION_FUNCTIONS.register(
-                    applicable_attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
-                )
-                # log that we used kernel fallback
+                load_and_register_kernel(applicable_attn_implementation)
+                # log that we used kernel fallback if successful
                 if attn_implementation == "flash_attention_2":
                     logger.warning_once(
                         "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` "
@@ -2848,8 +2795,8 @@ def _check_and_adjust_attn_implementation(
                 if attn_implementation == "flash_attention_2":
                     self._flash_attn_2_can_dispatch()  # will fail as fa2 is not available but raise the proper exception
                 logger.warning_once(
-                    f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
-                    "default attention implementation instead (sdpa if available, eager otherwise)."
+                    f"Could not find a kernel matching `{applicable_attn_implementation}` compatible with your device in the "
+                    f"hub:\n{e}.\nUsing default attention implementation instead (sdpa if available, eager otherwise)."
                 )
                 try:
                     self._sdpa_can_dispatch(is_init_check)

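In practice, the dispatch path above means a hub kernel reference can be passed straight through `attn_implementation`, and `_check_and_adjust_attn_implementation` hands it to `load_and_register_kernel`. A hedged usage sketch, assuming `kernels` is installed; the checkpoint id is a placeholder, and `kernels-community/flash-attn` is the fallback repo used in the diff above.

from transformers import AutoModelForCausalLM

# Placeholder checkpoint id; any causal LM checkpoint would do.
model = AutoModelForCausalLM.from_pretrained(
    "my-org/my-model",
    attn_implementation="kernels-community/flash-attn",  # resolved via load_and_register_kernel
)

# The string syntax from this commit is "<wrapper>|<org>/<repo>[@rev][:kernel_name]", where the
# optional wrapper routes the hub kernel through one of the built-in attention wrappers
# registered in ALL_ATTENTION_FUNCTIONS. Whether a given kernel works with a given wrapper
# depends on the kernel itself.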
src/transformers/models/blip/modeling_blip_text.py

Lines changed: 2 additions & 6 deletions
@@ -31,12 +31,8 @@
     BaseModelOutputWithPoolingAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import logging
 from ...utils.deprecation import deprecate_kwarg
 from .configuration_blip import BlipTextConfig

src/transformers/models/bridgetower/modeling_bridgetower.py

Lines changed: 2 additions & 2 deletions
@@ -34,8 +34,8 @@
     ModelOutput,
     SequenceClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import auto_docstring, logging, torch_int
 from ...utils.deprecation import deprecate_kwarg
 from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig

src/transformers/models/cvt/modeling_cvt.py

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,8 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...modeling_outputs import ImageClassifierOutputWithNoAttention, ModelOutput
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import auto_docstring, logging
 from .configuration_cvt import CvtConfig
 

src/transformers/models/deprecated/mctct/modeling_mctct.py

Lines changed: 2 additions & 6 deletions
@@ -28,12 +28,8 @@
 from ....modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ....modeling_layers import GradientCheckpointingLayer
 from ....modeling_outputs import BaseModelOutput, CausalLMOutput
-from ....modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ....modeling_utils import PreTrainedModel
+from ....pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ....utils import logging
 from .configuration_mctct import MCTCTConfig
 

src/transformers/models/esm/modeling_esm.py

Lines changed: 2 additions & 6 deletions
@@ -31,13 +31,9 @@
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    ALL_ATTENTION_FUNCTIONS,
-    PreTrainedModel,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.generic import OutputRecorder, check_model_inputs
 from .configuration_esm import EsmConfig

src/transformers/models/evolla/modeling_evolla.py

Lines changed: 2 additions & 8 deletions
@@ -41,15 +41,9 @@
     ModelOutput,
 )
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
-from ...modeling_utils import (
-    ALL_ATTENTION_FUNCTIONS,
-    ModuleUtilsMixin,
-    PreTrainedModel,
-    find_pruneable_heads_and_indices,
-    get_parameter_dtype,
-    prune_linear_layer,
-)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, ModuleUtilsMixin, PreTrainedModel, get_parameter_dtype
 from ...processing_utils import Unpack
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import OutputRecorder, check_model_inputs

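The model files above all follow the same pattern: with the `# noqa: F401` re-export removed from `modeling_utils.py`, the pruning and chunking helpers are imported from `pytorch_utils` directly. For code outside the library, the equivalent top-level import would presumably be:

# Sketch of the updated import location for the helpers that modeling_utils no longer re-exports.
from transformers.pytorch_utils import (
    Conv1D,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)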