diff --git a/models/xti_attention_processor.py b/models/xti_attention_processor.py
index d938902..fb3dbd2 100644
--- a/models/xti_attention_processor.py
+++ b/models/xti_attention_processor.py
@@ -1,12 +1,16 @@
 from typing import Dict, Optional
 
 import torch
-from diffusers.models.cross_attention import CrossAttention
+try:
+    from diffusers.models.cross_attention import CrossAttention as Attention
+except ImportError:
+    print("current version of diffusers does not have diffusers.models.cross_attention, using diffusers.models.attention_processor.Attention instead.")
+    from diffusers.models.attention_processor import Attention
 
 
 class XTIAttenProc:
 
-    def __call__(self, attn: CrossAttention,
+    def __call__(self, attn: Attention,
                  hidden_states: torch.Tensor,
                  encoder_hidden_states: Optional[Dict[str, torch.Tensor]] = None,
                  attention_mask: Optional[torch.Tensor] = None):
@@ -31,7 +35,7 @@ def __call__(self, attn: CrossAttention,
 
         if _ehs is None:
             _ehs = hidden_states
-        elif attn.cross_attention_norm:
+        elif (hasattr(attn, "cross_attention_norm") and attn.cross_attention_norm) or getattr(attn, "norm_cross", None) is not None:
            _ehs = attn.norm_cross(_ehs)
            _ehs_bypass = attn.norm_cross(_ehs_bypass)