From 8fca23763cd009f9b0acb5479d6c65cff70d119c Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Mon, 12 Aug 2024 18:57:29 -0700 Subject: [PATCH 1/4] fix cache cleanup issues Signed-off-by: Vibhu Jawa --- .../backend/torch/hf/memory_curve_utils.py | 7 +++-- crossfit/backend/torch/hf/model.py | 5 ++-- crossfit/backend/torch/loader.py | 30 ++++++++++++------- crossfit/backend/torch/op/base.py | 16 +++++----- crossfit/report/beir/report.py | 7 ++--- crossfit/utils/torch_utils.py | 7 +++++ 6 files changed, 42 insertions(+), 30 deletions(-) diff --git a/crossfit/backend/torch/hf/memory_curve_utils.py b/crossfit/backend/torch/hf/memory_curve_utils.py index dfbe47ed..0dfad8fb 100644 --- a/crossfit/backend/torch/hf/memory_curve_utils.py +++ b/crossfit/backend/torch/hf/memory_curve_utils.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc import joblib import numpy as np @@ -22,6 +21,7 @@ from transformers import PreTrainedModel from crossfit.utils.model_adapter import adapt_model_input +from crossfit.utils.torch_utils import cleanup_torch_cache def fit_memory_estimate_curve( @@ -65,6 +65,8 @@ def fit_memory_estimate_curve( y.append(memory_used) except RuntimeError as e: + # Catching run time error because: + # https://github.com/pytorch/pytorch/issues/133280 if "out of memory" in str(e) or "out_of_memory" in str(e): # Early stopping for this batch size seq_len_pbar.close() @@ -75,8 +77,7 @@ def fit_memory_estimate_curve( del batch if "outputs" in vars(): del outputs - gc.collect() - torch.cuda.empty_cache() + cleanup_torch_cache() # Check if we've hit the memory limit for all sequence lengths if seq_len == start_seq_len: diff --git a/crossfit/backend/torch/hf/model.py b/crossfit/backend/torch/hf/model.py index df72a38b..b4316026 100644 --- a/crossfit/backend/torch/hf/model.py +++ b/crossfit/backend/torch/hf/model.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc import os from functools import lru_cache @@ -24,6 +23,7 @@ from crossfit.backend.torch.hf.memory_curve_utils import fit_memory_estimate_curve from crossfit.backend.torch.model import Model from crossfit.dataset.home import CF_HOME +from crossfit.utils.torch_utils import cleanup_torch_cache class HFModel(Model): @@ -88,8 +88,7 @@ def unload_from_worker(self, worker): delattr(worker, "torch_model") if hasattr(worker, "cfg"): delattr(worker, "cfg") - gc.collect() - torch.cuda.empty_cache() + cleanup_torch_cache() def load_model(self, device="cuda"): return AutoModel.from_pretrained(self.path_or_name).to(device) diff --git a/crossfit/backend/torch/loader.py b/crossfit/backend/torch/loader.py index 3514df23..2b91e9ef 100644 --- a/crossfit/backend/torch/loader.py +++ b/crossfit/backend/torch/loader.py @@ -24,6 +24,7 @@ from crossfit.data.dataframe.dispatch import CrossFrame from crossfit.op.tokenize import clip_tokens from crossfit.utils.model_adapter import adapt_model_input +from crossfit.utils.torch_utils import cleanup_torch_cache DEFAULT_BATCH_SIZE = 512 @@ -173,17 +174,24 @@ def __next__(self): batch = adapt_model_input(fn, batch) break - except torch.cuda.OutOfMemoryError: - mid = start + (end - start) // 2 - if mid == start: - raise - warnings.warn( - f"Not enough memory for a batch size of {end - start}. " - f"Retrying with a new batch size of {mid - start}. " - f"Consider setting initial batch size to {mid - start}." - ) - self.splits.insert(self.current_idx, mid) - end = min(self.splits[self.current_idx], self.num_rows) + except RuntimeError as e: + # Catching run time error because: + # https://github.com/pytorch/pytorch/issues/133280 + if "out of memory" in str(e) or "out_of_memory" in str(e): + mid = start + (end - start) // 2 + if mid == start: + raise + warnings.warn( + f"Not enough memory for a batch size of {end - start}. " + f"Retrying with a new batch size of {mid - start}. " + f"Consider setting initial batch size to {mid - start}." + ) + del batch + cleanup_torch_cache() + self.splits.insert(self.current_idx, mid) + end = min(self.splits[self.current_idx], self.num_rows) + else: + raise e self.current_idx += 1 diff --git a/crossfit/backend/torch/op/base.py b/crossfit/backend/torch/op/base.py index 2c86fa13..8fb6e31f 100644 --- a/crossfit/backend/torch/op/base.py +++ b/crossfit/backend/torch/op/base.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc from typing import Optional import cudf @@ -26,7 +25,7 @@ from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE, InMemoryLoader, SortedSeqLoader from crossfit.backend.torch.model import Model from crossfit.op.base import Op -from crossfit.utils.torch_utils import concat_and_pad_tensors +from crossfit.utils.torch_utils import cleanup_torch_cache, concat_and_pad_tensors class Predictor(Op): @@ -55,7 +54,7 @@ def __init__( @torch.no_grad() def call(self, data, partition_info=None): - index = data.index + index = data.index.copy() if self.sorted_data_loader: loader = SortedSeqLoader( data[["input_ids", "attention_mask"]], @@ -71,7 +70,7 @@ def call(self, data, partition_info=None): progress_bar=self.create_progress_bar(len(data), partition_info), max_seq_len=self.model.max_seq_length(), ) - + del data all_outputs_ls = [] for output in loader.map(self.model.get_model(self.get_worker())): if isinstance(output, dict): @@ -91,16 +90,17 @@ def call(self, data, partition_info=None): ) ) _index = loader.sort_column(index.values) if self.sorted_data_loader else index + del all_outputs_ls + del loader + cleanup_torch_cache() if len(outputs.shape) <= 2: out[self.pred_output_col] = create_list_series_from_1d_or_2d_ar(outputs, _index) elif len(outputs.shape) == 3: out[self.pred_output_col] = create_nested_list_series_from_3d_ar(outputs, _index) else: raise RuntimeError(f"Unexpected output shape: {output.shape}") - - gc.collect() - torch.cuda.empty_cache() - + del outputs, _index + cleanup_torch_cache() return out def meta(self): diff --git a/crossfit/report/beir/report.py b/crossfit/report/beir/report.py index d21717fe..e6883f50 100644 --- a/crossfit/report/beir/report.py +++ b/crossfit/report/beir/report.py @@ -12,13 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc from typing import List, Optional import cudf import cupy as cp import dask_cudf -import torch from cuml.preprocessing import LabelEncoder from crossfit.backend.dask.aggregate import aggregate @@ -32,6 +30,7 @@ from crossfit.op.vector_search import VectorSearchOp from crossfit.report.base import Report from crossfit.report.beir.embed import embed +from crossfit.utils.torch_utils import cleanup_torch_cache class BeirMetricAggregator(Aggregator): @@ -211,9 +210,7 @@ def beir_report( del data del embeddings - gc.collect() - torch.cuda.empty_cache() - + cleanup_torch_cache() aggregator = BeirMetricAggregator(ks) aggregator = Aggregator(aggregator, groupby=groupby, name="") diff --git a/crossfit/utils/torch_utils.py b/crossfit/utils/torch_utils.py index 18132c27..5fc3ec03 100644 --- a/crossfit/utils/torch_utils.py +++ b/crossfit/utils/torch_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc from typing import List, Union import torch @@ -92,3 +93,9 @@ def concat_and_pad_tensors( # Concatenate the padded tensors return torch.cat(padded_outputs, dim=0) + + +def cleanup_torch_cache() -> None: + gc.collect() + torch.cuda.empty_cache() + return None From 7e20f5ebecc2e3808558aa03c96b26800aa1a815 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Tue, 27 Aug 2024 00:16:15 -0700 Subject: [PATCH 2/4] Add memory stats after each partition --- crossfit/backend/torch/op/base.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/crossfit/backend/torch/op/base.py b/crossfit/backend/torch/op/base.py index 8fb6e31f..3b19665b 100644 --- a/crossfit/backend/torch/op/base.py +++ b/crossfit/backend/torch/op/base.py @@ -54,6 +54,13 @@ def __init__( @torch.no_grad() def call(self, data, partition_info=None): + # Get the current CUDA device + current_device = torch.cuda.current_device() + + # Print CUDA memory at the beginning of the method + print(f"CUDA memory at start (device {current_device}):") + print(torch.cuda.memory_summary(device=current_device)) + index = data.index.copy() if self.sorted_data_loader: loader = SortedSeqLoader( @@ -101,6 +108,10 @@ def call(self, data, partition_info=None): raise RuntimeError(f"Unexpected output shape: {output.shape}") del outputs, _index cleanup_torch_cache() + + # Print CUDA memory at the end of the method + print(f"CUDA memory at end (device {current_device}):") + print(torch.cuda.memory_summary(device=current_device)) return out def meta(self): From c7e09dffa34567a9aabb51dd4a753831b44ccb40 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Wed, 4 Sep 2024 17:53:39 -0700 Subject: [PATCH 3/4] Fix memory leak Signed-off-by: Vibhu Jawa --- crossfit/data/dataframe/core.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/crossfit/data/dataframe/core.py b/crossfit/data/dataframe/core.py index 9ec969c7..a59a44c3 100644 --- a/crossfit/data/dataframe/core.py +++ b/crossfit/data/dataframe/core.py @@ -419,9 +419,12 @@ def columns(self): def assign(self, **kwargs): data = self.data.copy() - for k, v in kwargs.items(): - if self.columns and len(v) != len(self): - raise ValueError(f"Column {k} was length {len(v)}, but expected length {len(self)}") + # Uncommenting below caueses memory leak + # Find out why + # for k, v in kwargs.items(): + # if self.columns and len(v) != len(self): + # raise ValueError(f"Column {k} was length {len(v)}, " + # f"but expected length {len(self)}") data.update(**kwargs) return self.__class__(data) From 2d9a2a0d0c7ebcd936f9b722ccf78b80b523d1fa Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Thu, 5 Sep 2024 11:55:22 -0700 Subject: [PATCH 4/4] fix lru_cache --- crossfit/data/dataframe/core.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/crossfit/data/dataframe/core.py b/crossfit/data/dataframe/core.py index a59a44c3..2e202f46 100644 --- a/crossfit/data/dataframe/core.py +++ b/crossfit/data/dataframe/core.py @@ -348,7 +348,7 @@ def _(data): # Fall-back `ArrayBundle` definition class ArrayBundle(FrameBackend): - @lru_cache + @lru_cache(maxsize=1) def __len__(self): _len = None for k, v in self.data.items(): @@ -419,12 +419,9 @@ def columns(self): def assign(self, **kwargs): data = self.data.copy() - # Uncommenting below caueses memory leak - # Find out why - # for k, v in kwargs.items(): - # if self.columns and len(v) != len(self): - # raise ValueError(f"Column {k} was length {len(v)}, " - # f"but expected length {len(self)}") + for k, v in kwargs.items(): + if self.columns and len(v) != len(self): + raise ValueError(f"Column {k} was length {len(v)} but expected length {len(self)}") data.update(**kwargs) return self.__class__(data)