|
19 | 19 |
|
20 | 20 |
|
21 | 21 | def export_model(tmpdir: Path, device: str | torch.device, aoti_compile_and_package: bool, no_transforms: bool) -> None:
|
| 22 | + input_shape = (-1, 8, -1, -1) |
| 23 | + archive_path = pathlib.Path(tmpdir) / "model.pt2" |
| 24 | + metadata_path = pathlib.Path("tests") / "torch" / "ftw-metadata.yaml" |
| 25 | + transforms = torch.nn.Sequential(T.Resize((16, 16)), T.Normalize(mean=[0.0], std=[3000.0])) |
| 26 | + model = torch.nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1, padding=0) |
| 27 | + model_program, transforms_program = export( |
| 28 | + model=model, |
| 29 | + transforms=None if no_transforms else transforms, |
| 30 | + input_shape=input_shape, |
| 31 | + device=device, |
| 32 | + dtype=torch.float32, |
| 33 | + ) |
| 34 | + |
| 35 | + if no_transforms: |
| 36 | + assert transforms_program is None |
| 37 | + |
| 38 | + package( |
| 39 | + output_file=archive_path, |
| 40 | + model_program=model_program, |
| 41 | + transforms_program=transforms_program, |
| 42 | + metadata_path=metadata_path, |
| 43 | + aoti_compile_and_package=aoti_compile_and_package, |
| 44 | + ) |
| 45 | + |
| 46 | + # Validate that pt2 is loadable and model/transform are usable |
| 47 | + pt2 = load_pt2(archive_path) |
| 48 | + |
| 49 | + x = torch.randn(1, 8, 8, 8, device=device, dtype=torch.float32) |
| 50 | + if pt2.aoti_runners != {}: |
| 51 | + model_aoti = pt2.aoti_runners["model"] |
| 52 | + preds = model_aoti(x) |
| 53 | + assert preds.shape == (1, 3, 8, 8) |
| 54 | + |
| 55 | + if no_transforms: |
| 56 | + assert "transforms" not in pt2.aoti_runners |
| 57 | + else: |
| 58 | + assert "transforms" in pt2.aoti_runners |
| 59 | + |
| 60 | + if "transforms" in pt2.aoti_runners: |
| 61 | + transforms_aoti = pt2.aoti_runners["transforms"] |
| 62 | + transformed = transforms_aoti(x) |
| 63 | + assert transformed.shape == (1, 8, 16, 16) |
| 64 | + else: |
| 65 | + model_exported = pt2.exported_programs["model"].module() |
| 66 | + preds = model_exported(x) |
| 67 | + assert preds.shape == (1, 3, 8, 8) |
| 68 | + |
| 69 | + if no_transforms: |
| 70 | + assert "transforms" not in pt2.exported_programs |
| 71 | + else: |
| 72 | + assert "transforms" in pt2.exported_programs |
| 73 | + |
| 74 | + if "transforms" in pt2.exported_programs: |
| 75 | + transforms_exported = pt2.exported_programs["transforms"].module() |
| 76 | + transformed = transforms_exported(x) |
| 77 | + assert transformed.shape == (1, 8, 16, 16) |
| 78 | + |
| 79 | + # Validate metadata is valid yaml |
| 80 | + metadata = pt2.extra_files["mlm-metadata"] |
| 81 | + metadata = yaml.safe_load(metadata) |
| 82 | + MLModelProperties.model_validate(metadata["properties"]) |
| 83 | + |
| 84 | + |
| 85 | +def export_ftw_model(tmpdir: Path, device: str | torch.device, aoti_compile_and_package: bool, no_transforms: bool) -> None: |
22 | 86 | input_shape = (-1, 8, -1, -1)
|
23 | 87 | archive_path = pathlib.Path(tmpdir) / "model.pt2"
|
24 | 88 | metadata_path = pathlib.Path("tests") / "torch" / "ftw-metadata.yaml"
|
@@ -83,16 +147,29 @@ def export_model(tmpdir: Path, device: str | torch.device, aoti_compile_and_pack
|
83 | 147 | MLModelProperties.model_validate(metadata["properties"])
|
84 | 148 |
|
85 | 149 |
|
| 150 | +@pytest.mark.parametrize("no_transforms", [True, False]) |
| 151 | +@pytest.mark.parametrize("aoti_compile_and_package", [False, True]) |
| 152 | +def test_export_cpu(tmpdir: Path, aoti_compile_and_package: bool, no_transforms: bool) -> None: |
| 153 | + export_model(tmpdir, "cpu", aoti_compile_and_package, no_transforms) |
| 154 | + |
| 155 | + |
| 156 | +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") |
| 157 | +@pytest.mark.parametrize("no_transforms", [True, False]) |
| 158 | +@pytest.mark.parametrize("aoti_compile_and_package", [False, True]) |
| 159 | +def test_export_cuda(tmpdir: Path, aoti_compile_and_package: bool, no_transforms: bool) -> None: |
| 160 | + export_model(tmpdir, "cuda", aoti_compile_and_package, no_transforms) |
| 161 | + |
| 162 | + |
86 | 163 | @pytest.mark.slow
|
87 | 164 | @pytest.mark.parametrize("no_transforms", [True, False])
|
88 | 165 | @pytest.mark.parametrize("aoti_compile_and_package", [False, True])
|
89 | 166 | def test_ftw_export_cpu(tmpdir: Path, aoti_compile_and_package: bool, no_transforms: bool) -> None:
|
90 |
| - export_model(tmpdir, "cpu", aoti_compile_and_package, no_transforms) |
| 167 | + export_ftw_model(tmpdir, "cpu", aoti_compile_and_package, no_transforms) |
91 | 168 |
|
92 | 169 |
|
93 | 170 | @pytest.mark.slow
|
94 | 171 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
|
95 | 172 | @pytest.mark.parametrize("no_transforms", [True, False])
|
96 | 173 | @pytest.mark.parametrize("aoti_compile_and_package", [False, True])
|
97 | 174 | def test_ftw_export_cuda(tmpdir: Path, aoti_compile_and_package: bool, no_transforms: bool) -> None:
|
98 |
| - export_model(tmpdir, "cuda", aoti_compile_and_package, no_transforms) |
| 175 | + export_ftw_model(tmpdir, "cuda", aoti_compile_and_package, no_transforms) |
0 commit comments