51 changes: 6 additions & 45 deletions fastvideo/v1/tests/encoders/test_clip_encoder.py
@@ -15,16 +15,15 @@
from fastvideo.v1.logger import init_logger
from fastvideo.v1.utils import maybe_download_model
from fastvideo.v1.configs.models.encoders import CLIPTextConfig
from torch.testing import assert_close

logger = init_logger(__name__)

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29503"

BASE_MODEL_PATH = "hunyuanvideo-community/HunyuanVideo"
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH,
local_dir=os.path.join(
"data", BASE_MODEL_PATH))
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH)
TEXT_ENCODER_PATH = os.path.join(MODEL_PATH, "text_encoder_2")


@@ -79,18 +78,7 @@ def test_clip_encoder():
logger.info("Model1 has %d parameters", len(params1))
logger.info("Model2 has %d parameters", len(params2))

# Compare a few key parameters

# weight_diffs = []
# for (name1, param1), (name2, param2) in zip(
# sorted(params1.items()), sorted(params2.items())
# ):
# # if len(weight_diffs) < 5: # Just check a few parameters
# max_diff = torch.max(torch.abs(param1 - param2)).item()
# mean_diff = torch.mean(torch.abs(param1 - param2)).item()
# weight_diffs.append((name1, name2, max_diff, mean_diff))
# logger.info(f"Parameter: {name1} vs {name2}")
# logger.info(f" Max diff: {max_diff}, Mean diff: {mean_diff}")
assert_close(model1.embed_tokens.weight, model2.embed_tokens.weight, atol=1e-4, rtol=1e-4)

# Load tokenizer
tokenizer, _ = load_tokenizer(tokenizer_type="clipL",
@@ -134,40 +122,13 @@ def test_clip_encoder():

assert last_hidden_state1.shape == last_hidden_state2.shape, \
f"Hidden state shapes don't match: {last_hidden_state1.shape} vs {last_hidden_state2.shape}"

max_diff_hidden = torch.max(
torch.abs(last_hidden_state1 - last_hidden_state2))
mean_diff_hidden = torch.mean(
torch.abs(last_hidden_state1 - last_hidden_state2))

logger.info("Maximum difference in last hidden states: %f",
max_diff_hidden.item())
logger.info("Mean difference in last hidden states: %f",
mean_diff_hidden.item())

# Compare pooler outputs
pooler_output1 = outputs1.pooler_output
pooler_output2 = outputs2.pooler_output

assert pooler_output1.shape == pooler_output2.shape, \
f"Pooler output shapes don't match: {pooler_output1.shape} vs {pooler_output2.shape}"

max_diff_pooler = torch.max(
torch.abs(pooler_output1 - pooler_output2))
mean_diff_pooler = torch.mean(
torch.abs(pooler_output1 - pooler_output2))

logger.info("Maximum difference in pooler outputs: %f",
max_diff_pooler.item())
logger.info("Mean difference in pooler outputs: %f",
mean_diff_pooler.item())

# Check if outputs are similar (allowing for small numerical differences)
assert mean_diff_hidden < 1e-2, \
f"Hidden states differ significantly: mean diff = {mean_diff_hidden.item()}"
assert mean_diff_pooler < 1e-2, \
f"Pooler outputs differ significantly: mean diff = {mean_diff_pooler.item()}"
assert max_diff_hidden < 1e-1, \
f"Hidden states differ significantly: max diff = {max_diff_hidden.item()}"
assert max_diff_pooler < 1e-2, \
f"Pooler outputs differ significantly: max diff = {max_diff_pooler.item()}"
assert_close(pooler_output1, pooler_output2, atol=1e-3, rtol=1e-3)
assert_close(last_hidden_state1, last_hidden_state2, atol=1e-3, rtol=1e-3)
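
A note on the new pattern: torch.testing.assert_close passes only when |actual - expected| <= atol + rtol * |expected| holds elementwise, and by default it also checks shape, dtype, and device, which makes the explicit shape asserts above redundant (though harmless). A minimal self-contained sketch of the semantics, with illustrative tensor names not taken from this PR:

    import torch
    from torch.testing import assert_close

    a = torch.randn(4, 8)
    b = a + 1e-6 * torch.randn(4, 8)  # tiny numerical perturbation

    # Passes iff |a - b| <= atol + rtol * |b| for every element; on failure
    # it raises with the greatest absolute and relative differences.
    assert_close(a, b, atol=1e-4, rtol=1e-4)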

43 changes: 7 additions & 36 deletions fastvideo/v1/tests/encoders/test_llama_encoder.py
@@ -14,16 +14,15 @@
from fastvideo.v1.models.loader.component_loader import TextEncoderLoader
from fastvideo.v1.utils import maybe_download_model
from fastvideo.v1.configs.models.encoders import LlamaConfig
from torch.testing import assert_close

logger = init_logger(__name__)

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29503"

BASE_MODEL_PATH = "hunyuanvideo-community/HunyuanVideo"
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH,
local_dir=os.path.join(
'data', BASE_MODEL_PATH))
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH)
TEXT_ENCODER_PATH = os.path.join(MODEL_PATH, "text_encoder")
TOKENIZER_PATH = os.path.join(MODEL_PATH, "tokenizer")

@@ -75,37 +74,24 @@ def test_llama_encoder():
logger.info("Model1 has %d parameters", len(params1))
logger.info("Model2 has %d parameters", len(params2))

# Compare a few key parameters
weight_diffs = []

# check if embed_tokens are the same
print(model1.embed_tokens.weight.shape, model2.embed_tokens.weight.shape)
assert torch.allclose(model1.embed_tokens.weight,
model2.embed_tokens.weight)
assert_close(model1.embed_tokens.weight, model2.embed_tokens.weight, atol=1e-4, rtol=1e-4)

weights = [
"layers.{}.input_layernorm.weight",
"layers.{}.post_attention_layernorm.weight"
]
# for (name1, param1), (name2, param2) in zip(
# sorted(params1.items()), sorted(params2.items())
# ):
for layer_idx in range(hf_config.num_hidden_layers):
for w in weights:
name1 = w.format(layer_idx)
name2 = w.format(layer_idx)
p1 = params1[name1]
p2 = params2[name2]
# print(type(p2))
if "gate_up" in name2:
# print("skipping gate_up")
continue
try:
# logger.info(f"Parameter: {name1} vs {name2}")
max_diff = torch.max(torch.abs(p1 - p2)).item()
mean_diff = torch.mean(torch.abs(p1 - p2)).item()
weight_diffs.append((name1, name2, max_diff, mean_diff))
# logger.info(f" Max diff: {max_diff}, Mean diff: {mean_diff}")
except Exception as e:
logger.info("Error comparing %s and %s: %s", name1, name2, e)
assert_close(p1, p2, atol=1e-4, rtol=1e-4)

tokenizer, _ = load_tokenizer(tokenizer_type="llm",
tokenizer_path=TOKENIZER_PATH,
@@ -150,19 +136,4 @@ def test_llama_encoder():

assert last_hidden_state1.shape == last_hidden_state2.shape, \
f"Hidden state shapes don't match: {last_hidden_state1.shape} vs {last_hidden_state2.shape}"

max_diff_hidden = torch.max(
torch.abs(last_hidden_state1 - last_hidden_state2))
mean_diff_hidden = torch.mean(
torch.abs(last_hidden_state1 - last_hidden_state2))

logger.info("Maximum difference in last hidden states: %f",
max_diff_hidden.item())
logger.info("Mean difference in last hidden states: %f",
mean_diff_hidden.item())

# Check if outputs are similar (allowing for small numerical differences)
assert mean_diff_hidden < 1e-2, \
f"Hidden states differ significantly: mean diff = {mean_diff_hidden.item()}"
assert max_diff_hidden < 1e-1, \
f"Hidden states differ significantly: max diff = {max_diff_hidden.item()}"
assert_close(last_hidden_state1, last_hidden_state2, atol=1e-3, rtol=1e-3)
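
The loop above skips any name containing "gate_up", presumably because fastvideo fuses HF's separate gate and up projections into a single weight, which a direct elementwise compare cannot handle. A sketch of how a fused weight could still be checked, assuming a row-concatenated [gate; up] layout; the parameter names here are hypothetical:

    # Assumptions: fused layout is [gate; up] along dim 0, and these
    # parameter names exist in the respective models.
    fused = params1[f"layers.{layer_idx}.mlp.gate_up_proj.weight"]
    gate, up = fused.chunk(2, dim=0)
    assert_close(gate, params2[f"layers.{layer_idx}.mlp.gate_proj.weight"],
                 atol=1e-4, rtol=1e-4)
    assert_close(up, params2[f"layers.{layer_idx}.mlp.up_proj.weight"],
                 atol=1e-4, rtol=1e-4)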
34 changes: 5 additions & 29 deletions fastvideo/v1/tests/encoders/test_t5_encoder.py
@@ -12,16 +12,15 @@
from fastvideo.v1.utils import maybe_download_model, PRECISION_TO_TYPE
from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.configs.models.encoders import T5Config
from torch.testing import assert_close

logger = init_logger(__name__)

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29503"

BASE_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH,
local_dir=os.path.join(
'data', BASE_MODEL_PATH))
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH)
TEXT_ENCODER_PATH = os.path.join(MODEL_PATH, "text_encoder")
TOKENIZER_PATH = os.path.join(MODEL_PATH, "tokenizer")

@@ -57,7 +56,6 @@ def test_t5_encoder():
logger.info("Model1 has %s parameters", len(params1))
logger.info("Model2 has %s parameters", len(params2))

weight_diffs = []
# check if embed_tokens are the same
weights = ["encoder.block.{}.layer.0.layer_norm.weight", "encoder.block.{}.layer.0.SelfAttention.relative_attention_bias.weight", \
"encoder.block.{}.layer.0.SelfAttention.o.weight", "encoder.block.{}.layer.1.DenseReluDense.wi_0.weight", "encoder.block.{}.layer.1.DenseReluDense.wi_1.weight",\
@@ -70,15 +68,7 @@
p1 = params1[name1]
p2 = params2[name2]
assert p1.dtype == p2.dtype
try:
logger.info("Parameter: %s vs %s", name1, name2)
max_diff = torch.max(torch.abs(p1 - p2)).item()
mean_diff = torch.mean(torch.abs(p1 - p2)).item()
weight_diffs.append((name1, name2, max_diff, mean_diff))
logger.info(" Max diff: %s, Mean diff: %s", max_diff,
mean_diff)
except Exception as e:
logger.info("Error comparing %s and %s: %s", name1, name2, e)
assert_close(p1, p2, atol=1e-4, rtol=1e-4)

# Test with some sample prompts
prompts = [
@@ -122,19 +112,5 @@

assert last_hidden_state1.shape == last_hidden_state2.shape, \
f"Hidden state shapes don't match: {last_hidden_state1.shape} vs {last_hidden_state2.shape}"

max_diff_hidden = torch.max(
torch.abs(last_hidden_state1 - last_hidden_state2))
mean_diff_hidden = torch.mean(
torch.abs(last_hidden_state1 - last_hidden_state2))

logger.info("Maximum difference in last hidden states: %s",
max_diff_hidden.item())
logger.info("Mean difference in last hidden states: %s",
mean_diff_hidden.item())

# Check if outputs are similar (allowing for small numerical differences)
assert mean_diff_hidden < 1e-4, \
f"Hidden states differ significantly: mean diff = {mean_diff_hidden.item()}"
assert max_diff_hidden < 1e-4, \
f"Hidden states differ significantly: max diff = {max_diff_hidden.item()}"

assert_close(last_hidden_state1, last_hidden_state2, atol=1e-3, rtol=1e-3)
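
The tolerances settle on 1e-4 for weights and 1e-3 for hidden states, which suits the fp32 comparisons these tests run today. If the encoders are ever compared in half precision, a dtype-keyed table is one option; the values below are illustrative assumptions, not project policy:

    # Sketch: loosen tolerances for lower-precision dtypes (values assumed).
    TOLS = {
        torch.float32: dict(atol=1e-4, rtol=1e-4),
        torch.float16: dict(atol=1e-3, rtol=1e-3),
        torch.bfloat16: dict(atol=1e-2, rtol=1e-2),
    }
    assert_close(last_hidden_state1, last_hidden_state2,
                 **TOLS[last_hidden_state1.dtype])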
4 changes: 1 addition & 3 deletions fastvideo/v1/tests/transformers/test_hunyuanvideo_load.py
@@ -23,9 +23,7 @@
os.environ["MASTER_PORT"] = "29503"

BASE_MODEL_PATH = "hunyuanvideo-community/HunyuanVideo"
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH,
local_dir=os.path.join(
"data", BASE_MODEL_PATH))
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH)
TRANSFORMER_PATH = os.path.join(MODEL_PATH, "transformer")
CONFIG_PATH = os.path.join(TRANSFORMER_PATH, "config.json")

13 changes: 3 additions & 10 deletions fastvideo/v1/tests/transformers/test_wanvideo.py
@@ -5,6 +5,7 @@
import pytest
import torch
from diffusers import WanTransformer3DModel
from torch.testing import assert_close

from fastvideo.v1.forward_context import set_forward_context
from fastvideo.v1.fastvideo_args import FastVideoArgs
@@ -21,9 +22,7 @@
os.environ["MASTER_PORT"] = "29503"

BASE_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH,
local_dir=os.path.join(
'data', BASE_MODEL_PATH))
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH)
TRANSFORMER_PATH = os.path.join(MODEL_PATH, "transformer")


@@ -120,10 +119,4 @@ def test_wan_transformer():
assert output1.dtype == output2.dtype, f"Output dtypes don't match: {output1.dtype} vs {output2.dtype}"

# Check if outputs are similar (allowing for small numerical differences)
max_diff = torch.max(torch.abs(output1 - output2))
mean_diff = torch.mean(torch.abs(output1 - output2))
logger.info("Max Diff: %s", max_diff.item())
logger.info("Mean Diff: %s", mean_diff.item())
assert max_diff < 1e-1, f"Maximum difference between outputs: {max_diff.item()}"
# mean diff
assert mean_diff < 1e-2, f"Mean difference between outputs: {mean_diff.item()}"
assert_close(output1, output2, atol=1e-3, rtol=1e-3)
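
Dropping the max/mean logging loses nothing on failure: assert_close's error report already includes the number of mismatched elements and the greatest absolute and relative differences. If a test-specific message is wanted, the msg argument accepts a string or a callable; a sketch, not part of this PR:

    # msg as a callable receives the generated report and returns the
    # final message, so the built-in diff statistics are preserved.
    assert_close(output1, output2, atol=1e-3, rtol=1e-3,
                 msg=lambda report: f"Wan transformer mismatch:\n{report}")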
4 changes: 1 addition & 3 deletions fastvideo/v1/tests/vaes/test_hunyuan_vae.py
@@ -21,9 +21,7 @@
os.environ["MASTER_PORT"] = "29503"

BASE_MODEL_PATH = "hunyuanvideo-community/HunyuanVideo"
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH,
local_dir=os.path.join(
"data", BASE_MODEL_PATH))
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH)
VAE_PATH = os.path.join(MODEL_PATH, "vae")
CONFIG_PATH = os.path.join(VAE_PATH, "config.json")

22 changes: 4 additions & 18 deletions fastvideo/v1/tests/vaes/test_wan_vae.py
@@ -5,7 +5,7 @@
import pytest
import torch
from diffusers import AutoencoderKLWan

from torch.testing import assert_close
from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.logger import init_logger
from fastvideo.v1.models.loader.component_loader import VAELoader
Expand All @@ -18,9 +18,7 @@
os.environ["MASTER_PORT"] = "29503"

BASE_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH,
local_dir=os.path.join(
'data', BASE_MODEL_PATH))
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH)
VAE_PATH = os.path.join(MODEL_PATH, "vae")


@@ -69,13 +67,7 @@ def test_wan_vae():
# Check if latents have the same shape
assert latent1.mean.shape == latent2.mean.shape, f"Latent shapes don't match: {latent1.mean.shape} vs {latent2.mean.shape}"
# Check if latents are similar
max_diff_encode = torch.max(torch.abs(latent1.mean - latent2.mean))
mean_diff_encode = torch.mean(torch.abs(latent1.mean - latent2.mean))
logger.info("Maximum difference between encoded latents: %s",
max_diff_encode.item())
logger.info("Mean difference between encoded latents: %s",
mean_diff_encode.item())
assert max_diff_encode < 1e-5, f"Encoded latents differ significantly: max diff = {mean_diff_encode.item()}"
assert_close(latent1.mean, latent2.mean, atol=1e-4, rtol=1e-4)
# Test decoding
logger.info("Testing decoding...")
latent1_tensor = latent1.mode()
@@ -97,10 +89,4 @@
assert output1.shape == output2.shape, f"Output shapes don't match: {output1.shape} vs {output2.shape}"

# Check if outputs are similar
max_diff_decode = torch.max(torch.abs(output1 - output2))
mean_diff_decode = torch.mean(torch.abs(output1 - output2))
logger.info("Maximum difference between decoded outputs: %s",
max_diff_decode.item())
logger.info("Mean difference between decoded outputs: %s",
mean_diff_decode.item())
assert max_diff_decode < 1e-5, f"Decoded outputs differ significantly: max diff = {mean_diff_decode.item()}"
assert_close(output1, output2, atol=1e-3, rtol=1e-3)
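
The encode check compares only the posterior means. If latent1 and latent2 are diffusers DiagonalGaussianDistribution objects, as the .mean/.mode() usage suggests, the log-variances could be asserted as well; a sketch under that assumption:

    # Assumption: latents expose diffusers' DiagonalGaussianDistribution API.
    assert_close(latent1.mean, latent2.mean, atol=1e-4, rtol=1e-4)
    assert_close(latent1.logvar, latent2.logvar, atol=1e-4, rtol=1e-4)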