
Commit ac35306
update changelog
1 parent e02cb74

2 files changed: 12 additions, 13 deletions

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Add torch export and packaging utilities for combining a model, transforms, and MLM schema compliant metadata into a single `.pt2` archive.
 - Add [ML-Model Legacy](./docs/legacy/ml-model.md) document providing migration guidance
   from the deprecated [ML-Model](https://github.com/stac-extensions/ml-model) extension
   (relates to [stac-extensions/ml-model#16](https://github.com/stac-extensions/ml-model/pull/16)).
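
For readers of the changelog entry above, here is a minimal usage sketch of the new utilities, based only on the function signatures visible in the `stac_model/torch/export.py` diff below. The model, transforms, input shape, and file paths are illustrative placeholders and are not part of this commit.

```python
import torch

from stac_model.torch.export import (
    export_model_and_transforms,
    package_model_and_transforms,
)


class Model(torch.nn.Module):
    # Illustrative stand-in for a real model; any torch.nn.Module would do.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mean(dim=(2, 3))


class Transforms(torch.nn.Module):
    # Illustrative normalization step standing in for real transforms.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x / 255.0


# Export both modules to ExportedPrograms.
model_program, transforms_program = export_model_and_transforms(
    model=Model(),
    transforms=Transforms(),
    input_shape=(2, 3, 224, 224),  # placeholder shape
    device=torch.device("cpu"),
)

# Bundle the programs and metadata into a single .pt2 archive.
# "mlm_metadata.json" is a hypothetical path to MLM schema compliant metadata.
package_model_and_transforms(
    "model.pt2",
    model_program=model_program,
    transforms_program=transforms_program,
    metadata_path="mlm_metadata.json",
)
```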

stac_model/torch/export.py

Lines changed: 11 additions & 13 deletions
@@ -11,16 +11,14 @@
 from torch.export.dynamic_shapes import Dim
 from torch.export.pt2_archive._package import package_pt2
 
-import stac_model.torch.export as export
-
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
 
 def package_model_and_transforms(
     output_file: str,
-    model_program: export.export.ExportedProgram,
-    transforms_program: export.export.ExportedProgram | None = None,
+    model_program: torch.export.ExportedProgram,
+    transforms_program: torch.export.ExportedProgram | None = None,
     metadata_path: str | None = None,
     aoti_compile_and_package: bool = False,
 ) -> None:
@@ -60,7 +58,7 @@ def package_model_and_transforms(
         tempfile.TemporaryDirectory() as transforms_tmpdir,
     ):
         # Package and extract transforms from pt2 archive
-        model_path = export._inductor.aoti_compile_and_package(
+        model_path = torch._inductor.aoti_compile_and_package(
             model_program, package_path=os.path.join(archive_tmpdir, "model.pt2")
         )
 
@@ -77,7 +75,7 @@ def package_model_and_transforms(
 
         if transforms_program is not None:
             # Package and extract transforms from pt2 archive
-            transforms_path = export._inductor.aoti_compile_and_package(
+            transforms_path = torch._inductor.aoti_compile_and_package(
                 transforms_program,
                 package_path=os.path.join(archive_tmpdir, "transforms.pt2"),
             )
@@ -100,13 +98,13 @@
     )
 
 
-@export.no_grad()
+@torch.no_grad()
 def export_model_and_transforms(
     model: torch.nn.Module,
     transforms: torch.nn.Module,
     input_shape: Sequence[int],
-    device: export.device,
-) -> tuple[export.export.ExportedProgram, export.export.ExportedProgram]:
+    device: torch.device,
+) -> tuple[torch.export.ExportedProgram, torch.export.ExportedProgram]:
     """Exports a model and its transforms to programs.
 
     Args:
@@ -135,16 +133,16 @@ def export_model_and_transforms(
     transforms_arg = next(iter(inspect.signature(transforms.forward).parameters))
 
     # Export model and transforms
-    model_program = export.export.export(
+    model_program = torch.export.export(
         mod=model, args=(example_inputs,), dynamic_shapes={model_arg: dims}
     )
-    transforms_program = export.export.export(
+    transforms_program = torch.export.export(
         mod=transforms, args=(example_inputs,), dynamic_shapes={transforms_arg: dims}
     )
     return model_program, transforms_program
 
 
-def _create_example_input_from_shape(input_shape: Sequence[int]) -> export.Tensor:
+def _create_example_input_from_shape(input_shape: Sequence[int]) -> torch.Tensor:
     """Creates an example input tensor based on the provided input shape.
 
     Args:
@@ -183,4 +181,4 @@ def _create_example_input_from_shape(input_shape: Sequence[int]) -> export.Tenso
     else:
         shape.append(input_shape[3])
 
-    return export.randn(*shape, requires_grad=False)
+    return torch.randn(*shape, requires_grad=False)
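
For context on the corrected calls, this is a minimal sketch of the direct `torch.export` pattern the file now relies on, instead of going through the accidental `import stac_model.torch.export as export` self-alias. The module, argument name, and dimension bounds here are illustrative assumptions, not code from this commit.

```python
import torch
from torch.export.dynamic_shapes import Dim


class Model(torch.nn.Module):
    # Illustrative module whose forward() takes a single tensor argument named "x".
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0


example = torch.randn(2, 3, 224, 224, requires_grad=False)

# Declare the batch dimension of the "x" argument as dynamic.
dims = {0: Dim("batch", min=2, max=64)}
program = torch.export.export(mod=Model(), args=(example,), dynamic_shapes={"x": dims})

assert isinstance(program, torch.export.ExportedProgram)
```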
