Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import inspect
import io
import os
import pickle
import warnings
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, cast
Expand Down Expand Up @@ -1693,19 +1692,22 @@ def _encode_extra_state(self, state):
# TE 2.0 changed the format of extra_state to be a byte tensor
if is_te_min_version("2.0.0"):
torch.cuda.synchronize()
state_serialized = bytearray(pickle.dumps(state))
state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
buffer = io.BytesIO()
torch.save(state, buffer)
state_serialized = torch.frombuffer(bytearray(buffer.getvalue()), dtype=torch.uint8)
else:
state_serialized = io.BytesIO()
torch.save(state, state_serialized)
return state_serialized

def _decode_extra_state(self, state):
if isinstance(state, torch.Tensor):
# No FP8 is indicated by an empty tensor we don't need to unpickle.
# No FP8 is indicated by an empty tensor; we don't need to deserialize.
if state.numel() == 0:
return
return pickle.loads(state.detach().cpu().numpy().tobytes())
raw_bytes = state.detach().cpu().numpy().tobytes()
buffer = io.BytesIO(raw_bytes)
return torch.load(buffer, map_location="cuda", weights_only=False)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the story for backward compatibility with existing checkpoints?

elif isinstance(state, io.BytesIO):
state.seek(0)
return torch.load(state, map_location="cuda")
Expand Down
46 changes: 33 additions & 13 deletions megatron/legacy/data/realm_index.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can just remove this file?

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import itertools
import os
import pickle
import json
import shutil

import numpy as np
Expand Down Expand Up @@ -52,12 +52,16 @@ def load_from_file(self):
"""Populate members from instance saved to file"""

if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Unpickling BlockData", flush=True)
state_dict = pickle.load(open(self.embedding_path, 'rb'))
print("\n> Loading BlockData", flush=True)
with open(self.embedding_path, 'r') as f:
state_dict = json.load(f)
if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
print(">> Finished unpickling BlockData\n", flush=True)
print(">> Finished loading BlockData\n", flush=True)

self.embed_data = state_dict['embed_data']
# Convert string keys back to ints and lists back to numpy float16 arrays
self.embed_data = {
int(k): np.float16(v) for k, v in state_dict['embed_data'].items()
}

def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
"""
Expand All @@ -72,6 +76,18 @@ def add_block_data(self, row_id, block_embeds, allow_overwrite=False):

self.embed_data[idx] = np.float16(embed)

def _state_to_json_serializable(self):
"""Convert state to a JSON-serializable format.

Converts numpy arrays to lists and int keys to strings for JSON compatibility.
"""
state = self.state()
state['embed_data'] = {
str(k): v.tolist() if hasattr(v, 'tolist') else v
for k, v in state['embed_data'].items()
}
return state

def save_shard(self):
"""
Save the block data that was created in this process
Expand All @@ -80,9 +96,9 @@ def save_shard(self):
os.makedirs(self.temp_dir_name, exist_ok=True)

# save the data for each shard
with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') \
with open('{}/{}.json'.format(self.temp_dir_name, self.rank), 'w') \
as writer:
pickle.dump(self.state(), writer)
json.dump(self._state_to_json_serializable(), writer)

def merge_shards_and_save(self):
# Combine all the shards made using save_shard
Expand All @@ -95,21 +111,25 @@ def merge_shards_and_save(self):
seen_own_shard = True
continue

with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
data = pickle.load(f)
with open('{}/{}'.format(self.temp_dir_name, fname), 'r') as f:
data = json.load(f)
old_size = len(self.embed_data)
shard_size = len(data['embed_data'])

# add the shard's data and check to make sure there
# is no overlap
self.embed_data.update(data['embed_data'])
# is no overlap. Convert string keys back to ints and
# lists back to numpy float16 arrays.
loaded_embed_data = {
int(k): np.float16(v) for k, v in data['embed_data'].items()
}
self.embed_data.update(loaded_embed_data)
assert len(self.embed_data) == old_size + shard_size

assert seen_own_shard

# save the consolidated shards and remove temporary directory
with open(self.embedding_path, 'wb') as final_file:
pickle.dump(self.state(), final_file)
with open(self.embedding_path, 'w') as final_file:
json.dump(self._state_to_json_serializable(), final_file)
shutil.rmtree(self.temp_dir_name, ignore_errors=True)

print("Finished merging {} shards for a total of {} embeds".format(
Expand Down
4 changes: 2 additions & 2 deletions megatron/training/common_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class ProfilingConfig:
record_memory_history: bool = False
"""Record memory history in last rank."""

memory_snapshot_path: str = "snapshot.pickle"
"""Specifies where to dump the memory history pickle."""
memory_snapshot_path: str = "snapshot.json"
"""Specifies where to dump the memory history snapshot."""

record_shapes: bool = False
"""Record shapes of tensors."""
Expand Down
6 changes: 3 additions & 3 deletions megatron/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -2072,10 +2072,10 @@ def training_log(
if iteration % args.log_interval == 0 or is_first_iteration:
if args.record_memory_history and (is_last_rank() or torch.distributed.get_backend() == 'fake'):
snapshot = torch.cuda.memory._snapshot()
from pickle import dump
from json import dump

with open(args.memory_snapshot_path, 'wb') as f:
dump(snapshot, f)
with open(args.memory_snapshot_path, 'w') as f:
dump(snapshot, f, default=str)

elapsed_time = timers('interval-time').elapsed(barrier=True, reset=should_reset)
elapsed_time_per_iteration = elapsed_time / total_iterations
Expand Down
5 changes: 3 additions & 2 deletions tools/checkpoint/checkpoint_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,11 +731,12 @@ def oom_observer(device, alloc, device_alloc, device_free):
)
)
snapshot = torch.cuda.memory._snapshot()
from pickle import dump
from json import dump

dump(
snapshot,
open(f"oom_rank-{torch.distributed.get_rank()}_snapshot.pickle", "wb"),
open(f"oom_rank-{torch.distributed.get_rank()}_snapshot.json", "w"),
default=str,
)

torch._C._cuda_attach_out_of_memory_observer(oom_observer)
Expand Down
Loading