From e25e08928b07b1546a0534fb4295e0a53f233f50 Mon Sep 17 00:00:00 2001 From: Bo Li Date: Thu, 23 Apr 2026 15:05:01 +0800 Subject: [PATCH] fix(eval): abort all ranks on per-rank exception instead of deadlocking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When torchrun or accelerate launches multiple ranks and one rank's evaluate() raises, the current handler logs the error, appends None, and lets the rank return normally. The other ranks continue into the next collective (gather_object, barrier, etc.) and block on NCCL until the launcher's wall-clock timeout tears the job down. On a cluster this can waste hours of GPU time for a failure that is visible in a single rank's log at second zero. Propagate the failure immediately: if torch.distributed is initialized when the exception reaches the outer handler, destroy the process group and sys.exit(1) so the launcher's elastic supervisor sees a non-zero exit and tears down the rest of the world. No behavior change for single-process runs — the is_initialized check gates the new path entirely. --- lmms_eval/__main__.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/lmms_eval/__main__.py b/lmms_eval/__main__.py index d9804d98f..316e0e191 100755 --- a/lmms_eval/__main__.py +++ b/lmms_eval/__main__.py @@ -563,6 +563,16 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: traceback.print_exc() eval_logger.error(f"Error during evaluation: {e}. Please set `--verbosity=DEBUG` to get more information.") results_list.append(None) + # Under torchrun/accelerate, a rank returning after a local + # exception leaves peers blocked in NCCL collectives until + # the launcher's timeout. Exit this rank with a non-zero status so + # the launcher tears down the remaining ranks instead of deadlocking. 
+ if torch.distributed.is_available() and torch.distributed.is_initialized(): + try: + torch.distributed.destroy_process_group() + except Exception: + pass + sys.exit(1) for args, results in zip(args_list, results_list): # cli_evaluate will return none if the process is not the main process (rank 0)