diff --git a/timm/models/davit.py b/timm/models/davit.py
index 22b4a1a05f..347750652e 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -129,6 +129,7 @@ def __init__(self, dim, num_heads=8, qkv_bias=True, dynamic_scale=True):
         self.groups = num_heads
         self.head_dim = dim // num_heads
         self.dynamic_scale = dynamic_scale
+        self.fused_attn = use_fused_attn()

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.proj = nn.Linear(dim, dim)
@@ -136,18 +137,23 @@ def __init__(self, dim, num_heads=8, qkv_bias=True, dynamic_scale=True):
     def forward(self, x):
         B, N, C = x.shape

-        qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
+        qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 4, 1)
         q, k, v = qkv.unbind(0)

         if self.dynamic_scale:
-            q = q * N ** -0.5
+            scale = N ** -0.5
         else:
-            q = q * self.head_dim ** -0.5
-        attn = q.transpose(-1, -2) @ k
-        attn = attn.softmax(dim=-1)
-        x = (attn @ v.transpose(-1, -2)).transpose(-1, -2)
+            scale = self.head_dim ** -0.5

-        x = x.transpose(1, 2).reshape(B, N, C)
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(q, k, v, scale=scale)
+        else:
+            q = q * scale
+            attn = (q @ k.transpose(-2, -1))
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+
+        x = x.permute(0, 3, 1, 2).reshape(B, N, C)
         x = self.proj(x)
         return x
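
For reviewers, below is a small standalone sketch (not part of the diff, with arbitrary assumed tensor sizes) checking that the new fused branch and the eager branch agree for the transposed (B, heads, head_dim, N) layout produced by permute(2, 0, 3, 4, 1). Note that the explicit scale= argument to F.scaled_dot_product_attention requires a reasonably recent PyTorch (it was added around 2.1, if I recall correctly).

# Standalone sanity check, not part of the PR; shapes below are assumptions.
# Verifies that eager channel attention matches F.scaled_dot_product_attention
# when q/k/v are laid out as (B, heads, head_dim, N), as in the refactored forward.
import torch
import torch.nn.functional as F

B, N, heads, head_dim = 2, 49, 8, 32
q = torch.randn(B, heads, head_dim, N)
k = torch.randn(B, heads, head_dim, N)
v = torch.randn(B, heads, head_dim, N)
scale = N ** -0.5  # dynamic_scale=True branch

# Fused path: SDPA applies the scale to q @ k^T internally.
x_fused = F.scaled_dot_product_attention(q, k, v, scale=scale)

# Eager path, mirroring the non-fused branch of the diff.
attn = (q * scale) @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
x_eager = attn @ v

print(torch.allclose(x_fused, x_eager, atol=1e-5))  # expected: True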