Skip to content

Commit f61aa89

Browse files
[TorchFX][Conformance] Move all models to export_for_training (#3078)
### Changes All `capture_pre_autograd_graph` calls in the conformance test were replaced by `torch.export.export_for_training`. ### Reason for changes To remove deprecated `capture_pre_autograd_graph` from the conformance test. ### Related tickets #2766 ### Tests post_training_quantization/555/ have finished succesfully
1 parent 2284df5 commit f61aa89

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

tests/post_training/pipelines/image_classification_torchvision.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import onnx
1717
import openvino as ov
1818
import torch
19-
from torch._export import capture_pre_autograd_graph
2019
from torchvision import models
2120

2221
from nncf.torch import disable_patching
@@ -25,11 +24,11 @@
2524
from tests.post_training.pipelines.image_classification_base import ImageClassificationBase
2625

2726

28-
def _capture_pre_autograd_module(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch.fx.GraphModule:
29-
return capture_pre_autograd_graph(model, args)
27+
def _torch_export_for_training(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch.fx.GraphModule:
28+
return torch.export.export_for_training(model, args).module()
3029

3130

32-
def _export_graph_module(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch.fx.GraphModule:
31+
def _torch_export(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch.fx.GraphModule:
3332
return torch.export.export(model, args).module()
3433

3534

@@ -44,15 +43,15 @@ class ImageClassificationTorchvision(ImageClassificationBase):
4443
"""Pipeline for Image Classification model from torchvision repository"""
4544

4645
models_vs_model_params = {
47-
models.resnet18: VisionModelParams(models.ResNet18_Weights.DEFAULT, _capture_pre_autograd_module),
46+
models.resnet18: VisionModelParams(models.ResNet18_Weights.DEFAULT, _torch_export_for_training),
4847
models.mobilenet_v3_small: VisionModelParams(
49-
models.MobileNet_V3_Small_Weights.DEFAULT, _capture_pre_autograd_module
48+
models.MobileNet_V3_Small_Weights.DEFAULT, _torch_export_for_training
5049
),
5150
models.vit_b_16: VisionModelParams(
52-
models.ViT_B_16_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True
51+
models.ViT_B_16_Weights.DEFAULT, _torch_export_for_training, export_torch_before_ov_convert=True
5352
),
5453
models.swin_v2_s: VisionModelParams(
55-
models.Swin_V2_S_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True
54+
models.Swin_V2_S_Weights.DEFAULT, _torch_export, export_torch_before_ov_convert=True
5655
),
5756
}
5857

0 commit comments

Comments
 (0)