From ef6a4833a3254088c98dce9ff85ab03ef4f64002 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Oct 2025 15:14:27 +0530 Subject: [PATCH 1/7] refactor how attention kernels from hub are used. --- src/diffusers/models/attention_dispatch.py | 57 ++++++++++++++++------ src/diffusers/models/modeling_utils.py | 4 +- 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index e1694910997a..d741e56d37d1 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -17,7 +17,7 @@ import inspect import math from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union import torch @@ -78,17 +78,8 @@ flash_attn_3_func = None flash_attn_3_varlen_func = None -if DIFFUSERS_ENABLE_HUB_KERNELS: - if not is_kernels_available(): - raise ImportError( - "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`." - ) - from ..utils.kernels_utils import _get_fa3_from_hub - - flash_attn_interface_hub = _get_fa3_from_hub() - flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func -else: - flash_attn_3_func_hub = None +flash_attn_3_func_hub = None +_PREPARED_BACKENDS: Set["AttentionBackendName"] = set() if _CAN_USE_SAGE_ATTN: from sageattention import ( @@ -231,7 +222,9 @@ def decorator(func): @classmethod def get_active_backend(cls): - return cls._active_backend, cls._backends[cls._active_backend] + backend = cls._active_backend + _ensure_attention_backend_ready(backend) + return backend, cls._backends[backend] @classmethod def list_backends(cls): @@ -258,7 +251,7 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke raise ValueError(f"Backend {backend} is not registered.") backend = AttentionBackendName(backend) - _check_attention_backend_requirements(backend) + _ensure_attention_backend_ready(backend) old_backend = _AttentionBackendRegistry._active_backend _AttentionBackendRegistry._active_backend = backend @@ -452,6 +445,39 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None ) +def _ensure_flash_attn_3_func_hub_loaded(): + global flash_attn_3_func_hub + + if flash_attn_3_func_hub is not None: + return flash_attn_3_func_hub + + from ..utils.kernels_utils import _get_fa3_from_hub + + flash_attn_interface_hub = _get_fa3_from_hub() + flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func + + return flash_attn_3_func_hub + + +_BACKEND_PREPARERS: Dict[AttentionBackendName, Callable[[], None]] = { + AttentionBackendName._FLASH_3_HUB: _ensure_flash_attn_3_func_hub_loaded, +} + + +def _prepare_attention_backend(backend: AttentionBackendName) -> None: + preparer = _BACKEND_PREPARERS.get(backend) + if preparer is not None: + preparer() + + +def _ensure_attention_backend_ready(backend: AttentionBackendName) -> None: + if backend in _PREPARED_BACKENDS: + return + _check_attention_backend_requirements(backend) + _prepare_attention_backend(backend) + _PREPARED_BACKENDS.add(backend) + + @functools.lru_cache(maxsize=128) def _prepare_for_flash_attn_or_sage_varlen_without_mask( batch_size: int, @@ -1322,7 +1348,8 @@ def _flash_attention_3_hub( return_attn_probs: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - out = flash_attn_3_func_hub( + func 
= flash_attn_3_func_hub or _ensure_flash_attn_3_func_hub_loaded() + out = func( q=query, k=key, v=value, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 1af7ba9ac511..786d2e901b01 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -594,7 +594,7 @@ def set_attention_backend(self, backend: str) -> None: attention as backend. """ from .attention import AttentionModuleMixin - from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements + from .attention_dispatch import AttentionBackendName, _ensure_attention_backend_ready # TODO: the following will not be required when everything is refactored to AttentionModuleMixin from .attention_processor import Attention, MochiAttention @@ -606,7 +606,7 @@ def set_attention_backend(self, backend: str) -> None: if backend not in available_backends: raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) backend = AttentionBackendName(backend) - _check_attention_backend_requirements(backend) + _ensure_attention_backend_ready(backend) attention_classes = (Attention, MochiAttention, AttentionModuleMixin) for module in self.modules(): From 7fd26bc037313980443c97f7bb898866c8f2ead7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Oct 2025 15:23:18 +0530 Subject: [PATCH 2/7] up --- src/diffusers/models/attention_dispatch.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index d741e56d37d1..0eb817c0a853 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -78,7 +78,7 @@ flash_attn_3_func = None flash_attn_3_varlen_func = None -flash_attn_3_func_hub = None +_BACKEND_HANDLES: Dict["AttentionBackendName", Callable] = {} _PREPARED_BACKENDS: Set["AttentionBackendName"] = set() if _CAN_USE_SAGE_ATTN: @@ -446,17 +446,17 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None def _ensure_flash_attn_3_func_hub_loaded(): - global flash_attn_3_func_hub - - if flash_attn_3_func_hub is not None: - return flash_attn_3_func_hub + cached = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB) + if cached is not None: + return cached from ..utils.kernels_utils import _get_fa3_from_hub flash_attn_interface_hub = _get_fa3_from_hub() - flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func + func = flash_attn_interface_hub.flash_attn_func + _BACKEND_HANDLES[AttentionBackendName._FLASH_3_HUB] = func - return flash_attn_3_func_hub + return func _BACKEND_PREPARERS: Dict[AttentionBackendName, Callable[[], None]] = { @@ -1348,7 +1348,9 @@ def _flash_attention_3_hub( return_attn_probs: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - func = flash_attn_3_func_hub or _ensure_flash_attn_3_func_hub_loaded() + func = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB) + if func is None: + func = _ensure_flash_attn_3_func_hub_loaded() out = func( q=query, k=key, From f48ec46b0a9033ec929ae3c7c5b3cc7239defdb7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 24 Oct 2025 20:20:48 -1000 Subject: [PATCH 3/7] refactor according to Dhruv's ideas. 
Co-authored-by: Dhruv Nair --- src/diffusers/models/attention_dispatch.py | 99 +++++++++++----------- src/diffusers/models/modeling_utils.py | 10 ++- src/diffusers/utils/constants.py | 1 - tests/others/test_attention_backends.py | 1 - 4 files changed, 57 insertions(+), 54 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 0eb817c0a853..845cecb03e8a 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -16,8 +16,9 @@ import functools import inspect import math +from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union import torch @@ -40,7 +41,7 @@ is_xformers_available, is_xformers_version, ) -from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS +from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS if TYPE_CHECKING: @@ -78,9 +79,6 @@ flash_attn_3_func = None flash_attn_3_varlen_func = None -_BACKEND_HANDLES: Dict["AttentionBackendName", Callable] = {} -_PREPARED_BACKENDS: Set["AttentionBackendName"] = set() - if _CAN_USE_SAGE_ATTN: from sageattention import ( sageattn, @@ -222,9 +220,7 @@ def decorator(func): @classmethod def get_active_backend(cls): - backend = cls._active_backend - _ensure_attention_backend_ready(backend) - return backend, cls._backends[backend] + return cls._active_backend, cls._backends[cls._active_backend] @classmethod def list_backends(cls): @@ -242,6 +238,25 @@ def _is_context_parallel_enabled( return supports_context_parallel and is_degree_greater_than_1 +@dataclass +class _HubKernelConfig: + """Configuration for downloading and using a hub-based attention kernel.""" + + repo_id: str + function_attr: str + revision: Optional[str] = None + kernel_fn: Optional[Callable] = None + + +# Registry for hub-based attention kernels +_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = { + # TODO: temporary revision for now. Remove when merged upstream into `main`. + AttentionBackendName._FLASH_3_HUB: _HubKernelConfig( + repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs" + ) +} + + @contextlib.contextmanager def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE): """ @@ -251,7 +266,7 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke raise ValueError(f"Backend {backend} is not registered.") backend = AttentionBackendName(backend) - _ensure_attention_backend_ready(backend) + _check_attention_backend_requirements(backend) old_backend = _AttentionBackendRegistry._active_backend _AttentionBackendRegistry._active_backend = backend @@ -398,13 +413,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None # TODO: add support Hub variant of FA3 varlen later elif backend in [AttentionBackendName._FLASH_3_HUB]: - if not DIFFUSERS_ENABLE_HUB_KERNELS: - raise RuntimeError( - f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`." - ) if not is_kernels_available(): raise RuntimeError( - f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. 
Please install it with `pip install kernels`." + f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`." ) elif backend in [ @@ -445,39 +456,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None ) -def _ensure_flash_attn_3_func_hub_loaded(): - cached = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB) - if cached is not None: - return cached - - from ..utils.kernels_utils import _get_fa3_from_hub - - flash_attn_interface_hub = _get_fa3_from_hub() - func = flash_attn_interface_hub.flash_attn_func - _BACKEND_HANDLES[AttentionBackendName._FLASH_3_HUB] = func - - return func - - -_BACKEND_PREPARERS: Dict[AttentionBackendName, Callable[[], None]] = { - AttentionBackendName._FLASH_3_HUB: _ensure_flash_attn_3_func_hub_loaded, -} - - -def _prepare_attention_backend(backend: AttentionBackendName) -> None: - preparer = _BACKEND_PREPARERS.get(backend) - if preparer is not None: - preparer() - - -def _ensure_attention_backend_ready(backend: AttentionBackendName) -> None: - if backend in _PREPARED_BACKENDS: - return - _check_attention_backend_requirements(backend) - _prepare_attention_backend(backend) - _PREPARED_BACKENDS.add(backend) - - @functools.lru_cache(maxsize=128) def _prepare_for_flash_attn_or_sage_varlen_without_mask( batch_size: int, @@ -581,6 +559,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): return q_idx >= kv_idx +# ===== Helpers for downloading kernels ===== +def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None: + if backend not in _HUB_KERNELS_REGISTRY: + return + config = _HUB_KERNELS_REGISTRY[backend] + + if config._kernel_fn is not None: + return + + try: + from kernels import get_kernel + + kernel_module = get_kernel(config.repo_id, revision=config.revision) + kernel_func = getattr(kernel_module, config.function_attr) + + # Cache the downloaded kernel function in the config object + config._kernel_fn = kernel_func + + except Exception as e: + logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}") + raise + + # ===== torch op registrations ===== # Registrations are required for fullgraph tracing compatibility # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding @@ -1348,9 +1349,7 @@ def _flash_attention_3_hub( return_attn_probs: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - func = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB) - if func is None: - func = _ensure_flash_attn_3_func_hub_loaded() + func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]._kernel_fn out = func( q=query, k=key, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index daa5b913060f..3880418fb03e 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -595,7 +595,11 @@ def set_attention_backend(self, backend: str) -> None: attention as backend. 
""" from .attention import AttentionModuleMixin - from .attention_dispatch import AttentionBackendName, _ensure_attention_backend_ready + from .attention_dispatch import ( + AttentionBackendName, + _check_attention_backend_requirements, + _maybe_download_kernel_for_backend, + ) # TODO: the following will not be required when everything is refactored to AttentionModuleMixin from .attention_processor import Attention, MochiAttention @@ -606,8 +610,10 @@ def set_attention_backend(self, backend: str) -> None: available_backends = {x.value for x in AttentionBackendName.__members__.values()} if backend not in available_backends: raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) + backend = AttentionBackendName(backend) - _ensure_attention_backend_ready(backend) + _check_attention_backend_requirements(backend) + _maybe_download_kernel_for_backend(backend) attention_classes = (Attention, MochiAttention, AttentionModuleMixin) for module in self.modules(): diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 42a53e181034..051a0c034e52 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -46,7 +46,6 @@ DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES -DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES # Below should be `True` if the current version of `peft` and `transformers` are compatible with # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py index 42cdcd56f74a..8f4667792a02 100644 --- a/tests/others/test_attention_backends.py +++ b/tests/others/test_attention_backends.py @@ -7,7 +7,6 @@ ```bash export RUN_ATTENTION_BACKEND_TESTS=yes -export DIFFUSERS_ENABLE_HUB_KERNELS=yes pytest tests/others/test_attention_backends.py ``` From eed79ac63a32b432e25e95955f04752bd62d3b95 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 24 Oct 2025 20:24:03 -1000 Subject: [PATCH 4/7] empty Co-authored-by: Dhruv Nair From 7036bc3d1feb9df852aa2790b2ef9fffb601ce59 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 24 Oct 2025 20:24:26 -1000 Subject: [PATCH 5/7] empty Co-authored-by: Dhruv Nair From 52eace166a0ac134c1e579c6150a135510741c26 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 24 Oct 2025 20:24:41 -1000 Subject: [PATCH 6/7] empty Co-authored-by: dn6 From df9fb6b373e7ec3b0df361d6827b502a721b3a4c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 24 Oct 2025 21:14:27 -1000 Subject: [PATCH 7/7] up --- src/diffusers/models/attention_dispatch.py | 6 +++--- src/diffusers/utils/kernels_utils.py | 23 ---------------------- 2 files changed, 3 insertions(+), 26 deletions(-) delete mode 100644 src/diffusers/utils/kernels_utils.py diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 845cecb03e8a..9d8896f91b3c 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -565,7 +565,7 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None: return config = _HUB_KERNELS_REGISTRY[backend] - if config._kernel_fn is not None: + if config.kernel_fn is not None: 
return try: @@ -575,7 +575,7 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None: kernel_func = getattr(kernel_module, config.function_attr) # Cache the downloaded kernel function in the config object - config._kernel_fn = kernel_func + config.kernel_fn = kernel_func except Exception as e: logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}") @@ -1349,7 +1349,7 @@ def _flash_attention_3_hub( return_attn_probs: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]._kernel_fn + func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn out = func( q=query, k=key, diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py deleted file mode 100644 index 26d6e3972fb7..000000000000 --- a/src/diffusers/utils/kernels_utils.py +++ /dev/null @@ -1,23 +0,0 @@ -from ..utils import get_logger -from .import_utils import is_kernels_available - - -logger = get_logger(__name__) - - -_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3" - - -def _get_fa3_from_hub(): - if not is_kernels_available(): - return None - else: - from kernels import get_kernel - - try: - # TODO: temporary revision for now. Remove when merged upstream into `main`. - flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs") - return flash_attn_3_hub - except Exception as e: - logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}") - raise
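
A minimal end-to-end sketch of the refactored flow after this series, for illustration only: the backend string `"_flash_3_hub"`, the Flux checkpoint, and the dtype are assumptions rather than values prescribed by the patches. `set_attention_backend` now performs the requirement check and the hub-kernel download itself (via `_check_attention_backend_requirements` and `_maybe_download_kernel_for_backend`), so the removed `DIFFUSERS_ENABLE_HUB_KERNELS` environment variable is no longer needed.

```python
# Sketch, not part of the patch series. Assumes a CUDA machine with the
# `kernels` package installed; the model id, dtype, and backend string are
# illustrative assumptions.
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")

# Validates that `kernels` is importable, then fetches
# kernels-community/flash-attn3 from the Hub and caches its flash_attn_func on
# _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn.
transformer.set_attention_backend("_flash_3_hub")

# Attention modules now dispatch through _flash_attention_3_hub(), which reads
# the cached kernel_fn from the registry on each call.
```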