[WIP][Exp]Support ray dpo #6395
base: main
Conversation
…-ray # Conflicts: # swift/ray/resource_manager.py
Summary of Changes
Hello @tastelikefeet, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces comprehensive Ray support for Direct Preference Optimization (DPO), significantly expanding the distributed training capabilities. It involves a major refactoring of the underlying Ray integration mechanisms, providing more granular control over how different model components and training steps are distributed across Ray workers. This update not only adds DPO to the suite of Ray-supported features but also enhances the scalability and efficiency of existing pre-training and supervised fine-tuning workflows through a more robust and modular Ray infrastructure.
Highlights
Code Review
This pull request adds support for distributed DPO training using Ray, which is a significant new feature. The implementation is comprehensive, with changes to documentation, new examples, and substantial refactoring of the core training logic to accommodate a distributed setup with separate workers for policy and reference models. While the overall approach is solid, I've identified a couple of critical issues related to model initialization and usage in the distributed context that would prevent the feature from functioning as intended. I've also noted a minor issue regarding error handling. Addressing these points will be crucial for the stability and correctness of the new Ray DPO support.
super().__init__(args)
self.reward_model = []
if self.args.rlhf_type == 'grpo':
    self.reward_template = []
The initialization of self.reward_model and self.reward_template happens after super().__init__(args) is called. The super().__init__ call chain eventually triggers remote calls to methods like _prepare_reward_model, which attempt to use these attributes on the worker instances. Since they are not yet initialized on the workers when these remote methods are called, this will lead to an AttributeError.
To fix this, these attributes should be initialized before the super().__init__(args) call. Also, note that self.args is not available before super().__init__, so you should use the args from the __init__ method signature for the conditional initialization of self.reward_template.
Suggested change:

Original:
super().__init__(args)
self.reward_model = []
if self.args.rlhf_type == 'grpo':
    self.reward_template = []

Suggested:
self.reward_model = []
if args.rlhf_type == 'grpo':
    self.reward_template = []
super().__init__(args)
return (mini_batch, ), {}


class DPOTrainer(RLHFTrainerMixin, SwiftMixin, DataLoaderMixin, HFDPOTrainer):
The DPOTrainer inherits from trl.DPOTrainer, which uses self.ref_model to compute reference log probabilities within its get_batch_logps method. In this distributed setup, self.ref_model is None on the default worker where the trainer runs, as the reference model resides on a separate ref worker. This will cause the trainer to incorrectly use the policy model for computing reference log probabilities.
While you've correctly introduced _compute_ref_log_probs to perform remote computation on the ref worker, it is not being called. To fix this, you should override the get_batch_logps method in your custom DPOTrainer to use self.compute_ref_log_probs when a separate reference model is specified. This will ensure the reference log probabilities are computed on the correct worker.
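To make the suggestion concrete, below is a minimal sketch of one possible override inside this DPOTrainer. It assumes the remote helper added in this PR is named `_compute_ref_log_probs` and accepts the padded batch; the choice to hook `compute_ref_log_probs` (trl's entry point for reference log probabilities) rather than `get_batch_logps`, and the exact signature, are assumptions rather than the PR's final design.

```python
# Sketch only: a method override meant to live inside swift's DPOTrainer,
# which subclasses trl.DPOTrainer.
def compute_ref_log_probs(self, batch):
    # On the default worker the reference model lives on a separate ref worker,
    # so self.ref_model is None here. If the remote helper added in this PR is
    # available, delegate to it instead of letting trl reuse the policy model.
    if self.ref_model is None and hasattr(self, '_compute_ref_log_probs'):
        return self._compute_ref_log_probs(batch)
    # Otherwise keep the upstream trl behaviour.
    return super().compute_ref_log_probs(batch)
```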
try:
    ray.get(RayHelper._registry.clear.remote())
    ray.kill(RayHelper._registry)
except:  # noqa
Using a bare except clause is generally discouraged as it can catch and hide unexpected system-exiting exceptions like SystemExit or KeyboardInterrupt, making it harder to debug issues. It's better to catch a more specific exception, such as Exception, to avoid unintentionally suppressing important errors during the teardown process.
Suggested change:

Original:
except:  # noqa

Suggested:
except Exception:  # noqa
PR type
PR information
Write the detailed information that belongs to this PR.
Experiment results
Paste your experiment results here (if needed).