
fix: Use DistributedSampler for dataloader instead of splitting dataset #456

Merged
garrett4wade merged 10 commits into main from dhh/fix-distri-sample on Oct 16, 2025

Conversation

@dhh1995 (Collaborator) commented Oct 15, 2025

Issue

When a dataset is split among ranks, len(dataloader) can differ across ranks.
This can deadlock training: ranks that expect more steps block on collective operations that the other ranks never enter.

Example

Take 255 data points, a global batch_size=64, and world_size=8, so the per-rank batch size is 64 / 8 = 8.
With drop_last=True, steps_per_epoch should be 255 // 64 = 3.
However, split_dataset_by_node gives ranks 0-6 shards of 32 data points each, and rank 7 a shard of 31.
Building a dataloader per shard, rank 7 then gets steps_per_epoch=3 (31 // 8 = 3 with drop_last=True) while the other ranks get steps_per_epoch=4 (32 // 8 = 4).
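The arithmetic above can be reproduced with a few lines of plain Python. The helpers below are hypothetical (not from the PR); they mirror how a contiguous split like split_dataset_by_node distributes a remainder, with the first `rem` ranks receiving one extra example:

```python
def shard_sizes(n_examples, world_size):
    """Contiguous split: the first `rem` ranks get one extra example."""
    base, rem = divmod(n_examples, world_size)
    return [base + 1 if r < rem else base for r in range(world_size)]

def steps_per_epoch(shard_size, per_rank_batch_size):
    """Steps a rank's dataloader yields with drop_last=True."""
    return shard_size // per_rank_batch_size

sizes = shard_sizes(255, 8)                        # [32]*7 + [31]
steps = [steps_per_epoch(s, 8) for s in sizes]     # [4]*7 + [3]
# Ranks disagree on the step count -> the training loop can hang.
```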

Solution

We use torch.utils.data.DistributedSampler instead, which assigns every rank an equal number of samples from the full dataset, so len(dataloader) is identical across ranks.
This PR applies the change to all examples.
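A minimal sketch of the fix (the helper name is illustrative, not the PR's actual utility; a plain DataLoader stands in for the StatefulDataLoader the repo uses). With drop_last=True on the sampler, the tail is trimmed so every rank receives the same number of indices:

```python
from torch.utils.data import DataLoader, DistributedSampler

def make_rank_loader(dataset, rank, world_size, per_rank_batch_size):
    # DistributedSampler trims the tail (drop_last=True) so all shards
    # are the same length; no manual splitting of the dataset is needed.
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True
    )
    return DataLoader(
        dataset, batch_size=per_rank_batch_size, sampler=sampler, drop_last=True
    )

# 255 examples, 8 ranks, per-rank batch size 8: each rank sees 31 samples
# and therefore the same number of steps.
loaders = [make_rank_loader(list(range(255)), r, 8, 8) for r in range(8)]
assert len({len(dl) for dl in loaders}) == 1
```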

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @dhh1995, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the robustness and correctness of distributed training by addressing a critical flaw in how datasets were previously split across multiple processes. By transitioning to torch.utils.data.DistributedSampler, the system now guarantees that each training process receives a perfectly balanced and non-overlapping subset of data, thereby eliminating potential deadlocks caused by inconsistent batch sizes. This change centralizes data loading logic into a new utility, simplifying future development and maintenance of distributed training pipelines.

Highlights

  • Distributed Data Loading Fix: Resolved an issue where uneven dataset splitting across distributed training ranks could lead to dataloader length inconsistencies and training deadlocks. The solution involves adopting torch.utils.data.DistributedSampler for robust data distribution.
  • Standardized Dataloader Creation: Introduced a new utility function, create_dataloader, in areal/utils/dataloader.py to centralize and standardize the process of creating StatefulDataLoader instances with DistributedSampler across all training examples.
  • Simplified Dataset Functions: Refactored existing dataset retrieval functions (e.g., get_custom_dataset, get_boba_math_dataset) to no longer handle explicit data splitting by rank and world_size. These functions now return complete datasets, with distributed sampling handled by the new create_dataloader utility.
  • Deprecation of Old Splitting Method: The original get_custom_dataset function has been deprecated and now issues a warning, guiding users to use the new get_complete_custom_dataset in conjunction with DistributedSampler.
  • Widespread Example Updates: Applied the new distributed data loading pattern across all relevant training examples, ensuring consistency and correctness in how datasets are prepared and loaded for distributed training.
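One caveat of sampler-based shuffling worth noting alongside a centralized utility like the create_dataloader described above: when shuffle=True, the training loop must call sampler.set_epoch(epoch) each epoch, or every epoch replays the same permutation. This is standard DistributedSampler behavior, not something specific to this PR; the toy dataset below is illustrative:

```python
from torch.utils.data import DataLoader, DistributedSampler

dataset = list(range(255))  # toy stand-in for a real dataset
sampler = DistributedSampler(
    dataset, num_replicas=8, rank=0, shuffle=True, drop_last=True
)
loader = DataLoader(dataset, batch_size=8, sampler=sampler, drop_last=True)

def epoch_order(epoch):
    sampler.set_epoch(epoch)  # reseeds the shuffle for this epoch
    return [int(x) for batch in loader for x in batch]

# Different epochs shuffle differently; the same epoch is deterministic.
assert epoch_order(0) != epoch_order(1)
assert epoch_order(0) == epoch_order(0)
```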

@gemini-code-assist (Bot) left a comment

Code Review

This pull request addresses a critical issue in distributed training where uneven dataset splitting could cause processes to hang. The solution of replacing manual dataset splitting with torch.utils.data.DistributedSampler is a solid approach and has been applied consistently across most of the example scripts. This is a great improvement for the stability and correctness of distributed training. I've identified a few critical issues in the implementation: a couple of dataset loading functions were missed during the refactoring, which will lead to runtime errors, and there's a logical error in the new create_dataloader utility. Once these are addressed, the PR will be in excellent shape.

  • Comment thread: areal/dataset/__init__.py (outdated)
  • Comment thread: examples/alignment/hhrlhf_rw.py (outdated)
@garrett4wade (Collaborator) commented

In addition:

  • Please run pre-commit run --all-files
  • You can use pytest -sv --lf --sw areal/tests/test_examples.py to run all the examples and check that the new scripts are runnable.

@garrett4wade (Collaborator) commented

/gemini review

@gemini-code-assist (Bot) left a comment

Code Review

This pull request refactors the dataset loading mechanism to use torch.utils.data.DistributedSampler instead of manually splitting datasets with split_dataset_by_node. This is a solid improvement that correctly addresses the issue of uneven data distribution across ranks, which could cause training to hang.

I've identified a few critical issues where the refactoring was incomplete, leading to NameError or TypeError in some dataset loading functions. I've also pointed out one function that was missed during the refactoring. Please address these points to ensure the stability and consistency of the new data loading pipeline.

Overall, great work on centralizing the distributed data loading logic!

@garrett4wade (Collaborator) commented

/gemini review

@gemini-code-assist (Bot) left a comment

Code Review

This pull request refactors the dataset loading logic to use DistributedSampler instead of manually splitting datasets, which is a great improvement to prevent training hangs in distributed setups. The changes are extensive and correctly applied across most of the example scripts. However, I've identified three critical issues that will cause runtime errors and need to be addressed: a broken fallback in the new get_custom_dataset function, an oversight in updating one of the dataset helper functions (get_geometry3k_sft_dataset), and a NameError in another (get_torl_data_rl_dataset). Once these are fixed, the PR should be in good shape.

  • Comment thread: areal/dataset/__init__.py
  • Comment thread: areal/dataset/geometry3k.py
  • Comment thread: areal/dataset/torl_data.py (outdated)
  • Comment thread: areal/utils/dataloader.py
@garrett4wade (Collaborator) left a comment

LGTM!

@garrett4wade merged commit e02aa9a into main on Oct 16, 2025.
1 of 4 checks passed.
@garrett4wade deleted the dhh/fix-distri-sample branch on October 16, 2025 at 05:06.
leandermaben pushed a commit to leandermaben/AReaL that referenced this pull request Mar 24, 2026
…et (inclusionAI#456)

* use DistributedSampler and update get_custom_dataset interface

* apply changes to all other examples

* fix remaining datasets based on gemini review