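https://github.com/lucidrains/CoCa-pytorch/blob/edee92c74e311ccfa4a0024412fd991c98aff5fd/coca_pytorch/coca_pytorch.py#L532

FYI, the batch size used for the contrastive labels isn't correct in the distributed case. The labels should be sized from the latents rather than from the local batch:

torch.arange(batch, device=device) -> torch.arange(text_latents.shape[0], device=device)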
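
To illustrate why this matters, here is a minimal sketch, not the repository's code. It assumes the latents are gathered across ranks before the similarity matrix is computed, and it uses torch.cat as a single-process stand-in for that gather. The names local_batch, world_size and dim are made up for the example.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration only.
local_batch, world_size, dim = 4, 2, 8

# Single-process stand-in for a cross-rank gather: concatenate one
# chunk of latents per simulated rank along the batch dimension.
text_latents = torch.cat([torch.randn(local_batch, dim) for _ in range(world_size)])
image_latents = torch.cat([torch.randn(local_batch, dim) for _ in range(world_size)])

# Similarity matrix has one row per gathered sample, not per local sample.
sim = text_latents @ image_latents.t()

# Proposed fix: size the labels from the (gathered) latents.
labels = torch.arange(text_latents.shape[0], device=text_latents.device)
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) * 0.5

# Current behaviour: labels sized from the local batch no longer match sim.
stale_labels = torch.arange(local_batch, device=text_latents.device)
try:
    F.cross_entropy(sim, stale_labels)
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```

When nothing is gathered, text_latents.shape[0] equals the local batch size, so the proposed expression should behave the same as the current one in the non-distributed case.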