feat(archon): add EP-aware padding wrapper for MoE grouped_mm #836
garrett4wade merged 1 commit into main
Conversation
Code Review (Gemini Code Assist)
This pull request introduces an EP-aware padding wrapper for MoE grouped_mm, which is a solid performance enhancement. The core logic change in GroupedExperts.forward is sound. The accompanying tests are thorough and cover many important cases. However, I've identified a few areas for improvement in the tests to make them more robust and accurate. Specifically, one test is incomplete, and another doesn't fully verify the intended behavior. I've also included a minor suggestion to enhance code readability. Overall, great work on this feature.
Pull request overview
This PR adds EP-aware automatic padding for MoE grouped matrix multiplication operations. The padding wrapper from utils is now automatically applied in GroupedExperts.forward() when Expert Parallelism is not being used, following Torchtitan's approach for token alignment.
Changes:
- Import and conditionally apply `indices_padding_wrapper` to `_run_experts_grouped_mm` based on EP detection (see the sketch after this list)
- Add comprehensive tests for gradient flow, aligned/unaligned tokens, and EP detection logic
- Update integration tests to use CUDA devices (required for `permute_tokens` and `grouped_mm`)
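Below is a minimal sketch of the conditional application described in the first bullet. `indices_padding_wrapper`, `_run_experts_grouped_mm`, and the `device_mesh.mesh_dim_names` check are named in this PR, but the helper `_is_ep_sharded`, the import paths, and the `w1`/`w2`/`w3` weight names are assumptions for illustration, not necessarily the merged code:

```python
import torch
from torch.distributed.tensor import DTensor

# Assumed import paths, based on this PR's file layout:
# from areal.experimental.models.archon.moe.utils import indices_padding_wrapper
# from areal.experimental.models.archon.moe.grouped_experts import _run_experts_grouped_mm


def _is_ep_sharded(weight) -> bool:
    """EP is active only when the weight is a DTensor whose mesh has an "ep" dim."""
    return isinstance(weight, DTensor) and "ep" in (
        weight.device_mesh.mesh_dim_names or ()
    )


class GroupedExperts(torch.nn.Module):
    def forward(self, x, num_tokens_per_expert):
        run = _run_experts_grouped_mm
        if not _is_ep_sharded(self.w1):
            # Without EP there are no hooks to pad tokens for
            # torch._grouped_mm, so the wrapper aligns them here
            # (Torchtitan's approach).
            run = indices_padding_wrapper(run)
        return run(self.w1, self.w2, self.w3, x, num_tokens_per_expert)
```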
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| areal/experimental/models/archon/moe/grouped_experts.py | Adds EP-aware conditional application of padding wrapper in forward() method |
| areal/tests/experimental/archon/test_grouped_experts.py | Adds tests for padding wrapper functionality and updates integration tests for CUDA compatibility |
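As a shape for the EP-detection tests listed above (the real ones live in `areal/tests/experimental/archon/test_grouped_experts.py` and cover more), a plain-tensor case can run without any distributed setup, reusing the hypothetical `_is_ep_sharded` from the earlier sketch:

```python
import torch


def test_plain_tensor_weight_is_not_ep_sharded():
    # A plain tensor carries no device mesh, so the padding wrapper
    # should be applied in forward().
    w = torch.randn(4, 16, 32)
    assert not _is_ep_sharded(w)
```

Exercising the DTensor/"ep" branch requires an initialized device mesh, so it is better suited to the integration tests.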
Description
Apply indices_padding_wrapper automatically in GroupedExperts.forward() when Expert Parallelism (EP) is not used. This follows Torchtitan's approach where the padding wrapper handles token alignment for torch._grouped_mm, while EP hooks handle padding when EP is enabled.
The wrapper is applied when:
- Weights are not `DTensor`, OR
- The `DTensor` doesn't have "ep" in `device_mesh.mesh_dim_names`
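To make "token alignment" concrete, the sketch below shows the kind of padding such a wrapper performs before calling `torch._grouped_mm`: per-expert token counts are rounded up so each group's offset lands on a fixed boundary. The `ALIGN` value of 16 and the `pad_offsets` name are assumptions for illustration; the exact alignment depends on the grouped-matmul kernel.

```python
import torch

ALIGN = 16  # assumed alignment; the real requirement depends on the kernel


def pad_offsets(tokens_per_expert: torch.Tensor) -> torch.Tensor:
    """Round each expert's token count up to a multiple of ALIGN and
    return the cumulative offsets a grouped matmul would consume."""
    padded = (tokens_per_expert + ALIGN - 1) // ALIGN * ALIGN
    return torch.cumsum(padded, dim=0)


counts = torch.tensor([5, 0, 17, 8])
print(pad_offsets(counts))  # tensor([16, 16, 48, 64])
```

Padded slots typically hold dummy tokens whose outputs are dropped when the results are un-permuted.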