Merged
128 changes: 74 additions & 54 deletions modelopt/torch/utils/dataset_utils.py
@@ -18,7 +18,8 @@
import copy
import json
import os
-from collections.abc import Callable
+from collections.abc import Callable, Iterator
+from contextlib import contextmanager, suppress
from pathlib import Path
from typing import TYPE_CHECKING, Any
from warnings import warn
@@ -437,6 +438,36 @@ def get_supported_datasets() -> list[str]:
    return list(SUPPORTED_DATASET_CONFIG.keys())
Collaborator

Missing tests: Please add unit tests for _disable_use_cache covering:

  1. Model with no config attribute (no-op)
  2. Model whose config already has use_cache (restored on exit)
  3. Model whose config lacks use_cache (attribute should not persist after exit)

These are simple to write with a mock nn.Module and would directly validate the bug fix.
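For reference, a minimal sketch of case 3, the one that directly exercises the attribute cleanup; `_DummyConfig` and the test name are hypothetical here, and the tests added later in this PR use an equivalent `_Config` stand-in:

import torch

from modelopt.torch.utils.dataset_utils import _disable_use_cache


class _DummyConfig:
    """Config stand-in with no use_cache attribute."""


def test_use_cache_attr_does_not_leak():
    model = torch.nn.Linear(4, 4)
    model.config = _DummyConfig()

    with _disable_use_cache(model):
        # Forced to False inside the block.
        assert model.config.use_cache is False

    # After exit the attribute must not persist on the config.
    assert not hasattr(model.config, "use_cache")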



@contextmanager
def _disable_use_cache(model: torch.nn.Module) -> Iterator[None]:
    """Set ``model.config.use_cache = False`` for the duration of the block.

    KV caching is unwanted during calibration / memory-probe forward passes:
    it wastes memory, and for hybrid Mamba/attention models (e.g., NemotronH)
    the cache state is mutated in-place and breaks correctness. Setting
    ``use_cache`` unconditionally (rather than only when it was already
    present) also sidesteps configs that never assign the attribute at all
    — e.g., ``Step3p5Config`` from stepfun-ai/Step-3.5-Flash — where forward
    code that reads ``self.config.use_cache`` would otherwise raise
    ``AttributeError``. The prior value is restored on exit if one existed.
    """
    config = getattr(model, "config", None)
    if config is None:
        yield
        return
    had_attr = hasattr(config, "use_cache")
    prev = config.use_cache if had_attr else None
    config.use_cache = False
    try:
        yield
    finally:
        if had_attr:
Collaborator

Bug: When had_attr is False (the config didn't originally have use_cache), the context manager sets config.use_cache = False on entry but never removes it on exit. This leaks a new attribute onto the config object.

The finally block should clean up:

finally:
    if had_attr:
        config.use_cache = prev
    else:
        delattr(config, "use_cache")

Without this, after calling get_max_batch_size or _forward_loop, a config that never had use_cache will now permanently have use_cache = False, which could change model behavior for subsequent inference.
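A minimal standalone repro of the leak, assuming the pre-fix restore-only behavior; `_leaky_disable` is a hypothetical stand-in, not the function in this PR:

from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def _leaky_disable(config):
    # Pre-fix behavior: remembers the old value but never deletes the attribute.
    had_attr = hasattr(config, "use_cache")
    prev = config.use_cache if had_attr else None
    config.use_cache = False
    try:
        yield
    finally:
        if had_attr:
            config.use_cache = prev


config = SimpleNamespace()  # a config that never defined use_cache
with _leaky_disable(config):
    pass

# The attribute now exists and is stuck at False after the block exits.
assert hasattr(config, "use_cache") and config.use_cache is False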

            config.use_cache = prev
        else:
            with suppress(AttributeError):
                delattr(config, "use_cache")


def get_max_batch_size(
    model: torch.nn.Module,
    max_sample_length: int = 512,
@@ -467,42 +498,43 @@ def _get_free_gpu_mem():
        torch.ones([1, max_sample_length], dtype=torch.int32, device=model.device) * 100
    )

-    # Calculate single batch inference with dummy input.
-    with torch.set_grad_enabled(enable_grad):
-        infer_method(sample_input_single_batch)
-    free_mem_after, max_allocated_after = _get_free_gpu_mem()
+    with _disable_use_cache(model):
+        # Calculate single batch inference with dummy input.
+        with torch.set_grad_enabled(enable_grad):
+            infer_method(sample_input_single_batch)
+        free_mem_after, max_allocated_after = _get_free_gpu_mem()

-    mem_diff_per_data_batch = (
-        max(
-            (free_mem_before - free_mem_after),
-            (max_allocated_after - max_allocated_before),
-        )
-        * sample_memory_usage_ratio
-    )
-    if mem_diff_per_data_batch <= 0:
-        print(
-            "Warning: No measurable memory usage found for a single batch. "
-            "Falling back to batch_size=1."
-        )
-        target_data_batch = 1
-    else:
-        target_data_batch = max(int(free_mem_before / mem_diff_per_data_batch), 1)
-    target_input = sample_input_single_batch.expand(
-        [
-            target_data_batch if index == 0 else dim
-            for index, dim in enumerate(sample_input_single_batch.shape)
-        ]
-    )
+        mem_diff_per_data_batch = (
+            max(
+                (free_mem_before - free_mem_after),
+                (max_allocated_after - max_allocated_before),
+            )
+            * sample_memory_usage_ratio
+        )
+        if mem_diff_per_data_batch <= 0:  # pragma: no cover - GPU memory probe edge case
+            print(  # pragma: no cover
+                "Warning: No measurable memory usage found for a single batch. "
+                "Falling back to batch_size=1."
+            )
+            target_data_batch = 1  # pragma: no cover
+        else:
+            target_data_batch = max(int(free_mem_before / mem_diff_per_data_batch), 1)
+        target_input = sample_input_single_batch.expand(
+            [
+                target_data_batch if index == 0 else dim
+                for index, dim in enumerate(sample_input_single_batch.shape)
+            ]
+        )

-    # For some models on multi GPU, we observe the memory per batch is not a constant.
-    # So we just test the target batch size and make sure we do not go OOM.
-    while target_data_batch > 1:
-        with torch.set_grad_enabled(enable_grad):
-            try:
-                infer_method(target_input)
-                break
-            except torch.cuda.OutOfMemoryError:
-                target_data_batch = target_data_batch // 2
+        # For some models on multi GPU, we observe the memory per batch is not a constant.
+        # So we just test the target batch size and make sure we do not go OOM.
+        while target_data_batch > 1:
+            with torch.set_grad_enabled(enable_grad):
+                try:
+                    infer_method(target_input)
+                    break
+                except torch.cuda.OutOfMemoryError:  # pragma: no cover - GPU OOM retry path
+                    target_data_batch = target_data_batch // 2  # pragma: no cover

Comment on lines +522 to 538
Contributor
@coderabbitai (bot) commented on Apr 27, 2026

⚠️ Potential issue | 🟠 Major

Retry loop halves target_data_batch but reuses stale target_input.

After an OOM, the loop updates target_data_batch only; it keeps retrying with the original expanded tensor shape, so reduced batch sizes are never actually tested.

💡 Proposed fix
-        target_input = sample_input_single_batch.expand(
-            [
-                target_data_batch if index == 0 else dim
-                for index, dim in enumerate(sample_input_single_batch.shape)
-            ]
-        )
-
         # For some models on multi GPU, we observe the memory per batch is not a constant.
         # So we just test the target batch size and make sure we do not go OOM.
         while target_data_batch > 1:
+            target_input = sample_input_single_batch.expand(
+                [
+                    target_data_batch if index == 0 else dim
+                    for index, dim in enumerate(sample_input_single_batch.shape)
+                ]
+            )
             with torch.set_grad_enabled(enable_grad):
                 try:
                     infer_method(target_input)
                     break
                 except torch.cuda.OutOfMemoryError:
                     target_data_batch = target_data_batch // 2
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/dataset_utils.py` around lines 524 - 540, The retry loop
halves target_data_batch on OOM but never rebuilds target_input, so infer_method
keeps receiving the original (too-large) tensor; modify the loop to recreate
target_input from sample_input_single_batch after each reduction of
target_data_batch (i.e., move or add the target_input =
sample_input_single_batch.expand([...]) construction inside the while loop using
the updated target_data_batch and same shape logic with enumerate), then call
infer_method with the rebuilt tensor so each smaller batch size is actually
tested.

Contributor Author

Good catch, but this is pre-existing — target_input is built outside the OOM retry loop on the parent commit too (verified via git show 204daaf^:modelopt/torch/utils/dataset_utils.py; the loop body is unchanged here, the lines only show in the diff because we re-indented to wrap them in with _disable_use_cache(model):). Will file a separate fix to move target_input = ...expand(...) inside the while loop. Out of scope for this PR.

Contributor

@meenchen, thanks for verifying — that's a fair point. Since the bug predates this PR and you're planning a follow-up fix, would you like me to open a GitHub issue to track moving target_input = sample_input_single_batch.expand(...) inside the while loop? That way it won't get lost.


🧠 Learnings used
Learnt from: kevalmorabia97
Repo: NVIDIA/Model-Optimizer PR: 974
File: modelopt/torch/puzzletron/pruning/pruning_utils.py:195-201
Timestamp: 2026-03-06T14:26:21.439Z
Learning: In NVIDIA/Model-Optimizer, for PyTorch >= 2.6, torch.load() calls without an explicit weights_only argument are safe. Do not flag bare torch.load(...) as a security issue in files under the modelopt package (e.g., modelopt/torch/puzzletron/pruning/pruning_utils.py) as long as the PyTorch version constraint is maintained. If supporting PyTorch < 2.6, require an explicit weights_only argument to torch.load() to avoid potential issues.

# Regulate the data batch target to be 1, 2, 4, 8, 12, ..., capped at 64
if target_data_batch < 2:
@@ -601,28 +633,16 @@ def _forward_loop(
        dataloader: DataLoader containing the batched input data
        allowed_non_tensor_keys: Set of key names whose values may be non-tensor types
    """
-    # Disable KV caching during calibration — it is unnecessary overhead and causes
-    # correctness issues with hybrid Mamba/attention models whose cache state is mutated
-    # in-place (e.g., NemotronH).
-    config = getattr(model, "config", None)
-    prev_use_cache = getattr(config, "use_cache", None)
-    if config is not None and prev_use_cache is not None:
-        config.use_cache = False
-
-    try:
-        with torch.no_grad():
-            is_enc_dec = model_type_is_enc_dec(model)
-            infer_method = model.generate if is_enc_dec else model.forward
-            max_working_batch_size = None  # Initialize max working batch size as None
-
-            for _, data in enumerate(tqdm(dataloader)):
-                # Process batch and update max working batch size
-                max_working_batch_size = _process_batch(
-                    data, infer_method, max_working_batch_size, allowed_non_tensor_keys
-                )
-    finally:
-        if config is not None and prev_use_cache is not None:
-            config.use_cache = prev_use_cache
+    with _disable_use_cache(model), torch.no_grad():
+        is_enc_dec = model_type_is_enc_dec(model)
+        infer_method = model.generate if is_enc_dec else model.forward
+        max_working_batch_size = None  # Initialize max working batch size as None
+
+        for _, data in enumerate(tqdm(dataloader)):
+            # Process batch and update max working batch size
+            max_working_batch_size = _process_batch(
+                data, infer_method, max_working_batch_size, allowed_non_tensor_keys
+            )


def create_forward_loop(
88 changes: 87 additions & 1 deletion tests/unit/torch/utils/test_dataset_utils.py
@@ -17,8 +17,14 @@

import pytest
import torch
+from torch.utils.data import DataLoader

-from modelopt.torch.utils.dataset_utils import _process_batch, get_dataset_samples
+from modelopt.torch.utils.dataset_utils import (
+    _disable_use_cache,
+    _forward_loop,
+    _process_batch,
+    get_dataset_samples,
+)


def setup_test_data():
@@ -145,6 +151,86 @@ def mock_infer(**kwargs):
    _process_batch(batch_data, mock_infer, allowed_non_tensor_keys={"base_model_outputs"})


class _Config:
    """Minimal config stand-in; instances start with no `use_cache` attribute."""


def test_disable_use_cache_no_config_attr():
    """Model without a `config` attribute: CM is a no-op and does not raise."""
    model = torch.nn.Linear(4, 4)
    assert not hasattr(model, "config")

    with _disable_use_cache(model):
        assert not hasattr(model, "config")

    assert not hasattr(model, "config")


@pytest.mark.parametrize("prev_value", [True, False])
def test_disable_use_cache_with_existing_attr(prev_value):
    """Config that already has `use_cache`: forced to False inside, restored on exit."""
    model = torch.nn.Linear(4, 4)
    model.config = _Config()
    model.config.use_cache = prev_value

    with _disable_use_cache(model):
        assert model.config.use_cache is False

    assert model.config.use_cache is prev_value


def test_disable_use_cache_without_existing_attr():
    """Config that lacks `use_cache`: set to False inside, attribute removed on exit (no leak)."""
    model = torch.nn.Linear(4, 4)
    model.config = _Config()
    assert not hasattr(model.config, "use_cache")

    with _disable_use_cache(model):
        assert model.config.use_cache is False

    assert not hasattr(model.config, "use_cache")


def test_forward_loop_runs_under_disabled_use_cache():
    """`_forward_loop` runs forward on every batch and restores `use_cache` on exit."""
    seen_use_cache: list[bool] = []

    class _Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.config = _Config()
            self.config.use_cache = True

        def forward(self, **kwargs):
            seen_use_cache.append(self.config.use_cache)

    model = _Model()

    def _collate(samples):
        return {"input_ids": torch.stack([s["input_ids"] for s in samples])}

    data = [{"input_ids": torch.zeros(8, dtype=torch.long)} for _ in range(3)]
    loader = DataLoader(data, batch_size=1, collate_fn=_collate)

    _forward_loop(model, loader)

    assert seen_use_cache == [False, False, False]
    assert model.config.use_cache is True


def test_disable_use_cache_restores_on_exception():
    """Restore must run even if the with-block raises."""
    model = torch.nn.Linear(4, 4)
    model.config = _Config()
    model.config.use_cache = True

    with pytest.raises(RuntimeError, match="boom"), _disable_use_cache(model):
        assert model.config.use_cache is False
        raise RuntimeError("boom")

    assert model.config.use_cache is True


@pytest.mark.parametrize("test_local_path", [True, False])
def test_get_dataset_samples_with_unsupported_minipile_dataset(tmp_path, test_local_path):
    pytest.importorskip("datasets")