11
11
from torch .export .dynamic_shapes import Dim
12
12
from torch .export .pt2_archive ._package import package_pt2
13
13
14
- import stac_model .torch .export as export
15
-
16
14
logging .basicConfig (level = logging .DEBUG )
17
15
logger = logging .getLogger (__name__ )
18
16
19
17
20
18
def package_model_and_transforms (
21
19
output_file : str ,
22
- model_program : export .export .ExportedProgram ,
23
- transforms_program : export .export .ExportedProgram | None = None ,
20
+ model_program : torch .export .ExportedProgram ,
21
+ transforms_program : torch .export .ExportedProgram | None = None ,
24
22
metadata_path : str | None = None ,
25
23
aoti_compile_and_package : bool = False ,
26
24
) -> None :
@@ -60,7 +58,7 @@ def package_model_and_transforms(
60
58
tempfile .TemporaryDirectory () as transforms_tmpdir ,
61
59
):
62
60
# Package and extract transforms from pt2 archive
63
- model_path = export ._inductor .aoti_compile_and_package (
61
+ model_path = torch ._inductor .aoti_compile_and_package (
64
62
model_program , package_path = os .path .join (archive_tmpdir , "model.pt2" )
65
63
)
66
64
@@ -77,7 +75,7 @@ def package_model_and_transforms(
77
75
78
76
if transforms_program is not None :
79
77
# Package and extract transforms from pt2 archive
80
- transforms_path = export ._inductor .aoti_compile_and_package (
78
+ transforms_path = torch ._inductor .aoti_compile_and_package (
81
79
transforms_program ,
82
80
package_path = os .path .join (archive_tmpdir , "transforms.pt2" ),
83
81
)
@@ -100,13 +98,13 @@ def package_model_and_transforms(
100
98
)
101
99
102
100
103
- @export .no_grad ()
101
+ @torch .no_grad ()
104
102
def export_model_and_transforms (
105
103
model : torch .nn .Module ,
106
104
transforms : torch .nn .Module ,
107
105
input_shape : Sequence [int ],
108
- device : export .device ,
109
- ) -> tuple [export .export .ExportedProgram , export .export .ExportedProgram ]:
106
+ device : torch .device ,
107
+ ) -> tuple [torch .export .ExportedProgram , torch .export .ExportedProgram ]:
110
108
"""Exports a model and its transforms to programs.
111
109
112
110
Args:
@@ -135,16 +133,16 @@ def export_model_and_transforms(
135
133
transforms_arg = next (iter (inspect .signature (transforms .forward ).parameters ))
136
134
137
135
# Export model and transforms
138
- model_program = export .export .export (
136
+ model_program = torch .export .export (
139
137
mod = model , args = (example_inputs ,), dynamic_shapes = {model_arg : dims }
140
138
)
141
- transforms_program = export .export .export (
139
+ transforms_program = torch .export .export (
142
140
mod = transforms , args = (example_inputs ,), dynamic_shapes = {transforms_arg : dims }
143
141
)
144
142
return model_program , transforms_program
145
143
146
144
147
- def _create_example_input_from_shape (input_shape : Sequence [int ]) -> export .Tensor :
145
+ def _create_example_input_from_shape (input_shape : Sequence [int ]) -> torch .Tensor :
148
146
"""Creates an example input tensor based on the provided input shape.
149
147
150
148
Args:
@@ -183,4 +181,4 @@ def _create_example_input_from_shape(input_shape: Sequence[int]) -> export.Tenso
183
181
else :
184
182
shape .append (input_shape [3 ])
185
183
186
- return export .randn (* shape , requires_grad = False )
184
+ return torch .randn (* shape , requires_grad = False )
0 commit comments