@@ -1,3 +1,4 @@
+import subprocess
 import tempfile
 
 import pytest
@@ -7,7 +8,7 @@
 from mmdeploy.utils.constants import Backend
 
 onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
-test_img = torch.rand([1, 3, 64, 64])
+test_img = torch.rand(1, 3, 8, 8)
 
 
 @pytest.mark.skip(reason='This is not a test class but a utility class.')
@@ -17,25 +18,15 @@ def __init__(self):
         super().__init__()
 
     def forward(self, x):
-        return x * 0.5
+        return x + test_img
 
 
-model = TestModel().eval().cuda()
+model = TestModel().eval()
 
 
 @pytest.fixture(autouse=True, scope='module')
 def generate_onnx_file():
     with torch.no_grad():
-        dynamic_axes = {
-            'input': {
-                0: 'batch',
-                2: 'width',
-                3: 'height'
-            },
-            'output': {
-                0: 'batch'
-            }
-        }
         torch.onnx.export(
             model,
             test_img,
@@ -46,7 +37,7 @@ def generate_onnx_file():
             do_constant_folding=True,
             verbose=False,
             opset_version=11,
-            dynamic_axes=dynamic_axes)
+            dynamic_axes=None)
 
 
 def check_backend_avaiable(backend):
@@ -57,6 +48,22 @@ def check_backend_avaiable(backend):
                 'TensorRT is not installed or custom ops are not compiled.')
         if not torch.cuda.is_available():
            pytest.skip('CUDA is not available.')
+    elif backend == Backend.ONNXRUNTIME:
+        from mmdeploy.apis.onnxruntime import is_available as ort_available
+        if not ort_available():
+            pytest.skip(
+                'ONNXRuntime is not installed or custom ops are not compiled.')
+    elif backend == Backend.PPL:
+        from mmdeploy.apis.ppl import is_available as ppl_avaiable
+        if not ppl_avaiable():
+            pytest.skip('PPL is not available.')
+    elif backend == Backend.NCNN:
+        from mmdeploy.apis.ncnn import is_available as ncnn_available
+        if not ncnn_available():
+            pytest.skip(
+                'NCNN is not installed or custom ops are not compiled.')
+    else:
+        raise NotImplementedError(f'Unknown backend type: {backend.value}')
 
 
 def onnx2backend(backend, onnx_file):
@@ -66,20 +73,46 @@ def onnx2backend(backend, onnx_file):
         engine = create_trt_engine(
             onnx_file, {
                 'input': {
-                    'min_shape': [1, 3, 64, 64],
-                    'opt_shape': [1, 3, 64, 64],
-                    'max_shape': [1, 3, 64, 64]
+                    'min_shape': [1, 3, 8, 8],
+                    'opt_shape': [1, 3, 8, 8],
+                    'max_shape': [1, 3, 8, 8]
                 }
             })
         save_trt_engine(engine, backend_file)
         return backend_file
-
-
-def create_wrapper(backend, engine_file):
+    elif backend == Backend.ONNXRUNTIME:
+        return onnx_file
+    elif backend == Backend.PPL:
+        return onnx_file
+    elif backend == Backend.NCNN:
+        from mmdeploy.apis.ncnn import get_onnx2ncnn_path
+        onnx2ncnn_path = get_onnx2ncnn_path()
+        param_file = tempfile.NamedTemporaryFile(suffix='.param').name
+        bin_file = tempfile.NamedTemporaryFile(suffix='.bin').name
+        subprocess.call([onnx2ncnn_path, onnx_file, param_file, bin_file])
+        return param_file, bin_file
+
+
+def create_wrapper(backend, model_files):
     if backend == Backend.TENSORRT:
         from mmdeploy.apis.tensorrt import TRTWrapper
-        trt_model = TRTWrapper(engine_file)
+        trt_model = TRTWrapper(model_files)
         return trt_model
+    elif backend == Backend.ONNXRUNTIME:
+        from mmdeploy.apis.onnxruntime import ORTWrapper
+        ort_model = ORTWrapper(model_files, 0)
+        return ort_model
+    elif backend == Backend.PPL:
+        from mmdeploy.apis.ppl import PPLWrapper
+        ppl_model = PPLWrapper(model_files, 0)
+        return ppl_model
+    elif backend == Backend.NCNN:
+        from mmdeploy.apis.ncnn import NCNNWrapper
+        param_file, bin_file = model_files
+        ncnn_model = NCNNWrapper(param_file, bin_file, output_names=['output'])
+        return ncnn_model
+    else:
+        raise NotImplementedError(f'Unknown backend type: {backend.value}')
 
 
 def run_wrapper(backend, wrapper, input):
@@ -88,9 +121,27 @@ def run_wrapper(backend, wrapper, input):
         results = wrapper({'input': input})['output']
         results = results.detach().cpu()
         return results
+    elif backend == Backend.ONNXRUNTIME:
+        input = input.cuda()
+        results = wrapper({'input': input})[0]
+        return list(results)
+    elif backend == Backend.PPL:
+        input = input.cuda()
+        results = wrapper({'input': input})[0]
+        return list(results)
+    elif backend == Backend.NCNN:
+        input = input.float()
+        results = wrapper({'input': input})['output']
+        results = results.detach().cpu().numpy()
+        results_list = list(results)
+        return results_list
+    else:
+        raise NotImplementedError(f'Unknown backend type: {backend.value}')
 
 
-ALL_BACKEND = [Backend.TENSORRT]
+ALL_BACKEND = [
+    Backend.TENSORRT, Backend.ONNXRUNTIME, Backend.PPL, Backend.NCNN
+]
 
 
 @pytest.mark.parametrize('backend', ALL_BACKEND)