Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import inspect
import logging
import os
import pickle
import queue
from functools import partial
from heapq import heappop, heappush
Expand Down Expand Up @@ -493,7 +492,7 @@ def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
path = os.path.join(self.checkpoint_dir, ".metadata")

with msc.open(path, "wb") as metadata_file:
pickle.dump(metadata, metadata_file)
torch.save(metadata, metadata_file)
else:
super().finish(metadata, results)

Expand Down
3 changes: 1 addition & 2 deletions megatron/core/dist_checkpointing/strategies/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
""" Strategies using PyTorch distributed.checkpoint as an underlying format. """
import io
import os
import pickle
import warnings
from collections import defaultdict
from contextlib import contextmanager
Expand Down Expand Up @@ -917,7 +916,7 @@ def remove_sharded_tensors(self, checkpoint_dir: str, key_prefix: str):
)
## save the new metadata
with fs_writer.fs.create_stream(tmp_path, "wb") as metadata_file:
pickle.dump(metadata, metadata_file)
torch.save(metadata, metadata_file)
try:
os.fsync(metadata_file.fileno())
except AttributeError:
Expand Down
7 changes: 2 additions & 5 deletions tests/unit_tests/dist_checkpointing/test_torch_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

"""Tests for PyTorch DCP based checkpoint format. """

import pickle
from copy import deepcopy
from dataclasses import fields

Expand Down Expand Up @@ -50,8 +49,7 @@ def test_cached_metadata(self, tmp_path_dist_ckpt):
save(sharded_state_dict_non_cached, ckpt_dir, async_sharded_save=False)
loaded_non_cached = load(sharded_state_dict_non_cached, ckpt_dir)
md_path = ckpt_dir / '.metadata'
with md_path.open('rb') as f:
md_non_cached = pickle.load(f)
md_non_cached = torch.load(md_path, weights_only=False)

save_strategy = deepcopy(get_default_save_sharded_strategy())
save_strategy.use_cached_ckpt_structure = True
Expand All @@ -72,8 +70,7 @@ def test_cached_metadata(self, tmp_path_dist_ckpt):
loaded_cached = load(sharded_state_dict_cached, ckpt_dir.__enter__())
md_path = ckpt_dir.__enter__() / '.metadata'

with md_path.open('rb') as f:
md_cached = pickle.load(f)
md_cached = torch.load(md_path, weights_only=False)

# Check loaded state dict
diffs = diff(loaded_non_cached, loaded_cached)
Expand Down
Loading