Hello, I’m interested in the sequence parallel part of the code, but I have some questions regarding its behavior.
In the code snippet:
if max_seq_length <= max_tokens_per_device:
    # take the seq length field and only retain seq lengths with indices that are valid for this rank
    seq_indices = seq_lengths.cumsum(-1)
    seq_indices = seq_indices[(seq_indices < end_index) & (seq_indices >= start_index)]
    start_index_tensor = torch.tensor([start_index], device=seq_indices.device)
    end_index_tensor = torch.tensor([end_index], device=seq_indices.device)
    seq_lengths = seq_indices.diff(prepend=start_index_tensor, append=end_index_tensor)
    seq_lengths = seq_lengths[seq_lengths > 0]
    inputs["seq_lengths"] = seq_lengths
    inputs["seq_parallel_group"] = None
I believe there might be a bug. Specifically, when batch_size > 1 and sequence_num > 1, an individual sequence can still end up split across multiple devices even though its length is smaller than max_tokens_per_device. If sequence parallelism (SP) is not enabled in that case, the attention computation over the split sequence would be incorrect.
For example, if seq_lengths = [19, 28, 24, 12, 25, 8, 4, 8] and we use 4 GPUs, each GPU is allocated 32 tokens. However, the second sequence (length 28) gets split across gpu-0 and gpu-1, which would produce inconsistent results without SP (see the sketch below).
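
To make the concern concrete, here is a minimal standalone sketch that replays the per-rank slicing from the snippet above on that example. The start_index / end_index values (rank * max_tokens_per_device) are my assumption about how the surrounding code defines each rank's token window; this is not the repo's actual driver code.

import torch

# Hypothetical reproduction of the per-rank windowing from the snippet above.
seq_lengths_global = torch.tensor([19, 28, 24, 12, 25, 8, 4, 8])
world_size = 4
max_tokens_per_device = 32  # 128 total tokens split over 4 ranks

for rank in range(world_size):
    # Assumed per-rank token window: [rank * 32, (rank + 1) * 32)
    start_index = rank * max_tokens_per_device
    end_index = start_index + max_tokens_per_device

    # Same logic as the snippet: keep cumulative sequence boundaries that fall
    # inside this rank's window, then recover local segment lengths via diff().
    seq_indices = seq_lengths_global.cumsum(-1)
    seq_indices = seq_indices[(seq_indices < end_index) & (seq_indices >= start_index)]
    start_index_tensor = torch.tensor([start_index])
    end_index_tensor = torch.tensor([end_index])
    local_lengths = seq_indices.diff(prepend=start_index_tensor, append=end_index_tensor)
    local_lengths = local_lengths[local_lengths > 0]
    print(f"rank {rank}: seq_lengths = {local_lengths.tolist()}")

# Output:
# rank 0: seq_lengths = [19, 13]   <- first 13 tokens of the 28-token sequence
# rank 1: seq_lengths = [15, 17]   <- its remaining 15 tokens, plus 17 of the next
# rank 2: seq_lengths = [7, 12, 13]
# rank 3: seq_lengths = [12, 8, 4, 8]

So the 28-token sequence ends up as a 13-token segment on rank 0 and a 15-token segment on rank 1, and without SP each rank would attend over its fragment as if it were a complete sequence.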
Could you please provide more clarification on this? Thank you!