From 37a234f9b1354fba1e2723fc7c43b7907c2654f1 Mon Sep 17 00:00:00 2001 From: Philip Petrakian Date: Wed, 11 Feb 2026 18:12:35 +0000 Subject: [PATCH 1/2] Remove pickle from rest of repo --- .../core/extensions/transformer_engine.py | 12 +++-- megatron/legacy/data/realm_index.py | 46 +++++++++++++------ megatron/training/common_config.py | 4 +- megatron/training/training.py | 6 +-- model_provider.py | 5 +- tools/checkpoint/checkpoint_inspector.py | 5 +- 6 files changed, 51 insertions(+), 27 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 996330f5674..9a58a8160a6 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -5,7 +5,6 @@ import inspect import io import os -import pickle import warnings from contextlib import nullcontext from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, cast @@ -1693,8 +1692,9 @@ def _encode_extra_state(self, state): # TE 2.0 changed the format of extra_state to be a byte tensor if is_te_min_version("2.0.0"): torch.cuda.synchronize() - state_serialized = bytearray(pickle.dumps(state)) - state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8) + buffer = io.BytesIO() + torch.save(state, buffer) + state_serialized = torch.frombuffer(bytearray(buffer.getvalue()), dtype=torch.uint8) else: state_serialized = io.BytesIO() torch.save(state, state_serialized) @@ -1702,10 +1702,12 @@ def _encode_extra_state(self, state): def _decode_extra_state(self, state): if isinstance(state, torch.Tensor): - # No FP8 is indicated by an empty tensor we don't need to unpickle. + # No FP8 is indicated by an empty tensor we don't need to deserialize. if state.numel() == 0: return - return pickle.loads(state.detach().cpu().numpy().tobytes()) + raw_bytes = state.detach().cpu().numpy().tobytes() + buffer = io.BytesIO(raw_bytes) + return torch.load(buffer, map_location="cuda", weights_only=False) elif isinstance(state, io.BytesIO): state.seek(0) return torch.load(state, map_location="cuda") diff --git a/megatron/legacy/data/realm_index.py b/megatron/legacy/data/realm_index.py index dbe924a52ae..6ec28b4eae4 100644 --- a/megatron/legacy/data/realm_index.py +++ b/megatron/legacy/data/realm_index.py @@ -1,7 +1,7 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. import itertools import os -import pickle +import json import shutil import numpy as np @@ -52,12 +52,16 @@ def load_from_file(self): """Populate members from instance saved to file""" if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0: - print("\n> Unpickling BlockData", flush=True) - state_dict = pickle.load(open(self.embedding_path, 'rb')) + print("\n> Loading BlockData", flush=True) + with open(self.embedding_path, 'r') as f: + state_dict = json.load(f) if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0: - print(">> Finished unpickling BlockData\n", flush=True) + print(">> Finished loading BlockData\n", flush=True) - self.embed_data = state_dict['embed_data'] + # Convert string keys back to ints and lists back to numpy float16 arrays + self.embed_data = { + int(k): np.float16(v) for k, v in state_dict['embed_data'].items() + } def add_block_data(self, row_id, block_embeds, allow_overwrite=False): """ @@ -72,6 +76,18 @@ def add_block_data(self, row_id, block_embeds, allow_overwrite=False): self.embed_data[idx] = np.float16(embed) + def _state_to_json_serializable(self): + """Convert state to a JSON-serializable format. + + Converts numpy arrays to lists and int keys to strings for JSON compatibility. + """ + state = self.state() + state['embed_data'] = { + str(k): v.tolist() if hasattr(v, 'tolist') else v + for k, v in state['embed_data'].items() + } + return state + def save_shard(self): """ Save the block data that was created this in this process @@ -80,9 +96,9 @@ def save_shard(self): os.makedirs(self.temp_dir_name, exist_ok=True) # save the data for each shard - with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') \ + with open('{}/{}.json'.format(self.temp_dir_name, self.rank), 'w') \ as writer: - pickle.dump(self.state(), writer) + json.dump(self._state_to_json_serializable(), writer) def merge_shards_and_save(self): #Combine all the shards made using save_shard @@ -95,21 +111,25 @@ def merge_shards_and_save(self): seen_own_shard = True continue - with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f: - data = pickle.load(f) + with open('{}/{}'.format(self.temp_dir_name, fname), 'r') as f: + data = json.load(f) old_size = len(self.embed_data) shard_size = len(data['embed_data']) # add the shard's data and check to make sure there - # is no overlap - self.embed_data.update(data['embed_data']) + # is no overlap. Convert string keys back to ints and + # lists back to numpy float16 arrays. + loaded_embed_data = { + int(k): np.float16(v) for k, v in data['embed_data'].items() + } + self.embed_data.update(loaded_embed_data) assert len(self.embed_data) == old_size + shard_size assert seen_own_shard # save the consolidated shards and remove temporary directory - with open(self.embedding_path, 'wb') as final_file: - pickle.dump(self.state(), final_file) + with open(self.embedding_path, 'w') as final_file: + json.dump(self._state_to_json_serializable(), final_file) shutil.rmtree(self.temp_dir_name, ignore_errors=True) print("Finished merging {} shards for a total of {} embeds".format( diff --git a/megatron/training/common_config.py b/megatron/training/common_config.py index 06c84bf7f13..8436b23e13d 100644 --- a/megatron/training/common_config.py +++ b/megatron/training/common_config.py @@ -47,8 +47,8 @@ class ProfilingConfig: record_memory_history: bool = False """Record memory history in last rank.""" - memory_snapshot_path: str = "snapshot.pickle" - """Specifies where to dump the memory history pickle.""" + memory_snapshot_path: str = "snapshot.json" + """Specifies where to dump the memory history snapshot.""" record_shapes: bool = False """Record shapes of tensors.""" diff --git a/megatron/training/training.py b/megatron/training/training.py index c4021fa698e..4be557d12b7 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -2071,10 +2071,10 @@ def training_log( if iteration % args.log_interval == 0 or is_first_iteration: if args.record_memory_history and (is_last_rank() or torch.distributed.get_backend() == 'fake'): snapshot = torch.cuda.memory._snapshot() - from pickle import dump + from json import dump - with open(args.memory_snapshot_path, 'wb') as f: - dump(snapshot, f) + with open(args.memory_snapshot_path, 'w') as f: + dump(snapshot, f, default=str) elapsed_time = timers('interval-time').elapsed(barrier=True, reset=should_reset) elapsed_time_per_iteration = elapsed_time / total_iterations diff --git a/model_provider.py b/model_provider.py index f8f6ccae01c..65b173981f2 100644 --- a/model_provider.py +++ b/model_provider.py @@ -51,11 +51,12 @@ def oom_observer(device, alloc, device_alloc, device_free): # snapshot right after an OOM happened print('saving allocated state during OOM') snapshot = torch.cuda.memory._snapshot() - from pickle import dump + from json import dump dump( snapshot, - open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'), + open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'w'), + default=str, ) torch._C._cuda_attach_out_of_memory_observer(oom_observer) diff --git a/tools/checkpoint/checkpoint_inspector.py b/tools/checkpoint/checkpoint_inspector.py index 3d03f4db959..71fcc1aa3dd 100644 --- a/tools/checkpoint/checkpoint_inspector.py +++ b/tools/checkpoint/checkpoint_inspector.py @@ -731,11 +731,12 @@ def oom_observer(device, alloc, device_alloc, device_free): ) ) snapshot = torch.cuda.memory._snapshot() - from pickle import dump + from json import dump dump( snapshot, - open(f"oom_rank-{torch.distributed.get_rank()}_snapshot.pickle", "wb"), + open(f"oom_rank-{torch.distributed.get_rank()}_snapshot.json", "w"), + default=str, ) torch._C._cuda_attach_out_of_memory_observer(oom_observer) From 8efb93e5b7496fedf128bc2b88ef8359e94ac0a6 Mon Sep 17 00:00:00 2001 From: Philip Petrakian Date: Thu, 12 Feb 2026 06:51:44 +0000 Subject: [PATCH 2/2] undo changes to model_provider.py --- model_provider.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/model_provider.py b/model_provider.py index 65b173981f2..f8f6ccae01c 100644 --- a/model_provider.py +++ b/model_provider.py @@ -51,12 +51,11 @@ def oom_observer(device, alloc, device_alloc, device_free): # snapshot right after an OOM happened print('saving allocated state during OOM') snapshot = torch.cuda.memory._snapshot() - from json import dump + from pickle import dump dump( snapshot, - open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'w'), - default=str, + open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'), ) torch._C._cuda_attach_out_of_memory_observer(oom_observer)