refactor: simplify rollout in training scripts with the `connect_engine` API (#451)
Summary of Changes

Hello @garrett4wade, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request refactors the rollout process in training scripts by introducing a new `connect_engine` API.
Code Review
This pull request refactors the training scripts to simplify rollout generation by introducing a new connect_engine API and helper methods (rollout_batch, prepare_batch) in the FSDP and Megatron engines. This encapsulates the complex logic of distributed rollout, redistribution, and broadcasting, making the training scripts much cleaner and easier to read. The changes are a significant improvement.
However, I've found a critical issue in the new _broadcast_and_redistribute_batch method in both fsdp_engine.py and megatron_engine.py. The redistribute function returns a RedistributedData object, but the code fails to access the .data attribute, which will lead to a type error. I've also pointed out some minor inconsistencies in megatron_engine.py regarding a redundant check, an incomplete error message, and docstring style.
/gemini review
Code Review
This pull request introduces a significant and valuable refactor by centralizing the distributed rollout logic into the FSDPEngine and MegatronEngine. This greatly simplifies the training scripts, making them cleaner and more maintainable. The addition of _check_rollout_engine_connected is also a good defensive measure against runtime errors.
My main feedback focuses on a high-impact opportunity to reduce code duplication. The new rollout methods are nearly identical across both engine classes, and abstracting them into a shared mixin would significantly improve long-term maintainability. I've also noted a minor inconsistency and another opportunity for refactoring within the new methods. Overall, this is a strong improvement to the codebase's architecture.
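The mixin suggestion can be sketched roughly as follows. All class and method names here are illustrative (the review does not show the actual code), assuming the duplicated helpers share the signatures implied by the summary above.

```python
class DistRolloutMixin:
    """Hypothetical shared home for the rollout helpers that are
    currently duplicated across FSDPEngine and MegatronEngine."""

    rollout_engine = None

    def connect_engine(self, rollout_engine):
        # Register the inference/rollout engine this trainer talks to.
        self.rollout_engine = rollout_engine

    def _check_rollout_engine_connected(self):
        # Defensive guard: fail fast with a clear message instead of
        # an AttributeError deep inside the rollout path.
        if self.rollout_engine is None:
            raise RuntimeError(
                "Rollout engine is not connected; call connect_engine() "
                "before rollout_batch() or weight updates."
            )

    def rollout_batch(self, data, workflow=None):
        self._check_rollout_engine_connected()
        return self.rollout_engine.rollout_batch(data, workflow=workflow)


class FSDPEngine(DistRolloutMixin):
    pass  # FSDP-specific training logic elsewhere


class MegatronEngine(DistRolloutMixin):
    pass  # Megatron-specific training logic elsewhere
```

With this shape, a fix to the rollout plumbing lands in one place instead of two, which is the long-term maintainability win the review points at.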
rchardx left a comment:
LGTM in general except for one minor issue.
/gemini review
Code Review
This pull request is a significant and valuable refactoring. It successfully centralizes the distributed rollout logic into a DistRolloutCoordinator, which greatly simplifies the training scripts by removing boilerplate code for batch redistribution and broadcasting. The introduction of rollout_batch and prepare_batch methods in the FSDP and Megatron engines abstracts away the complexities of distributed data handling, improving maintainability and readability.
My review has identified a critical logical error in the _check_rollout_engine_connected method in both FSDPEngine and MegatronEngine that would prevent the code from executing correctly. Once this is addressed, the changes will be a solid improvement to the codebase.
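The review does not quote the faulty guard, so the exact defect is unknown; a plausible version of such a logical error is an inverted condition, sketched below with hypothetical names.

```python
class Engine:
    """Illustrative sketch only; not the actual AReaL engine code."""

    def __init__(self):
        self.rollout_engine = None

    def connect_engine(self, rollout_engine):
        self.rollout_engine = rollout_engine

    def _check_rollout_engine_connected(self):
        # A buggy form would be `if self.rollout_engine is not None: raise`,
        # which rejects every correctly connected engine. The guard should
        # raise only when no engine has been connected yet.
        if self.rollout_engine is None:
            raise RuntimeError(
                "No rollout engine connected; call connect_engine() before "
                "rollout or weight-update operations."
            )
```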
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…ne` API (inclusionAI#451)

* add rollout methods in train engine to simplify data redistribution
* fix
* fix docstring
* .
* .
* refactor rollout logic into areal.core.dist_rollout
* Apply suggestion from @gemini-code-assist[bot]
* Apply suggestion from @gemini-code-assist[bot]

---------

Co-authored-by: Wentai Zhang <zhangwentai.zwt@antgroup.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>


This pull request introduces a major refactor to the distributed rollout workflow for both the FSDP and Megatron training engines, centralizing batch redistribution and broadcast logic within the engine classes. This change simplifies how rollout batches are prepared and distributed across workers, improving code maintainability and reducing redundancy in example scripts. Additionally, new validation checks ensure the rollout engine is properly connected before rollout or weight update operations.
Distributed Rollout and Batch Coordination Improvements:
* Added `_broadcast_and_redistribute_batch`, `rollout_batch`, and `prepare_batch` methods to both `FSDPEngine` (`fsdp_engine.py`) and `MegatronEngine` (`megatron_engine.py`). These methods encapsulate batch redistribution, broadcasting, and synchronization, streamlining distributed rollout handling. [1] [2]
* Added `_check_rollout_engine_connected` to both engines, enforcing that the rollout engine is connected before performing rollouts or weight updates, preventing runtime errors. [1] [2] [3] [4]

Example Script Simplification:
* Updated example training scripts (`countdown/train.py` and `experimental/dapo/gsm8k_dapo.py`) to delegate rollout batch preparation and distribution to the engine methods, removing manual broadcasting, tensor transfer, and synchronization barriers. This results in cleaner, less error-prone code. [1] [2]

Imports and Dependency Updates:
* Updated imports, including `StatefulDataLoader`, `RolloutWorkflow`, and various utility functions. [1] [2] [3] [4] [5] [6] [7] [8]

Code Cleanup:
* Removed now-redundant calls to `broadcast_tensor_container` and `tensor_container_to` as well as synchronization barriers, since these are handled by the engine methods. [1] [2]

These changes collectively make distributed rollout workflows more robust, modular, and easier to use in both core engine code and user-facing training scripts.
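The script-level simplification described above can be sketched as a before/after comparison. Everything here is a stub built from the names in this summary (`broadcast_tensor_container`, `tensor_container_to`, `rollout_batch`, `prepare_batch`); it mirrors the shape of the change, not the real AReaL code.

```python
def broadcast_tensor_container(batch):
    # Stub for the manual broadcast the old scripts performed.
    return batch


def tensor_container_to(batch, device):
    # Stub for the manual device transfer.
    return batch


class StubActor:
    """Post-refactor engine: owns rollout plus distributed plumbing."""

    def __init__(self, rollout_engine):
        self.rollout_engine = rollout_engine

    def rollout_batch(self, data, workflow=None):
        return self.rollout_engine.rollout_batch(data, workflow=workflow)

    def prepare_batch(self, batch):
        # Encapsulates broadcast, redistribution, and device placement.
        return tensor_container_to(broadcast_tensor_container(batch), "cuda")


# Before: the training script wired every distribution step by hand.
def train_step_before(rollout_engine, data, workflow):
    batch = rollout_engine.rollout_batch(data, workflow=workflow)
    batch = broadcast_tensor_container(batch)
    return tensor_container_to(batch, "cuda")


# After: two engine calls, no manual plumbing in the script.
def train_step_after(actor, data, workflow):
    return actor.prepare_batch(actor.rollout_batch(data, workflow=workflow))
```

Both paths produce the same batch; the difference is that the distributed bookkeeping now lives behind the engine interface, so example scripts cannot get it subtly wrong.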