Open
Description
@triton.jit
def _attn_bwd(Q, K, V, mask, sm_scale,
              DO,
              DQ, DK, DV,
              M, D,
              stride_z, stride_h, stride_tok, stride_d,
              mask_stride_z, mask_stride_h, mask_stride_tok, mask_stride_tokk,
              H, N_CTX,
              BLOCK_M1: tl.constexpr,
              BLOCK_N1: tl.constexpr,
              BLOCK_M2: tl.constexpr,
              BLOCK_N2: tl.constexpr,
              BLK_SLICE_FACTOR: tl.constexpr,
              HEAD_DIM: tl.constexpr,
              USE_MASK: tl.constexpr):
The kernel takes a single N_CTX, so the queries and the keys/values currently have to be the same length. Do you plan to support inputs of unequal length?
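To make the request concrete, here is a minimal plain-Python sketch of single-head softmax attention in which the query length and the key/value length differ. It is not the Triton kernel and says nothing about how the backward pass would need to change; the name `attention_reference` and the list-of-rows layout are assumptions made up for illustration.

```python
import math


def attention_reference(q, k, v, sm_scale):
    """Softmax attention for one head, written with plain lists.

    q has N_q rows; k and v each have N_kv rows.
    N_q and N_kv are allowed to differ (hypothetical reference, not repo code).
    """
    head_dim_v = len(v[0])
    out = []
    for q_row in q:
        # One score per key row, so the score vector has length N_kv.
        scores = [sm_scale * sum(a * b for a, b in zip(q_row, k_row))
                  for k_row in k]
        # Subtract the row maximum before exponentiating for numerical stability.
        row_max = max(scores)
        weights = [math.exp(s - row_max) for s in scores]
        denom = sum(weights)
        # Weighted average of the value rows.
        out.append([
            sum(w * v_row[d] for w, v_row in zip(weights, v)) / denom
            for d in range(head_dim_v)
        ])
    return out


# Two queries attending over three keys/values.
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = attention_reference(q, k, v, sm_scale=0.5)
```

The output has one row per query (2 here) while the scores span all 3 keys, which is the shape the kernel would need to handle if it took separate query and key/value lengths instead of one N_CTX.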