Skip to content

Commit b96adb1

Browse files
committed
Refactor Cosmos 2.5 model and tests for improved precision and functionality
- Added torchvision transforms import to the Cosmos 2.5 model for enhanced functionality. - Updated max_size calculation in positional embedding classes to enforce strict behavior. - Changed precision from bfloat16 to float32 in test cases to improve numerical stability. - Tightened numerical difference assertions in tests to ensure higher accuracy in model outputs. These changes enhance the model's robustness and ensure better alignment with reference implementations.
1 parent bbfae59 commit b96adb1

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

fastvideo/models/dits/cosmos2_5.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import torch
88
import torch.nn as nn
9+
from torchvision import transforms
910

1011
from fastvideo.attention import DistributedAttention, LocalAttention
1112
from fastvideo.configs.models.dits.cosmos2_5 import Cosmos25VideoConfig
@@ -18,6 +19,7 @@
1819
from fastvideo.platforms import AttentionBackendEnum
1920

2021

22+
2123
class Cosmos25PatchEmbed(nn.Module):
2224
"""
2325
COSMOS 2.5 patch embedding - converts video (B, C, T, H, W) to patches (B, T', H', W', D).
@@ -579,7 +581,7 @@ def __init__(
579581
) -> None:
580582
super().__init__()
581583

582-
self.max_size = [size // patch for size, patch in zip(max_size, patch_size, strict=False)]
584+
self.max_size = [size // patch for size, patch in zip(max_size, patch_size, strict=True)]
583585
self.patch_size = patch_size
584586
self.base_fps = base_fps
585587
self.enable_fps_modulation = enable_fps_modulation
@@ -668,7 +670,7 @@ def __init__(
668670
) -> None:
669671
super().__init__()
670672

671-
self.max_size = [size // patch for size, patch in zip(max_size, patch_size, strict=False)]
673+
self.max_size = [size // patch for size, patch in zip(max_size, patch_size, strict=True)]
672674
self.patch_size = patch_size
673675
self.eps = eps
674676

@@ -907,7 +909,6 @@ def forward(
907909

908910
# 2. Concatenate padding mask if needed
909911
if self.concat_padding_mask and padding_mask is not None:
910-
from torchvision import transforms
911912
padding_mask = transforms.functional.resize(
912913
padding_mask,
913914
list(hidden_states.shape[-2:]),

fastvideo/tests/transformers/test_cosmos2_5.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,7 @@ def load_reference_cosmos25_model(checkpoint_path: str, device, dtype):
177177
def test_cosmos25_transformer():
178178
"""Test COSMOS 2.5 transformer against reference implementation."""
179179
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
180-
precision = torch.bfloat16
181-
precision_str = "bfloat16"
180+
precision = torch.float32
182181

183182
# Create COSMOS 2.5 specific config
184183
from fastvideo.configs.models.dits.cosmos2_5 import Cosmos25ArchConfig
@@ -441,8 +440,8 @@ def test_cosmos25_transformer():
441440

442441

443442
# Allow for some numerical differences due to implementation details
444-
assert max_diff < 1e-1, f"Maximum difference too large: {max_diff.item()}"
445-
assert mean_diff < 1e-2, f"Mean difference too large: {mean_diff.item()}"
443+
assert max_diff < 1e-4, f"Maximum difference too large: {max_diff.item()}"
444+
assert mean_diff < 1e-5, f"Mean difference too large: {mean_diff.item()}"
446445

447446
logger.info("✓ COSMOS 2.5 FastVideo implementation matches reference!")
448447
else:
@@ -454,8 +453,7 @@ def test_cosmos25_transformer():
454453
def test_cosmos25_transformer_video():
455454
"""Test COSMOS 2.5 transformer with video input (multiple frames)."""
456455
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
457-
precision = torch.bfloat16
458-
precision_str = "bfloat16"
456+
precision = torch.float32
459457

460458
# Create COSMOS 2.5 specific config
461459
from fastvideo.configs.models.dits.cosmos2_5 import Cosmos25ArchConfig

0 commit comments

Comments
 (0)