diff --git a/megatron/rl/parallel_utils.py b/megatron/rl/parallel_utils.py
index da4ee8aa4cf..3f169d6e2ad 100644
--- a/megatron/rl/parallel_utils.py
+++ b/megatron/rl/parallel_utils.py
@@ -6,7 +6,12 @@
 
 from typing import Optional
 
-import torch.distributed as dist
+try:
+    import torch.distributed as dist
+    HAS_TORCH_DISTRIBUTED = True
+except ImportError:
+    HAS_TORCH_DISTRIBUTED = False
+    dist = None
 
 from megatron.core import mpu
 from megatron.core.hyper_comm_grid import HyperCommGrid
@@ -71,6 +76,9 @@ def build_inference_pg_collection(
         f"World size ({world_size}) must be divisible by expt_tp*ep*pp ({expt_tp_size * ep_size * pp_size})"
     )
 
+    if not HAS_TORCH_DISTRIBUTED:
+        raise ImportError("torch.distributed is required for building inference process groups.")
+
     rank = dist.get_rank()
 
     # ====================
diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py
index 8f145bd5a4c..7b0ab7f5031 100644
--- a/megatron/rl/rl_utils.py
+++ b/megatron/rl/rl_utils.py
@@ -571,9 +571,15 @@ def get_environment_rollouts(
     rollouts = [[None for _ in range(samples_per_group)] for _ in range(n_prompts)]
 
     with nvtx_range("sync-rollouts"):
-        # Wait for Rollouts to be collected
-        # TODO(jbarker): double check why this isn't causing rank 0 memory allocations
-        torch.distributed.broadcast_object_list(rollouts, src=0)
+        # Wait for Rollouts to be collected. Broadcast in small chunks to avoid
+        # pickling the entire `rollouts` list at once (which can cause large
+        # temporary allocations on rank 0).
+        for i in range(n_prompts):
+            group = rollouts[i] if rank == 0 else None
+            obj_list = [group]
+            torch.distributed.broadcast_object_list(obj_list, src=0)
+            if rank != 0:
+                rollouts[i] = obj_list[0]
 
         logger.debug(f"Got rollouts on rank {rank}")
 
     if args.rl_offload_optimizer_during_inference:
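
Reviewer note: below is a minimal, self-contained sketch of the chunked `broadcast_object_list` pattern this PR introduces in `rl_utils.py`. It runs on CPU with the gloo backend (e.g. `torchrun --nproc_per_node=2 sketch.py`); the helper name `chunked_broadcast` and the toy rollout data are illustrative assumptions, not part of the PR itself.

```python
import torch.distributed as dist


def chunked_broadcast(rollouts, n_prompts, rank):
    # Broadcast one prompt group at a time instead of pickling the whole
    # nested list in a single broadcast_object_list call, so rank 0's peak
    # temporary allocation is bounded by the largest single group.
    for i in range(n_prompts):
        obj_list = [rollouts[i] if rank == 0 else None]
        dist.broadcast_object_list(obj_list, src=0)
        if rank != 0:
            rollouts[i] = obj_list[0]
    return rollouts


if __name__ == "__main__":
    # torchrun provides MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()

    n_prompts, samples_per_group = 4, 2
    # Rank 0 holds the real data; other ranks start with placeholders,
    # mirroring the shape of `rollouts` in get_environment_rollouts.
    if rank == 0:
        rollouts = [[f"rollout-{i}-{j}" for j in range(samples_per_group)] for i in range(n_prompts)]
    else:
        rollouts = [[None] * samples_per_group for _ in range(n_prompts)]

    rollouts = chunked_broadcast(rollouts, n_prompts, rank)
    print(f"rank {rank}: {rollouts}")
    dist.destroy_process_group()
```

The trade-off is one collective per prompt group instead of a single large one; that adds some latency, but the per-call pickle buffer on rank 0 stays small, which is the allocation the removed TODO was worried about.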