feat(archon): add Context Parallelism (Ulysses SP) support #817
garrett4wade merged 1 commit into main from
Conversation
Summary of Changes

Hello @rchardx, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request significantly enhances the Archon Engine's capabilities by introducing Ulysses Sequence Parallelism (Context Parallelism). This new feature allows for more efficient training of models with very long sequences by distributing the attention computation across multiple GPUs. The changes involve fundamental updates to the parallel dimension management, model parallelization logic for Qwen2 and Qwen3, and improvements to data handling and testing infrastructure to support this advanced parallelism strategy.

Highlights
Code Review
This pull request introduces Ulysses Sequence Parallelism (Context Parallelism) to the Archon Engine, a significant feature for training models on long sequences. The implementation is comprehensive and well-executed, touching on parallelism configuration, engine logic, model implementation, and testing.
Key highlights of the changes include:
- A robust refactoring of ArchonParallelDims to use a 3D device mesh (dp, cp, tp), which provides a clean and extensible foundation for managing complex parallelism strategies.
- Seamless integration of Context Parallelism into the ArchonEngine, including input slicing, output gathering, and page-aligned batch padding.
- Modifications to Qwen2 and Qwen3 models to support Ulysses SP via All-to-All communication in the attention mechanism.
- A comprehensive new test suite for distributed execution (test_distributed.py and run_cp_forward.py) that validates DP, TP, and the new CP functionality.
- Significant improvements to the weight synchronization tests, making them more robust and comprehensive by parameterizing across all supported model types and adding bidirectional completeness checks.
The code is of high quality, with clear documentation and strong validation logic. The addition of a check for incompatible configurations (TP + AC + compile) is particularly valuable for preventing user errors.
I have one suggestion for a minor refactoring to reduce code duplication. Overall, this is an excellent contribution.
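To make the Ulysses SP layout change concrete, here is a single-process simulation of the pre-attention All-to-All: before it, each CP rank holds a sequence shard with all attention heads; after it, each rank holds the full sequence with a slice of the heads. This is an illustrative sketch (the function name is hypothetical); the real engine performs this exchange with torch.distributed All-to-All collectives rather than in-process slicing.

```python
import torch

def ulysses_head_scatter(shards, cp):
    """Simulate the pre-attention All-to-All of Ulysses SP.

    Input: `shards[r]` is rank r's sequence shard of shape
    [seq/cp, heads, dim] (sequence-parallel layout, all heads).
    Output: list where entry r has shape [seq, heads/cp, dim]
    (head-parallel layout, full sequence), as attention requires.
    """
    _, heads, _ = shards[0].shape
    out = []
    for r in range(cp):
        # Rank r collects its head slice from every sequence shard,
        # concatenating along the sequence dimension.
        h0, h1 = r * heads // cp, (r + 1) * heads // cp
        out.append(torch.cat([s[:, h0:h1] for s in shards], dim=0))
    return out

torch.manual_seed(0)
cp, seq, heads, dim = 2, 8, 4, 3
x = torch.randn(seq, heads, dim)
shards = list(x.chunk(cp, dim=0))       # sequence-parallel layout
per_head = ulysses_head_scatter(shards, cp)

# Both layouts partition the same tensor, just along different axes.
assert per_head[0].shape == (seq, heads // cp, dim)
assert torch.equal(torch.cat(per_head, dim=1), x)
```

The inverse All-to-All after attention maps the head-parallel output back to the sequence-parallel layout, which is why the head counts must be divisible by the CP degree.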
Force-pushed from 6705c63 to dfcbfc2
Pull request overview
This PR implements Ulysses Sequence Parallelism (Context Parallelism) for the Archon Engine, enabling efficient long-sequence training by distributing attention computation across GPUs via All-to-All communication.
Key changes:
- Implements 3D mesh topology (dp, cp, tp) with proper process groups and rank accessors
- Adds Ulysses SP utilities for input slicing and output gathering during forward/backward passes
- Integrates CP into Qwen2/Qwen3 attention modules with validation of head count constraints
- Enhances batch padding logic to align with page size and CP requirements using LCM
- Expands test coverage with comprehensive multi-GPU distributed tests and 100% weight sync verification
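The LCM-based padding mentioned above can be illustrated as follows. This is a sketch, not the engine's actual code: the names padded_len, page_seq_align, and seq_len_divisor are illustrative. The idea is that the flattened batch length must simultaneously be a multiple of the paged-memory alignment and of the CP sequence divisor, so it is rounded up to a multiple of their least common multiple.

```python
import math

def padded_len(total_len: int, page_seq_align: int, seq_len_divisor: int) -> int:
    """Round a batch's flattened token count up to a multiple of
    lcm(page alignment, CP divisor), so that page-aligned allocation
    and CP sequence slicing are both satisfied."""
    align = math.lcm(page_seq_align, seq_len_divisor)
    return ((total_len + align - 1) // align) * align

# e.g. 1000 tokens, 128-token page alignment, divisor 3 -> lcm = 384
assert padded_len(1000, 128, 3) == 1152   # 3 * 384
assert padded_len(384, 128, 3) == 384     # already aligned
```

Using the LCM (rather than the product) keeps the padding overhead minimal whenever the two alignment requirements share common factors.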
Reviewed changes
Copilot reviewed 26 out of 26 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| areal/utils/ulysses.py | Refactored Ulysses utilities; removed async_op, simplified All-to-All operations |
| areal/utils/data.py | Added batch_align_to parameter for CP-compatible padding alignment |
| areal/utils/constants.py | Added DEFAULT_PAGE_SIZE_BYTES constant for memory allocation |
| areal/experimental/utils/archon/parallel.py | Implemented 3D mesh (dp, cp, tp) with proper sub-meshes and process groups |
| areal/experimental/models/archon/ulysses.py | New file with CP-specific input slicing and output gathering utilities |
| areal/experimental/models/archon/qwen2/model/model.py | Added Ulysses SP support to Attention module with head scattering/gathering |
| areal/experimental/models/archon/qwen3/model/model.py | Added Ulysses SP support to Attention module with Q/K norm handling |
| areal/experimental/models/archon/qwen2/infra/parallelize.py | Added apply_cp() with head count validation and set_cp_group() integration |
| areal/experimental/models/archon/qwen3/infra/parallelize.py | Added apply_cp() with GQA-aware head repeating logic |
| areal/experimental/models/archon/qwen2/model/state_dict_adapter.py | Strip torch.compile _orig_mod prefix for weight sync compatibility |
| areal/experimental/models/archon/qwen3/model/state_dict_adapter.py | Strip torch.compile _orig_mod prefix for weight sync compatibility |
| areal/experimental/models/archon/model_spec.py | Updated ParallelizeFn protocol to include cp_group parameter |
| areal/experimental/models/archon/base.py | Added cu_seqlens and max_seqlen to forward signature |
| areal/experimental/models/archon/attention.py | Added enable_gqa flag to SDPA for GQA support |
| areal/experimental/engine/archon_engine.py | Integrated CP with input slicing, output gathering, and LCM-based padding |
| areal/api/alloc_mode.py | Removed CP validation error to enable CP configuration |
| areal/engine/fsdp_engine.py | Fixed import path for Scheduler (API module) |
| areal/tests/experimental/archon/test_distributed.py | New comprehensive multi-GPU tests for DP, TP, and CP |
| areal/tests/experimental/archon/torchrun/run_cp_forward.py | New CP forward test script with golden comparison |
| areal/tests/experimental/archon/test_weight_sync.py | Enhanced to test all supported model types with 100% weight verification |
| areal/tests/experimental/archon/test_forward.py | Moved multi-GPU tests to test_distributed.py |
| areal/tests/experimental/archon/test_grpo.py | Updated to use original batch length for padding-agnostic comparison |
| areal/tests/experimental/archon/utils.py | Added get_model_path_for_type() and dist cleanup in teardown |
| areal/tests/experimental/archon/torchrun/run_vs_fsdp.py | Updated comparison logic to handle different padding strategies |
| areal/tests/experimental/archon/torchrun/run_tp_forward.py | Added process group cleanup |
| areal/tests/experimental/archon/torchrun/run_forward.py | Added process group cleanup |
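The _orig_mod stripping performed by the state dict adapters (see the qwen2/qwen3 state_dict_adapter.py rows above) can be sketched like this. The helper name is hypothetical, but the mechanism is standard: torch.compile wraps a module in an OptimizedModule whose state dict keys are prefixed with "_orig_mod.", so the prefix must be removed for weight sync to see the same key names as the eager model.

```python
def strip_compile_prefix(state_dict: dict) -> dict:
    """Remove the '_orig_mod.' prefix that torch.compile's OptimizedModule
    wrapper adds to every parameter key, restoring eager-mode key names."""
    prefix = "_orig_mod."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

sd = {"_orig_mod.model.layers.0.q_proj.weight": 1, "lm_head.weight": 2}
assert strip_compile_prefix(sd) == {
    "model.layers.0.q_proj.weight": 1,
    "lm_head.weight": 2,
}
```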
Force-pushed from dfcbfc2 to 454889a
/gemini review
Code Review
This pull request introduces significant new functionality by adding Context Parallelism (Ulysses SP) support to the Archon engine. The changes are extensive and well-structured, including a major refactoring of the parallelism handling with a clearer 3D device mesh, and the addition of dedicated utility modules for Ulysses and parallelism validation. The test suite has also been commendably enhanced with new distributed tests and parameterization of existing ones, which is crucial for a change of this complexity. I've found one critical issue in the new ulysses.py file that would cause a runtime error, which I've detailed below. Besides that, the implementation appears solid and thoughtfully designed.
Implement Context Parallelism for Archon engine using Ulysses SP with All-to-All communication pattern for distributed attention computation.

Key changes:
- Add CP support to ArchonParallelDims with proper mesh configuration
- Implement ulysses_slice_inputs/ulysses_gather_output for input/output handling
- Integrate CP into Qwen2/Qwen3 attention modules via set_cp_group()
- Add validation for TP/CP constraints on attention head counts
- Enable page-aligned batching with LCM of page_size and seq_len_divisor
- Strip ._orig_mod prefix in state dict adapters for torch.compile compat
- Add incompatible config check for TP + AC + compile combination
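The "validation for TP/CP constraints on attention head counts" mentioned in the commit message can be sketched as a simple helper. All names here are illustrative, and this simplification ignores the GQA-aware head-repeating path the PR adds for Qwen3: after TP splits the heads across tensor-parallel ranks, Ulysses further scatters them across the CP group, so the per-TP-rank head counts must be divisible by the CP degree.

```python
def validate_cp_heads(num_q_heads: int, num_kv_heads: int, cp: int, tp: int) -> None:
    """Check that attention heads can be split first by TP, then by CP.

    Ulysses SP scatters heads across the CP group on top of the TP split,
    so heads-per-TP-rank must divide evenly by cp for both Q and KV heads.
    """
    q_per_tp = num_q_heads // tp
    kv_per_tp = num_kv_heads // tp
    if q_per_tp % cp != 0:
        raise ValueError(
            f"{q_per_tp} query heads per TP rank not divisible by cp={cp}"
        )
    if kv_per_tp % cp != 0:
        raise ValueError(
            f"{kv_per_tp} KV heads per TP rank not divisible by cp={cp}"
        )

# 32 Q / 8 KV heads, tp=2, cp=4: 16 and 4 heads per TP rank, both divisible.
validate_cp_heads(num_q_heads=32, num_kv_heads=8, cp=4, tp=2)
```

When the KV head count per rank is smaller than the CP degree, an implementation can instead repeat KV heads before scattering, which is the GQA-aware variant this sketch omits.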
Force-pushed from 454889a to 2542933
Description
Implement Context Parallelism for Archon engine using Ulysses SP with
All-to-All communication pattern for distributed attention computation.
Key changes:
- Add CP support to ArchonParallelDims with proper mesh configuration
- Implement ulysses_slice_inputs/ulysses_gather_output for input/output handling
- Integrate CP into Qwen2/Qwen3 attention modules via set_cp_group()
- Add validation for TP/CP constraints on attention head counts
- Enable page-aligned batching with LCM of page_size and seq_len_divisor
- Strip ._orig_mod prefix in state dict adapters for torch.compile compat
- Add incompatible config check for TP + AC + compile combination