Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 44 additions & 15 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import inspect
import math
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union

import torch

Expand Down Expand Up @@ -78,17 +78,8 @@
flash_attn_3_func = None
flash_attn_3_varlen_func = None

if DIFFUSERS_ENABLE_HUB_KERNELS:
if not is_kernels_available():
raise ImportError(
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
)
from ..utils.kernels_utils import _get_fa3_from_hub

flash_attn_interface_hub = _get_fa3_from_hub()
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
else:
flash_attn_3_func_hub = None
_BACKEND_HANDLES: Dict["AttentionBackendName", Callable] = {}
_PREPARED_BACKENDS: Set["AttentionBackendName"] = set()

if _CAN_USE_SAGE_ATTN:
from sageattention import (
Expand Down Expand Up @@ -231,7 +222,9 @@ def decorator(func):

@classmethod
def get_active_backend(cls):
return cls._active_backend, cls._backends[cls._active_backend]
backend = cls._active_backend
_ensure_attention_backend_ready(backend)
return backend, cls._backends[backend]

@classmethod
def list_backends(cls):
Expand All @@ -258,7 +251,7 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
raise ValueError(f"Backend {backend} is not registered.")

backend = AttentionBackendName(backend)
_check_attention_backend_requirements(backend)
_ensure_attention_backend_ready(backend)

old_backend = _AttentionBackendRegistry._active_backend
_AttentionBackendRegistry._active_backend = backend
Expand Down Expand Up @@ -452,6 +445,39 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
)


def _ensure_flash_attn_3_func_hub_loaded():
cached = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB)
if cached is not None:
return cached

from ..utils.kernels_utils import _get_fa3_from_hub

flash_attn_interface_hub = _get_fa3_from_hub()
func = flash_attn_interface_hub.flash_attn_func
_BACKEND_HANDLES[AttentionBackendName._FLASH_3_HUB] = func

return func


_BACKEND_PREPARERS: Dict[AttentionBackendName, Callable[[], None]] = {
AttentionBackendName._FLASH_3_HUB: _ensure_flash_attn_3_func_hub_loaded,
}


def _prepare_attention_backend(backend: AttentionBackendName) -> None:
preparer = _BACKEND_PREPARERS.get(backend)
if preparer is not None:
preparer()


def _ensure_attention_backend_ready(backend: AttentionBackendName) -> None:
if backend in _PREPARED_BACKENDS:
return
_check_attention_backend_requirements(backend)
_prepare_attention_backend(backend)
_PREPARED_BACKENDS.add(backend)


@functools.lru_cache(maxsize=128)
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
batch_size: int,
Expand Down Expand Up @@ -1322,7 +1348,10 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
out = flash_attn_3_func_hub(
func = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB)
if func is None:
func = _ensure_flash_attn_3_func_hub_loaded()
out = func(
q=query,
k=key,
v=value,
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def set_attention_backend(self, backend: str) -> None:
attention as backend.
"""
from .attention import AttentionModuleMixin
from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
from .attention_dispatch import AttentionBackendName, _ensure_attention_backend_ready

# TODO: the following will not be required when everything is refactored to AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
Expand All @@ -606,7 +606,7 @@ def set_attention_backend(self, backend: str) -> None:
if backend not in available_backends:
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
backend = AttentionBackendName(backend)
_check_attention_backend_requirements(backend)
_ensure_attention_backend_ready(backend)

attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
Expand Down
Loading