Ref #40002. This PR consolidated this function into `transformers.modeling_flash_attention_utils` and removed the `index_first_axis` function that was copied directly from the flash-attention repo. I am wondering whether the two are functionally equivalent.
The original one: the shape is recovered at the end and, per the inline comment, it is effectively equal to `return input[indices]`.

```python
import torch
from einops import rearrange, repeat


class IndexFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
        ).reshape(-1, *other_shape)
```
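As a quick sanity check on that inline comment, the gather path does seem to reduce to plain first-axis indexing in the forward pass. A minimal sketch (arbitrary shapes, nothing here comes from either library):

```python
import torch
from einops import rearrange, repeat

x = torch.randn(6, 3, 4)           # (first_axis_dim, ...)
indices = torch.tensor([1, 4, 5])

other_shape = x.shape[1:]
gathered = torch.gather(
    rearrange(x, "b ... -> b (...)"), 0,
    repeat(indices, "z -> z d", d=other_shape.numel()),
).reshape(-1, *other_shape)

assert torch.equal(gathered, x[indices])  # gather path matches plain indexing
```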
The current one: I don't think the shapes are exactly the same here, since no shape is recovered.

```python
def _index_first_axis(tensor, indices):
    """
    A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
    after flattening the first two dimensions of the tensor. This is functionally equivalent to
    FA2's `index_first_axis` and replaces the need to import it.
    """
    # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
    # two dimensions to get (total_tokens, ...) before indexing.
    reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
    return reshaped_tensor[indices]
```
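To compare the two directly, here is a minimal sketch; `fa2_index_first_axis` and `hf_index_first_axis` are local stand-ins I wrote for this check, not the actual library symbols. Going by the docstring, the difference looks like calling convention only: FA2's version is handed a tensor whose first two dims were already flattened by the caller, while the transformers version flattens `(batch, seq_len)` internally, so the outputs should line up. This only exercises the forward pass shown above:

```python
import torch
from einops import rearrange, repeat


def fa2_index_first_axis(input, indices):
    # Gather-based forward, as in IndexFirstAxis.forward above; expects the
    # caller to have flattened (batch, seq_len, ...) into (batch * seq_len, ...).
    other_shape = input.shape[1:]
    return torch.gather(
        rearrange(input, "b ... -> b (...)"), 0,
        repeat(indices, "z -> z d", d=other_shape.numel()),
    ).reshape(-1, *other_shape)


def hf_index_first_axis(tensor, indices):
    # Transformers-style version: flattens the first two dims internally.
    return tensor.reshape(-1, *tensor.shape[2:])[indices]


hidden = torch.randn(2, 4, 8)          # (batch, seq_len, dim)
indices = torch.tensor([0, 2, 5, 7])   # token indices into the flattened axis

out_fa2 = fa2_index_first_axis(rearrange(hidden, "b s ... -> (b s) ..."), indices)
out_hf = hf_index_first_axis(hidden, indices)
assert torch.equal(out_fa2, out_hf)    # both return shape (num_tokens, dim)
```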