diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index e1694910997a..0eb817c0a853 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -17,7 +17,7 @@
 import inspect
 import math
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union
 
 import torch
 
@@ -78,17 +78,8 @@
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
 
-if DIFFUSERS_ENABLE_HUB_KERNELS:
-    if not is_kernels_available():
-        raise ImportError(
-            "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
-        )
-    from ..utils.kernels_utils import _get_fa3_from_hub
-
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
-else:
-    flash_attn_3_func_hub = None
+_BACKEND_HANDLES: Dict["AttentionBackendName", Callable] = {}
+_PREPARED_BACKENDS: Set["AttentionBackendName"] = set()
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -231,7 +222,9 @@ def decorator(func):
 
     @classmethod
     def get_active_backend(cls):
-        return cls._active_backend, cls._backends[cls._active_backend]
+        backend = cls._active_backend
+        _ensure_attention_backend_ready(backend)
+        return backend, cls._backends[backend]
 
     @classmethod
     def list_backends(cls):
@@ -258,7 +251,7 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
         raise ValueError(f"Backend {backend} is not registered.")
 
     backend = AttentionBackendName(backend)
-    _check_attention_backend_requirements(backend)
+    _ensure_attention_backend_ready(backend)
 
     old_backend = _AttentionBackendRegistry._active_backend
     _AttentionBackendRegistry._active_backend = backend
@@ -452,6 +445,39 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             )
 
 
+def _ensure_flash_attn_3_func_hub_loaded():
+    cached = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB)
+    if cached is not None:
+        return cached
+
+    from ..utils.kernels_utils import _get_fa3_from_hub
+
+    flash_attn_interface_hub = _get_fa3_from_hub()
+    func = flash_attn_interface_hub.flash_attn_func
+    _BACKEND_HANDLES[AttentionBackendName._FLASH_3_HUB] = func
+
+    return func
+
+
+_BACKEND_PREPARERS: Dict[AttentionBackendName, Callable[[], None]] = {
+    AttentionBackendName._FLASH_3_HUB: _ensure_flash_attn_3_func_hub_loaded,
+}
+
+
+def _prepare_attention_backend(backend: AttentionBackendName) -> None:
+    preparer = _BACKEND_PREPARERS.get(backend)
+    if preparer is not None:
+        preparer()
+
+
+def _ensure_attention_backend_ready(backend: AttentionBackendName) -> None:
+    if backend in _PREPARED_BACKENDS:
+        return
+    _check_attention_backend_requirements(backend)
+    _prepare_attention_backend(backend)
+    _PREPARED_BACKENDS.add(backend)
+
+
 @functools.lru_cache(maxsize=128)
 def _prepare_for_flash_attn_or_sage_varlen_without_mask(
     batch_size: int,
@@ -1322,7 +1348,10 @@ def _flash_attention_3_hub(
     return_attn_probs: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    out = flash_attn_3_func_hub(
+    func = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB)
+    if func is None:
+        func = _ensure_flash_attn_3_func_hub_loaded()
+    out = func(
         q=query,
         k=key,
         v=value,
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 91daca1ad809..daa5b913060f 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -595,7 +595,7 @@ def set_attention_backend(self, backend: str) -> None:
                 attention as backend.
         """
         from .attention import AttentionModuleMixin
-        from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
+        from .attention_dispatch import AttentionBackendName, _ensure_attention_backend_ready
 
         # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
         from .attention_processor import Attention, MochiAttention
@@ -607,7 +607,7 @@ def set_attention_backend(self, backend: str) -> None:
         if backend not in available_backends:
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
         backend = AttentionBackendName(backend)
-        _check_attention_backend_requirements(backend)
+        _ensure_attention_backend_ready(backend)
 
         attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
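
A minimal sketch of how the lazy preparation path introduced above is exercised. It calls the private helpers directly for illustration only and assumes a machine where the FA3 Hub kernel requirements (the `kernels` package and compatible hardware) are satisfied; in normal use the same path is reached through `ModelMixin.set_attention_backend(...)` and the `attention_backend(...)` context manager, both of which now call `_ensure_attention_backend_ready` instead of `_check_attention_backend_requirements`.

from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    _ensure_attention_backend_ready,
)

# First call checks backend requirements and, for _FLASH_3_HUB, fetches the
# kernel from the Hub and caches the handle in _BACKEND_HANDLES.
_ensure_attention_backend_ready(AttentionBackendName._FLASH_3_HUB)

# Subsequent calls return immediately because the backend is recorded
# in _PREPARED_BACKENDS.
_ensure_attention_backend_ready(AttentionBackendName._FLASH_3_HUB)

The net effect is that the Hub kernel download no longer happens at `attention_dispatch` import time, only when the `_FLASH_3_HUB` backend is actually selected or dispatched.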