diff --git a/sdp/processors/base_processor.py b/sdp/processors/base_processor.py index a4257e53..8cd48ef2 100644 --- a/sdp/processors/base_processor.py +++ b/sdp/processors/base_processor.py @@ -19,6 +19,7 @@ import time from abc import ABC, abstractmethod from dataclasses import dataclass +from joblib import Parallel, delayed from typing import Any, Dict, List, Optional, Union from tqdm import tqdm @@ -191,22 +192,20 @@ def _process_with_dask(self, metrics): def _process_with_multiprocessing(self, metrics): with open(self.output_manifest_file, "wt", encoding="utf8") as fout: for manifest_chunk in self._chunk_manifest(): - data = itertools.chain( - *process_map( - self.process_dataset_entry, - manifest_chunk, - max_workers=self.max_workers, - chunksize=self.chunksize, - ) + # Parallel processing using joblib + results = Parallel(n_jobs=self.max_workers, backend="multiprocessing")( + delayed(self.process_dataset_entry)(entry) for entry in manifest_chunk ) - for data_entry in tqdm(data): - metrics.append(data_entry.metrics) - if data_entry.data is None: - continue - json.dump(data_entry.data, fout, ensure_ascii=False) - fout.write("\n") - self.number_of_entries += 1 - self.total_duration += data_entry.data.get("duration", 0) + + for result_group in tqdm(results): + for data_entry in result_group: + metrics.append(data_entry.metrics) + if data_entry.data is None: + continue + json.dump(data_entry.data, fout, ensure_ascii=False) + fout.write("\n") + self.number_of_entries += 1 + self.total_duration += data_entry.data.get("duration", 0) def _chunk_manifest(self): """Splits the input manifest into chunks of in_memory_chunksize size. @@ -379,24 +378,22 @@ def process(self): metrics = [] with open(self.output_manifest_file, "wt", encoding="utf8") as fout: for manifest_chunk in self._chunk_manifest(): - # this will unroll all inner lists - data = itertools.chain( - *process_map( - self.process_dataset_entry, - manifest_chunk, - max_workers=self.max_workers, - chunksize=self.chunksize, - ) + + results = Parallel(n_jobs=self.max_workers, backend="multiprocessing")( + delayed(self.process_dataset_entry)(entry) for entry in manifest_chunk ) - for data_entry in tqdm(data): - if data_entry.metrics is not None: - pass # optionally accumulate metrics here - if data_entry.data is None: - continue - json.dump(data_entry.data, fout, ensure_ascii=False) - self.number_of_entries += 1 - self.total_duration += data_entry.data.get("duration", 0) - fout.write("\n") + + for result_group in tqdm(results): + for data_entry in result_group: + if data_entry.metrics is not None: + pass # optionally accumulate metrics here + if data_entry.data is None: + continue + json.dump(data_entry.data, fout, ensure_ascii=False) + self.number_of_entries += 1 + self.total_duration += data_entry.data.get("duration", 0) + fout.write("\n") + self.finalize(self.test_cases) def prepare(self):