Skip to content
Draft
24 changes: 22 additions & 2 deletions charge3net_ft/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ def _build_parquet_index(parquet_dir: Path) -> tuple:
index.append((fi, ri))

n_valid = len(index)
print(f"LeMatRhoDataset: {n_valid}/{n_total} valid rows indexed from {len(file_paths)} files")
print(
f"LeMatRhoDataset: {n_valid}/{n_total} valid rows indexed from {len(file_paths)} files"
)
return file_paths, index


Expand Down Expand Up @@ -230,6 +232,7 @@ def build_dataloaders(
num_workers: int = 4,
seed: int = 42,
pin_memory: bool = False,
distributed: bool = False,
) -> tuple:
"""
Build train, validation, and test DataLoaders.
Expand Down Expand Up @@ -298,10 +301,27 @@ def build_dataloaders(

collate_fn = partial(collate_list_of_dicts, pin_memory=pin_memory)

# DDP path: shard the training set across ranks via DistributedSampler.
# Val/test stay non-distributed (each rank evaluates the whole set; only
# rank 0 reports). This duplicates the validation and test compute on every
# rank but keeps eval simple and rank-agnostic. Those splits are tiny
# (5% + 5% of 65k), so the redundancy is acceptable.
train_sampler = None
if distributed:
from torch.utils.data.distributed import DistributedSampler

train_sampler = DistributedSampler(
train_subset,
shuffle=True,
seed=seed,
drop_last=True,
)

train_loader = DataLoader(
train_subset,
batch_size=batch_size,
shuffle=True,
# shuffle and sampler are mutually exclusive in DataLoader.
shuffle=(train_sampler is None),
sampler=train_sampler,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
Expand Down
Loading