-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Model Export to liteRT #21674
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Model Export to liteRT #21674
Changes from 104 commits
04cd682
1a74465
c11eb81
c81e18c
d938e20
f60811e
b3ae323
28eeb24
de81e5b
66661ac
c37f2b5
498dece
b0b5f63
653f5b1
e681e4c
27ad80b
50f6292
579cc11
7a0c547
f7a2290
8bae892
ee196cd
40583c8
7c918ba
a927e7e
f04eafa
bbc29a7
bac3416
98877eb
39c559b
417e4b1
14bfd9b
9d34d0a
837506d
631850e
f5aa72e
2b952d6
8f81dd5
011f1d8
9a99a32
d0070c6
761793f
7bb0506
c219eb1
e26ff6b
4a32e04
20d29a8
3aca2f6
926b0a8
441a778
4a8a9d5
f4b43b4
0fe4bd5
8c3faa3
88b6a6f
da13d04
f1f700c
5944780
4404c39
6a119fb
e469244
4cec7cd
3a7fcc4
51a1c7f
e1fca24
fd197d9
214558a
73f00f1
ebf11e2
c6f0c70
a6746e1
3c1d90a
657a271
8ce8bfa
cd9d063
fa3d3ed
e775ff2
34b662d
33b0550
cbe0229
e52de85
87af9ed
c643772
f243a6e
d8236fa
83577be
c53b264
487184d
374d90b
6a5597d
e843f7e
f99a103
d01a4cb
52440e1
794d85d
7a46f78
d2b90eb
191f802
b736ede
27f1d07
17dccf2
3e16ab3
efbc6d3
7825983
676a53c
4b6386e
79f05c8
a22eb65
315f7f6
4efae3e
f019a0a
5067904
1c8dbcd
ff4a81e
bcd965b
022cce8
820f73b
85e878b
1005063
c984a6b
809f6bc
30e4cdd
65dc0f9
dd1cfbd
4bf2e80
4773089
d98cca1
11bb4be
0f9f214
ddf911f
2a46ab3
26ac160
7c5cb3f
537880f
211b44d
4199c69
d376afb
66acb8f
3c2a4be
071c819
f17422c
a550fcc
b8267c6
46ead2f
b523552
ada71de
00088c9
42407e8
30deea8
2f64e9a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,255 @@ | ||
| from keras.src import layers | ||
| from keras.src import models | ||
| from keras.src.export.export_utils import get_input_signature | ||
| from keras.src.export.export_utils import make_input_spec | ||
| from keras.src.utils import io_utils | ||
| from keras.src.utils.module_utils import tensorflow as tf | ||
|
|
||
|
|
||
def export_litert(
    model,
    filepath,
    input_signature=None,
    **kwargs,
):
    """Export the model as a LiteRT artifact for inference.

    Args:
        model: The Keras model to export.
        filepath: The path to save the exported artifact.
        input_signature: Optional input signature specification. If
            `None`, it will be inferred.
        **kwargs: Additional keyword arguments passed to the exporter.
    """
    # Delegate the actual conversion to the exporter, then report where
    # the artifact landed.
    LiteRTExporter(
        model=model,
        input_signature=input_signature,
        **kwargs,
    ).export(filepath)
    io_utils.print_msg(f"Saved artifact at '{filepath}'.")
|
|
||
|
|
||
class LiteRTExporter:
    """Exporter for the LiteRT (TFLite) format.

    This class handles the conversion of Keras models for LiteRT runtime
    and generates a `.tflite` model file. For efficient inference on
    mobile and embedded devices, it creates a single callable signature
    based on the model's `call()` method.
    """

    def __init__(
        self,
        model,
        input_signature=None,
        **kwargs,
    ):
        """Initialize the LiteRT exporter.

        Args:
            model: The Keras model to export.
            input_signature: Input signature specification (e.g.,
                TensorFlow TensorSpec, list of TensorSpec, or a dict
                mapping input names to specs). Inferred when `None`.
            **kwargs: Additional export parameters, applied as
                `tf.lite.TFLiteConverter` attributes at export time
                (validated in `_apply_converter_kwargs`).
        """
        self.model = model
        self.input_signature = input_signature
        self.kwargs = kwargs

    def _infer_dict_input_signature(self):
        """Infer the input signature from a model with dict inputs.

        Reads the actual shapes and dtypes from `model._inputs_struct`
        (which preserves the dict structure) or, failing that, from
        `model.inputs`.

        Returns:
            dict or None: Dictionary mapping input names to InputSpec,
            or `None` when the model does not expose dict inputs.
        """
        # Check `_inputs_struct` first (preserves dict structure).
        inputs_struct = getattr(self.model, "_inputs_struct", None)
        if isinstance(inputs_struct, dict):
            return {
                name: make_input_spec(inp)
                for name, inp in inputs_struct.items()
            }

        # Fall back to `model.inputs` if it's a dict.
        inputs = getattr(self.model, "inputs", None)
        if isinstance(inputs, dict):
            return {
                name: make_input_spec(inp) for name, inp in inputs.items()
            }

        return None

    def export(self, filepath):
        """Export the Keras model to a TFLite file.

        Args:
            filepath: Output path for the exported model. Must end with
                `.tflite`.

        Returns:
            The path the exported model was written to.

        Raises:
            ValueError: If `filepath` does not end with `.tflite`.
            RuntimeError: If the TFLite conversion fails.
        """
        # 1. Validate the destination before doing any expensive
        #    conversion work. (Fixes two defects in the original: the
        #    check ran *after* conversion, and the message was not an
        #    f-string, so `{filepath}` appeared literally in the error.)
        if not filepath.endswith(".tflite"):
            raise ValueError(
                "The LiteRT export requires the filepath to end with "
                f"'.tflite'. Got: {filepath}"
            )

        # 2. Resolve / infer the input signature.
        if self.input_signature is None:
            # Try dict-specific inference first (models with dict
            # inputs); fall back to standard inference otherwise.
            dict_signature = self._infer_dict_input_signature()
            if dict_signature is not None:
                self.input_signature = dict_signature
            else:
                self.input_signature = get_input_signature(self.model)

        # 3. Handle dictionary inputs by creating an adapter. TFLite only
        #    supports positional (list) inputs, so dict-input models are
        #    wrapped, and the dict signature is flattened to a list.
        if isinstance(self.input_signature, dict):
            model_to_convert = self._create_dict_adapter(
                self.input_signature
            )
            signature_for_conversion = list(self.input_signature.values())
        else:
            # No dict inputs - use model as-is.
            model_to_convert = self.model
            signature_for_conversion = self.input_signature

        # 4. Convert the model to TFLite. `self.model` is temporarily
        #    replaced so `_convert_to_tflite` sees the (possibly adapted)
        #    model; the original is restored even if conversion raises.
        original_model = self.model
        self.model = model_to_convert
        try:
            tflite_model = self._convert_to_tflite(signature_for_conversion)
        finally:
            self.model = original_model

        # 5. Write the serialized TFLite flatbuffer to disk.
        with open(filepath, "wb") as f:
            f.write(tflite_model)

        return filepath

    def _create_dict_adapter(self, input_signature_dict):
        """Create an adapter model that converts list inputs to dict inputs.

        This adapter allows models expecting dictionary inputs to be
        exported to TFLite format (which only supports positional/list
        inputs).

        Args:
            input_signature_dict: Dictionary mapping input names to
                InputSpec.

        Returns:
            A Functional model that accepts list inputs and converts to
            dict before calling the original model.
        """
        io_utils.print_msg(
            f"Creating adapter for dictionary inputs: "
            f"{list(input_signature_dict.keys())}"
        )

        input_keys = list(input_signature_dict.keys())

        # Create one positional Input layer per dict key for TFLite
        # (batch dimension removed from the spec shape).
        input_layers = [
            layers.Input(
                shape=input_signature_dict[name].shape[1:],
                dtype=input_signature_dict[name].dtype,
                name=name,
            )
            for name in input_keys
        ]

        # Rebuild the dict from the positional inputs.
        inputs_dict = dict(zip(input_keys, input_layers))

        # Call the original model with dict inputs.
        outputs = self.model(inputs_dict)

        # Build as Functional model (list inputs -> dict -> model -> output).
        adapted_model = models.Model(inputs=input_layers, outputs=outputs)

        # Share the original model's variables so the exported artifact
        # carries the trained weights.
        adapted_model._variables = self.model.variables
        adapted_model._trainable_variables = self.model.trainable_variables
        adapted_model._non_trainable_variables = (
            self.model.non_trainable_variables
        )

        return adapted_model

    def _convert_to_tflite(self, input_signature):
        """Convert the Keras model to TFLite format.

        Args:
            input_signature: Signature used for conversion. NOTE(review):
                currently unused — `from_keras_model` traces the model
                directly; kept for interface stability.

        Returns:
            A bytes object containing the serialized TFLite model.

        Raises:
            RuntimeError: If the direct conversion fails.
        """
        # Try direct conversion first for all models.
        try:
            converter = tf.lite.TFLiteConverter.from_keras_model(self.model)
            converter.target_spec.supported_ops = [
                tf.lite.OpsSet.TFLITE_BUILTINS,
                tf.lite.OpsSet.SELECT_TF_OPS,
            ]
            # Keras 3 only supports resource variables.
            converter.experimental_enable_resource_variables = True

            # Apply any additional converter settings from kwargs.
            self._apply_converter_kwargs(converter)

            return converter.convert()

        except Exception as e:
            # If direct conversion fails, raise the error with a helpful
            # message, chaining the original cause.
            raise RuntimeError(
                f"Direct TFLite conversion failed. This may be due to model "
                f"complexity or unsupported operations. Error: {e}"
            ) from e

    def _apply_converter_kwargs(self, converter):
        """Apply additional converter settings from kwargs.

        Args:
            converter: `tf.lite.TFLiteConverter` instance to configure.

        Raises:
            ValueError: If any kwarg is not a valid converter attribute.
        """
        for attr, value in self.kwargs.items():
            if attr == "target_spec" and isinstance(value, dict):
                # Handle nested target_spec settings.
                for spec_key, spec_value in value.items():
                    if hasattr(converter.target_spec, spec_key):
                        setattr(converter.target_spec, spec_key, spec_value)
                    else:
                        raise ValueError(
                            f"Unknown target_spec attribute '{spec_key}'"
                        )
            elif hasattr(converter, attr):
                setattr(converter, attr, value)
            else:
                raise ValueError(f"Unknown converter attribute '{attr}'")
Uh oh!
There was an error while loading. Please reload this page.