Skip to content

Commit 2c7961f

Browse files
committed
add single layer export/package test
1 parent 55558f3 commit 2c7961f

File tree

1 file changed

+79
-2
lines changed

1 file changed

+79
-2
lines changed

tests/torch/test_export.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,70 @@
1919

2020

2121
def export_model(tmpdir: Path, device: str | torch.device, aoti_compile_and_package: bool, no_transforms: bool) -> None:
22+
input_shape = (-1, 8, -1, -1)
23+
archive_path = pathlib.Path(tmpdir) / "model.pt2"
24+
metadata_path = pathlib.Path("tests") / "torch" / "ftw-metadata.yaml"
25+
transforms = torch.nn.Sequential(T.Resize((16, 16)), T.Normalize(mean=[0.0], std=[3000.0]))
26+
model = torch.nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1, padding=0)
27+
model_program, transforms_program = export(
28+
model=model,
29+
transforms=None if no_transforms else transforms,
30+
input_shape=input_shape,
31+
device=device,
32+
dtype=torch.float32,
33+
)
34+
35+
if no_transforms:
36+
assert transforms_program is None
37+
38+
package(
39+
output_file=archive_path,
40+
model_program=model_program,
41+
transforms_program=transforms_program,
42+
metadata_path=metadata_path,
43+
aoti_compile_and_package=aoti_compile_and_package,
44+
)
45+
46+
# Validate that pt2 is loadable and model/transform are usable
47+
pt2 = load_pt2(archive_path)
48+
49+
x = torch.randn(1, 8, 8, 8, device=device, dtype=torch.float32)
50+
if pt2.aoti_runners != {}:
51+
model_aoti = pt2.aoti_runners["model"]
52+
preds = model_aoti(x)
53+
assert preds.shape == (1, 3, 8, 8)
54+
55+
if no_transforms:
56+
assert "transforms" not in pt2.aoti_runners
57+
else:
58+
assert "transforms" in pt2.aoti_runners
59+
60+
if "transforms" in pt2.aoti_runners:
61+
transforms_aoti = pt2.aoti_runners["transforms"]
62+
transformed = transforms_aoti(x)
63+
assert transformed.shape == (1, 8, 16, 16)
64+
else:
65+
model_exported = pt2.exported_programs["model"].module()
66+
preds = model_exported(x)
67+
assert preds.shape == (1, 3, 8, 8)
68+
69+
if no_transforms:
70+
assert "transforms" not in pt2.exported_programs
71+
else:
72+
assert "transforms" in pt2.exported_programs
73+
74+
if "transforms" in pt2.exported_programs:
75+
transforms_exported = pt2.exported_programs["transforms"].module()
76+
transformed = transforms_exported(x)
77+
assert transformed.shape == (1, 8, 16, 16)
78+
79+
# Validate metadata is valid yaml
80+
metadata = pt2.extra_files["mlm-metadata"]
81+
metadata = yaml.safe_load(metadata)
82+
MLModelProperties.model_validate(metadata["properties"])
83+
84+
85+
def export_ftw_model(tmpdir: Path, device: str | torch.device, aoti_compile_and_package: bool, no_transforms: bool) -> None:
2286
input_shape = (-1, 8, -1, -1)
2387
archive_path = pathlib.Path(tmpdir) / "model.pt2"
2488
metadata_path = pathlib.Path("tests") / "torch" / "ftw-metadata.yaml"
@@ -83,16 +147,29 @@ def export_model(tmpdir: Path, device: str | torch.device, aoti_compile_and_pack
83147
MLModelProperties.model_validate(metadata["properties"])
84148

85149

150+
@pytest.mark.parametrize("no_transforms", [True, False])
151+
@pytest.mark.parametrize("aoti_compile_and_package", [False, True])
152+
def test_export_cpu(tmpdir: Path, aoti_compile_and_package: bool, no_transforms: bool) -> None:
153+
export_model(tmpdir, "cpu", aoti_compile_and_package, no_transforms)
154+
155+
156+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
157+
@pytest.mark.parametrize("no_transforms", [True, False])
158+
@pytest.mark.parametrize("aoti_compile_and_package", [False, True])
159+
def test_export_cuda(tmpdir: Path, aoti_compile_and_package: bool, no_transforms: bool) -> None:
160+
export_model(tmpdir, "cuda", aoti_compile_and_package, no_transforms)
161+
162+
86163
@pytest.mark.slow
87164
@pytest.mark.parametrize("no_transforms", [True, False])
88165
@pytest.mark.parametrize("aoti_compile_and_package", [False, True])
89166
def test_ftw_export_cpu(tmpdir: Path, aoti_compile_and_package: bool, no_transforms: bool) -> None:
90-
export_model(tmpdir, "cpu", aoti_compile_and_package, no_transforms)
167+
export_ftw_model(tmpdir, "cpu", aoti_compile_and_package, no_transforms)
91168

92169

93170
@pytest.mark.slow
94171
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
95172
@pytest.mark.parametrize("no_transforms", [True, False])
96173
@pytest.mark.parametrize("aoti_compile_and_package", [False, True])
97174
def test_ftw_export_cuda(tmpdir: Path, aoti_compile_and_package: bool, no_transforms: bool) -> None:
98-
export_model(tmpdir, "cuda", aoti_compile_and_package, no_transforms)
175+
export_ftw_model(tmpdir, "cuda", aoti_compile_and_package, no_transforms)

0 commit comments

Comments
 (0)