diff --git a/open_diloco/train_diloco_torch.py b/open_diloco/train_diloco_torch.py index 6fa66d0..ae3772e 100644 --- a/open_diloco/train_diloco_torch.py +++ b/open_diloco/train_diloco_torch.py @@ -210,7 +210,6 @@ def main( streaming=True, data_files={ "train": "en/c4-train.*.json.gz", - "validation": "en/c4-validation.00000-of-00008.json.gz", }, ) ) @@ -229,9 +228,18 @@ def tokenize_function(data): train_dataloader = DataLoader(train_dataset, collate_fn=data_collator, batch_size=per_device_train_batch_size) if eval_steps is not None: - eval_dataset = tokenized_datasets["validation"] + # Load validation dataset without streaming to avoid infinite loop + eval_dataset_raw = load_dataset( + "allenai/c4", + "en", + streaming=False, + data_files={ + "validation": "en/c4-validation.00000-of-00008.json.gz", + }, + ) + eval_tokenized = eval_dataset_raw.map(tokenize_function, batched=True, remove_columns=["text", "timestamp", "url"]) eval_dataloader = DataLoader( - eval_dataset, + eval_tokenized["validation"], collate_fn=data_collator, batch_size=per_device_train_batch_size, )