1 parent 4ae5c11 commit 911baa5
native_sparse_attention_pytorch/triton_native_sparse_attention.py
@@ -1477,8 +1477,8 @@ def native_sparse_attend(
    fk: Float['b kh n d'],
    fv: Float['b kh n d'],
    block_size: int,
-    selected_block_indices: Int['b qh sel'] | Int['b kh sel'],
-    fmask: Bool['b qh sel'] | Bool['b kh sel'],
+    selected_block_indices: Int['b qh n sel'] | Int['b kh n sel'],
+    fmask: Bool['b qh n sel'] | Bool['b kh n sel'],
    return_lse = False
):
    seq_len = fq.shape[-2]
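For context, here is a minimal usage sketch of `native_sparse_attend` with the new per-position shapes. The dimension names, tensor values, and import path are inferred from the diff and the file path above rather than taken from the repository's documentation, so treat them as assumptions.

```python
# Hypothetical usage sketch (not from the repo's docs). Dimension names follow
# the shape annotations in the diff: b = batch, qh = query heads, kh = key/value
# heads, n = sequence length, d = head dim, sel = selected blocks per position.
# A Triton-capable GPU is assumed, since this is the Triton kernel path.
import torch
from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend

b, qh, kh, n, d = 1, 8, 2, 512, 64
block_size = 64
sel = 2  # number of selected key/value blocks per query position

fq = torch.randn(b, qh, n, d, device = 'cuda')
fk = torch.randn(b, kh, n, d, device = 'cuda')
fv = torch.randn(b, kh, n, d, device = 'cuda')

# After this commit, the block indices and mask carry an explicit sequence
# dimension: one set of selected blocks for every query position, per kv head.
# Here every position simply selects block 0, which is always a valid choice.
selected_block_indices = torch.zeros(b, kh, n, sel, dtype = torch.long, device = 'cuda')
fmask = torch.ones(b, kh, n, sel, dtype = torch.bool, device = 'cuda')

out = native_sparse_attend(fq, fk, fv, block_size, selected_block_indices, fmask)
```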