+import os
 from packaging import version
 import warnings
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 import tensorflow as tf
+import tempfile
 
-from larq_compute_engine.mlir._graphdef_tfl_flatbuffer import (
+from larq_compute_engine.mlir._tf_tfl_flatbuffer import (
     convert_graphdef_to_tflite_flatbuffer,
+    convert_saved_model_to_tflite_flatbuffer,
 )
+
 from larq_compute_engine.mlir.python.util import modify_integer_quantized_model_io_type
 
 from tensorflow.core.framework.types_pb2 import DataType
@@ -57,6 +61,120 @@ def _contains_training_quant_op(graph_def):
     return False
 
 
+def _validate_options(
+    *,
+    inference_input_type=None,
+    inference_output_type=None,
+    target=None,
+    experimental_default_int8_range=None,
+):
+    if inference_input_type not in (tf.float32, tf.int8):
+        raise ValueError(
+            "Expected `inference_input_type` to be either `tf.float32` or `tf.int8`, "
+            f"got {inference_input_type}."
+        )
+    if inference_output_type not in (tf.float32, tf.int8):
+        raise ValueError(
+            "Expected `inference_output_type` to be either `tf.float32` or `tf.int8`, "
+            f"got {inference_output_type}."
+        )
+    if target not in ("arm", "xcore"):
+        raise ValueError(f'Expected `target` to be "arm" or "xcore", but got {target}.')
+
+    if not tf.executing_eagerly():
+        raise RuntimeError(
+            "Graph mode is not supported. Please enable eager execution using "
+            "tf.enable_eager_execution() when using TensorFlow 1.x"
+        )
+    if experimental_default_int8_range:
+        warnings.warn(
+            "Using `experimental_default_int8_range` as fallback quantization stats. "
+            "This should only be used for latency tests."
+        )
+
+
+def convert_saved_model(
+    saved_model_dir: Union[str, os.PathLike],
+    *,  # Require remaining arguments to be keyword-only.
+    inference_input_type: tf.DType = tf.float32,
+    inference_output_type: tf.DType = tf.float32,
+    target: str = "arm",
+    experimental_default_int8_range: Optional[Tuple[float, float]] = None,
+    experimental_enable_bitpacked_activations: bool = False,
+) -> bytes:
+    """Converts a SavedModel to TFLite flatbuffer.
+
+    !!! example
+        ```python
+        tflite_model = convert_saved_model(saved_model_dir)
+        with open("/tmp/my_model.tflite", "wb") as f:
+            f.write(tflite_model)
+        ```
+
+    # Arguments
+        saved_model_dir: SavedModel directory to convert.
+        inference_input_type: Data type of the input layer. Defaults to `tf.float32`,
+            must be either `tf.float32` or `tf.int8`.
+        inference_output_type: Data type of the output layer. Defaults to `tf.float32`,
+            must be either `tf.float32` or `tf.int8`.
+        target: Target hardware platform. Defaults to "arm", must be either "arm"
+            or "xcore".
+        experimental_default_int8_range: Tuple of integers representing `(min, max)`
+            range values for all arrays without a specified range. Intended for
+            experimenting with quantization via "dummy quantization". (default None)
+        experimental_enable_bitpacked_activations: Enable an experimental
+            converter optimisation that attempts to reduce intermediate
+            activation memory usage by bitpacking the activation tensor between
+            consecutive binary convolutions where possible.
+
+    # Returns
+        The converted data in serialized format.
+    """
+    if version.parse(tf.__version__) < version.parse("2.2"):
+        raise RuntimeError(
+            "TensorFlow 2.2 or newer is required for saved model conversion."
+        )
+    _validate_options(
+        inference_input_type=inference_input_type,
+        inference_output_type=inference_output_type,
+        target=target,
+        experimental_default_int8_range=experimental_default_int8_range,
+    )
+
+    saved_model_dir = str(saved_model_dir)
+    saved_model_tags = [tf.saved_model.SERVING]
+    saved_model_exported_names = [tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
+
+    from tensorflow.python.saved_model import loader_impl
+
+    saved_model_pb, _ = loader_impl.parse_saved_model_with_debug_info(saved_model_dir)
+
+    saved_model_version = saved_model_pb.saved_model_schema_version
+    if saved_model_version not in (1, 2):
+        raise ValueError(
+            f"SavedModel file format({saved_model_version}) is not supported"
+        )
+
+    tflite_buffer = convert_saved_model_to_tflite_flatbuffer(
+        saved_model_dir,
+        saved_model_tags,
+        saved_model_exported_names,
+        saved_model_version,
+        target,
+        experimental_default_int8_range,
+        experimental_enable_bitpacked_activations,
+    )
+
+    if inference_input_type != tf.float32 or inference_output_type != tf.float32:
+        tflite_buffer = modify_integer_quantized_model_io_type(
+            tflite_buffer,
+            inference_input_type=inference_input_type,
+            inference_output_type=inference_output_type,
+        )
+
+    return tflite_buffer
+
+
 def convert_keras_model(
     model: tf.keras.Model,
     *,  # Require remaining arguments to be keyword-only.
@@ -81,7 +199,8 @@ def convert_keras_model(
             must be either `tf.float32` or `tf.int8`.
         inference_output_type: Data type of the output layer. Defaults to `tf.float32`,
             must be either `tf.float32` or `tf.int8`.
-        target: Target hardware platform. Must be "arm" or "xcore".
+        target: Target hardware platform. Defaults to "arm", must be either "arm"
+            or "xcore".
         experimental_default_int8_range: Tuple of integers representing `(min, max)`
             range values for all arrays without a specified range. Intended for
             experimenting with quantization via "dummy quantization". (default None)
@@ -97,35 +216,37 @@ def convert_keras_model(
         raise ValueError(
             f"Expected `model` argument to be a `tf.keras.Model` instance, got `{model}`."
         )
-    if inference_input_type not in (tf.float32, tf.int8):
-        raise ValueError(
-            "Expected `inference_input_type` to be either `tf.float32` or `tf.int8`, "
-            f"got {inference_input_type}."
-        )
-    if inference_output_type not in (tf.float32, tf.int8):
-        raise ValueError(
-            "Expected `inference_output_type` to be either `tf.float32` or `tf.int8`, "
-            f"got {inference_output_type}."
-        )
-    if target not in ("arm", "xcore"):
-        raise ValueError(f'Expected `target` to be "arm" or "xcore", but got {target}.')
-
-    if not tf.executing_eagerly():
-        raise RuntimeError(
-            "Graph mode is not supported. Please enable eager execution using "
-            "tf.enable_eager_execution() when using TensorFlow 1.x"
-        )
-    if experimental_default_int8_range:
-        warnings.warn(
-            "Using `experimental_default_int8_range` as fallback quantization stats. "
-            "This should only be used for latency tests."
-        )
     if hasattr(model, "dtype_policy") and model.dtype_policy.name != "float32":
-        raise RuntimeError(
+        raise ValueError(
             "Mixed precision float16 models are not supported by the TFLite converter, "
             "please convert them to float32 first. See also: "
             "https://github.com/tensorflow/tensorflow/issues/46380"
         )
+    _validate_options(
+        inference_input_type=inference_input_type,
+        inference_output_type=inference_output_type,
+        target=target,
+        experimental_default_int8_range=experimental_default_int8_range,
+    )
+
+    # First attempt conversion as saved model
+    try:
+        with tempfile.TemporaryDirectory() as saved_model_dir:
+            model.save(saved_model_dir, save_format="tf")
+
+            return convert_saved_model(
+                saved_model_dir,
+                inference_input_type=inference_input_type,
+                inference_output_type=inference_output_type,
+                experimental_default_int8_range=experimental_default_int8_range,
+                experimental_enable_bitpacked_activations=experimental_enable_bitpacked_activations,
+                target=target,
+            )
+    except Exception:
+        warnings.warn(
+            "Saved-model conversion failed, falling back to graphdef-based conversion."
+        )
+
     func = concrete_function_from_keras_model(model)
     if version.parse(tf.__version__) >= version.parse("1.15"):
         frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False)
@@ -168,6 +289,7 @@ def convert_keras_model(
         experimental_default_int8_range,
         experimental_enable_bitpacked_activations,
     )
+
     if should_quantize and (
         inference_input_type != tf.float32 or inference_output_type != tf.float32
     ):
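
For reference, a minimal usage sketch of the new `convert_saved_model` entrypoint this change introduces. The import path and the `my_saved_model` directory are assumptions for illustration (not shown in the diff); the function may also be re-exported at the package level alongside `convert_keras_model`.

    import tensorflow as tf

    # Assumed module path of the converter file edited in this diff.
    from larq_compute_engine.mlir.python.converter import convert_saved_model

    # "my_saved_model" is a placeholder directory written by tf.saved_model.save()
    # or model.save(..., save_format="tf").
    tflite_model = convert_saved_model(
        "my_saved_model",
        target="arm",
        experimental_enable_bitpacked_activations=True,
    )
    with open("my_model.tflite", "wb") as f:
        f.write(tflite_model)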