Implement PyTorch RLDS data loader with distributed support#920

Open
tahsinkose wants to merge 1 commit into Physical-Intelligence:main from tahsinkose:feat/pytorch-rlds-dataloader
Conversation

@tahsinkose

Summary

Implements the PyTorch RLDS data loader, replacing the NotImplementedError stub in
create_rlds_data_loader. This enables PyTorch DDP training on RLDS-format datasets (e.g.,
DROID).

Design

  • Rank-0-only pipeline: Only rank 0 builds the heavy TF-based RLDS pipeline; other ranks
    receive batches via torch.distributed.broadcast to avoid redundant resource usage.
  • Broadcast + shard: Each full batch is broadcast from rank 0, then sliced so every rank gets
    batch_size // world_size samples.
  • Exhaustion sync: A broadcast flag ensures all ranks break together, preventing deadlocks on
    collectives.
  • Single-GPU: When torch.distributed is not initialized, the loader falls back to a plain
    pass-through with no collective-communication overhead.
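The four design points above can be sketched roughly as follows. This is an illustrative outline, not the PR's actual code: the function name `distributed_batches` and the `feature_shape` parameter (used so non-zero ranks can allocate receive buffers of the right shape) are assumptions for the sketch.

```python
import torch
import torch.distributed as dist


def distributed_batches(rlds_iter, batch_size, feature_shape, device="cpu"):
    """Yield per-rank shards of batches produced only on rank 0.

    Sketch of the rank-0 broadcast + shard scheme described in the PR;
    names and signature are illustrative, not the merged implementation.
    """
    if not (dist.is_available() and dist.is_initialized()):
        # Single-GPU / single-process: plain pass-through, no collectives.
        yield from rlds_iter
        return

    rank, world_size = dist.get_rank(), dist.get_world_size()
    shard = batch_size // world_size

    while True:
        # Exhaustion sync: rank 0 checks whether a next batch exists and
        # broadcasts a continue/stop flag so all ranks break together,
        # avoiding deadlocks on the collectives below.
        if rank == 0:
            batch = next(rlds_iter, None)
            flag = torch.tensor([batch is not None], dtype=torch.uint8, device=device)
        else:
            batch = None
            flag = torch.empty(1, dtype=torch.uint8, device=device)
        dist.broadcast(flag, src=0)
        if not flag.item():
            break

        # Broadcast the full batch from rank 0, then slice out this
        # rank's contiguous shard of batch_size // world_size samples.
        if rank != 0:
            batch = torch.empty((batch_size, *feature_shape), device=device)
        dist.broadcast(batch, src=0)
        yield batch[rank * shard : (rank + 1) * shard]
```

In the single-process case the generator never touches `torch.distributed`, which matches the stated zero-overhead fallback; under DDP, only rank 0 pays the cost of the TF-based RLDS pipeline.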

