From 17c0c6f1131232772d39d665472d3c8e137217fa Mon Sep 17 00:00:00 2001 From: Philip Petrakian Date: Wed, 11 Feb 2026 18:07:40 +0000 Subject: [PATCH] Remove pickle from dist ckpt --- .../core/dist_checkpointing/strategies/filesystem_async.py | 3 +-- megatron/core/dist_checkpointing/strategies/torch.py | 3 +-- tests/unit_tests/dist_checkpointing/test_torch_dist.py | 7 ++----- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/megatron/core/dist_checkpointing/strategies/filesystem_async.py b/megatron/core/dist_checkpointing/strategies/filesystem_async.py index b23c4e9893d..d905b471719 100644 --- a/megatron/core/dist_checkpointing/strategies/filesystem_async.py +++ b/megatron/core/dist_checkpointing/strategies/filesystem_async.py @@ -6,7 +6,6 @@ import inspect import logging import os -import pickle import queue from functools import partial from heapq import heappop, heappush @@ -493,7 +492,7 @@ def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: path = os.path.join(self.checkpoint_dir, ".metadata") with msc.open(path, "wb") as metadata_file: - pickle.dump(metadata, metadata_file) + torch.save(metadata, metadata_file) else: super().finish(metadata, results) diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py index a5b6c009ba4..bdbd96cea87 100644 --- a/megatron/core/dist_checkpointing/strategies/torch.py +++ b/megatron/core/dist_checkpointing/strategies/torch.py @@ -3,7 +3,6 @@ """ Strategies using PyTorch distributed.checkpoint as an underlying format. """ import io import os -import pickle import warnings from collections import defaultdict from contextlib import contextmanager @@ -917,7 +916,7 @@ def remove_sharded_tensors(self, checkpoint_dir: str, key_prefix: str): ) ## save the new metadata with fs_writer.fs.create_stream(tmp_path, "wb") as metadata_file: - pickle.dump(metadata, metadata_file) + torch.save(metadata, metadata_file) try: os.fsync(metadata_file.fileno()) except AttributeError: diff --git a/tests/unit_tests/dist_checkpointing/test_torch_dist.py b/tests/unit_tests/dist_checkpointing/test_torch_dist.py index 4f4df058977..c8f4a8106d7 100644 --- a/tests/unit_tests/dist_checkpointing/test_torch_dist.py +++ b/tests/unit_tests/dist_checkpointing/test_torch_dist.py @@ -2,7 +2,6 @@ """Tests for PyTorch DCP based checkpoint format. """ -import pickle from copy import deepcopy from dataclasses import fields @@ -50,8 +49,7 @@ def test_cached_metadata(self, tmp_path_dist_ckpt): save(sharded_state_dict_non_cached, ckpt_dir, async_sharded_save=False) loaded_non_cached = load(sharded_state_dict_non_cached, ckpt_dir) md_path = ckpt_dir / '.metadata' - with md_path.open('rb') as f: - md_non_cached = pickle.load(f) + md_non_cached = torch.load(md_path, weights_only=False) save_strategy = deepcopy(get_default_save_sharded_strategy()) save_strategy.use_cached_ckpt_structure = True @@ -72,8 +70,7 @@ def test_cached_metadata(self, tmp_path_dist_ckpt): loaded_cached = load(sharded_state_dict_cached, ckpt_dir.__enter__()) md_path = ckpt_dir.__enter__() / '.metadata' - with md_path.open('rb') as f: - md_cached = pickle.load(f) + md_cached = torch.load(md_path, weights_only=False) # Check loaded state dict diffs = diff(loaded_non_cached, loaded_cached)