Skip to content

Commit b493ea7

Browse files
vloncarJanFSchulte
andauthored
Support for parsing ONNX Pad node (#1352)
* Support for parsing ONNX Pad node * Use dynamo onnx export * Parse Pad node with ONNX opset >= 11 * Make ONNX models manually instead of using torch's export --------- Co-authored-by: Jan-Frederik Schulte <jschulte@cern.ch>
1 parent f43776d commit b493ea7

File tree

6 files changed

+253
-45
lines changed

6 files changed

+253
-45
lines changed

hls4ml/converters/onnx/reshape.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from hls4ml.converters.onnx_to_hls import get_onnx_attribute, onnx_handler
1+
from hls4ml.converters.onnx_to_hls import get_constant_value, get_onnx_attribute, onnx_handler
22

33

44
@onnx_handler('Transpose')
@@ -58,3 +58,75 @@ def parse_resize_layer(node, input_names, input_shapes, graph):
5858
)
5959

6060
return layer
61+
62+
63+
@onnx_handler('Pad')
64+
def parse_pad_layer(node, input_names, input_shapes, graph):
65+
layer = {}
66+
layer['name'] = node.name
67+
layer['class_name'] = 'ZeroPadding'
68+
layer['inputs'] = input_names
69+
layer['outputs'] = list(node.output)
70+
layer['data_format'] = (
71+
'channels_last' if any(node.domain == 'qonnx.custom_op.channels_last' for node in graph.node) else 'channels_first'
72+
)
73+
74+
mode = get_onnx_attribute(node, 'mode')
75+
if mode is not None and mode != 'constant':
76+
raise RuntimeError(f'Unsupported padding mode: {mode} in node {node.name}')
77+
78+
pads = get_constant_value(graph, node.input[1])
79+
if len(input_names) > 2:
80+
const_val = get_constant_value(graph, node.input[2])
81+
if const_val != 0:
82+
raise RuntimeError(f'Only constant value of 0 supported for Pad node {node.name}, got {const_val}')
83+
84+
if len(input_names) > 3:
85+
raise RuntimeError(f'Parsing axes input of Pad node {node.name} is not supported.')
86+
87+
dim = 0
88+
if len(input_shapes[0]) == 3:
89+
dim = 1 # 2D input (batch, channels, width), will use ZeroPadding1D
90+
if layer['data_format'] == 'channels_first':
91+
_, channels, width = input_shapes[0]
92+
pad_left, pad_right = pads[2], pads[5]
93+
else:
94+
_, width, channels = input_shapes[0]
95+
pad_left, pad_right = pads[1], pads[4]
96+
out_width = width + pad_left + pad_right
97+
98+
layer['n_chan'] = channels
99+
layer['in_width'] = width
100+
layer['out_width'] = out_width
101+
102+
layer['pad_left'] = pad_left
103+
layer['pad_right'] = pad_right
104+
elif len(input_shapes[0]) == 4:
105+
dim = 2 # 3D input (batch, channels, height, width), will use ZeroPadding2D
106+
if layer['data_format'] == 'channels_first':
107+
_, channels, height, width = input_shapes[0]
108+
pad_top, pad_bottom = pads[2], pads[6]
109+
pad_left, pad_right = pads[3], pads[7]
110+
else:
111+
_, height, width, channels = input_shapes[0]
112+
pad_top, pad_bottom = pads[1], pads[5]
113+
pad_left, pad_right = pads[2], pads[6]
114+
out_height = height + pad_top + pad_bottom
115+
out_width = width + pad_left + pad_right
116+
117+
layer['n_chan'] = channels
118+
layer['in_height'] = height
119+
layer['in_width'] = width
120+
layer['out_height'] = out_height
121+
layer['out_width'] = out_width
122+
123+
layer['pad_top'] = pad_top
124+
layer['pad_bottom'] = pad_bottom
125+
layer['pad_left'] = pad_left
126+
layer['pad_right'] = pad_right
127+
else:
128+
raise RuntimeError(f'Unsupported input shape: {input_shapes[0]} for Pad node {node.name}')
129+
130+
layer['class_name'] += str(dim) + 'D'
131+
132+
return layer

hls4ml/converters/pytorch/reshape.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ def parse_constantpad2d_layer(operation, layer_name, input_names, input_shapes,
207207
layer['out_height'] = out_height
208208
layer['out_width'] = out_width
209209

210+
layer['data_format'] = 'channels_first' # Default data format in PyTorch
211+
210212
return layer, output_shape
211213

212214

@@ -246,4 +248,6 @@ def parse_constantpad1d_layer(operation, layer_name, input_names, input_shapes,
246248
layer['in_width'] = width
247249
layer['out_width'] = out_width
248250

251+
layer['data_format'] = 'channels_first' # Default data format in PyTorch
252+
249253
return layer, output_shape

hls4ml/model/optimizer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
'parse_qonnx',
3535
[
3636
'reshape_constant',
37+
'padding_constant',
3738
'resize_remove_constants',
3839
'quant_constant_parameters',
3940
'bipolar_quant_constant_parameters',
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from hls4ml.model.layers import Constant, ZeroPadding1D, ZeroPadding2D
2+
from hls4ml.model.optimizer import OptimizerPass
3+
4+
5+
class PaddingConstant(OptimizerPass):
6+
"""
7+
ONNX has the padding come as an input, not a parameter. This removes the Constant node from the input.
8+
The constant value was already used; this is just a cleanup uptimization.
9+
"""
10+
11+
def match(self, node):
12+
is_match = (
13+
isinstance(node, (ZeroPadding1D, ZeroPadding2D))
14+
and len(node.inputs) > 1
15+
and isinstance(node.get_input_node(node.inputs[1]), Constant)
16+
)
17+
18+
return is_match
19+
20+
def transform(self, model, node):
21+
"""
22+
Remove Constant node(s) from the graph. Note, padding is already present in ZeroPadding node.
23+
"""
24+
if len(node.inputs) > 2:
25+
const_val_node = node.get_input_node(node.inputs[2])
26+
if not isinstance(const_val_node, Constant):
27+
raise RuntimeError(f'Non-constant padding inputs are not currently supported ({node.name})')
28+
model.remove_node(const_val_node)
29+
node.inputs.pop(2)
30+
31+
pad_node = node.get_input_node(node.inputs[1])
32+
if not isinstance(pad_node, Constant):
33+
raise RuntimeError(f'Non-constant padding inputs are not currently supported ({node.name})')
34+
model.remove_node(pad_node)
35+
node.inputs.pop(1)
36+
37+
return True

test/pytest/test_pytorch_constpadmapping.py

Lines changed: 0 additions & 44 deletions
This file was deleted.
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from pathlib import Path
2+
3+
import numpy as np
4+
import torch.nn as nn
5+
from onnx import TensorProto, helper
6+
7+
from hls4ml.converters import convert_from_onnx_model, convert_from_pytorch_model
8+
from hls4ml.utils.config import config_from_onnx_model, config_from_pytorch_model
9+
10+
test_root_path = Path(__file__).parent
11+
12+
13+
def _make_constantpad_onnx_1d():
14+
input_tensor = helper.make_tensor_value_info('global_in', TensorProto.FLOAT, [1, 2, 4])
15+
output_tensor = helper.make_tensor_value_info('global_out', TensorProto.FLOAT, [1, 2, 9])
16+
pads_tensor = helper.make_tensor_value_info('pads', TensorProto.INT64, [6])
17+
value_tensor = helper.make_tensor_value_info('value', TensorProto.FLOAT, [])
18+
19+
# Pads = [N_before, C_before, W_before, N_after, C_after, W_after]
20+
pads = [0, 0, 2, 0, 0, 3]
21+
22+
pads_initializer = helper.make_tensor(name='pads', data_type=TensorProto.INT64, dims=[6], vals=pads)
23+
value_initializer = helper.make_tensor(name='value', data_type=TensorProto.FLOAT, dims=[], vals=[0.0])
24+
25+
pad_node = helper.make_node(
26+
'Pad', name='const_pad', inputs=['global_in', 'pads', 'value'], outputs=['global_out'], mode='constant'
27+
)
28+
29+
graph = helper.make_graph(
30+
nodes=[pad_node],
31+
name='Pad1DGraph',
32+
inputs=[input_tensor],
33+
outputs=[output_tensor],
34+
initializer=[pads_initializer, value_initializer],
35+
value_info=[pads_tensor, value_tensor],
36+
)
37+
38+
model = helper.make_model(graph)
39+
40+
return model
41+
42+
43+
def test_constantpad_1d():
44+
class Pad1DModel(nn.Module):
45+
def __init__(self):
46+
super().__init__()
47+
self.pad = nn.ConstantPad1d((2, 3), 0) # pad 2 left, 3 right
48+
49+
def forward(self, x):
50+
return self.pad(x)
51+
52+
model = Pad1DModel()
53+
model.eval()
54+
config_pytorch = config_from_pytorch_model(model, (2, 4), channels_last_conversion='off')
55+
hls_model_pytorch = convert_from_pytorch_model(
56+
model, output_dir=str(test_root_path / 'hls4mlprj_constpad_1d/pytorch'), hls_config=config_pytorch
57+
)
58+
59+
hls_model_pytorch.compile()
60+
61+
pad1d_onnx = _make_constantpad_onnx_1d()
62+
63+
config_onnx = config_from_onnx_model(pad1d_onnx)
64+
hls_model_onnx = convert_from_onnx_model(
65+
pad1d_onnx, output_dir=str(test_root_path / 'hls4mlprj_constpad_1d/onnx'), hls_config=config_onnx
66+
)
67+
68+
hls_model_onnx.compile()
69+
70+
input_data = np.random.randn(10, 2, 4)
71+
pred_pytorch = hls_model_pytorch.predict(input_data)
72+
pred_onnx = hls_model_onnx.predict(input_data)
73+
74+
np.testing.assert_allclose(pred_pytorch, pred_onnx, rtol=0, atol=1e-5)
75+
76+
77+
def _make_constantpad_onnx_2d():
78+
input_tensor = helper.make_tensor_value_info('global_in', TensorProto.FLOAT, [1, 2, 3, 4])
79+
output_tensor = helper.make_tensor_value_info('global_out', TensorProto.FLOAT, [1, 2, 10, 7])
80+
pads_tensor = helper.make_tensor_value_info('pads', TensorProto.INT64, [8])
81+
value_tensor = helper.make_tensor_value_info('value', TensorProto.FLOAT, [])
82+
83+
# Pads = [N_before, C_before, H_before, W_before, N_after, C_after, H_after, W_after]
84+
pads = [0, 0, 3, 1, 0, 0, 4, 2]
85+
86+
pads_initializer = helper.make_tensor(name='pads', data_type=TensorProto.INT64, dims=[8], vals=pads)
87+
value_initializer = helper.make_tensor(name='value', data_type=TensorProto.FLOAT, dims=[], vals=[0.0])
88+
89+
pad_node = helper.make_node(
90+
'Pad', name='const_pad', inputs=['global_in', 'pads', 'value'], outputs=['global_out'], mode='constant'
91+
)
92+
93+
graph = helper.make_graph(
94+
nodes=[pad_node],
95+
name='Pad2DGraph',
96+
inputs=[input_tensor],
97+
outputs=[output_tensor],
98+
initializer=[pads_initializer, value_initializer],
99+
value_info=[pads_tensor, value_tensor],
100+
)
101+
102+
model = helper.make_model(graph)
103+
104+
return model
105+
106+
107+
def test_constantpad_2d():
108+
class Pad2DModel(nn.Module):
109+
def __init__(self):
110+
super().__init__()
111+
self.pad = nn.ConstantPad2d((1, 2, 3, 4), 0) # left, right, top, bottom
112+
113+
def forward(self, x):
114+
return self.pad(x)
115+
116+
model = Pad2DModel()
117+
model.eval()
118+
config_pytorch = config_from_pytorch_model(model, (2, 3, 4), channels_last_conversion='off')
119+
hls_model_pytorch = convert_from_pytorch_model(
120+
model, output_dir=str(test_root_path / 'hls4mlprj_constpad_2d/pytorch'), hls_config=config_pytorch
121+
)
122+
123+
hls_model_pytorch.compile()
124+
125+
pad2d_onnx = _make_constantpad_onnx_2d()
126+
127+
config_onnx = config_from_onnx_model(pad2d_onnx)
128+
hls_model_onnx = convert_from_onnx_model(
129+
pad2d_onnx, output_dir=str(test_root_path / 'hls4mlprj_constpad_2d/onnx'), hls_config=config_onnx
130+
)
131+
132+
hls_model_onnx.compile()
133+
134+
input_data = np.random.randn(10, 2, 3, 4)
135+
pred_pytorch = hls_model_pytorch.predict(input_data)
136+
pred_onnx = hls_model_onnx.predict(input_data)
137+
138+
np.testing.assert_allclose(pred_pytorch, pred_onnx, rtol=0, atol=1e-5)

0 commit comments

Comments
 (0)