10 changes: 9 additions & 1 deletion megatron/rl/parallel_utils.py
@@ -6,7 +6,12 @@

 from typing import Optional

-import torch.distributed as dist
+try:
+    import torch.distributed as dist
+    HAS_TORCH_DISTRIBUTED = True

Contributor: How can you not have torch distributed here?

+except ImportError:
+    HAS_TORCH_DISTRIBUTED = False
+    dist = None

 from megatron.core import mpu
 from megatron.core.hyper_comm_grid import HyperCommGrid

@@ -71,6 +76,9 @@ def build_inference_pg_collection(
             f"World size ({world_size}) must be divisible by expt_tp*ep*pp ({expt_tp_size * ep_size * pp_size})"
         )

+    if not HAS_TORCH_DISTRIBUTED:
+        raise ImportError("torch.distributed is required for building inference process groups.")
+
     rank = dist.get_rank()

     # ====================
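
A minimal standalone sketch of the optional-import guard pattern introduced above (illustrative only; `current_rank` is a hypothetical helper, not part of this PR): the module stays importable when torch.distributed is unavailable, and code that actually needs a collective checks the flag and raises at call time.

# Sketch under the assumption that torch.distributed may fail to import.
try:
    import torch.distributed as dist
    HAS_TORCH_DISTRIBUTED = True
except ImportError:
    HAS_TORCH_DISTRIBUTED = False
    dist = None

def current_rank() -> int:
    """Return the caller's rank; fail loudly if torch.distributed is unavailable."""
    if not HAS_TORCH_DISTRIBUTED:
        raise ImportError("torch.distributed is required for this operation.")
    return dist.get_rank()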
12 changes: 9 additions & 3 deletions megatron/rl/rl_utils.py
@@ -571,9 +571,15 @@ def get_environment_rollouts(
     rollouts = [[None for _ in range(samples_per_group)] for _ in range(n_prompts)]

     with nvtx_range("sync-rollouts"):
-        # Wait for Rollouts to be collected
-        # TODO(jbarker): double check why this isn't causing rank 0 memory allocations
-        torch.distributed.broadcast_object_list(rollouts, src=0)
+        # Wait for Rollouts to be collected. Broadcast in small chunks to avoid
+        # pickling the entire `rollouts` list at once (which can cause large
+        # temporary allocations on rank 0).
+        for i in range(n_prompts):

Contributor: Could you confirm that this is indeed fixing what it claims to fix, please?

+            group = rollouts[i] if rank == 0 else None
+            obj_list = [group]
+            torch.distributed.broadcast_object_list(obj_list, src=0)
+            if rank != 0:
+                rollouts[i] = obj_list[0]
         logger.debug(f"Got rollouts on rank {rank}")

     if args.rl_offload_optimizer_during_inference:
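
For comparison, a minimal sketch of the per-chunk object broadcast used above (illustrative standalone helper; `broadcast_groups_in_chunks` and `groups` are assumed names, not the Megatron-RL API): each element is broadcast on its own, so rank 0 pickles one group at a time instead of the entire nested list.

import torch.distributed as dist

def broadcast_groups_in_chunks(groups, rank):
    """Broadcast each element of `groups` from rank 0 to all ranks, one chunk at a time.

    Assumes dist.init_process_group(...) has already been called.
    """
    for i in range(len(groups)):
        # Non-source ranks pass a placeholder that broadcast_object_list fills in.
        obj_list = [groups[i] if rank == 0 else None]
        dist.broadcast_object_list(obj_list, src=0)
        if rank != 0:
            groups[i] = obj_list[0]
    return groups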