Skip to content

Commit 3fc2658

Browse files
authored
Add support for multiple subgraphs in some of the Python MLIR utils (#776)
1 parent bb00f64 commit 3fc2658

File tree

1 file changed

+26
-31
lines changed
  • larq_compute_engine/mlir/python

1 file changed

+26
-31
lines changed

larq_compute_engine/mlir/python/util.py

Lines changed: 26 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,12 @@ def _update_signature_def_tensors(tensor_maps, map_old_to_new_tensors):
6363
]
6464

6565

66-
def _remove_tensors_from_model(model, remove_tensors_idxs):
66+
def _remove_tensors_from_model(
67+
model: tflite_schema.ModelT, subgraph: tflite_schema.SubGraphT, remove_tensors_idxs
68+
):
6769
"""Remove tensors from model."""
6870
if not remove_tensors_idxs:
6971
return
70-
if len(model.subgraphs) > 1:
71-
raise ValueError(
72-
"Model must only have one subgraph. Instead, it has "
73-
"{} subgraphs.".format(len(model.subgraphs))
74-
)
75-
subgraph = model.subgraphs[0]
7672
tensors = subgraph.tensors
7773
operators = subgraph.operators
7874

@@ -89,6 +85,7 @@ def _remove_tensors_from_model(model, remove_tensors_idxs):
8985
left_shift_by += 1
9086
else:
9187
d_old_to_new_tensors[idx] = idx - left_shift_by
88+
9289
# Update tensor indices referenced throughout the model
9390
def update_tensors(tensor_idxs):
9491
for i, ti in enumerate(tensor_idxs):
@@ -110,14 +107,9 @@ def update_tensors(tensor_idxs):
110107

111108
def _find_int8_quantized_inputs_outputs(model):
112109
"""Validate that model input is quantized and output is dequantized."""
113-
if len(model.subgraphs) > 1:
114-
raise ValueError(
115-
"Model must only have one subgraph. Instead, it has "
116-
"{} subgraphs.".format(len(model.subgraphs))
117-
)
118-
subgraph = model.subgraphs[0]
119-
tensors = subgraph.tensors
120-
operators = subgraph.operators
110+
main_subgraph = model.subgraphs[0] # the main subgraph has ID 0
111+
tensors = main_subgraph.tensors
112+
operators = main_subgraph.operators
121113

122114
# Ensure model has atleast one quantize and dequantize operator
123115
quant_opcode_idx, dequant_opcode_idx = None, None
@@ -138,15 +130,18 @@ def _find_int8_quantized_inputs_outputs(model):
138130
input_quant_ops, output_dequant_ops = [], []
139131
for op in operators:
140132
# Find input quantize operator
141-
if op.opcodeIndex == quant_opcode_idx and op.inputs[0] in subgraph.inputs:
133+
if op.opcodeIndex == quant_opcode_idx and op.inputs[0] in main_subgraph.inputs:
142134
pos, float_tensor, int_tensor = (
143135
"input",
144136
tensors[op.inputs[0]],
145137
tensors[op.outputs[0]],
146138
)
147139
input_quant_ops.append(op)
148140
# Find output dequantize operator
149-
elif op.opcodeIndex == dequant_opcode_idx and op.outputs[0] in subgraph.outputs:
141+
elif (
142+
op.opcodeIndex == dequant_opcode_idx
143+
and op.outputs[0] in main_subgraph.outputs
144+
):
150145
pos, float_tensor, int_tensor = (
151146
"output",
152147
tensors[op.outputs[0]],
@@ -221,23 +216,28 @@ def modify_integer_quantized_model_io_type(
221216
operators.remove(op)
222217

223218
# Remove tensors marked for deletion.
224-
_remove_tensors_from_model(model, remove_tensors_idxs)
219+
_remove_tensors_from_model(model, subgraph, remove_tensors_idxs)
225220

226221
# Convert the model to a bytearray
227222
return _convert_model_from_object_to_bytearray(model)
228223

229224

230-
def strip_lcedequantize_ops(model):
225+
def strip_lcedequantize_ops(model: bytes) -> bytes:
231226
"""Strip the LceDequantize ops to directly output bitpacked tf.int32 tensors."""
232227
# Convert the model to an object
233228
model = _convert_model_from_bytearray_to_object(model)
234229

235-
if len(model.subgraphs) > 1:
236-
raise ValueError(
237-
"Model must only have one subgraph. Instead, it has "
238-
"{} subgraphs.".format(len(model.subgraphs))
239-
)
230+
# Process each subgraph separately
231+
for subgraph in model.subgraphs:
232+
_strip_lcedequantize_ops_subgraph(model, subgraph)
233+
234+
# Convert the model to a bytearray
235+
return _convert_model_from_object_to_bytearray(model)
240236

237+
238+
def _strip_lcedequantize_ops_subgraph(
239+
model: tflite_schema.ModelT, subgraph: tflite_schema.SubGraphT
240+
) -> None:
241241
# Ensure model has at least one LceDequantize and/or Dequantize operator
242242
lce_dequant_opcode_idx, dequant_opcode_idx = None, None
243243
for idx, opcode in enumerate(model.operatorCodes):
@@ -254,7 +254,6 @@ def strip_lcedequantize_ops(model):
254254

255255
# Ensure model outputs are dequantized and remove Dequantize ops first if any
256256
if dequant_opcode_idx is not None:
257-
subgraph = model.subgraphs[0]
258257
tensors = subgraph.tensors
259258
operators = subgraph.operators
260259
remove_tensors_idxs = set()
@@ -306,9 +305,8 @@ def strip_lcedequantize_ops(model):
306305
operators.remove(op)
307306

308307
# Remove tensors marked for deletion.
309-
_remove_tensors_from_model(model, remove_tensors_idxs)
308+
_remove_tensors_from_model(model, subgraph, remove_tensors_idxs)
310309

311-
subgraph = model.subgraphs[0]
312310
tensors = subgraph.tensors
313311
operators = subgraph.operators
314312
remove_tensors_idxs = set()
@@ -364,7 +362,4 @@ def strip_lcedequantize_ops(model):
364362
operators.remove(op)
365363

366364
# Remove tensors marked for deletion.
367-
_remove_tensors_from_model(model, remove_tensors_idxs)
368-
369-
# Convert the model to a bytearray
370-
return _convert_model_from_object_to_bytearray(model)
365+
_remove_tensors_from_model(model, subgraph, remove_tensors_idxs)

0 commit comments

Comments
 (0)