@@ -44,12 +44,6 @@
 from torch.distributions import constraints
 from torch.utils.checkpoint import checkpoint
 
-from transformers.utils import is_torchao_available
-
-
-if is_torchao_available():
-    from torchao.quantization import Int4WeightOnlyConfig
-
 from .configuration_utils import PretrainedConfig
 from .distributed import DistributedConfig
 from .dynamic_module_utils import custom_object_save
@@ -61,6 +55,7 @@
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flash_paged import paged_attention_forward
 from .integrations.flex_attention import flex_attention_forward
+from .integrations.hub_kernels import is_kernel, load_and_register_kernel
 from .integrations.sdpa_attention import sdpa_attention_forward
 from .integrations.sdpa_paged import sdpa_attention_paged_forward
 from .integrations.tensor_parallel import (
@@ -73,17 +68,8 @@
     verify_tp_plan,
 )
 from .loss.loss_utils import LOSS_MAPPING
-from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
 from .modeling_flash_attention_utils import lazy_import_flash_attention
-from .pytorch_utils import (  # noqa: F401
-    Conv1D,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    id_tensor_storage,
-    prune_conv1d_layer,
-    prune_layer,
-    prune_linear_layer,
-)
+from .pytorch_utils import id_tensor_storage
 from .quantizers import HfQuantizer
 from .quantizers.auto import get_hf_quantizer
 from .quantizers.quantizers_utils import get_module_from_name
@@ -124,6 +110,7 @@
     is_torch_npu_available,
     is_torch_xla_available,
     is_torch_xpu_available,
+    is_torchao_available,
     logging,
 )
 from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
@@ -138,9 +125,8 @@
 from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
 
 
-XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
-XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
-
+if is_torchao_available():
+    from torchao.quantization import Int4WeightOnlyConfig
 
 if is_accelerate_available():
     from accelerate import dispatch_model, infer_auto_device_map
@@ -164,32 +150,14 @@
     from safetensors.torch import load_file as safe_load_file
     from safetensors.torch import save_file as safe_save_file
 
+if is_peft_available():
+    from .utils import find_adapter_config_file
 
-if is_kernels_available():
-    from kernels import get_kernel
-
-
-logger = logging.get_logger(__name__)
-
-
-_init_weights = True
-_is_quantized = False
-_is_ds_init_called = False
 _torch_distributed_available = torch.distributed.is_available()
-
 _is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
 if _is_dtensor_available:
     from torch.distributed.tensor import DTensor
 
-
-def is_local_dist_rank_0():
-    return (
-        torch.distributed.is_available()
-        and torch.distributed.is_initialized()
-        and int(os.environ.get("LOCAL_RANK", "-1")) == 0
-    )
-
-
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
     from smdistributed.modelparallel import __version__ as SMP_VERSION
@@ -198,11 +166,24 @@ def is_local_dist_rank_0():
 else:
     IS_SAGEMAKER_MP_POST_1_10 = False
 
-if is_peft_available():
-    from .utils import find_adapter_config_file
 
+logger = logging.get_logger(__name__)
 
+XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
+XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
 SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
+_init_weights = True
+_is_quantized = False
+_is_ds_init_called = False
+
+
+def is_local_dist_rank_0():
+    return (
+        torch.distributed.is_available()
+        and torch.distributed.is_initialized()
+        and int(os.environ.get("LOCAL_RANK", "-1")) == 0
+    )
+
 
 TORCH_INIT_FUNCTIONS = {
     "uniform_": nn.init.uniform_,
@@ -2801,44 +2782,10 @@ def _check_and_adjust_attn_implementation(
             and is_kernels_available()
         ):
             applicable_attn_implementation = "kernels-community/flash-attn"
-        if applicable_attn_implementation is not None and re.match(
-            r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", applicable_attn_implementation
-        ):
-            if not is_kernels_available():
-                raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
-            attention_wrapper = None
-            # FIXME: @ArthurZucker this is dirty, did not want to do a lof of extra work
-            actual_attn_name = applicable_attn_implementation
-            if "|" in applicable_attn_implementation:
-                attention_wrapper, actual_attn_name = applicable_attn_implementation.split("|")
-                # `transformers` has wrapper for sdpa, paged, flash, flex etc.
-                attention_wrapper = ALL_ATTENTION_FUNCTIONS.get(attention_wrapper)
-            # Extract repo_id and kernel_name from the string
-            if ":" in actual_attn_name:
-                repo_id, kernel_name = actual_attn_name.split(":")
-                kernel_name = kernel_name.strip()
-            else:
-                repo_id = actual_attn_name
-                kernel_name = None
-            repo_id = repo_id.strip()
-            # extract the rev after the @ if it exists
-            repo_id, _, rev = repo_id.partition("@")
-            repo_id = repo_id.strip()
-            rev = rev.strip() if rev else None
+        if is_kernel(applicable_attn_implementation):
             try:
-                kernel = get_kernel(repo_id, revision=rev)
-                if hasattr(kernel, "flash_attn_varlen_func"):
-                    if attention_wrapper is None:
-                        attention_wrapper = flash_attention_forward
-                    kernel_function = partial(attention_wrapper, implementation=kernel)
-                    lazy_import_flash_attention(kernel)
-                elif kernel_name is not None:
-                    kernel_function = getattr(kernel, kernel_name)
-                ALL_ATTENTION_FUNCTIONS.register(applicable_attn_implementation, kernel_function)
-                ALL_MASK_ATTENTION_FUNCTIONS.register(
-                    applicable_attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
-                )
-                # log that we used kernel fallback
+                load_and_register_kernel(applicable_attn_implementation)
+                # log that we used kernel fallback if successful
                 if attn_implementation == "flash_attention_2":
                     logger.warning_once(
                         "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` "
@@ -2848,8 +2795,8 @@ def _check_and_adjust_attn_implementation(
28482795 if attn_implementation == "flash_attention_2" :
28492796 self ._flash_attn_2_can_dispatch () # will fail as fa2 is not available but raise the proper exception
28502797 logger .warning_once (
2851- f"Could not find a kernel repository ' { repo_id } ' compatible with your device in the hub: { e } . Using "
2852- " default attention implementation instead (sdpa if available, eager otherwise)."
2798+ f"Could not find a kernel matching ` { applicable_attn_implementation } ` compatible with your device in the "
2799+ f"hub: \n { e } . \n Using default attention implementation instead (sdpa if available, eager otherwise)."
28532800 )
28542801 try :
28552802 self ._sdpa_can_dispatch (is_init_check )