@pctablet505 pctablet505 commented Sep 17, 2025

This pull request adds support for exporting Keras models to the LiteRT (TFLite) format, along with improvements to input signature handling and export utility documentation. The changes ensure that LiteRT export is only available when TensorFlow is installed, update the export API and documentation, and enhance input signature inference for various model types.

LiteRT export Design Doc
LiteRT Export Support:

  • Added conditional import of LitertExporter and export_litert in keras/src/export/__init__.py, making LiteRT export available only if TensorFlow is installed.
  • Updated the Model.export method to support the "litert" format, including new options for LiteRT export plus user-facing documentation and an example. Raises an informative error if TensorFlow is not installed.
  • Registered litert as a lazy module in keras/src/utils/module_utils.py for dynamic import support.
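
The conditional-availability behavior described above can be sketched as follows. This is a minimal illustration only; the function name, defaults, and error text are assumptions, not the actual Keras `Model.export` signature:

```python
import importlib.util

def export_model(filepath, format="tf_saved_model"):
    # Illustrative dispatch only -- names and defaults are assumptions,
    # not the actual Keras Model.export signature.
    if format == "litert":
        if importlib.util.find_spec("tensorflow") is None:
            # LiteRT export is only available when TensorFlow is installed.
            raise ImportError(
                "LiteRT export requires TensorFlow. "
                "Install it with `pip install tensorflow`."
            )
        # ... delegate to the LiteRT exporter here ...
        return
    # ... handle other formats (e.g. tf_saved_model) ...
    return
```

The key point is that the import check happens at export time, so Keras itself remains importable without TensorFlow.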

Input Signature and Export Utilities:

  • Improved documentation and logic in get_input_signature to clarify behavior for different model types and ensure correct input signature construction for export.
  • Enhanced _infer_input_signature_from_model to handle flexible batch dimensions and ensure compatibility with downstream exporters, always returning a flat list of input specs.
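
The flattening behavior can be illustrated with a small sketch. Plain shape tuples stand in for actual TensorSpec objects, and the helper name is hypothetical; the real helper operates on KerasTensors:

```python
def infer_flat_input_specs(structure):
    """Flatten a (possibly nested) structure of static input shapes into a
    flat list, replacing the batch dimension with None (flexible batch).

    Illustrative sketch only -- the real _infer_input_signature_from_model
    works on model inputs and returns backend TensorSpec objects.
    """
    flat = []

    def visit(node):
        if isinstance(node, dict):
            for key in sorted(node):
                visit(node[key])
        elif isinstance(node, (list, tuple)) and not all(
            isinstance(dim, (int, type(None))) for dim in node
        ):
            # A container of shapes (not a single shape): recurse.
            for item in node:
                visit(item)
        else:
            # A single shape tuple like (32, 10): free the batch dim.
            flat.append((None,) + tuple(node[1:]))

    visit(structure)
    return flat
```

Whatever the nesting (dict, list, or bare shape), downstream exporters always receive one flat list of specs.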

pctablet505 and others added 30 commits May 6, 2025 10:25
Corrected indentation in doc string
Fixed issue with passing a single image without batch dimension.
…scale.py

Co-authored-by: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
Test case for unbatched inputs
Test case for checking both unbatched and batched single image inputs.
There was a bug causing a cycle in the graph.
removed the use of tree.map_structure
Enhanced the _can_use_flash_attention function to provide more detailed
error messages when flash attention compatibility checks fail.

Changes:
- Replace generic exception catching with specific error propagation
- When raise_error=True, directly re-raise original exceptions from
  check_layout() and check_is_flash_attention() functions
- Preserve detailed error context from JAX internal validation functions
- Maintain existing behavior when raise_error=False (returns False)

This improves debugging experience by surfacing specific technical details
about tensor layout incompatibilities, cuDNN version requirements, and
other flash attention compatibility issues.

Relates to keras-hub PR keras-team#2257 and addresses flash attention debugging needs.
…sh_attention`

Changes:
- Add missing q_offsets=None and kv_offsets=None parameters to check_layout()
  call to match updated JAX function signature
- Replace bare `except:` with `except Exception as e:` and `raise e` to
  preserve detailed error messages from JAX validation functions
- Maintain existing fallback behavior when raise_error=False

This resolves compatibility issues with newer JAX versions and improves
debugging experience by surfacing specific technical details about
flash attention compatibility failures.
Simplified the check for `flash_attention` by removing redundant checks that are already done in `_can_use_flash_attention`.
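
The error-propagation pattern described in these commits can be sketched as below. The stub stands in for JAX's `check_layout` and `check_is_flash_attention` validation functions; only the `raise_error` control flow is the point here:

```python
def _compatibility_checks(query, key, value):
    # Stand-in for JAX's check_layout / check_is_flash_attention.
    if query is None:
        raise ValueError("query layout is incompatible with flash attention")

def can_use_flash_attention(query, key, value, raise_error=False):
    try:
        _compatibility_checks(query, key, value)
    except Exception as e:
        if raise_error:
            # Re-raise the original exception so the detailed context
            # from the validation functions is preserved for debugging.
            raise e
        # Existing behavior: silently fall back when raise_error=False.
        return False
    return True
```

Compared with a bare `except:` that swallows everything, re-raising the original exception surfaces the specific layout or cuDNN-version message to the user.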
pctablet505 and others added 5 commits November 12, 2025 19:18
Eliminated the _ensure_model_built method and its invocation from LiteRTExporter. The exporter now skips explicit model building before input signature resolution, simplifying the export process.
Eliminates Ahead-of-Time (AOT) compilation functionality from the LiteRT exporter and related tests. The exporter no longer accepts or processes aot_compile_targets, and all AOT-related methods and test cases have been removed for simplification and maintenance.

AOT compilation was open source earlier; it has since been removed by Google as an open-source feature.
Moved import statements for export_utils functions to the top of the file and removed redundant inline imports. This improves code clarity and reduces repeated imports within methods.
Replaces assertion with ValueError when the export filepath does not end with '.tflite' in LiteRTExporter. Updates corresponding test to expect ValueError instead of AssertionError for incorrect file extension.
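
The stricter validation described in this commit can be sketched as follows (the helper name and message wording are illustrative, not the exporter's actual code):

```python
def validate_litert_filepath(filepath):
    # A ValueError (not an AssertionError) is raised for a bad extension,
    # so the check survives `python -O` and reads as an API contract.
    if not str(filepath).endswith(".tflite"):
        raise ValueError(
            "Expected the LiteRT export filepath to end with '.tflite'. "
            f"Received: {filepath}"
        )
    return filepath
```

Unlike an `assert`, a `ValueError` is never stripped by optimized bytecode and gives callers a stable exception type to catch.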
@pctablet505
Collaborator Author

@gemini-code-assist review


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for exporting Keras models to the LiteRT (TFLite) format. The implementation is comprehensive, handling various model types, including those with dictionary inputs, and includes robust fallback mechanisms for the conversion process. The accompanying tests are thorough and cover a wide range of scenarios.

My review focuses on improving maintainability, debuggability, and adherence to the project's documentation standards. I've identified some unused code, suggested adding more detailed logging to exception handling to aid in debugging conversion failures, and recommended adding comments to a particularly complex method to improve its clarity. I also pointed out a missing code example in a docstring, to align with the Keras API design guidelines.

Moved tracked collection conversion logic into a dedicated method with no automatic dependency tracking to avoid TensorFlow wrappers. Added a compatibility shim for TensorFlow saving utilities by implementing the _get_save_spec method, which generates TensorSpec objects for model input signatures.

Needed for LiteRT export.
Refactored the fallback TFLite conversion method to use a direct tf.function approach instead of a tf.Module wrapper, simplifying the conversion logic. Added a 'verbose' parameter to export_litert and LiteRTExporter for progress messaging. Improved converter kwargs handling to only apply known TFLite settings.
Eliminates the 'verbose' parameter from export_litert and LiteRTExporter, simplifying the API and reducing unnecessary options for export progress messages.
Update LiteRTExporter to always enable resource variables during TFLite conversion, as Keras 3 only supports resource variables. Simplify conversion logic by removing strategy loop and error handling for unsupported conversion paths.
Deleted the _has_dict_inputs method from the LiteRTExporter class in litert.py as it is no longer used. This helps clean up the code and improve maintainability.
Simplifies and enforces stricter validation for converter kwargs in LiteRTExporter. Unknown attributes now raise ValueError instead of being ignored, and the method no longer maintains a list of known attributes, relying on attribute existence checks.
Eliminates the fallback method that used SavedModel as an intermediate step for TFLite conversion. Now, if direct conversion fails, a RuntimeError is raised with a helpful message, simplifying the export logic and error handling.
@pctablet505
Collaborator Author

@gemini-code-assist review


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for exporting Keras models to the LiteRT (TFLite) format, which is a valuable addition for on-device inference workflows. The implementation is well-structured, with a dedicated LiteRTExporter that handles the conversion logic, including a clever adapter for models with dictionary inputs. The changes to the Model.export API are clean and the new functionality is well-documented with examples. The accompanying test suite is comprehensive and covers a wide range of scenarios, ensuring the robustness of the new feature.

My review includes a few suggestions to improve maintainability and fix a minor bug in an error message, as well as a recommendation to enhance the documentation example to make it more user-friendly. Overall, this is a high-quality contribution.

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Nov 18, 2025
Replaces tf.keras.layers and tf.keras.Model references with locally imported layers and models from keras.src. This improves consistency and may help with modularity or compatibility within the keras.src namespace.
@google-ml-butler google-ml-butler bot removed the ready to pull Ready to be merged into the codebase label Nov 18, 2025

@hertschuh hertschuh left a comment

Thanks! This is getting really close.

Improves input signature inference and adapter creation for models with nested input structures (dicts, lists, etc.) in LiteRTExporter. Moves TensorSpec creation logic to export_utils and updates TFLayer to use tree.map_structure for save spec generation. Removes legacy dict-specific input signature inference and centralizes input structure handling for TFLite conversion.
Included the ai-edge-litert package in requirements.txt to support the new LiteRT export functionality.
Improves analysis of input signatures by unwrapping single-element lists for Functional models and consistently using the correct structure for input handling. Also updates input layer naming to prefer spec.name when available, ensuring more accurate input identification.
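
The single-element unwrapping described in this commit can be sketched as below (the helper name is illustrative; the real logic lives in the exporter's signature analysis):

```python
def unwrap_functional_inputs(input_signature):
    # Functional models often carry a one-element list around their sole
    # input spec; unwrap it so downstream handling sees the spec directly,
    # while genuinely multi-input signatures pass through untouched.
    if isinstance(input_signature, list) and len(input_signature) == 1:
        return input_signature[0]
    return input_signature
```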
Brings in latest upstream changes from keras-team/master including:
- Support PyDataset in Normalization layer adapt methods
- Fix Torch output_padding constraint for ConvTranspose layers
- Improve error messages and validation
- Various bug fixes and improvements from upstream

This keeps the export branch up-to-date with the latest Keras codebase
while preserving all LiteRT export functionality.