Model Export to liteRT #21674
base: master
Conversation
Corrected indentation in doc string
Fixed issue with passing a single image without batch dimension.
…scale.py Co-authored-by: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
Test case for unbatched inputs
Test case for checking both unbatched and batched single-image inputs.
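The unbatched-input fix these commits describe can be sketched roughly like this. The helper name `ensure_batched` and the rank check are illustrative assumptions, not the actual Keras code path:

```python
import numpy as np


def ensure_batched(image, expected_rank=4):
    """Add a leading batch dimension if a single unbatched image is passed.

    Hypothetical helper: an image of rank expected_rank - 1 (e.g. (H, W, C))
    is promoted to (1, H, W, C); already-batched inputs pass through.
    """
    image = np.asarray(image)
    if image.ndim == expected_rank - 1:
        image = image[np.newaxis, ...]  # (H, W, C) -> (1, H, W, C)
    return image


single = np.zeros((32, 32, 3))
batched = np.zeros((5, 32, 32, 3))
assert ensure_batched(single).shape == (1, 32, 32, 3)
assert ensure_batched(batched).shape == (5, 32, 32, 3)
```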
There was a bug causing a cycle in the graph.
removed the use of tree.map_structure
…s-team#21254)" (keras-team#21329) This reverts commit 81821e0.
Enhanced the _can_use_flash_attention function to provide more detailed error messages when flash attention compatibility checks fail.

Changes:
- Replace generic exception catching with specific error propagation
- When raise_error=True, directly re-raise original exceptions from check_layout() and check_is_flash_attention() functions
- Preserve detailed error context from JAX internal validation functions
- Maintain existing behavior when raise_error=False (returns False)

This improves the debugging experience by surfacing specific technical details about tensor layout incompatibilities, cuDNN version requirements, and other flash attention compatibility issues. Relates to keras-hub PR keras-team#2257 and addresses flash attention debugging needs.
… debugging" This reverts commit 7a0c547.
…sh_attention`

Changes:
- Add missing q_offsets=None and kv_offsets=None parameters to the check_layout() call to match the updated JAX function signature
- Replace bare `except:` with `except Exception as e:` and `raise e` to preserve detailed error messages from JAX validation functions
- Maintain existing fallback behavior when raise_error=False

This resolves compatibility issues with newer JAX versions and improves the debugging experience by surfacing specific technical details about flash attention compatibility failures.
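The error-propagation pattern these commits describe can be sketched as follows. This is a hypothetical, simplified helper (`can_use_feature`, `bad_layout`) illustrating the `raise_error` behavior; the real `_can_use_flash_attention` wraps JAX-specific checks such as `check_layout()`:

```python
def can_use_feature(check_fns, raise_error=False):
    """Run compatibility checks in order.

    When raise_error=True, re-raise the original exception so detailed
    context from the underlying validation functions is preserved;
    otherwise return False on the first failure (legacy behavior).
    """
    for check in check_fns:
        try:
            check()
        except Exception as e:
            if raise_error:
                raise e  # surface the specific failure reason
            return False
    return True


def bad_layout():
    # Stand-in for a JAX-side validation failure.
    raise ValueError("unsupported tensor layout for flash attention")


# Default: failures are swallowed and reported as "not usable".
assert can_use_feature([bad_layout]) is False

# raise_error=True: the original, detailed exception propagates.
try:
    can_use_feature([bad_layout], raise_error=True)
except ValueError as e:
    assert "layout" in str(e)
```

The key difference from the earlier code is that the bare `except:` no longer discards the original message, which is what made layout and cuDNN-version failures hard to diagnose.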
Simplified the check for `flash_attention` by removing redundant checks that are already done in `_can_use_flash_attention`.
Eliminated the _ensure_model_built method and its invocation from LiteRTExporter. The exporter now skips explicit model building before input signature resolution, simplifying the export process.
Eliminates Ahead-of-Time (AOT) compilation functionality from the LiteRT exporter and related tests. The exporter no longer accepts or processes aot_compile_targets, and all AOT-related methods and test cases have been removed for simplification and maintenance. AOT compilation was open source earlier, but Google has since removed it as an open-source feature.
Moved import statements for export_utils functions to the top of the file and removed redundant inline imports. This improves code clarity and reduces repeated imports within methods.
Replaces assertion with ValueError when the export filepath does not end with '.tflite' in LiteRTExporter. Updates corresponding test to expect ValueError instead of AssertionError for incorrect file extension.
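The assertion-to-ValueError change can be illustrated with a minimal sketch (hypothetical function name and message text; the real exporter's wording may differ). A `ValueError` is preferable here because assertions are stripped when Python runs with `-O`:

```python
def validate_filepath(filepath):
    """Reject export paths that do not end in '.tflite'.

    Hypothetical validation helper; raises ValueError rather than using
    an assert so the check survives `python -O`.
    """
    if not str(filepath).endswith(".tflite"):
        raise ValueError(
            f"The export filepath must end in '.tflite'. Received: {filepath}"
        )


validate_filepath("model.tflite")  # accepted, no exception

try:
    validate_filepath("model.keras")
except ValueError as e:
    assert ".tflite" in str(e)
```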
@gemini-code-assist review
Code Review
This pull request introduces support for exporting Keras models to the LiteRT (TFLite) format. The implementation is comprehensive, handling various model types, including those with dictionary inputs, and includes robust fallback mechanisms for the conversion process. The accompanying tests are thorough and cover a wide range of scenarios.
My review focuses on improving maintainability, debuggability, and adherence to the project's documentation standards. I've identified some unused code, suggested adding more detailed logging to exception handling to aid in debugging conversion failures, and recommended adding comments to a particularly complex method to improve its clarity. I also pointed out a missing code example in a docstring, to align with the Keras API design guidelines.
Moved tracked collection conversion logic into a dedicated method with no automatic dependency tracking to avoid TensorFlow wrappers. Added a compatibility shim for TensorFlow saving utilities by implementing the _get_save_spec method, which generates TensorSpec objects for model input signatures. Needed for liteRT export.
Refactored the fallback TFLite conversion method to use a direct tf.function approach instead of a tf.Module wrapper, simplifying the conversion logic. Added a 'verbose' parameter to export_litert and LiteRTExporter for progress messaging. Improved converter kwargs handling to only apply known TFLite settings.
Eliminates the 'verbose' parameter from export_litert and LiteRTExporter, simplifying the API and reducing unnecessary options for export progress messages.
Update LiteRTExporter to always enable resource variables during TFLite conversion, as Keras 3 only supports resource variables. Simplify conversion logic by removing strategy loop and error handling for unsupported conversion paths.
Deleted the _has_dict_inputs method from the LiteRTExporter class in litert.py as it is no longer used. This helps clean up the code and improve maintainability.
Simplifies and enforces stricter validation for converter kwargs in LiteRTExporter. Unknown attributes now raise ValueError instead of being ignored, and the method no longer maintains a list of known attributes, relying on attribute existence checks.
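A minimal sketch of the stricter kwargs validation, using a stand-in converter object so the example stays self-contained (the real code sets attributes on a `tf.lite.TFLiteConverter`; all names here are illustrative):

```python
class FakeConverter:
    """Stand-in with a couple of TFLiteConverter-style attributes."""

    def __init__(self):
        self.allow_custom_ops = False
        self.experimental_enable_resource_variables = True


def apply_converter_kwargs(converter, **kwargs):
    """Apply user-supplied converter settings.

    Strict validation via attribute-existence checks: unknown attributes
    raise ValueError instead of being silently ignored, and no hand-kept
    list of known attributes is needed.
    """
    for name, value in kwargs.items():
        if not hasattr(converter, name):
            raise ValueError(f"Unknown TFLite converter attribute: '{name}'")
        setattr(converter, name, value)
    return converter


conv = apply_converter_kwargs(FakeConverter(), allow_custom_ops=True)
assert conv.allow_custom_ops is True

try:
    apply_converter_kwargs(FakeConverter(), not_a_real_option=1)
except ValueError as e:
    assert "not_a_real_option" in str(e)
```

Relying on `hasattr` keeps the validation in sync with whatever attributes the converter actually exposes, at the cost of also accepting private or unrelated attributes that happen to exist.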
Eliminates the fallback method that used SavedModel as an intermediate step for TFLite conversion. Now, if direct conversion fails, a RuntimeError is raised with a helpful message, simplifying the export logic and error handling.
@gemini-code-assist review
Code Review
This pull request introduces support for exporting Keras models to the LiteRT (TFLite) format, which is a valuable addition for on-device inference workflows. The implementation is well-structured, with a dedicated LiteRTExporter that handles the conversion logic, including a clever adapter for models with dictionary inputs. The changes to the Model.export API are clean and the new functionality is well-documented with examples. The accompanying test suite is comprehensive and covers a wide range of scenarios, ensuring the robustness of the new feature.
My review includes a few suggestions to improve maintainability and fix a minor bug in an error message, as well as a recommendation to enhance the documentation example to make it more user-friendly. Overall, this is a high-quality contribution.
Replaces tf.keras.layers and tf.keras.Model references with locally imported layers and models from keras.src. This improves consistency and may help with modularity or compatibility within the keras.src namespace.
hertschuh left a comment
Thanks! This is getting really close.
Improves input signature inference and adapter creation for models with nested input structures (dicts, lists, etc.) in LiteRTExporter. Moves TensorSpec creation logic to export_utils and updates TFLayer to use tree.map_structure for save spec generation. Removes legacy dict-specific input signature inference and centralizes input structure handling for TFLite conversion.
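The structure-mapping approach can be sketched without TensorFlow, using plain tuples in place of `tf.TensorSpec` objects. This is a simplified stand-in: it treats dicts and lists as structure and tuples as leaf shapes, unlike `keras.tree.map_structure`, which also recurses into tuples:

```python
def map_structure(fn, structure):
    """Apply fn to every leaf of a nested dict/list structure.

    Simplified stand-in for keras.tree.map_structure; tuples are treated
    as leaves (shapes) here for illustration purposes.
    """
    if isinstance(structure, dict):
        return {k: map_structure(fn, v) for k, v in structure.items()}
    if isinstance(structure, list):
        return [map_structure(fn, v) for v in structure]
    return fn(structure)


def to_spec(shape):
    # Replace the concrete batch size with None for a flexible batch dim.
    return (None,) + tuple(shape[1:])


nested = {"image": (8, 32, 32, 3), "meta": [(8, 10), (8, 4)]}
specs = map_structure(to_spec, nested)
assert specs == {"image": (None, 32, 32, 3), "meta": [(None, 10), (None, 4)]}
```

Centralizing this mapping is what lets the exporter handle dicts, lists, and other nested input structures uniformly instead of special-casing dict inputs.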
Included the ai-edge-litert package in requirements.txt to support new functionality or dependencies.
Improves analysis of input signatures by unwrapping single-element lists for Functional models and consistently using the correct structure for input handling. Also updates input layer naming to prefer spec.name when available, ensuring more accurate input identification.
Brings in latest upstream changes from keras-team/master including:
- Support PyDataset in Normalization layer adapt methods
- Fix Torch output_padding constraint for ConvTranspose layers
- Improve error messages and validation
- Various bug fixes and improvements from upstream

This keeps the export branch up-to-date with the latest Keras codebase while preserving all LiteRT export functionality.
This pull request adds support for exporting Keras models to the LiteRT (TFLite) format, along with improvements to input signature handling and export utility documentation. The changes ensure that LiteRT export is only available when TensorFlow is installed, update the `export` API and documentation, and enhance input signature inference for various model types.

LiteRT export Design Doc
LiteRT Export Support:
- `LitertExporter` and `export_litert` in `keras/src/export/__init__.py`, making LiteRT export available only if TensorFlow is installed.
- `Model.export` method updated to support the `"litert"` format, including new options for LiteRT export and user-facing documentation and an example. Raises an informative error if TensorFlow is not installed. [1] [2] [3] [4]
- `litert` registered as a lazy module in `keras/src/utils/module_utils.py` for dynamic import support.

Input Signature and Export Utilities:
- `get_input_signature` updated to clarify behavior for different model types and ensure correct input signature construction for export. [1] [2]
- `_infer_input_signature_from_model` updated to handle flexible batch dimensions and ensure compatibility with downstream exporters, always returning a flat list of input specs.
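The lazy-module registration mentioned above can be sketched as follows. This is a minimal, simplified stand-in; the real `LazyModule` in `keras/src/utils/module_utils.py` has a different interface, but the core idea of deferring the import and probing availability is the same:

```python
import importlib


class LazyModule:
    """Defer importing a module until one of its attributes is accessed."""

    def __init__(self, name):
        self.name = name
        self._module = None

    @property
    def available(self):
        # Probe whether the dependency is installed without keeping it.
        try:
            importlib.import_module(self.name)
            return True
        except ImportError:
            return False

    def __getattr__(self, attr):
        # Import on first real use, then delegate attribute access.
        if self._module is None:
            self._module = importlib.import_module(self.name)
        return getattr(self._module, attr)


# Demonstrated with stdlib modules rather than ai-edge-litert:
math_mod = LazyModule("math")
assert math_mod.available is True
assert math_mod.sqrt(9) == 3.0
assert LazyModule("not_a_real_pkg_xyz").available is False
```

This pattern is what allows `model.export(..., format="litert")` to raise an informative error only when the optional dependency is actually missing, instead of failing at `import keras` time.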