params:
batch_size: 64, num_heads: 16, seq_len: 2048, dimension: 64, dtype: torch.float16, with mask, on an A800
example:
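(For completeness, the snippet below assumes roughly the following preamble; the exact import path for flash_attention_custom_mask is my best guess from the repo layout, and the parameter values are the ones listed above.)

import torch
import torch.utils.benchmark as benchmark
from fa2_custom_mask import flash_attention_custom_mask  # assumed import path

batch_size, num_heads, seq_len, dimension = 64, 16, 2048, 64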
# helper used to benchmark the backward pass
def func(y, grad):
    y.backward(grad, retain_graph=True)

# fp16 inputs with gradients enabled
q = torch.empty((batch_size, num_heads, seq_len, dimension), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()
k = torch.empty((batch_size, num_heads, seq_len, dimension), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()
v = torch.empty((batch_size, num_heads, seq_len, dimension), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()
# lower-triangular (causal) mask
mask = torch.tril(torch.ones((batch_size, num_heads, seq_len, seq_len), dtype=torch.uint8, device="cuda", requires_grad=False))

# time the forward pass
t = benchmark.Timer(
    stmt="flash_attention_custom_mask(q, k, v, mask, scale)",
    globals={"flash_attention_custom_mask": flash_attention_custom_mask, "q": q, "k": k, "v": v, "mask": mask, "scale": 0.5},
    num_threads=torch.get_num_threads(),
)
fwd_time = t.timeit(10).mean * 1000
torch.cuda.synchronize()
torch.cuda.empty_cache()

# time the backward pass; this is where the error below is raised
y = flash_attention_custom_mask(q, k, v, mask, 0.5).half()
for x in [q, k, v]:
    x.grad = None
grad = torch.rand_like(y)
t = benchmark.Timer(
    stmt="f(y, grad)",
    globals={"f": func, "y": y, "grad": grad},
    num_threads=torch.get_num_threads(),
)
bwd_time = t.timeit(10).mean * 1000
error:
Traceback (most recent call last):
File "week5_mask_test.py", line 372, in <module>
bwd_time = t.timeit(10).mean * 1000
File "/usr/local/lib/python3.8/dist-packages/torch/utils/benchmark/utils/timer.py", line 274, in timeit
self._timeit(number=max(int(number // 100), 2))
File "/usr/local/lib/python3.8/dist-packages/torch/utils/benchmark/utils/timer.py", line 264, in _timeit
return max(self._timer.timeit(number), 1e-9)
File "/usr/lib/python3.8/timeit.py", line 177, in timeit
timing = self.inner(it, self.timer)
File "<timeit-src>", line 6, in inner
File "week5_mask_test.py", line 64, in func
y.backward(grad, retain_graph=True)
File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 522, in backward
torch.autograd.backward(
File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 266, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/usr/local/lib/python3.8/dist-packages/torch/autograd/function.py", line 289, in apply
return user_fn(self, *args)
File "/home/hadoop-perception/flashattention2-custom-mask-main/fa2_custom_mask/fa2_custom_mask.py", line 87, in backward
_attn_bwd[grid](
File "/home/hadoop-perception/.local/lib/python3.8/site-packages/triton/runtime/jit.py", line 345, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/home/hadoop-perception/.local/lib/python3.8/site-packages/triton/runtime/jit.py", line 691, in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
File "/home/hadoop-perception/.local/lib/python3.8/site-packages/triton/backends/nvidia/driver.py", line 365, in __call__
self.launch(*args, **kwargs)
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
some attempts:
The issue seemed to be related to an incorrect tiling block size, so I modified the autotune config parameters in the fa2_fwd file:
configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w)
    for BM in [64]  # removed 128
    for BN in [32, 64]
    for s in ([1] if is_hip() else [3, 4, 7])
    for w in [4, 8]
]
but the issue still occurs.
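(For context on where this list takes effect: in Triton the configs are consumed by the autotuner attached to the kernel. A minimal sketch of that pattern is below; the kernel name, key, and body are placeholders rather than the repo's actual code.)

import triton
import triton.language as tl

@triton.autotune(configs=configs, key=['N_CTX'])  # `configs` is the list above
@triton.jit
def _attn_fwd_sketch(Out, N_CTX, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # placeholder body; the real forward kernel computes attention tiles of
    # shape BLOCK_M x BLOCK_N and is autotuned over the configs above
    pid = tl.program_id(0)
    tl.store(Out + pid, pid)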
Looking forward to your response.