From eed2b863d371df85d6505fb2111be747dbdf74fa Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Wed, 22 Apr 2026 19:11:41 +0000 Subject: [PATCH 1/7] fix step3.5 Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/utils/dataset_utils.py | 125 +++++++++++++++----------- 1 file changed, 71 insertions(+), 54 deletions(-) diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 01cb3abe88..03319e364e 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -18,7 +18,8 @@ import copy import json import os -from collections.abc import Callable +from collections.abc import Callable, Iterator +from contextlib import contextmanager from pathlib import Path from typing import TYPE_CHECKING, Any from warnings import warn @@ -437,6 +438,33 @@ def get_supported_datasets() -> list[str]: return list(SUPPORTED_DATASET_CONFIG.keys()) +@contextmanager +def _disable_use_cache(model: torch.nn.Module) -> Iterator[None]: + """Set ``model.config.use_cache = False`` for the duration of the block. + + KV caching is unwanted during calibration / memory-probe forward passes: + it wastes memory, and for hybrid Mamba/attention models (e.g., NemotronH) + the cache state is mutated in-place and breaks correctness. Setting + ``use_cache`` unconditionally (rather than only when it was already + present) also sidesteps configs that never assign the attribute at all + — e.g., ``Step3p5Config`` from stepfun-ai/Step-3.5-Flash — where forward + code that reads ``self.config.use_cache`` would otherwise raise + ``AttributeError``. The prior value is restored on exit if one existed. 
+ """ + config = getattr(model, "config", None) + if config is None: + yield + return + had_attr = hasattr(config, "use_cache") + prev = config.use_cache if had_attr else None + config.use_cache = False + try: + yield + finally: + if had_attr: + config.use_cache = prev + + def get_max_batch_size( model: torch.nn.Module, max_sample_length: int = 512, @@ -467,42 +495,43 @@ def _get_free_gpu_mem(): torch.ones([1, max_sample_length], dtype=torch.int32, device=model.device) * 100 ) - # Calculate single batch inference with dummy input. - with torch.set_grad_enabled(enable_grad): - infer_method(sample_input_single_batch) - free_mem_after, max_allocated_after = _get_free_gpu_mem() + with _disable_use_cache(model): + # Calculate single batch inference with dummy input. + with torch.set_grad_enabled(enable_grad): + infer_method(sample_input_single_batch) + free_mem_after, max_allocated_after = _get_free_gpu_mem() - mem_diff_per_data_batch = ( - max( - (free_mem_before - free_mem_after), - (max_allocated_after - max_allocated_before), + mem_diff_per_data_batch = ( + max( + (free_mem_before - free_mem_after), + (max_allocated_after - max_allocated_before), + ) + * sample_memory_usage_ratio ) - * sample_memory_usage_ratio - ) - if mem_diff_per_data_batch <= 0: - print( - "Warning: No measurable memory usage found for a single batch. " - "Falling back to batch_size=1." + if mem_diff_per_data_batch <= 0: + print( + "Warning: No measurable memory usage found for a single batch. " + "Falling back to batch_size=1." 
+ ) + target_data_batch = 1 + else: + target_data_batch = max(int(free_mem_before / mem_diff_per_data_batch), 1) + target_input = sample_input_single_batch.expand( + [ + target_data_batch if index == 0 else dim + for index, dim in enumerate(sample_input_single_batch.shape) + ] ) - target_data_batch = 1 - else: - target_data_batch = max(int(free_mem_before / mem_diff_per_data_batch), 1) - target_input = sample_input_single_batch.expand( - [ - target_data_batch if index == 0 else dim - for index, dim in enumerate(sample_input_single_batch.shape) - ] - ) - # For some models on multi GPU, we observe the memory per batch is not a constant. - # So we just test the target batch size and make sure we do not go OOM. - while target_data_batch > 1: - with torch.set_grad_enabled(enable_grad): - try: - infer_method(target_input) - break - except torch.cuda.OutOfMemoryError: - target_data_batch = target_data_batch // 2 + # For some models on multi GPU, we observe the memory per batch is not a constant. + # So we just test the target batch size and make sure we do not go OOM. + while target_data_batch > 1: + with torch.set_grad_enabled(enable_grad): + try: + infer_method(target_input) + break + except torch.cuda.OutOfMemoryError: + target_data_batch = target_data_batch // 2 # Regulate the data batch target to be 1, 2, 4, 8, 12, ..., capped at 64 if target_data_batch < 2: @@ -601,28 +630,16 @@ def _forward_loop( dataloader: DataLoader containing the batched input data allowed_non_tensor_keys: Set of key names whose values may be non-tensor types """ - # Disable KV caching during calibration — it is unnecessary overhead and causes - # correctness issues with hybrid Mamba/attention models whose cache state is mutated - # in-place (e.g., NemotronH). 
- config = getattr(model, "config", None) - prev_use_cache = getattr(config, "use_cache", None) - if config is not None and prev_use_cache is not None: - config.use_cache = False + with _disable_use_cache(model), torch.no_grad(): + is_enc_dec = model_type_is_enc_dec(model) + infer_method = model.generate if is_enc_dec else model.forward + max_working_batch_size = None # Initialize max working batch size as None - try: - with torch.no_grad(): - is_enc_dec = model_type_is_enc_dec(model) - infer_method = model.generate if is_enc_dec else model.forward - max_working_batch_size = None # Initialize max working batch size as None - - for _, data in enumerate(tqdm(dataloader)): - # Process batch and update max working batch size - max_working_batch_size = _process_batch( - data, infer_method, max_working_batch_size, allowed_non_tensor_keys - ) - finally: - if config is not None and prev_use_cache is not None: - config.use_cache = prev_use_cache + for _, data in enumerate(tqdm(dataloader)): + # Process batch and update max working batch size + max_working_batch_size = _process_batch( + data, infer_method, max_working_batch_size, allowed_non_tensor_keys + ) def create_forward_loop( From 9783c102f4cbe0d8501f1dd15b626e667e0253a6 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Mon, 27 Apr 2026 18:47:34 +0000 Subject: [PATCH 2/7] address comment Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/utils/dataset_utils.py | 5 ++ tests/unit/torch/utils/test_dataset_utils.py | 60 +++++++++++++++++++- 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 03319e364e..2812487370 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -463,6 +463,11 @@ def _disable_use_cache(model: torch.nn.Module) -> Iterator[None]: finally: if had_attr: config.use_cache = prev + else: + try: + 
delattr(config, "use_cache") + except AttributeError: + pass def get_max_batch_size( diff --git a/tests/unit/torch/utils/test_dataset_utils.py b/tests/unit/torch/utils/test_dataset_utils.py index 9a89d53672..8db2e8aa6e 100644 --- a/tests/unit/torch/utils/test_dataset_utils.py +++ b/tests/unit/torch/utils/test_dataset_utils.py @@ -18,7 +18,11 @@ import pytest import torch -from modelopt.torch.utils.dataset_utils import _process_batch, get_dataset_samples +from modelopt.torch.utils.dataset_utils import ( + _disable_use_cache, + _process_batch, + get_dataset_samples, +) def setup_test_data(): @@ -145,6 +149,60 @@ def mock_infer(**kwargs): _process_batch(batch_data, mock_infer, allowed_non_tensor_keys={"base_model_outputs"}) +class _Config: + """Minimal config stand-in; instances start with no `use_cache` attribute.""" + + +def test_disable_use_cache_no_config_attr(): + """Model without a `config` attribute: CM is a no-op and does not raise.""" + model = torch.nn.Linear(4, 4) + assert not hasattr(model, "config") + + with _disable_use_cache(model): + assert not hasattr(model, "config") + + assert not hasattr(model, "config") + + +@pytest.mark.parametrize("prev_value", [True, False]) +def test_disable_use_cache_with_existing_attr(prev_value): + """Config that already has `use_cache`: forced to False inside, restored on exit.""" + model = torch.nn.Linear(4, 4) + model.config = _Config() + model.config.use_cache = prev_value + + with _disable_use_cache(model): + assert model.config.use_cache is False + + assert model.config.use_cache is prev_value + + +def test_disable_use_cache_without_existing_attr(): + """Config that lacks `use_cache`: set to False inside, attribute removed on exit (no leak).""" + model = torch.nn.Linear(4, 4) + model.config = _Config() + assert not hasattr(model.config, "use_cache") + + with _disable_use_cache(model): + assert model.config.use_cache is False + + assert not hasattr(model.config, "use_cache") + + +def 
test_disable_use_cache_restores_on_exception():
+    """Restore must run even if the with-block raises."""
+    model = torch.nn.Linear(4, 4)
+    model.config = _Config()
+    model.config.use_cache = True
+
+    with pytest.raises(RuntimeError, match="boom"):
+        with _disable_use_cache(model):
+            assert model.config.use_cache is False
+            raise RuntimeError("boom")
+
+    assert model.config.use_cache is True
+
+
 @pytest.mark.parametrize("test_local_path", [True, False])
 def test_get_dataset_samples_with_unsupported_minipile_dataset(tmp_path, test_local_path):
     pytest.importorskip("datasets")

From 5d5398b64d7d2849e1b9b70cf76c4f0f0f3d5976 Mon Sep 17 00:00:00 2001
From: weimingc <17592131+meenchen@users.noreply.github.com>
Date: Mon, 27 Apr 2026 19:23:48 +0000
Subject: [PATCH 3/7] fix code quality check

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
---
 modelopt/torch/utils/dataset_utils.py        | 6 ++----
 tests/unit/torch/utils/test_dataset_utils.py | 7 +++----
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py
index 2812487370..bcbf277fbc 100644
--- a/modelopt/torch/utils/dataset_utils.py
+++ b/modelopt/torch/utils/dataset_utils.py
@@ -19,7 +19,7 @@
 import json
 import os
 from collections.abc import Callable, Iterator
-from contextlib import contextmanager
+from contextlib import contextmanager, suppress
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
 from warnings import warn
@@ -464,10 +464,8 @@ def _disable_use_cache(model: torch.nn.Module) -> Iterator[None]:
         if had_attr:
             config.use_cache = prev
         else:
-            try:
+            with suppress(AttributeError):
                 delattr(config, "use_cache")
-            except AttributeError:
-                pass
 
 
 def get_max_batch_size(
diff --git a/tests/unit/torch/utils/test_dataset_utils.py b/tests/unit/torch/utils/test_dataset_utils.py
index 8db2e8aa6e..d00b9465aa 100644
--- a/tests/unit/torch/utils/test_dataset_utils.py
+++ b/tests/unit/torch/utils/test_dataset_utils.py
@@ -195,10 +195,9 @@ def test_disable_use_cache_restores_on_exception(): model.config = _Config() model.config.use_cache = True - with pytest.raises(RuntimeError, match="boom"): - with _disable_use_cache(model): - assert model.config.use_cache is False - raise RuntimeError("boom") + with pytest.raises(RuntimeError, match="boom"), _disable_use_cache(model): + assert model.config.use_cache is False + raise RuntimeError("boom") assert model.config.use_cache is True From 9e78129860ecd83eb62025a798e744d90d8560f4 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Mon, 27 Apr 2026 20:09:06 +0000 Subject: [PATCH 4/7] fix coverage Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- tests/unit/torch/utils/test_dataset_utils.py | 29 ++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/unit/torch/utils/test_dataset_utils.py b/tests/unit/torch/utils/test_dataset_utils.py index d00b9465aa..94a2a5a6aa 100644 --- a/tests/unit/torch/utils/test_dataset_utils.py +++ b/tests/unit/torch/utils/test_dataset_utils.py @@ -17,9 +17,11 @@ import pytest import torch +from torch.utils.data import DataLoader from modelopt.torch.utils.dataset_utils import ( _disable_use_cache, + _forward_loop, _process_batch, get_dataset_samples, ) @@ -189,6 +191,33 @@ def test_disable_use_cache_without_existing_attr(): assert not hasattr(model.config, "use_cache") +def test_forward_loop_runs_under_disabled_use_cache(): + """`_forward_loop` runs forward on every batch and restores `use_cache` on exit.""" + seen_use_cache: list[bool] = [] + + class _Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = _Config() + self.config.use_cache = True + + def forward(self, **kwargs): + seen_use_cache.append(self.config.use_cache) + + model = _Model() + + def _collate(samples): + return {"input_ids": torch.stack([s["input_ids"] for s in samples])} + + data = [{"input_ids": torch.zeros(8, dtype=torch.long)} for _ in 
range(3)] + loader = DataLoader(data, batch_size=1, collate_fn=_collate) + + _forward_loop(model, loader) + + assert seen_use_cache == [False, False, False] + assert model.config.use_cache is True + + def test_disable_use_cache_restores_on_exception(): """Restore must run even if the with-block raises.""" model = torch.nn.Linear(4, 4) From 56c5edb283f12a4c4706e7fc8700631a8f29590b Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Tue, 28 Apr 2026 18:00:17 +0000 Subject: [PATCH 5/7] test: add get_max_batch_size CPU coverage Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- tests/unit/torch/utils/test_dataset_utils.py | 39 ++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/unit/torch/utils/test_dataset_utils.py b/tests/unit/torch/utils/test_dataset_utils.py index 94a2a5a6aa..ed97d0ba4f 100644 --- a/tests/unit/torch/utils/test_dataset_utils.py +++ b/tests/unit/torch/utils/test_dataset_utils.py @@ -24,6 +24,7 @@ _forward_loop, _process_batch, get_dataset_samples, + get_max_batch_size, ) @@ -218,6 +219,44 @@ def _collate(samples): assert model.config.use_cache is True +def test_get_max_batch_size_disables_use_cache_during_probe(monkeypatch): + """Exercise `get_max_batch_size` on CPU by mocking CUDA memory probes. + + Verifies that the probe forward pass runs with `config.use_cache = False` + and that the prior value is restored after the function returns. 
+ """ + seen_use_cache: list[bool] = [] + + class _Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = _Config() + self.config.use_cache = True + self.device = torch.device("cpu") + + def forward(self, _input): + seen_use_cache.append(self.config.use_cache) + + model = _Model() + + total_mem = 100_000_000_000 # 100 GB + free_seq = iter([total_mem, total_mem - 1_000_000_000] * 8) # before/after, with extras + + monkeypatch.setattr(torch.cuda, "empty_cache", lambda: None) + monkeypatch.setattr( + torch.cuda, "get_device_properties", lambda _i: type("P", (), {"total_memory": total_mem})() + ) + monkeypatch.setattr(torch.cuda, "device_count", lambda: 1) + monkeypatch.setattr(torch.cuda, "mem_get_info", lambda _i: (next(free_seq), total_mem)) + monkeypatch.setattr(torch.cuda, "max_memory_allocated", lambda _i: 0) + + bsize = get_max_batch_size(model, max_sample_length=4) + + assert isinstance(bsize, int) and bsize >= 1 + assert seen_use_cache and all(v is False for v in seen_use_cache) + assert model.config.use_cache is True + + def test_disable_use_cache_restores_on_exception(): """Restore must run even if the with-block raises.""" model = torch.nn.Linear(4, 4) From 89748120c8c4cc7473391c4bfa4d89898fd5a0f6 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Tue, 28 Apr 2026 18:22:06 +0000 Subject: [PATCH 6/7] test: drop CPU-mocked get_max_batch_size test, mark GPU-only branches no-cover Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/utils/dataset_utils.py | 4 +- tests/unit/torch/utils/test_dataset_utils.py | 39 -------------------- 2 files changed, 2 insertions(+), 41 deletions(-) diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index bcbf277fbc..c24de20a23 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -511,7 +511,7 @@ def _get_free_gpu_mem(): ) * 
sample_memory_usage_ratio ) - if mem_diff_per_data_batch <= 0: + if mem_diff_per_data_batch <= 0: # pragma: no cover - GPU memory probe edge case print( "Warning: No measurable memory usage found for a single batch. " "Falling back to batch_size=1." @@ -533,7 +533,7 @@ def _get_free_gpu_mem(): try: infer_method(target_input) break - except torch.cuda.OutOfMemoryError: + except torch.cuda.OutOfMemoryError: # pragma: no cover - GPU OOM retry path target_data_batch = target_data_batch // 2 # Regulate the data batch target to be 1, 2, 4, 8, 12, ..., capped at 64 diff --git a/tests/unit/torch/utils/test_dataset_utils.py b/tests/unit/torch/utils/test_dataset_utils.py index ed97d0ba4f..94a2a5a6aa 100644 --- a/tests/unit/torch/utils/test_dataset_utils.py +++ b/tests/unit/torch/utils/test_dataset_utils.py @@ -24,7 +24,6 @@ _forward_loop, _process_batch, get_dataset_samples, - get_max_batch_size, ) @@ -219,44 +218,6 @@ def _collate(samples): assert model.config.use_cache is True -def test_get_max_batch_size_disables_use_cache_during_probe(monkeypatch): - """Exercise `get_max_batch_size` on CPU by mocking CUDA memory probes. - - Verifies that the probe forward pass runs with `config.use_cache = False` - and that the prior value is restored after the function returns. 
- """ - seen_use_cache: list[bool] = [] - - class _Model(torch.nn.Module): - def __init__(self): - super().__init__() - self.config = _Config() - self.config.use_cache = True - self.device = torch.device("cpu") - - def forward(self, _input): - seen_use_cache.append(self.config.use_cache) - - model = _Model() - - total_mem = 100_000_000_000 # 100 GB - free_seq = iter([total_mem, total_mem - 1_000_000_000] * 8) # before/after, with extras - - monkeypatch.setattr(torch.cuda, "empty_cache", lambda: None) - monkeypatch.setattr( - torch.cuda, "get_device_properties", lambda _i: type("P", (), {"total_memory": total_mem})() - ) - monkeypatch.setattr(torch.cuda, "device_count", lambda: 1) - monkeypatch.setattr(torch.cuda, "mem_get_info", lambda _i: (next(free_seq), total_mem)) - monkeypatch.setattr(torch.cuda, "max_memory_allocated", lambda _i: 0) - - bsize = get_max_batch_size(model, max_sample_length=4) - - assert isinstance(bsize, int) and bsize >= 1 - assert seen_use_cache and all(v is False for v in seen_use_cache) - assert model.config.use_cache is True - - def test_disable_use_cache_restores_on_exception(): """Restore must run even if the with-block raises.""" model = torch.nn.Linear(4, 4) From 80db86c8ba4fe5dbb7b35e35033c7e1045a22866 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Tue, 28 Apr 2026 18:47:52 +0000 Subject: [PATCH 7/7] test: extend pragma no-cover to GPU-only branch bodies Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/utils/dataset_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index c24de20a23..73cb917f37 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -512,11 +512,11 @@ def _get_free_gpu_mem(): * sample_memory_usage_ratio ) if mem_diff_per_data_batch <= 0: # pragma: no cover - GPU memory probe edge case - 
print( + print( # pragma: no cover "Warning: No measurable memory usage found for a single batch. " "Falling back to batch_size=1." ) - target_data_batch = 1 + target_data_batch = 1 # pragma: no cover else: target_data_batch = max(int(free_mem_before / mem_diff_per_data_batch), 1) target_input = sample_input_single_batch.expand( @@ -534,7 +534,7 @@ def _get_free_gpu_mem(): infer_method(target_input) break except torch.cuda.OutOfMemoryError: # pragma: no cover - GPU OOM retry path - target_data_batch = target_data_batch // 2 + target_data_batch = target_data_batch // 2 # pragma: no cover # Regulate the data batch target to be 1, 2, 4, 8, 12, ..., capped at 64 if target_data_batch < 2: