Skip to content

Conversation

@amitsrivastava78
Copy link
Collaborator

@amitsrivastava78 amitsrivastava78 commented Oct 22, 2025

Supports following feature

  • Asynchronous Checkpointing
  • Composite Checkpointing
  • Preservation Policies
  • Save Decision Policies

…re Supports following feature - Asynchronous Checkpointing - Composite Checkpointing - Preservation Policies - Save Decision Policies - Transformations - Custom Handlers
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @amitsrivastava78, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates Orbax checkpointing into Keras 3.0, providing a robust and flexible mechanism for saving and restoring training progress. The new OrbaxCheckpoint callback offers features like asynchronous saving, customizable save policies, and the ability to save complex states including model weights, optimizer variables, metrics, and data iterator positions. This enhancement significantly improves the reliability and efficiency of training large models, especially in distributed environments, by leveraging Orbax's advanced capabilities.

Highlights

  • New OrbaxCheckpoint Callback: Introduces a new OrbaxCheckpoint callback for Keras 3.0, enabling advanced data-centric saving and restoration of model states.
  • Asynchronous Checkpointing: Supports asynchronous saving of model weights and optimizer states, allowing training to continue without I/O blocking.
  • Comprehensive Checkpointing Features: Includes support for composite checkpointing, preservation policies (e.g., max_to_keep, keep_period), save decision policies (e.g., save_interval), and custom transformations during saving.
  • Distributed Training Support: Adds a get_process_index utility function to the Keras backend, facilitating distributed training setups by identifying the primary process for checkpoint operations across JAX, TensorFlow, and PyTorch.
  • Extensible with Custom Handlers: Exposes advanced Orbax functionalities like CheckpointManager, TypeHandler, and register_type_handler to allow users to define custom serialization logic for complex objects.
  • Iterator State Saving and Restoration: Enables saving and restoring the state of data iterators, crucial for seamless training resumption from a specific point, with backend-specific examples for TensorFlow, JAX, and PyTorch.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces OrbaxCheckpoint, a new Keras callback for advanced checkpointing using the Orbax library. This is a significant feature addition that enables asynchronous saving, composite checkpoints, and other powerful capabilities. The implementation is extensive and is supported by a comprehensive suite of tests.

My review has identified several important issues that need attention. There are critical correctness and performance bugs in the main implementation: the batch-based saving logic is flawed, and the asynchronous saving feature is effectively disabled by blocking calls. Additionally, some features are incomplete, and there are minor areas for improvement in the tests to enhance maintainability. I have provided specific suggestions to address these points. After these fixes, this will be a very valuable addition to Keras.

@codecov-commenter
Copy link

codecov-commenter commented Oct 22, 2025

Codecov Report

❌ Patch coverage is 77.09924% with 30 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.56%. Comparing base (47fcb39) to head (7b3cce9).
⚠️ Report is 54 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/callbacks/orbax_checkpoint.py 77.50% 17 Missing and 10 partials ⚠️
keras/src/utils/module_utils.py 75.00% 2 Missing ⚠️
keras/api/_tf_keras/keras/callbacks/__init__.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21762      +/-   ##
==========================================
- Coverage   82.69%   82.56%   -0.13%     
==========================================
  Files         573      578       +5     
  Lines       58888    59730     +842     
  Branches     9218     9378     +160     
==========================================
+ Hits        48696    49318     +622     
- Misses       7845     7995     +150     
- Partials     2347     2417      +70     
Flag Coverage Δ
keras 82.38% <75.57%> (-0.12%) ⬇️
keras-jax 62.89% <68.70%> (-0.35%) ⬇️
keras-numpy 57.46% <32.82%> (-0.26%) ⬇️
keras-openvino 34.35% <32.82%> (-0.05%) ⬇️
keras-tensorflow 64.42% <71.75%> (+0.41%) ⬆️
keras-torch 63.60% <71.75%> (+0.03%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. This checkpointing system has a ton of features!

Quick first pass.

Copy link
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple more comments I forgot.

- Remove conditional export decorator to ensure OrbaxCheckpoint is always available
- Remove unnecessary exception handling in state tree operations
- Update process index check comment for clarity
- Format code to comply with 80-character line limit
- Add distribution_lib modules for backend-specific distributed training support
- Remove unused 'result' variable in _reconstruct_state_tree_with_values
- Fix long comment line in test file
- Apply code formatting changes
…st handling

- Implement OrbaxCheckpoint callback for async checkpointing with state tree handling
- Add conditional exports for optional orbax-checkpoint dependency
- Use pytest.importorskip for clean optional dependency testing
- Ensure graceful handling when orbax-checkpoint is not installed
Copy link
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The JAX implementation of def process_id() is missing.

General questions:

  • Does this as-is support all backends?
  • Does this support JAX sharding? I don't see anything related to sharing (which may be normal). What about re-sharding?

- Preserve nested state tree structures instead of flattening for better layer name preservation
- Add backward compatibility for old flattened format checkpoints
- Simplify test class by using self.get_temp_dir() instead of setUp/tearDown
- Remove silent pytest.importorskip, add explicit skip conditions for backend-specific tests
- Move process_id function from backend to distribution module
- Update imports to use centralized LazyModule for orbax.checkpoint
- Test across all backends (JAX, TensorFlow, PyTorch) - all passing
- Optimize JAX array handling: avoid unnecessary numpy conversions for JAX >= 0.7.0
- Simplify step counting: use _total_batches_seen directly instead of dual mechanisms
- Remove impossible error checks and verbose messages
- Clean up unused Orbax exports that violated import policies
- Update error message for consistency
- All changes maintain backward compatibility and pass tests across JAX/TensorFlow/PyTorch backends
- Remove extra features: save_metadata, save_data_iterator, post_finalization_callback, save_decision_policy, keep_period
- Remove loading methods: load_checkpoint, load_latest, all_steps, _restore_model_state_from_full_tree
- Replace save_optimizer_state/save_metrics_state with save_weights_only parameter
- Add comprehensive test coverage for all remaining functionality
- Maintain async saving and preservation policies as Orbax-specific advantages
- All tests pass across JAX/TensorFlow/PyTorch backends
- Add test_checkpoint_loading: Verifies weights can be loaded from checkpoints
- Add test_checkpoint_loading_weights_only: Tests save_weights_only=True loading
- Add test_checkpoint_loading_with_optimizer_state: Tests full state loading with optimizer
- Fix array comparison logic for JAX, TensorFlow, and PyTorch backends
- Ensure all lines are within 80-character limit
- All tests pass on JAX, TensorFlow, and PyTorch backends
…ompatibility, and comprehensive testing

- Add complete model state saving (trainable/non-trainable vars, optimizer, metrics)
- Simplify save_weights_only logic to use full state tree when saving complete state
- Remove unnecessary try-except fallback for wait() method (V1 API always has it)
- Add comprehensive test coverage (13 tests) for all state components
- Ensure cross-backend compatibility (JAX, TensorFlow, PyTorch)
- Remove version dependencies and conditional imports
- Update requirements-common.txt with orbax-checkpoint dependency
@amitsrivastava78 amitsrivastava78 force-pushed the orbax-checkpoint-test-improvements branch 2 times, most recently from e44e546 to 6c5fda6 Compare November 24, 2025 08:33
…hen missing

- Prevents AttributeError in CI environments with older JAX versions
- Adds no-op lambda function when record_scalar is not available
- Ensures tests run across different JAX versions
- All 13 tests pass on JAX, TensorFlow, and PyTorch backends
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants