diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py
index 01cb3abe88..73cb917f37 100644
--- a/modelopt/torch/utils/dataset_utils.py
+++ b/modelopt/torch/utils/dataset_utils.py
@@ -18,7 +18,8 @@
 import copy
 import json
 import os
-from collections.abc import Callable
+from collections.abc import Callable, Iterator
+from contextlib import contextmanager, suppress
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
 from warnings import warn
@@ -437,6 +438,36 @@ def get_supported_datasets() -> list[str]:
     return list(SUPPORTED_DATASET_CONFIG.keys())
 
 
+@contextmanager
+def _disable_use_cache(model: torch.nn.Module) -> Iterator[None]:
+    """Set ``model.config.use_cache = False`` for the duration of the block.
+
+    KV caching is unwanted during calibration / memory-probe forward passes:
+    it wastes memory, and for hybrid Mamba/attention models (e.g., NemotronH)
+    the cache state is mutated in-place and breaks correctness. Setting
+    ``use_cache`` unconditionally (rather than only when it was already
+    present) also sidesteps configs that never assign the attribute at all
+    — e.g., ``Step3p5Config`` from stepfun-ai/Step-3.5-Flash — where forward
+    code that reads ``self.config.use_cache`` would otherwise raise
+    ``AttributeError``. The prior value is restored on exit if one existed.
+    """
+    config = getattr(model, "config", None)
+    if config is None:
+        yield
+        return
+    had_attr = hasattr(config, "use_cache")
+    prev = config.use_cache if had_attr else None
+    config.use_cache = False
+    try:
+        yield
+    finally:
+        if had_attr:
+            config.use_cache = prev
+        else:
+            with suppress(AttributeError):
+                delattr(config, "use_cache")
+
+
 def get_max_batch_size(
     model: torch.nn.Module,
     max_sample_length: int = 512,
@@ -467,42 +498,43 @@ def _get_free_gpu_mem():
         torch.ones([1, max_sample_length], dtype=torch.int32, device=model.device) * 100
     )
 
-    # Calculate single batch inference with dummy input.
-    with torch.set_grad_enabled(enable_grad):
-        infer_method(sample_input_single_batch)
-    free_mem_after, max_allocated_after = _get_free_gpu_mem()
+    with _disable_use_cache(model):
+        # Calculate single batch inference with dummy input.
+        with torch.set_grad_enabled(enable_grad):
+            infer_method(sample_input_single_batch)
+        free_mem_after, max_allocated_after = _get_free_gpu_mem()
 
-    mem_diff_per_data_batch = (
-        max(
-            (free_mem_before - free_mem_after),
-            (max_allocated_after - max_allocated_before),
+        mem_diff_per_data_batch = (
+            max(
+                (free_mem_before - free_mem_after),
+                (max_allocated_after - max_allocated_before),
+            )
+            * sample_memory_usage_ratio
         )
-        * sample_memory_usage_ratio
-    )
-    if mem_diff_per_data_batch <= 0:
-        print(
-            "Warning: No measurable memory usage found for a single batch. "
-            "Falling back to batch_size=1."
+        if mem_diff_per_data_batch <= 0:  # pragma: no cover - GPU memory probe edge case
+            print(  # pragma: no cover
+                "Warning: No measurable memory usage found for a single batch. "
+                "Falling back to batch_size=1."
+            )
+            target_data_batch = 1  # pragma: no cover
+        else:
+            target_data_batch = max(int(free_mem_before / mem_diff_per_data_batch), 1)
+        target_input = sample_input_single_batch.expand(
+            [
+                target_data_batch if index == 0 else dim
+                for index, dim in enumerate(sample_input_single_batch.shape)
+            ]
         )
-        target_data_batch = 1
-    else:
-        target_data_batch = max(int(free_mem_before / mem_diff_per_data_batch), 1)
-    target_input = sample_input_single_batch.expand(
-        [
-            target_data_batch if index == 0 else dim
-            for index, dim in enumerate(sample_input_single_batch.shape)
-        ]
-    )
 
-    # For some models on multi GPU, we observe the memory per batch is not a constant.
-    # So we just test the target batch size and make sure we do not go OOM.
-    while target_data_batch > 1:
-        with torch.set_grad_enabled(enable_grad):
-            try:
-                infer_method(target_input)
-                break
-            except torch.cuda.OutOfMemoryError:
-                target_data_batch = target_data_batch // 2
+        # For some models on multi GPU, we observe the memory per batch is not a constant.
+        # So we just test the target batch size and make sure we do not go OOM.
+        while target_data_batch > 1:
+            with torch.set_grad_enabled(enable_grad):
+                try:
+                    infer_method(target_input)
+                    break
+                except torch.cuda.OutOfMemoryError:  # pragma: no cover - GPU OOM retry path
+                    target_data_batch = target_data_batch // 2  # pragma: no cover
 
     # Regulate the data batch target to be 1, 2, 4, 8, 12, ..., capped at 64
     if target_data_batch < 2:
@@ -601,28 +633,16 @@ def _forward_loop(
         dataloader: DataLoader containing the batched input data
         allowed_non_tensor_keys: Set of key names whose values may be non-tensor types
     """
-    # Disable KV caching during calibration — it is unnecessary overhead and causes
-    # correctness issues with hybrid Mamba/attention models whose cache state is mutated
-    # in-place (e.g., NemotronH).
-    config = getattr(model, "config", None)
-    prev_use_cache = getattr(config, "use_cache", None)
-    if config is not None and prev_use_cache is not None:
-        config.use_cache = False
+    with _disable_use_cache(model), torch.no_grad():
+        is_enc_dec = model_type_is_enc_dec(model)
+        infer_method = model.generate if is_enc_dec else model.forward
+        max_working_batch_size = None  # Initialize max working batch size as None
 
-    try:
-        with torch.no_grad():
-            is_enc_dec = model_type_is_enc_dec(model)
-            infer_method = model.generate if is_enc_dec else model.forward
-            max_working_batch_size = None  # Initialize max working batch size as None
-
-            for _, data in enumerate(tqdm(dataloader)):
-                # Process batch and update max working batch size
-                max_working_batch_size = _process_batch(
-                    data, infer_method, max_working_batch_size, allowed_non_tensor_keys
-                )
-    finally:
-        if config is not None and prev_use_cache is not None:
-            config.use_cache = prev_use_cache
+        for _, data in enumerate(tqdm(dataloader)):
+            # Process batch and update max working batch size
+            max_working_batch_size = _process_batch(
+                data, infer_method, max_working_batch_size, allowed_non_tensor_keys
+            )
 
 
 def create_forward_loop(
diff --git a/tests/unit/torch/utils/test_dataset_utils.py b/tests/unit/torch/utils/test_dataset_utils.py
index 9a89d53672..94a2a5a6aa 100644
--- a/tests/unit/torch/utils/test_dataset_utils.py
+++ b/tests/unit/torch/utils/test_dataset_utils.py
@@ -17,8 +17,14 @@
 
 import pytest
 import torch
+from torch.utils.data import DataLoader
 
-from modelopt.torch.utils.dataset_utils import _process_batch, get_dataset_samples
+from modelopt.torch.utils.dataset_utils import (
+    _disable_use_cache,
+    _forward_loop,
+    _process_batch,
+    get_dataset_samples,
+)
 
 
 def setup_test_data():
@@ -145,6 +151,86 @@
     _process_batch(batch_data, mock_infer, allowed_non_tensor_keys={"base_model_outputs"})
 
 
+class _Config:
+    """Minimal config stand-in; instances start with no `use_cache` attribute."""
+
+
+def test_disable_use_cache_no_config_attr():
+    """Model without a `config` attribute: CM is a no-op and does not raise."""
+    model = torch.nn.Linear(4, 4)
+    assert not hasattr(model, "config")
+
+    with _disable_use_cache(model):
+        assert not hasattr(model, "config")
+
+    assert not hasattr(model, "config")
+
+
+@pytest.mark.parametrize("prev_value", [True, False])
+def test_disable_use_cache_with_existing_attr(prev_value):
+    """Config that already has `use_cache`: forced to False inside, restored on exit."""
+    model = torch.nn.Linear(4, 4)
+    model.config = _Config()
+    model.config.use_cache = prev_value
+
+    with _disable_use_cache(model):
+        assert model.config.use_cache is False
+
+    assert model.config.use_cache is prev_value
+
+
+def test_disable_use_cache_without_existing_attr():
+    """Config that lacks `use_cache`: set to False inside, attribute removed on exit (no leak)."""
+    model = torch.nn.Linear(4, 4)
+    model.config = _Config()
+    assert not hasattr(model.config, "use_cache")
+
+    with _disable_use_cache(model):
+        assert model.config.use_cache is False
+
+    assert not hasattr(model.config, "use_cache")
+
+
+def test_forward_loop_runs_under_disabled_use_cache():
+    """`_forward_loop` runs forward on every batch and restores `use_cache` on exit."""
+    seen_use_cache: list[bool] = []
+
+    class _Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.config = _Config()
+            self.config.use_cache = True
+
+        def forward(self, **kwargs):
+            seen_use_cache.append(self.config.use_cache)
+
+    model = _Model()
+
+    def _collate(samples):
+        return {"input_ids": torch.stack([s["input_ids"] for s in samples])}
+
+    data = [{"input_ids": torch.zeros(8, dtype=torch.long)} for _ in range(3)]
+    loader = DataLoader(data, batch_size=1, collate_fn=_collate)
+
+    _forward_loop(model, loader)
+
+    assert seen_use_cache == [False, False, False]
+    assert model.config.use_cache is True
+
+
+def test_disable_use_cache_restores_on_exception():
+    """Restore must run even if the with-block raises."""
+    model = torch.nn.Linear(4, 4)
+    model.config = _Config()
+    model.config.use_cache = True
+
+    with pytest.raises(RuntimeError, match="boom"), _disable_use_cache(model):
+        assert model.config.use_cache is False
+        raise RuntimeError("boom")
+
+    assert model.config.use_cache is True
+
+
 @pytest.mark.parametrize("test_local_path", [True, False])
 def test_get_dataset_samples_with_unsupported_minipile_dataset(tmp_path, test_local_path):
     pytest.importorskip("datasets")