@@ -63,16 +63,12 @@ def _update_signature_def_tensors(tensor_maps, map_old_to_new_tensors):
6363 ]
6464
6565
66- def _remove_tensors_from_model (model , remove_tensors_idxs ):
66+ def _remove_tensors_from_model (
67+ model : tflite_schema .ModelT , subgraph : tflite_schema .SubGraphT , remove_tensors_idxs
68+ ):
6769 """Remove tensors from model."""
6870 if not remove_tensors_idxs :
6971 return
70- if len (model .subgraphs ) > 1 :
71- raise ValueError (
72- "Model must only have one subgraph. Instead, it has "
73- "{} subgraphs." .format (len (model .subgraphs ))
74- )
75- subgraph = model .subgraphs [0 ]
7672 tensors = subgraph .tensors
7773 operators = subgraph .operators
7874
@@ -89,6 +85,7 @@ def _remove_tensors_from_model(model, remove_tensors_idxs):
8985 left_shift_by += 1
9086 else :
9187 d_old_to_new_tensors [idx ] = idx - left_shift_by
88+
9289 # Update tensor indices referenced throughout the model
9390 def update_tensors (tensor_idxs ):
9491 for i , ti in enumerate (tensor_idxs ):
@@ -110,14 +107,9 @@ def update_tensors(tensor_idxs):
110107
111108def _find_int8_quantized_inputs_outputs (model ):
112109 """Validate that model input is quantized and output is dequantized."""
113- if len (model .subgraphs ) > 1 :
114- raise ValueError (
115- "Model must only have one subgraph. Instead, it has "
116- "{} subgraphs." .format (len (model .subgraphs ))
117- )
118- subgraph = model .subgraphs [0 ]
119- tensors = subgraph .tensors
120- operators = subgraph .operators
110+ main_subgraph = model .subgraphs [0 ] # the main subgraph has ID 0
111+ tensors = main_subgraph .tensors
112+ operators = main_subgraph .operators
121113
122114 # Ensure model has atleast one quantize and dequantize operator
123115 quant_opcode_idx , dequant_opcode_idx = None , None
@@ -138,15 +130,18 @@ def _find_int8_quantized_inputs_outputs(model):
138130 input_quant_ops , output_dequant_ops = [], []
139131 for op in operators :
140132 # Find input quantize operator
141- if op .opcodeIndex == quant_opcode_idx and op .inputs [0 ] in subgraph .inputs :
133+ if op .opcodeIndex == quant_opcode_idx and op .inputs [0 ] in main_subgraph .inputs :
142134 pos , float_tensor , int_tensor = (
143135 "input" ,
144136 tensors [op .inputs [0 ]],
145137 tensors [op .outputs [0 ]],
146138 )
147139 input_quant_ops .append (op )
148140 # Find output dequantize operator
149- elif op .opcodeIndex == dequant_opcode_idx and op .outputs [0 ] in subgraph .outputs :
141+ elif (
142+ op .opcodeIndex == dequant_opcode_idx
143+ and op .outputs [0 ] in main_subgraph .outputs
144+ ):
150145 pos , float_tensor , int_tensor = (
151146 "output" ,
152147 tensors [op .outputs [0 ]],
@@ -221,23 +216,28 @@ def modify_integer_quantized_model_io_type(
221216 operators .remove (op )
222217
223218 # Remove tensors marked for deletion.
224- _remove_tensors_from_model (model , remove_tensors_idxs )
219+ _remove_tensors_from_model (model , subgraph , remove_tensors_idxs )
225220
226221 # Convert the model to a bytearray
227222 return _convert_model_from_object_to_bytearray (model )
228223
229224
230- def strip_lcedequantize_ops (model ) :
225+ def strip_lcedequantize_ops (model : bytes ) -> bytes :
231226 """Strip the LceDequantize ops to directly output bitpacked tf.int32 tensors."""
232227 # Convert the model to an object
233228 model = _convert_model_from_bytearray_to_object (model )
234229
235- if len (model .subgraphs ) > 1 :
236- raise ValueError (
237- "Model must only have one subgraph. Instead, it has "
238- "{} subgraphs." .format (len (model .subgraphs ))
239- )
230+ # Process each subgraph separately
231+ for subgraph in model .subgraphs :
232+ _strip_lcedequantize_ops_subgraph (model , subgraph )
233+
234+ # Convert the model to a bytearray
235+ return _convert_model_from_object_to_bytearray (model )
240236
237+
238+ def _strip_lcedequantize_ops_subgraph (
239+ model : tflite_schema .ModelT , subgraph : tflite_schema .SubGraphT
240+ ) -> None :
241241 # Ensure model has at least one LceDequantize and/or Dequantize operator
242242 lce_dequant_opcode_idx , dequant_opcode_idx = None , None
243243 for idx , opcode in enumerate (model .operatorCodes ):
@@ -254,7 +254,6 @@ def strip_lcedequantize_ops(model):
254254
255255 # Ensure model outputs are dequantized and remove Dequantize ops first if any
256256 if dequant_opcode_idx is not None :
257- subgraph = model .subgraphs [0 ]
258257 tensors = subgraph .tensors
259258 operators = subgraph .operators
260259 remove_tensors_idxs = set ()
@@ -306,9 +305,8 @@ def strip_lcedequantize_ops(model):
306305 operators .remove (op )
307306
308307 # Remove tensors marked for deletion.
309- _remove_tensors_from_model (model , remove_tensors_idxs )
308+ _remove_tensors_from_model (model , subgraph , remove_tensors_idxs )
310309
311- subgraph = model .subgraphs [0 ]
312310 tensors = subgraph .tensors
313311 operators = subgraph .operators
314312 remove_tensors_idxs = set ()
@@ -364,7 +362,4 @@ def strip_lcedequantize_ops(model):
364362 operators .remove (op )
365363
366364 # Remove tensors marked for deletion.
367- _remove_tensors_from_model (model , remove_tensors_idxs )
368-
369- # Convert the model to a bytearray
370- return _convert_model_from_object_to_bytearray (model )
365+ _remove_tensors_from_model (model , subgraph , remove_tensors_idxs )
0 commit comments