From 1378e04db23e7fc3866dc4c9484f1706cf5e99cd Mon Sep 17 00:00:00 2001
From: brukcodes
Date: Thu, 12 Mar 2026 12:35:37 +0300
Subject: [PATCH 1/2] Fix rollout sync memory spike by broadcasting rollouts in small chunks

---
 megatron/rl/rl_utils.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py
index d68a0330989..1852f63fc8d 100644
--- a/megatron/rl/rl_utils.py
+++ b/megatron/rl/rl_utils.py
@@ -570,9 +570,15 @@ def get_environment_rollouts(
         rollouts = [[None for _ in range(samples_per_group)] for _ in range(n_prompts)]
 
     with nvtx_range("sync-rollouts"):
-        # Wait for Rollouts to be collected
-        # TODO(jbarker): double check why this isn't causing rank 0 memory allocations
-        torch.distributed.broadcast_object_list(rollouts, src=0)
+        # Wait for Rollouts to be collected. Broadcast in small chunks to avoid
+        # pickling the entire `rollouts` list at once (which can cause large
+        # temporary allocations on rank 0).
+        for i in range(n_prompts):
+            group = rollouts[i] if rank == 0 else None
+            obj_list = [group]
+            torch.distributed.broadcast_object_list(obj_list, src=0)
+            if rank != 0:
+                rollouts[i] = obj_list[0]
         logger.debug(f"Got rollouts on rank {rank}")
 
     if args.rl_offload_optimizer_during_inference:

From 8cc6f73984e8bd45673794f9656c0c551f85e4bd Mon Sep 17 00:00:00 2001
From: brukcodes
Date: Mon, 16 Mar 2026 14:25:12 +0300
Subject: [PATCH 2/2] Fix torch.distributed import issue by adding optional import check

---
 megatron/rl/parallel_utils.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/megatron/rl/parallel_utils.py b/megatron/rl/parallel_utils.py
index da4ee8aa4cf..3f169d6e2ad 100644
--- a/megatron/rl/parallel_utils.py
+++ b/megatron/rl/parallel_utils.py
@@ -6,7 +6,12 @@
 
 from typing import Optional
 
-import torch.distributed as dist
+try:
+    import torch.distributed as dist
+    HAS_TORCH_DISTRIBUTED = True
+except ImportError:
+    HAS_TORCH_DISTRIBUTED = False
+    dist = None
 
 from megatron.core import mpu
 from megatron.core.hyper_comm_grid import HyperCommGrid
@@ -71,6 +76,9 @@ def build_inference_pg_collection(
             f"World size ({world_size}) must be divisible by expt_tp*ep*pp ({expt_tp_size * ep_size * pp_size})"
         )
 
+    if not HAS_TORCH_DISTRIBUTED:
+        raise ImportError("torch.distributed is required for building inference process groups.")
+
     rank = dist.get_rank()
 
     # ====================
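
Note (reviewer sketch, not part of the patch series): the chunked broadcast introduced in PATCH 1/2 can be exercised in isolation with a small standalone script. The file name, payload, group count, and gloo/torchrun setup below are illustrative assumptions, not code taken from megatron/rl/rl_utils.py; only torch.distributed.broadcast_object_list and its per-chunk usage mirror the patch.

    # chunked_broadcast_demo.py -- standalone sketch (hypothetical file name).
    # Run with e.g.: torchrun --nproc_per_node=2 chunked_broadcast_demo.py
    import torch.distributed as dist


    def chunked_broadcast(items, src=0):
        # Broadcast one element at a time so only a single element is pickled
        # and staged per call, instead of serializing the whole list at once.
        rank = dist.get_rank()
        for i in range(len(items)):
            obj_list = [items[i] if rank == src else None]
            dist.broadcast_object_list(obj_list, src=src)
            if rank != src:
                items[i] = obj_list[0]
        return items


    if __name__ == "__main__":
        dist.init_process_group(backend="gloo")
        rank = dist.get_rank()
        n_prompts = 4  # illustrative size
        # Rank 0 holds the real data; other ranks start with placeholders,
        # mirroring the shape of `rollouts` in the patch.
        data = [f"rollout-{i}" for i in range(n_prompts)] if rank == 0 else [None] * n_prompts
        data = chunked_broadcast(data, src=0)
        print(f"rank {rank}: {data}")
        dist.destroy_process_group()

With this pattern, peak temporary memory on rank 0 scales with the largest single element rather than with the full list, which is the behavior the first patch is aiming for.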