[NVBug 6102977] Add _disable_use_cache context manager to fix PTQ AttributeError on custom configs #1324
Changes from all commits: eed2b86, 9783c10, 5d5398b, 9e78129, 56c5edb, 8974812, 80db86c
modelopt/torch/utils/dataset_utils.py
```diff
@@ -18,7 +18,8 @@
 import copy
 import json
 import os
-from collections.abc import Callable
+from collections.abc import Callable, Iterator
+from contextlib import contextmanager, suppress
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
 from warnings import warn
```
```diff
@@ -437,6 +438,36 @@ def get_supported_datasets() -> list[str]:
     return list(SUPPORTED_DATASET_CONFIG.keys())


+@contextmanager
+def _disable_use_cache(model: torch.nn.Module) -> Iterator[None]:
+    """Set ``model.config.use_cache = False`` for the duration of the block.
+
+    KV caching is unwanted during calibration / memory-probe forward passes:
+    it wastes memory, and for hybrid Mamba/attention models (e.g., NemotronH)
+    the cache state is mutated in-place and breaks correctness. Setting
+    ``use_cache`` unconditionally (rather than only when it was already
+    present) also sidesteps configs that never assign the attribute at all
+    — e.g., ``Step3p5Config`` from stepfun-ai/Step-3.5-Flash — where forward
+    code that reads ``self.config.use_cache`` would otherwise raise
+    ``AttributeError``. The prior value is restored on exit if one existed.
+    """
+    config = getattr(model, "config", None)
+    if config is None:
+        yield
+        return
+    had_attr = hasattr(config, "use_cache")
+    prev = config.use_cache if had_attr else None
+    config.use_cache = False
+    try:
+        yield
+    finally:
+        if had_attr:
+            config.use_cache = prev
+        else:
+            with suppress(AttributeError):
+                delattr(config, "use_cache")
+
+
 def get_max_batch_size(
     model: torch.nn.Module,
     max_sample_length: int = 512,
```

Collaborator

Bug: when the config never had a `use_cache` attribute, restoring a previous value is not enough; the attribute must be deleted on exit. The `finally:` block should be:

```python
finally:
    if had_attr:
        config.use_cache = prev
    else:
        delattr(config, "use_cache")
```

Without this, after calling `_disable_use_cache` on such a config, `config.use_cache` would remain set to `False` permanently.
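For context, a minimal usage sketch of the new helper; the `types.SimpleNamespace` stand-in below mimics a config like `Step3p5Config` that never defines `use_cache` (the stand-in and the import path are illustrative assumptions, not part of the PR):

```python
import types

import torch

from modelopt.torch.utils.dataset_utils import _disable_use_cache

# Hypothetical config object that, like Step3p5Config, never sets `use_cache`.
model = torch.nn.Linear(4, 4)
model.config = types.SimpleNamespace()

# Without the context manager, forward code reading the flag fails:
try:
    _ = model.config.use_cache
except AttributeError:
    print("config has no use_cache attribute")

# Inside the block the flag exists and is False; on exit it is removed again,
# leaving the config exactly as it was found.
with _disable_use_cache(model):
    assert model.config.use_cache is False
assert not hasattr(model.config, "use_cache")
```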
```diff
@@ -467,42 +498,43 @@ def _get_free_gpu_mem():
         torch.ones([1, max_sample_length], dtype=torch.int32, device=model.device) * 100
     )

-    # Calculate single batch inference with dummy input.
-    with torch.set_grad_enabled(enable_grad):
-        infer_method(sample_input_single_batch)
-    free_mem_after, max_allocated_after = _get_free_gpu_mem()
-
-    mem_diff_per_data_batch = (
-        max(
-            (free_mem_before - free_mem_after),
-            (max_allocated_after - max_allocated_before),
-        )
-        * sample_memory_usage_ratio
-    )
-    if mem_diff_per_data_batch <= 0:
-        print(
-            "Warning: No measurable memory usage found for a single batch. "
-            "Falling back to batch_size=1."
-        )
-        target_data_batch = 1
-    else:
-        target_data_batch = max(int(free_mem_before / mem_diff_per_data_batch), 1)
-    target_input = sample_input_single_batch.expand(
-        [
-            target_data_batch if index == 0 else dim
-            for index, dim in enumerate(sample_input_single_batch.shape)
-        ]
-    )
-
-    # For some models on multi GPU, we observe the memory per batch is not a constant.
-    # So we just test the target batch size and make sure we do not go OOM.
-    while target_data_batch > 1:
-        with torch.set_grad_enabled(enable_grad):
-            try:
-                infer_method(target_input)
-                break
-            except torch.cuda.OutOfMemoryError:
-                target_data_batch = target_data_batch // 2
+    with _disable_use_cache(model):
+        # Calculate single batch inference with dummy input.
+        with torch.set_grad_enabled(enable_grad):
+            infer_method(sample_input_single_batch)
+        free_mem_after, max_allocated_after = _get_free_gpu_mem()
+
+        mem_diff_per_data_batch = (
+            max(
+                (free_mem_before - free_mem_after),
+                (max_allocated_after - max_allocated_before),
+            )
+            * sample_memory_usage_ratio
+        )
+        if mem_diff_per_data_batch <= 0:  # pragma: no cover - GPU memory probe edge case
+            print(  # pragma: no cover
+                "Warning: No measurable memory usage found for a single batch. "
+                "Falling back to batch_size=1."
+            )
+            target_data_batch = 1  # pragma: no cover
+        else:
+            target_data_batch = max(int(free_mem_before / mem_diff_per_data_batch), 1)
+        target_input = sample_input_single_batch.expand(
+            [
+                target_data_batch if index == 0 else dim
+                for index, dim in enumerate(sample_input_single_batch.shape)
+            ]
+        )
+
+        # For some models on multi GPU, we observe the memory per batch is not a constant.
+        # So we just test the target batch size and make sure we do not go OOM.
+        while target_data_batch > 1:
+            with torch.set_grad_enabled(enable_grad):
+                try:
+                    infer_method(target_input)
+                    break
+                except torch.cuda.OutOfMemoryError:  # pragma: no cover - GPU OOM retry path
+                    target_data_batch = target_data_batch // 2  # pragma: no cover
```
Comment on lines +522 to 538

Contributor

The retry loop halves `target_data_batch` but never rebuilds `target_input`: after an OOM, `target_data_batch` is updated, yet `target_input` keeps its original batch dimension, so every retry re-runs the same oversized tensor.

💡 Proposed fix

```diff
-    target_input = sample_input_single_batch.expand(
-        [
-            target_data_batch if index == 0 else dim
-            for index, dim in enumerate(sample_input_single_batch.shape)
-        ]
-    )
-
     # For some models on multi GPU, we observe the memory per batch is not a constant.
     # So we just test the target batch size and make sure we do not go OOM.
     while target_data_batch > 1:
+        target_input = sample_input_single_batch.expand(
+            [
+                target_data_batch if index == 0 else dim
+                for index, dim in enumerate(sample_input_single_batch.shape)
+            ]
+        )
         with torch.set_grad_enabled(enable_grad):
             try:
                 infer_method(target_input)
                 break
             except torch.cuda.OutOfMemoryError:
                 target_data_batch = target_data_batch // 2
```
Contributor (Author)

Good catch, but this is pre-existing — `target_input` is built outside the OOM retry loop on the parent commit too (verified via `git show 204daaf^:modelopt/torch/utils/dataset_utils.py`); the loop body is unchanged here, and the lines only show in the diff because we re-indented to wrap them in `with _disable_use_cache(model):`. Will file a separate fix to move `target_input = ...expand(...)` inside the `while` loop. Out of scope for this PR.
```diff
     # Regulate the data batch target to be 1, 2, 4, 8, 12, ..., capped at 64
     if target_data_batch < 2:
```
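To make the probe arithmetic above concrete, here is a small traced example; every number is made up for illustration (units arbitrary, not measurements from the PR):

```python
# Hypothetical measurements around the single-batch probe (not from the PR).
free_mem_before, free_mem_after = 40.0, 37.5          # free GPU memory, GiB
max_allocated_before, max_allocated_after = 2.0, 4.1  # peak allocations, GiB
sample_memory_usage_ratio = 1.2

# Per-batch cost: the larger of the free-memory drop and the peak-allocation rise.
mem_diff_per_data_batch = max(
    free_mem_before - free_mem_after,            # 2.5
    max_allocated_after - max_allocated_before,  # 2.1
) * sample_memory_usage_ratio                    # 2.5 * 1.2 = 3.0

# Initial estimate, clamped to at least 1: int(40.0 / 3.0) = 13.
target_data_batch = max(int(free_mem_before / mem_diff_per_data_batch), 1)
print(target_data_batch)  # 13; on CUDA OOM the loop above halves: 13 -> 6 -> 3 -> 1
```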
```diff
@@ -601,28 +633,16 @@ def _forward_loop(
         dataloader: DataLoader containing the batched input data
         allowed_non_tensor_keys: Set of key names whose values may be non-tensor types
     """
-    # Disable KV caching during calibration — it is unnecessary overhead and causes
-    # correctness issues with hybrid Mamba/attention models whose cache state is mutated
-    # in-place (e.g., NemotronH).
-    config = getattr(model, "config", None)
-    prev_use_cache = getattr(config, "use_cache", None)
-    if config is not None and prev_use_cache is not None:
-        config.use_cache = False
-
-    try:
-        with torch.no_grad():
-            is_enc_dec = model_type_is_enc_dec(model)
-            infer_method = model.generate if is_enc_dec else model.forward
-            max_working_batch_size = None  # Initialize max working batch size as None
-
-            for _, data in enumerate(tqdm(dataloader)):
-                # Process batch and update max working batch size
-                max_working_batch_size = _process_batch(
-                    data, infer_method, max_working_batch_size, allowed_non_tensor_keys
-                )
-    finally:
-        if config is not None and prev_use_cache is not None:
-            config.use_cache = prev_use_cache
+    with _disable_use_cache(model), torch.no_grad():
+        is_enc_dec = model_type_is_enc_dec(model)
+        infer_method = model.generate if is_enc_dec else model.forward
+        max_working_batch_size = None  # Initialize max working batch size as None
+
+        for _, data in enumerate(tqdm(dataloader)):
+            # Process batch and update max working batch size
+            max_working_batch_size = _process_batch(
+                data, infer_method, max_working_batch_size, allowed_non_tensor_keys
+            )


 def create_forward_loop(
```
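One design note on the `_forward_loop` rewrite above: listing both managers in a single `with` statement is equivalent to nesting them, so the cache flag is restored even if a batch raises, exactly like the old `try`/`finally`. A minimal sketch of that equivalence (the `tag` helper is hypothetical, not from the PR):

```python
from contextlib import contextmanager


@contextmanager
def tag(name: str):
    # Tiny stand-in context manager that logs enter/exit order.
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")


# Both forms print: enter a, enter b, body, exit b, exit a;
# the exits run even if the body raises.
with tag("a"), tag("b"):
    print("body")

with tag("a"):
    with tag("b"):
        print("body")
```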
Missing tests: Please add unit tests for `_disable_use_cache` covering:

- a model with no `config` attribute (no-op)
- a config with a pre-existing `use_cache` (restored on exit)
- a config without `use_cache` (attribute should not persist after exit)

These are simple to write with a mock `nn.Module` and would directly validate the bug fix.
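A minimal sketch of such tests (illustrative only: the test names and the `types.SimpleNamespace` config stand-in are assumptions, and the import path matches the file touched in this diff):

```python
import types

import torch

from modelopt.torch.utils.dataset_utils import _disable_use_cache


def test_no_config_attribute_is_noop():
    # A bare module has no `config`; the context manager must not invent one.
    model = torch.nn.Linear(2, 2)
    with _disable_use_cache(model):
        pass
    assert not hasattr(model, "config")


def test_existing_use_cache_is_restored_on_exit():
    model = torch.nn.Linear(2, 2)
    model.config = types.SimpleNamespace(use_cache=True)
    with _disable_use_cache(model):
        assert model.config.use_cache is False
    assert model.config.use_cache is True


def test_missing_use_cache_does_not_persist_after_exit():
    model = torch.nn.Linear(2, 2)
    model.config = types.SimpleNamespace()
    with _disable_use_cache(model):
        assert model.config.use_cache is False
    assert not hasattr(model.config, "use_cache")
```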