
Commit 6bb5a43

lgeiger and Tombana authored
Add support for saved model conversion (#655)
* Add support for saved model conversion
* Fix converter tests
* Improve target docstrings
* Add missing license comment

Co-authored-by: Tom Bannink <Tombana@users.noreply.github.com>
1 parent 3e1b148 commit 6bb5a43

File tree

9 files changed: +438 -62 lines changed


larq_compute_engine/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -1,4 +1,7 @@
-from larq_compute_engine.mlir.python.converter import convert_keras_model
+from larq_compute_engine.mlir.python.converter import (
+    convert_keras_model,
+    convert_saved_model,
+)
 from larq_compute_engine.tflite.python import interpreter as testing
 
 try:
@@ -9,4 +12,4 @@
 
 __version__ = metadata.version("larq_compute_engine")
 
-__all__ = ["convert_keras_model", "testing"]
+__all__ = ["convert_keras_model", "convert_saved_model", "testing"]

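With the package root now re-exporting `convert_saved_model`, converting a SavedModel takes two calls; a minimal usage sketch (the SavedModel directory and output path below are hypothetical):

```python
import larq_compute_engine as lce

# Convert a SavedModel directory straight to a serialized TFLite flatbuffer.
tflite_model = lce.convert_saved_model("/tmp/my_saved_model")

# Write the flatbuffer bytes to disk for deployment.
with open("/tmp/my_model.tflite", "wb") as f:
    f.write(tflite_model)
```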
larq_compute_engine/mlir/BUILD

Lines changed: 10 additions & 4 deletions
@@ -335,13 +335,19 @@ tf_cc_binary(
 )
 
 pybind_extension(
-    name = "_graphdef_tfl_flatbuffer",
-    srcs = ["python/graphdef_tfl_flatbuffer.cc"],
-    module_name = "graphdef_tfl_flatbuffer",
+    name = "_tf_tfl_flatbuffer",
+    srcs = [
+        "python/graphdef_tfl_flatbuffer.cc",
+        "python/pybind_export.cc",
+        "python/saved_model_tfl_flatbuffer.cc",
+    ],
+    module_name = "tf_tfl_flatbuffer",
     deps = [
         ":lce_tfl_passes",
         ":tf_to_tfl_flatbuffer",
         "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
+        "@org_tensorflow//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
+        "@org_tensorflow//tensorflow/compiler/mlir/lite/python:tf_tfl_flatbuffer_helpers",
         "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
         "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:import_utils",
         "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
@@ -366,7 +372,7 @@ py_library(
         ":tflite_schema_py",
     ],
     deps = [
-        ":_graphdef_tfl_flatbuffer",
+        ":_tf_tfl_flatbuffer",
     ],
 )

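Since the extension is renamed and now bundles both converters, Python callers load the two entry points from a single module; a sketch of the import this enables (mirroring the converter change below):

```python
# One native extension now exposes both conversion entry points.
from larq_compute_engine.mlir._tf_tfl_flatbuffer import (
    convert_graphdef_to_tflite_flatbuffer,
    convert_saved_model_to_tflite_flatbuffer,
)
```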
larq_compute_engine/mlir/python/converter.py

Lines changed: 149 additions & 27 deletions
@@ -1,11 +1,15 @@
+import os
 from packaging import version
 import warnings
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 import tensorflow as tf
+import tempfile
 
-from larq_compute_engine.mlir._graphdef_tfl_flatbuffer import (
+from larq_compute_engine.mlir._tf_tfl_flatbuffer import (
     convert_graphdef_to_tflite_flatbuffer,
+    convert_saved_model_to_tflite_flatbuffer,
 )
+
 from larq_compute_engine.mlir.python.util import modify_integer_quantized_model_io_type
 
 from tensorflow.core.framework.types_pb2 import DataType
@@ -57,6 +61,120 @@ def _contains_training_quant_op(graph_def):
     return False
 
 
+def _validate_options(
+    *,
+    inference_input_type=None,
+    inference_output_type=None,
+    target=None,
+    experimental_default_int8_range=None,
+):
+    if inference_input_type not in (tf.float32, tf.int8):
+        raise ValueError(
+            "Expected `inference_input_type` to be either `tf.float32` or `tf.int8`, "
+            f"got {inference_input_type}."
+        )
+    if inference_output_type not in (tf.float32, tf.int8):
+        raise ValueError(
+            "Expected `inference_output_type` to be either `tf.float32` or `tf.int8`, "
+            f"got {inference_output_type}."
+        )
+    if target not in ("arm", "xcore"):
+        raise ValueError(f'Expected `target` to be "arm" or "xcore", but got {target}.')
+
+    if not tf.executing_eagerly():
+        raise RuntimeError(
+            "Graph mode is not supported. Please enable eager execution using "
+            "tf.enable_eager_execution() when using TensorFlow 1.x"
+        )
+    if experimental_default_int8_range:
+        warnings.warn(
+            "Using `experimental_default_int8_range` as fallback quantization stats. "
+            "This should only be used for latency tests."
+        )
+
+
+def convert_saved_model(
+    saved_model_dir: Union[str, os.PathLike],
+    *,  # Require remaining arguments to be keyword-only.
+    inference_input_type: tf.DType = tf.float32,
+    inference_output_type: tf.DType = tf.float32,
+    target: str = "arm",
+    experimental_default_int8_range: Optional[Tuple[float, float]] = None,
+    experimental_enable_bitpacked_activations: bool = False,
+) -> bytes:
+    """Converts a SavedModel to TFLite flatbuffer.
+
+    !!! example
+        ```python
+        tflite_model = convert_saved_model(saved_model_dir)
+        with open("/tmp/my_model.tflite", "wb") as f:
+            f.write(tflite_model)
+        ```
+
+    # Arguments
+        saved_model_dir: SavedModel directory to convert.
+        inference_input_type: Data type of the input layer. Defaults to `tf.float32`,
+            must be either `tf.float32` or `tf.int8`.
+        inference_output_type: Data type of the output layer. Defaults to `tf.float32`,
+            must be either `tf.float32` or `tf.int8`.
+        target: Target hardware platform. Defaults to "arm", must be either "arm"
+            or "xcore".
+        experimental_default_int8_range: Tuple of integers representing `(min, max)`
+            range values for all arrays without a specified range. Intended for
+            experimenting with quantization via "dummy quantization". (default None)
+        experimental_enable_bitpacked_activations: Enable an experimental
+            converter optimisation that attempts to reduce intermediate
+            activation memory usage by bitpacking the activation tensor between
+            consecutive binary convolutions where possible.
+
+    # Returns
+        The converted data in serialized format.
+    """
+    if version.parse(tf.__version__) < version.parse("2.2"):
+        raise RuntimeError(
+            "TensorFlow 2.2 or newer is required for saved model conversion."
+        )
+    _validate_options(
+        inference_input_type=inference_input_type,
+        inference_output_type=inference_output_type,
+        target=target,
+        experimental_default_int8_range=experimental_default_int8_range,
+    )
+
+    saved_model_dir = str(saved_model_dir)
+    saved_model_tags = [tf.saved_model.SERVING]
+    saved_model_exported_names = [tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
+
+    from tensorflow.python.saved_model import loader_impl
+
+    saved_model_pb, _ = loader_impl.parse_saved_model_with_debug_info(saved_model_dir)
+
+    saved_model_version = saved_model_pb.saved_model_schema_version
+    if saved_model_version not in (1, 2):
+        raise ValueError(
+            f"SavedModel file format({saved_model_version}) is not supported"
+        )
+
+    tflite_buffer = convert_saved_model_to_tflite_flatbuffer(
+        saved_model_dir,
+        saved_model_tags,
+        saved_model_exported_names,
+        saved_model_version,
+        target,
+        experimental_default_int8_range,
+        experimental_enable_bitpacked_activations,
+    )
+
+    if inference_input_type != tf.float32 or inference_output_type != tf.float32:
+        tflite_buffer = modify_integer_quantized_model_io_type(
+            tflite_buffer,
+            inference_input_type=inference_input_type,
+            inference_output_type=inference_output_type,
+        )
+
+    return tflite_buffer
+
+
 def convert_keras_model(
     model: tf.keras.Model,
     *,  # Require remaining arguments to be keyword-only.
@@ -81,7 +199,8 @@ def convert_keras_model(
             must be either `tf.float32` or `tf.int8`.
         inference_output_type: Data type of the output layer. Defaults to `tf.float32`,
            must be either `tf.float32` or `tf.int8`.
-        target: Target hardware platform. Must be "arm" or "xcore".
+        target: Target hardware platform. Defaults to "arm", must be either "arm"
+            or "xcore".
         experimental_default_int8_range: Tuple of integers representing `(min, max)`
            range values for all arrays without a specified range. Intended for
            experimenting with quantization via "dummy quantization". (default None)
@@ -97,35 +216,37 @@ def convert_keras_model(
         raise ValueError(
             f"Expected `model` argument to be a `tf.keras.Model` instance, got `{model}`."
         )
-    if inference_input_type not in (tf.float32, tf.int8):
-        raise ValueError(
-            "Expected `inference_input_type` to be either `tf.float32` or `tf.int8`, "
-            f"got {inference_input_type}."
-        )
-    if inference_output_type not in (tf.float32, tf.int8):
-        raise ValueError(
-            "Expected `inference_output_type` to be either `tf.float32` or `tf.int8`, "
-            f"got {inference_output_type}."
-        )
-    if target not in ("arm", "xcore"):
-        raise ValueError(f'Expected `target` to be "arm" or "xcore", but got {target}.')
-
-    if not tf.executing_eagerly():
-        raise RuntimeError(
-            "Graph mode is not supported. Please enable eager execution using "
-            "tf.enable_eager_execution() when using TensorFlow 1.x"
-        )
-    if experimental_default_int8_range:
-        warnings.warn(
-            "Using `experimental_default_int8_range` as fallback quantization stats. "
-            "This should only be used for latency tests."
-        )
     if hasattr(model, "dtype_policy") and model.dtype_policy.name != "float32":
-        raise RuntimeError(
+        raise ValueError(
             "Mixed precision float16 models are not supported by the TFLite converter, "
             "please convert them to float32 first. See also: "
            "https://github.com/tensorflow/tensorflow/issues/46380"
         )
+    _validate_options(
+        inference_input_type=inference_input_type,
+        inference_output_type=inference_output_type,
+        target=target,
+        experimental_default_int8_range=experimental_default_int8_range,
+    )
+
+    # First attempt conversion as saved model
+    try:
+        with tempfile.TemporaryDirectory() as saved_model_dir:
+            model.save(saved_model_dir, save_format="tf")
+
+            return convert_saved_model(
+                saved_model_dir,
+                inference_input_type=inference_input_type,
+                inference_output_type=inference_output_type,
+                experimental_default_int8_range=experimental_default_int8_range,
+                experimental_enable_bitpacked_activations=experimental_enable_bitpacked_activations,
+                target=target,
+            )
+    except Exception:
+        warnings.warn(
+            "Saved-model conversion failed, falling back to graphdef-based conversion."
+        )
+
     func = concrete_function_from_keras_model(model)
     if version.parse(tf.__version__) >= version.parse("1.15"):
         frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False)
@@ -168,6 +289,7 @@ def convert_keras_model(
         experimental_default_int8_range,
         experimental_enable_bitpacked_activations,
     )
+
     if should_quantize and (
         inference_input_type != tf.float32 or inference_output_type != tf.float32
     ):

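Beyond the float32 defaults shown in the docstring, the new keyword arguments allow int8 input/output types and dummy-quantization ranges; a sketch under the assumption of an existing SavedModel at a hypothetical path:

```python
import tensorflow as tf
from larq_compute_engine import convert_saved_model

# int8 input/output tensors; arrays without quantization stats fall back to the
# experimental (min, max) range, which is intended for latency tests only.
tflite_model = convert_saved_model(
    "/tmp/my_saved_model",
    inference_input_type=tf.int8,
    inference_output_type=tf.int8,
    target="arm",
    experimental_default_int8_range=(-3.0, 3.0),
)

with open("/tmp/my_model_int8.tflite", "wb") as f:
    f.write(tflite_model)
```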
larq_compute_engine/mlir/python/converter_test.py

Lines changed: 22 additions & 14 deletions
@@ -1,21 +1,24 @@
 import sys
 import unittest
+from packaging import version
 from unittest import mock
 
+import tensorflow as tf
 import larq_zoo as lqz
 from tensorflow.python.eager import context
 
 sys.modules["importlib.metadata"] = mock.MagicMock()
 sys.modules["importlib_metadata"] = mock.MagicMock()
-sys.modules["larq_compute_engine.mlir._graphdef_tfl_flatbuffer"] = mock.MagicMock()
+sys.modules["larq_compute_engine.mlir._tf_tfl_flatbuffer"] = mock.MagicMock()
 sys.modules[
     "larq_compute_engine.tflite.python.interpreter_wrapper_lite"
 ] = mock.MagicMock()
 sys.modules["larq_compute_engine.mlir.python.tflite_schema"] = mock.MagicMock()
 
 from larq_compute_engine.mlir.python.converter import convert_keras_model
-from larq_compute_engine.mlir._graphdef_tfl_flatbuffer import (
-    convert_graphdef_to_tflite_flatbuffer as mocked_converter,
+from larq_compute_engine.mlir._tf_tfl_flatbuffer import (
+    convert_graphdef_to_tflite_flatbuffer as mocked_graphdef_converter,
+    convert_saved_model_to_tflite_flatbuffer as mocked_saved_model_converter,
 )
 
 
@@ -24,17 +27,22 @@ def test_larq_zoo_models(self):
         with context.eager_mode():
             model = lqz.sota.QuickNet(weights=None)
             convert_keras_model(model)
-            mocked_converter.assert_called_once_with(
-                mock.ANY,
-                ["input_1"],
-                ["DT_FLOAT"],
-                [[1, 224, 224, 3]],
-                ["Identity"],
-                False,
-                "arm",
-                None,
-                False,
-            )
+            if version.parse(tf.__version__) < version.parse("2.2"):
+                mocked_graphdef_converter.assert_called_once_with(
+                    mock.ANY,
+                    ["input_1"],
+                    ["DT_FLOAT"],
+                    [[1, 224, 224, 3]],
+                    ["Identity"],
+                    False,
+                    "arm",
+                    None,
+                    False,
+                )
+            else:
+                mocked_saved_model_converter.assert_called_once_with(
+                    mock.ANY, ["serve"], ["serving_default"], 1, "arm", None, False
+                )
 
     def test_wrong_arg(self):
         with self.assertRaises(ValueError):

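The two assertion branches reflect the fallback behaviour: `convert_keras_model` first tries the saved-model path, which requires TensorFlow 2.2 or newer, and otherwise falls back to the GraphDef converter; a small sketch of the version gate the test keys on:

```python
from packaging import version
import tensorflow as tf

# The test selects its assertion branch with the same TF version boundary
# that convert_saved_model enforces before the saved-model path is usable.
uses_saved_model_path = version.parse(tf.__version__) >= version.parse("2.2")
print("saved-model conversion path:", uses_saved_model_path)
```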
larq_compute_engine/mlir/python/graphdef_tfl_flatbuffer.cc

Lines changed: 0 additions & 6 deletions
@@ -12,7 +12,6 @@
 #include "pybind11/stl.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
-#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
@@ -118,8 +117,3 @@ pybind11::bytes ConvertGraphDefToTFLiteFlatBuffer(
 }
 
 }  // namespace tensorflow
-
-PYBIND11_MODULE(_graphdef_tfl_flatbuffer, m) {
-  m.def("convert_graphdef_to_tflite_flatbuffer",
-        &tensorflow::ConvertGraphDefToTFLiteFlatBuffer);
-};
larq_compute_engine/mlir/python/pybind_export.cc

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "pybind11/stl.h"
+
+namespace tensorflow {
+
+using std::string;
+
+pybind11::bytes ConvertGraphDefToTFLiteFlatBuffer(
+    const pybind11::bytes& graphdef_bytes,
+    const std::vector<string>& input_arrays,
+    const std::vector<string>& input_dtypes,
+    const std::vector<std::vector<int>>& input_shapes,
+    const std::vector<string>& output_arrays, const bool should_quantize,
+    const std::string& target_str, const pybind11::object& default_ranges,
+    const bool experimental_enable_bitpacked_activations);
+
+pybind11::bytes ConvertSavedModelToTFLiteFlatBuffer(
+    const std::string& saved_model_dir,
+    const std::vector<std::string>& saved_model_tags,
+    const std::vector<std::string>& exported_names,
+    const int saved_model_version, const std::string& target_str,
+    const pybind11::object& default_ranges,
+    const bool experimental_enable_bitpacked_activations);
+}  // namespace tensorflow
+
+PYBIND11_MODULE(_tf_tfl_flatbuffer, m) {
+  m.def("convert_graphdef_to_tflite_flatbuffer",
+        &tensorflow::ConvertGraphDefToTFLiteFlatBuffer);
+  m.def("convert_saved_model_to_tflite_flatbuffer",
+        &tensorflow::ConvertSavedModelToTFLiteFlatBuffer);
+};

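Binding both entry points into one `_tf_tfl_flatbuffer` module keeps the Python call sites positional; a sketch calling the saved-model binding directly with the same arguments the test above asserts on (the SavedModel path is hypothetical, and `convert_saved_model` remains the intended public entry point):

```python
from larq_compute_engine.mlir._tf_tfl_flatbuffer import (
    convert_saved_model_to_tflite_flatbuffer,
)

# Positional arguments mirror the C++ declaration in pybind_export.cc:
# saved_model_dir, tags, exported_names, schema version, target,
# default int8 range, bitpacked-activations flag.
tflite_bytes = convert_saved_model_to_tflite_flatbuffer(
    "/tmp/my_saved_model",
    ["serve"],
    ["serving_default"],
    1,
    "arm",
    None,
    False,
)
```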