diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2fb86925334..6dbab609d2c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -68,6 +68,8 @@ repos: entry: python tools/precommit/check_modelopt_recipes.py language: system files: ^modelopt_recipes/ + # configs/ contains reusable snippets (not full recipes) — skip recipe validation + exclude: ^modelopt_recipes/configs/ # Instructions to change license file if ever needed: # https://github.com/Lucas-C/pre-commit-hooks#removing-old-license-and-replacing-it-with-a-new-one diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d2369885431..2e3bf10709e 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,6 +17,7 @@ Changelog **New Features** +- Add composable ``$import`` system for recipe YAML configs, enabling reusable config snippets referenced via ``{$import: name}`` markers. All built-in PTQ recipes converted to use imports with shared snippets under ``modelopt_recipes/configs/`` (numeric formats, quant_cfg building blocks, presets). See :ref:`composable-imports`. - Add offline DFlash speculative decoding training. Train the draft module from pre-computed base-model hidden states dumped by ``examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py``; base-model transformer layers are deleted after conversion to save memory. Controlled by the auto-derived ``dflash_offline`` flag on ``DFlashConfig`` (derived from ``data_args.offline_data_path``). The dump scripts now share ``collect_hidden_states/common.py`` for aux-layer selection (``--aux-layers eagle|dflash|``) and optional assistant-token ``loss_mask`` for answer-only-loss training. - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `__ for usage. 
diff --git a/docs/source/guides/10_recipes.rst b/docs/source/guides/10_recipes.rst index 468a1d2d6ca..26f28afd756 100644 --- a/docs/source/guides/10_recipes.rst +++ b/docs/source/guides/10_recipes.rst @@ -54,14 +54,18 @@ A recipe contains two top-level sections: ``metadata`` and a type-specific configuration section (for example, ``quantize`` for PTQ recipes). These can live in a single YAML file or be split across files in a directory. +Recipes support two authoring styles: **inline** (all values written directly) +and **import-based** (reusable snippets referenced via ``$import``). Both +styles can be used in a single-file or directory layout. + Single-file format ------------------ -The simplest form is a single ``.yml`` or ``.yaml`` file. Here is a PTQ example: +The simplest form is a single ``.yaml`` file. -.. code-block:: yaml +**Inline style** — all config values are written directly: - # modelopt_recipes/general/ptq/fp8_default-fp8_kv.yml +.. code-block:: yaml metadata: recipe_type: ptq @@ -81,11 +85,42 @@ The simplest form is a single ``.yml`` or ``.yaml`` file. Here is a PTQ example num_bits: e4m3 axis: - quantizer_name: '*[kv]_bmm_quantizer' - enable: true cfg: num_bits: e4m3 # ... standard exclusions omitted for brevity +**Import style** — the same recipe using reusable config snippets: + +.. code-block:: yaml + + imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled: configs/ptq/units/default_disabled_quantizers + fp8: configs/numerics/fp8 + + metadata: + recipe_type: ptq + description: FP8 per-tensor weight and activation (W8A8), FP8 KV cache, max calibration. + + quantize: + algorithm: max + quant_cfg: + - $import: base_disable_all + - quantizer_name: '*input_quantizer' + cfg: + $import: fp8 + - quantizer_name: '*weight_quantizer' + cfg: + $import: fp8 + - quantizer_name: '*[kv]_bmm_quantizer' + cfg: + $import: fp8 + - $import: default_disabled + +Both styles produce identical results at load time. 
The import style reduces +duplication when multiple recipes share the same numeric formats or exclusion +lists. See :ref:`composable-imports` below for the full ``$import`` specification. + Directory format ---------------- @@ -96,10 +131,10 @@ example: .. code-block:: text my_recipe/ - recipe.yml # metadata section - quantize.yml # quantize section (quant_cfg + algorithm) + recipe.yaml # metadata section (+ optional imports) + quantize.yaml # quantize section (+ optional imports) -``recipe.yml``: +``recipe.yaml``: .. code-block:: yaml @@ -107,7 +142,7 @@ example: recipe_type: ptq description: My custom NVFP4 recipe. -``quantize.yml``: +``quantize.yaml``: .. code-block:: yaml @@ -124,6 +159,160 @@ example: num_bits: e4m3 axis: +Both inline and import styles work with the directory format. Any YAML file +in the directory can have its own ``imports`` section — ``recipe.yaml``, +``quantize.yaml``, or any other config file. + +.. _composable-imports: + +Composable imports +------------------ + +Recipes can import **reusable config snippets** via the ``imports`` section. +This eliminates duplication — numeric format definitions and standard exclusion +lists are authored once and referenced by name across recipes. + +The ``imports`` section is a dict mapping short names to config file paths. +References use the explicit ``{$import: name}`` marker so they are never +confused with literal values. + +.. note:: + + ``imports`` (no ``$``) is a **top-level structural section** — like + ``metadata`` or ``quantize``, it declares the recipe's dependencies. + ``$import`` (with ``$``) is an **inline directive** that appears inside + data values and gets resolved at load time. + +The ``$import`` marker can appear anywhere in the recipe: + +- As a **dict value** — the marker is replaced with the snippet content. +- As a **list element** — the snippet (which must itself be a list) is spliced + into the surrounding list. 
+ +As a **dict value**, ``$import`` supports composition with clear override +precedence (lowest to highest): + +1. **Imports in list order** — ``$import: [base, override]``: later snippets + override earlier ones on key conflicts. +2. **Inline keys** — extra keys alongside ``$import`` override all imported + values. + +This is equivalent to calling ``dict.update()`` in order: imports first (in +list order), then inline keys last. + +.. code-block:: yaml + + # Single import + cfg: + $import: nvfp4 + + # Import + override — import nvfp4, then override type inline + cfg: + $import: nvfp4 # imports {num_bits: e2m1, block_sizes: {-1: 16, type: dynamic, ...}} + block_sizes: + -1: 16 + type: static # overrides type: dynamic → static calibration + + # Multiple imports — later snippet overrides earlier on conflict + cfg: + $import: [base_format, kv_tweaks] # kv_tweaks wins on shared keys + + # All three: multi-import + inline override + cfg: + $import: [bits, scale] + axis: 0 # highest precedence + +As a **list element**, ``$import`` must be the only key — extra keys alongside +a list splice are not supported. + +.. code-block:: yaml + + imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled: configs/ptq/units/default_disabled_quantizers + fp8: configs/numerics/fp8 + + metadata: + recipe_type: ptq + description: FP8 W8A8, FP8 KV cache. + + quantize: + algorithm: max + quant_cfg: + - $import: base_disable_all # spliced from a single-element list snippet + - quantizer_name: '*weight_quantizer' + cfg: + $import: fp8 # cfg value replaced with imported dict + - $import: default_disabled # spliced from a multi-element list snippet + +In this example: + +- ``$import: base_disable_all`` and ``$import: default_disabled`` are **list elements** + — their snippets (YAML lists) are spliced into ``quant_cfg``. +- ``$import: fp8`` under ``cfg`` is a **dict value** — the snippet (a YAML dict of + quantizer attributes) replaces the ``cfg`` field. 
+ +Import paths are resolved via :func:`~modelopt.recipe.load_config` — the +built-in ``modelopt_recipes/`` library is checked first, then the filesystem. + +**Recursive imports:** An imported snippet may itself contain an ``imports`` +section. Each file's imports are scoped to that file — the same name can be +used in different files without conflict. Circular imports are detected and +raise ``ValueError``. + +Multi-document snippets +^^^^^^^^^^^^^^^^^^^^^^^ + +Dict-valued snippets (e.g., numeric format definitions) can use ``imports`` +directly because the ``imports`` key and the snippet content are both part of +the same YAML mapping. List-valued snippets have a problem: YAML only allows +one root node per document, so a file cannot be both a mapping (for +``imports``) and a list (for entries) at the same time. + +The solution is **multi-document YAML**: the first document holds the +``imports``, and the second document (after ``---``) holds the list content. +The loader parses both documents, resolves ``$import`` markers in the content, +and returns the resolved list: + +.. code-block:: yaml + + # configs/ptq/units/fp8_kv.yaml — list snippet that imports a dict snippet + imports: + fp8: configs/numerics/fp8 + --- + - quantizer_name: '*[kv]_bmm_quantizer' + cfg: + $import: fp8 + +This enables full composability — list snippets can reference dict snippets, +dict snippets can reference other dict snippets, and recipes can reference +any of them. All import resolution happens at load time with the same +precedence rules. + +Built-in config snippets +^^^^^^^^^^^^^^^^^^^^^^^^ + +Reusable snippets are stored under ``modelopt_recipes/configs/``: + +.. 
list-table:: + :header-rows: 1 + :widths: 45 55 + + * - Snippet path + - Description + * - ``configs/numerics/fp8`` + - FP8 E4M3 quantizer attributes + * - ``configs/numerics/nvfp4`` + - NVFP4 E2M1 blockwise, dynamic calibration, FP8 scales (default) + * - ``configs/numerics/nvfp4_static`` + - NVFP4 E2M1 blockwise, static calibration, FP8 scales + * - ``configs/ptq/units/base_disable_all`` + - Disable all quantizers (deny-all-then-configure pattern) + * - ``configs/ptq/units/default_disabled_quantizers`` + - Standard exclusions (LM head, routers, BatchNorm, etc.) + * - ``configs/ptq/units/fp8_kv`` + - FP8 E4M3 KV cache quantization (multi-document, imports ``fp8``) + Metadata section ================ @@ -287,7 +476,7 @@ type depends on the ``recipe_type`` in the metadata: .. code-block:: python # Load a custom recipe from the filesystem (file or directory) - recipe = load_recipe("/path/to/my_custom_recipe.yml") + recipe = load_recipe("/path/to/my_custom_recipe.yaml") # or: recipe = load_recipe("/path/to/my_recipe_dir/") Command-line usage @@ -341,7 +530,7 @@ This means built-in recipes can be referenced without any prefix: # These are all equivalent: load_recipe("general/ptq/fp8_default-fp8_kv") - load_recipe("general/ptq/fp8_default-fp8_kv.yml") + load_recipe("general/ptq/fp8_default-fp8_kv.yaml") Writing a custom recipe @@ -355,11 +544,15 @@ To create a custom recipe: 3. Update the ``metadata.description`` to describe your changes. 4. Save the file (or directory) and pass its path to ``load_recipe()`` or ``--recipe``. -Example -- creating a custom PTQ recipe (INT8 per-channel): +Example -- creating a custom PTQ recipe using imports: .. code-block:: yaml - # my_int8_recipe.yml + # my_int8_recipe.yaml + imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled: configs/ptq/units/default_disabled_quantizers + metadata: recipe_type: ptq description: INT8 per-channel weight, per-tensor activation. 
@@ -367,8 +560,7 @@ Example -- creating a custom PTQ recipe (INT8 per-channel): quantize: algorithm: max quant_cfg: - - quantizer_name: '*' - enable: false + - $import: base_disable_all - quantizer_name: '*weight_quantizer' cfg: num_bits: 8 @@ -377,10 +569,11 @@ Example -- creating a custom PTQ recipe (INT8 per-channel): cfg: num_bits: 8 axis: - - quantizer_name: '*lm_head*' - enable: false - - quantizer_name: '*output_layer*' - enable: false + - $import: default_disabled + +The built-in snippets (``base_disable_all``, ``default_disabled``) handle the +deny-all prefix and standard exclusions. Only the format-specific entries need +to be written inline. Recipe repository layout @@ -394,15 +587,31 @@ The ``modelopt_recipes/`` package is organized as follows: +-- __init__.py +-- general/ # Model-agnostic recipes | +-- ptq/ - | +-- fp8_default-fp8_kv.yml - | +-- nvfp4_default-fp8_kv.yml - | +-- nvfp4_mlp_only-fp8_kv.yml - | +-- nvfp4_experts_only-fp8_kv.yml - | +-- nvfp4_omlp_only-fp8_kv.yml + | +-- fp8_default-fp8_kv.yaml + | +-- nvfp4_default-fp8_kv.yaml + | +-- nvfp4_mlp_only-fp8_kv.yaml + | +-- nvfp4_experts_only-fp8_kv.yaml + | +-- nvfp4_omlp_only-fp8_kv.yaml +-- models/ # Model-specific recipes | +-- Step3.5-Flash/ | +-- nvfp4-mlp-only.yaml - +-- configs/ # Shared configuration fragments + +-- configs/ # Reusable config snippets (imported via $import) + +-- numerics/ # Numeric format definitions + | +-- fp8.yaml + | +-- nvfp4_static.yaml + | +-- nvfp4.yaml + +-- ptq/ + +-- units/ # Reusable quant_cfg building blocks + | +-- base_disable_all.yaml + | +-- default_disabled_quantizers.yaml + | +-- fp8_kv.yaml + | +-- w8a8_fp8_fp8.yaml + | +-- w4a4_nvfp4_nvfp4.yaml + +-- presets/ # Complete configs (backward compat with *_CFG dicts) + +-- model/ + | +-- fp8.yaml + +-- kv/ + +-- fp8.yaml Recipe data model diff --git a/modelopt/recipe/_config_loader.py b/modelopt/recipe/_config_loader.py index 188dcf236f7..1abbd36c98d 100644 --- a/modelopt/recipe/_config_loader.py +++ 
b/modelopt/recipe/_config_loader.py @@ -13,101 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""YAML config loading utilities. - -This module is intentionally free of ``modelopt.torch`` imports so that -``modelopt.torch.quantization.config`` can import :func:`load_config` without -triggering a circular import through ``modelopt.recipe.loader``. -""" - -from importlib.resources import files - -try: - from importlib.resources.abc import Traversable -except ImportError: # Python < 3.11 - from importlib.abc import Traversable -import re -from pathlib import Path -from typing import Any - -import yaml - -# Root to all built-in recipes. Users can create own recipes. -BUILTIN_RECIPES_LIB = files("modelopt_recipes") - -_EXMY_RE = re.compile(r"^[Ee](\d+)[Mm](\d+)$") -_EXMY_KEYS = frozenset({"num_bits", "scale_bits"}) - - -def _parse_exmy_num_bits(obj: Any) -> Any: - """Recursively convert ``ExMy`` strings in ``num_bits`` / ``scale_bits`` to ``(x, y)`` tuples.""" - if isinstance(obj, dict): - return { - k: ( - _parse_exmy(v) - if k in _EXMY_KEYS and isinstance(v, str) - else _parse_exmy_num_bits(v) - ) - for k, v in obj.items() - } - if isinstance(obj, list): - return [_parse_exmy_num_bits(item) for item in obj] - return obj - - -def _parse_exmy(s: str) -> tuple[int, int] | str: - m = _EXMY_RE.match(s) - if m: - return (int(m.group(1)), int(m.group(2))) - return s - - -def load_config(config_file: str | Path | Traversable) -> dict[str, Any]: - """Load a config yaml. - - config_file: Path to a config yaml file. The path suffix can be omitted. 
- """ - paths_to_check: list[Path | Traversable] = [] - if isinstance(config_file, str): - if not config_file.endswith(".yml") and not config_file.endswith(".yaml"): - paths_to_check.append(Path(f"{config_file}.yml")) - paths_to_check.append(Path(f"{config_file}.yaml")) - paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(f"{config_file}.yml")) - paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(f"{config_file}.yaml")) - else: - paths_to_check.append(Path(config_file)) - paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(config_file)) - elif isinstance(config_file, Path): - if config_file.suffix in (".yml", ".yaml"): - paths_to_check.append(config_file) - if not config_file.is_absolute(): - paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(str(config_file))) - else: - paths_to_check.append(Path(f"{config_file}.yml")) - paths_to_check.append(Path(f"{config_file}.yaml")) - if not config_file.is_absolute(): - paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(f"{config_file}.yml")) - paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(f"{config_file}.yaml")) - elif isinstance(config_file, Traversable): - paths_to_check.append(config_file) - else: - raise ValueError(f"Invalid config file of {config_file}") - - config_path = None - for path in paths_to_check: - if path.is_file(): - config_path = path - break - if not config_path: - raise ValueError( - f"Cannot find config file of {config_file}, paths checked: {paths_to_check}" - ) - - _raw = yaml.safe_load(config_path.read_text(encoding="utf-8")) - if _raw is None: - return {} - if not isinstance(_raw, dict): - raise ValueError( - f"Config file {config_path} must contain a YAML mapping, got {type(_raw).__name__}" - ) - return _parse_exmy_num_bits(_raw) +"""Re-export config loading utilities from ``modelopt.torch.opt.config_loader``.""" + +from modelopt.torch.opt.config_loader import ( + BUILTIN_CONFIG_ROOT, + _load_raw_config, + _resolve_imports, + load_config, +) + +BUILTIN_RECIPES_LIB = BUILTIN_CONFIG_ROOT + +__all__ 
= [ + "BUILTIN_CONFIG_ROOT", + "BUILTIN_RECIPES_LIB", + "_load_raw_config", + "_resolve_imports", + "load_config", +] diff --git a/modelopt/recipe/loader.py b/modelopt/recipe/loader.py index 3a9c66fb22d..2a0fbdf5268 100644 --- a/modelopt/recipe/loader.py +++ b/modelopt/recipe/loader.py @@ -20,8 +20,9 @@ except ImportError: # Python < 3.11 from importlib.abc import Traversable from pathlib import Path +from typing import Any -from ._config_loader import BUILTIN_RECIPES_LIB, load_config +from ._config_loader import BUILTIN_RECIPES_LIB, _load_raw_config, _resolve_imports, load_config from .config import ModelOptPTQRecipe, ModelOptRecipeBase, RecipeType __all__ = ["load_config", "load_recipe"] @@ -86,7 +87,15 @@ def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBas The file must contain a ``metadata`` section with at least ``recipe_type``, plus a ``quant_cfg`` mapping and an optional ``algorithm`` for PTQ recipes. """ - data = load_config(recipe_file) + raw = _load_raw_config(recipe_file) + if not isinstance(raw, dict): + raise ValueError( + f"Recipe file {recipe_file} must be a YAML mapping, got {type(raw).__name__}." + ) + data = _resolve_imports(raw) + assert isinstance(data, dict), ( + f"Recipe file {recipe_file} resolved to {type(data).__name__}; expected dict." + ) metadata = data.get("metadata", {}) recipe_type = metadata.get("recipe_type") @@ -105,7 +114,24 @@ def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBas def _load_recipe_from_dir(recipe_dir: Path | Traversable) -> ModelOptRecipeBase: - """Load a recipe from a directory containing ``recipe.yml`` and ``quantize.yml``.""" + """Load a recipe from a directory containing ``recipe.yml`` and ``quantize.yml``. + + Import resolution is deliberately two-pass: + + 1. ``quantize.yaml`` is resolved first against its own ``imports:`` section + (if any). After this pass, every ``$import`` that references a + ``quantize.yaml``-declared name is already expanded. + 2. 
The resolved ``quantize`` dict is then wrapped under a ``quantize:`` key + and merged with ``recipe.yaml``'s ``imports:``, and ``_resolve_imports`` + is called again. This second pass only fires ``$import`` markers that + name imports declared in ``recipe.yaml`` — which, by step 1, cannot + alias a ``quantize.yaml`` import name. + + Practical consequence: each file's ``imports:`` section defines names + scoped to that file; there is no cross-file import sharing. If + ``recipe.yaml`` and ``quantize.yaml`` both declare an import with the + same name but different paths, each file sees only its own. + """ recipe_file = None for name in ("recipe.yml", "recipe.yaml"): candidate = recipe_dir.joinpath(name) @@ -117,7 +143,12 @@ def _load_recipe_from_dir(recipe_dir: Path | Traversable) -> ModelOptRecipeBase: f"Cannot find a recipe descriptor in {recipe_dir}. Looked for: recipe.yml, recipe.yaml" ) - metadata = load_config(recipe_file).get("metadata", {}) + recipe_data = _load_raw_config(recipe_file) + if not isinstance(recipe_data, dict): + raise ValueError( + f"Recipe file {recipe_file} must be a YAML mapping, got {type(recipe_data).__name__}." + ) + metadata = recipe_data.get("metadata", {}) recipe_type = metadata.get("recipe_type") if recipe_type is None: raise ValueError(f"Recipe file {recipe_file} must contain a 'metadata.recipe_type' field.") @@ -133,9 +164,32 @@ def _load_recipe_from_dir(recipe_dir: Path | Traversable) -> ModelOptRecipeBase: raise ValueError( f"Cannot find quantize in {recipe_dir}. Looked for: quantize.yml, quantize.yaml" ) + # Resolve imports from both recipe.yaml and quantize.yaml + quantize_data = _load_raw_config(quantize_file) + if not isinstance(quantize_data, dict): + raise ValueError( + f"{quantize_file} must be a YAML mapping, got {type(quantize_data).__name__}." 
+ ) + # Resolve quantize.yaml's own imports first (if any) + if "imports" in quantize_data: + resolved = _resolve_imports(quantize_data) + assert isinstance(resolved, dict), ( + f"{quantize_file} resolved to {type(resolved).__name__}; expected dict." + ) + quantize_data = resolved + # Then resolve recipe.yaml's imports applied to the quantize data + combined: dict[str, Any] = {"quantize": quantize_data} + imports = recipe_data.get("imports") + if imports: + combined["imports"] = imports + resolved = _resolve_imports(combined) + assert isinstance(resolved, dict), ( + f"Recipe {recipe_dir} resolved to {type(resolved).__name__}; expected dict." + ) + combined = resolved return ModelOptPTQRecipe( recipe_type=RecipeType.PTQ, description=metadata.get("description", "PTQ recipe."), - quantize=load_config(quantize_file), + quantize=combined["quantize"], ) raise ValueError(f"Unsupported recipe type: {recipe_type!r}") diff --git a/modelopt/torch/opt/config_loader.py b/modelopt/torch/opt/config_loader.py new file mode 100644 index 00000000000..58770233c53 --- /dev/null +++ b/modelopt/torch/opt/config_loader.py @@ -0,0 +1,336 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""General-purpose YAML config loading with ``$import`` resolution. 
+ +This module provides the config loading infrastructure used by both +``modelopt.recipe`` and ``modelopt.torch.quantization.config``. It lives +in ``modelopt.torch.opt`` (the lowest dependency layer) to avoid circular +imports. +""" + +from dataclasses import dataclass, field +from importlib.resources import files + +try: + from importlib.resources.abc import Traversable +except ImportError: # Python < 3.11 + from importlib.abc import Traversable +import re +from pathlib import Path +from typing import Any + +import yaml + + +@dataclass +class _ListSnippet: + """Multi-document YAML: a header dict (with optional ``imports:``) + a list body. + + YAML requires one root node per document, so a file that is "a list with an + ``imports`` section" has to use two documents separated by ``---``. This + wrapper is the internal transport carrying both pieces from + :func:`_load_raw_config` to :func:`_resolve_imports` without smuggling them + through a sentinel dict key (which would collide if a user happened to + choose the same key name). + """ + + imports: dict[str, Any] = field(default_factory=dict) + content: list[Any] = field(default_factory=list) + + +# Root to all built-in configs and recipes. 
+BUILTIN_CONFIG_ROOT = files("modelopt_recipes") + +_EXMY_RE = re.compile(r"^[Ee](\d+)[Mm](\d+)$") +_EXMY_KEYS = frozenset({"num_bits", "scale_bits"}) + + +def _parse_exmy_num_bits(obj: Any) -> Any: + """Recursively convert ``ExMy`` strings in ``num_bits`` / ``scale_bits`` to ``(x, y)`` tuples.""" + if isinstance(obj, dict): + return { + k: ( + _parse_exmy(v) + if k in _EXMY_KEYS and isinstance(v, str) + else _parse_exmy_num_bits(v) + ) + for k, v in obj.items() + } + if isinstance(obj, list): + return [_parse_exmy_num_bits(item) for item in obj] + return obj + + +def _parse_exmy(s: str) -> tuple[int, int] | str: + m = _EXMY_RE.match(s) + if m: + return (int(m.group(1)), int(m.group(2))) + return s + + +def _resolve_config_path(config_file: str | Path | Traversable) -> Path | Traversable: + """Probe the filesystem and built-in library to locate a config file. + + Return type mirrors the input family: filesystem paths return ``Path``; + built-in package resources return a ``Traversable``. Raises ``ValueError`` + if no candidate exists. + + Factored out of :func:`_load_raw_config` so :func:`_resolve_imports` can + compute a canonical cycle-detection key without reading the file twice. + """ + # Probe order: filesystem first, then built-in library. + # This lets users override built-in configs by placing a file locally. 
+ paths_to_check: list[Path | Traversable] = [] + if isinstance(config_file, str): + if not config_file.endswith(".yml") and not config_file.endswith(".yaml"): + paths_to_check.append(Path(f"{config_file}.yml")) + paths_to_check.append(Path(f"{config_file}.yaml")) + paths_to_check.append(BUILTIN_CONFIG_ROOT.joinpath(f"{config_file}.yml")) + paths_to_check.append(BUILTIN_CONFIG_ROOT.joinpath(f"{config_file}.yaml")) + else: + paths_to_check.append(Path(config_file)) + paths_to_check.append(BUILTIN_CONFIG_ROOT.joinpath(config_file)) + elif isinstance(config_file, Path): + if config_file.suffix in (".yml", ".yaml"): + paths_to_check.append(config_file) + if not config_file.is_absolute(): + paths_to_check.append(BUILTIN_CONFIG_ROOT.joinpath(str(config_file))) + else: + paths_to_check.append(Path(f"{config_file}.yml")) + paths_to_check.append(Path(f"{config_file}.yaml")) + if not config_file.is_absolute(): + paths_to_check.append(BUILTIN_CONFIG_ROOT.joinpath(f"{config_file}.yml")) + paths_to_check.append(BUILTIN_CONFIG_ROOT.joinpath(f"{config_file}.yaml")) + elif isinstance(config_file, Traversable): + paths_to_check.append(config_file) + else: + raise ValueError(f"Invalid config file of {config_file}") + + for path in paths_to_check: + if path.is_file(): + return path + raise ValueError(f"Cannot find config file of {config_file}, paths checked: {paths_to_check}") + + +def _canonical_key(path: Path | Traversable) -> str: + """Stable cycle-detection key for :func:`_resolve_imports`. + + Filesystem paths are resolved (``Path.resolve()``) so that aliases like + ``foo/bar``, ``./foo/bar``, and their absolute form produce the same key. + Built-in ``Traversable`` resources are already canonical — their ``str()`` + points into the installed package. 
+ """ + if isinstance(path, Path): + try: + return str(path.resolve()) + except OSError: + return str(path) + return str(path) + + +def _load_raw_config( + config_file: str | Path | Traversable, +) -> dict[str, Any] | list[Any] | _ListSnippet: + """Load a config YAML without resolving ``$import`` references. + + config_file: Path to a config yaml file. The path suffix can be omitted. + + Return type: + * ``dict`` — single-document or two-document-dict YAML. + * ``list`` — single-document list YAML. + * :class:`_ListSnippet` — two-document YAML with a list body; + carries the header's ``imports`` alongside the list content. + """ + config_path = _resolve_config_path(config_file) + text = config_path.read_text(encoding="utf-8") + docs = list(yaml.safe_load_all(text)) + + if len(docs) == 0 or docs[0] is None: + return {} + if len(docs) == 1: + _raw = docs[0] + elif len(docs) == 2: + # Multi-document: first doc is imports/metadata, second is content. + # Merge the imports into the content for downstream resolution. + header, content = docs[0], docs[1] + if not isinstance(header, dict): + raise ValueError( + f"Config file {config_path}: first YAML document must be a mapping, " + f"got {type(header).__name__}" + ) + if content is None: + content = {} + if isinstance(content, dict): + _raw = {**header, **content} + elif isinstance(content, list): + # List body with a header dict (for declaring ``imports:``). + # Only ``imports`` from the header is carried forward; any other + # header keys are meaningless alongside a list body. 
+ imports = header.get("imports", {}) or {} + return _ListSnippet( + imports=imports, + content=_parse_exmy_num_bits(content), + ) + else: + raise ValueError( + f"Config file {config_path}: second YAML document must be a mapping or list, " + f"got {type(content).__name__}" + ) + else: + raise ValueError( + f"Config file {config_path}: expected 1 or 2 YAML documents, got {len(docs)}" + ) + + if not isinstance(_raw, (dict, list)): + raise ValueError( + f"Config file {config_path} must contain a YAML mapping or list, got {type(_raw).__name__}" + ) + return _parse_exmy_num_bits(_raw) + + +# --------------------------------------------------------------------------- +# $import resolution +# --------------------------------------------------------------------------- + +_IMPORT_KEY = "$import" + + +def _resolve_imports( + data: dict[str, Any] | _ListSnippet, _loading: frozenset[str] | None = None +) -> dict[str, Any] | list[Any]: + """Resolve the ``imports`` section and ``$import`` references. + + Accepts either a raw dict (with optional top-level ``imports:``) or a + :class:`_ListSnippet` (a list body carrying its own ``imports``). Returns + a dict for the former and a list for the latter — the imports section is + consumed. + + See the "Composable imports" section of ``docs/source/guides/10_recipes.rst`` + for the full specification. This function lives in ``modelopt.torch.opt.config_loader`` + so that it can be used from ``modelopt.torch.quantization.config`` without circular imports.
+ """ + if isinstance(data, _ListSnippet): + imports_dict = data.imports + body: dict[str, Any] | list[Any] = data.content + else: + imports_dict = data.get("imports") + body = {k: v for k, v in data.items() if k != "imports"} + + if not imports_dict: + return body + + if not isinstance(imports_dict, dict): + raise ValueError( + f"'imports' must be a dict mapping names to config paths, got: {type(imports_dict).__name__}" + ) + + if _loading is None: + _loading = frozenset() + + # Build name → config mapping (recursively resolve nested imports). + # Cycle detection uses the *resolved* file path as the key so that aliases + # such as ``foo/bar``, ``./foo/bar``, and its absolute form all map to the + # same cycle entry. + import_map: dict[str, Any] = {} + for name, config_path in imports_dict.items(): + if not config_path: + raise ValueError(f"Import {name!r} has an empty config path.") + resolved_path = _resolve_config_path(config_path) + cycle_key = _canonical_key(resolved_path) + if cycle_key in _loading: + raise ValueError( + f"Circular import detected: {config_path!r} (resolves to " + f"{cycle_key!r}) is already being loaded. " + f"Import chain: {sorted(_loading)}" + ) + snippet = _load_raw_config(config_path) + if isinstance(snippet, _ListSnippet) or ( + isinstance(snippet, dict) and "imports" in snippet + ): + snippet = _resolve_imports(snippet, _loading | {cycle_key}) + import_map[name] = snippet + + def _lookup(ref_name: str, context: str) -> Any: + if ref_name not in import_map: + raise ValueError( + f"Unknown $import reference {ref_name!r} in {context}. " + f"Available imports: {list(import_map.keys())}" + ) + return import_map[ref_name] + + def _resolve_value(obj: Any) -> Any: + """Recursively resolve ``$import`` markers anywhere in the config tree. 
+ + - Dict with ``$import`` as only key and list value → splice (in list context) + - Dict with ``$import`` key → replace/merge (import + override with inline keys) + - List → resolve each element (with list-splice for ``$import`` entries) + - Other → return as-is + """ + if isinstance(obj, dict): + if _IMPORT_KEY in obj: + # {$import: name, ...inline} → import, merge, override. + # Read without mutating ``obj`` so _resolve_value stays pure and + # idempotent — double resolution must be a no-op on the first + # result, not silently corrupt it. + ref = obj[_IMPORT_KEY] + inline_keys = {k: v for k, v in obj.items() if k != _IMPORT_KEY} + ref_names = ref if isinstance(ref, list) else [ref] + + merged: dict[str, Any] = {} + for rname in ref_names: + snippet = _lookup(rname, "dict value") + if not isinstance(snippet, dict): + raise ValueError( + f"$import {rname!r} in dict must resolve to a dict, " + f"got {type(snippet).__name__}." + ) + merged.update(snippet) + + merged.update(inline_keys) + return _resolve_value(merged) # resolve any nested $import in result + else: + return {k: _resolve_value(v) for k, v in obj.items()} + elif isinstance(obj, list): + resolved: list[Any] = [] + for entry in obj: + if isinstance(entry, dict) and _IMPORT_KEY in entry and len(entry) == 1: + # {$import: name} as sole key in list → splice + imported = _lookup(entry[_IMPORT_KEY], "list entry") + if not isinstance(imported, list): + raise ValueError( + f"$import {entry[_IMPORT_KEY]!r} in list must resolve to a " + f"list, got {type(imported).__name__}." + ) + resolved.extend(_resolve_value(imported)) + else: + resolved.append(_resolve_value(entry)) + return resolved + return obj + + return _resolve_value(body) + + +def load_config(config_path: str | Path | Traversable) -> dict[str, Any] | list[Any]: + """Load a YAML config and resolve all ``$import`` references. + + This is the primary config loading entry point. 
It loads the YAML file, + resolves any ``imports`` / ``$import`` directives, and returns the final + config dict or list. + """ + data = _load_raw_config(config_path) + if isinstance(data, _ListSnippet) or (isinstance(data, dict) and "imports" in data): + data = _resolve_imports(data) + return data diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 34e7f692ca0..dc2859dd363 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -158,6 +158,7 @@ from typing_extensions import Required, TypedDict from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField +from modelopt.torch.opt.config_loader import load_config from modelopt.torch.utils.network import ConstructorLike @@ -203,35 +204,11 @@ def find_quant_cfg_entry_by_path( return result -_base_disable_all: list[QuantizerCfgEntry] = [ - {"quantizer_name": "*", "enable": False}, -] +_base_disable_all: list[QuantizerCfgEntry] = load_config("configs/ptq/units/base_disable_all") -_default_disabled_quantizer_cfg: list[QuantizerCfgEntry] = [ - {"parent_class": "nn.BatchNorm1d", "quantizer_name": "*", "enable": False}, - {"parent_class": "nn.BatchNorm2d", "quantizer_name": "*", "enable": False}, - {"parent_class": "nn.BatchNorm3d", "quantizer_name": "*", "enable": False}, - {"parent_class": "nn.LeakyReLU", "quantizer_name": "*", "enable": False}, - {"quantizer_name": "*lm_head*", "enable": False}, - { - "quantizer_name": "*proj_out.*", - "enable": False, - }, # In Whisper model, lm_head has key name proj_out - { - "quantizer_name": "*block_sparse_moe.gate*", - "enable": False, - }, # Skip the MOE router - {"quantizer_name": "*router*", "enable": False}, # Skip the MOE router - {"quantizer_name": "*mlp.gate.*", "enable": False}, # Skip the MOE router - { - "quantizer_name": "*mlp.shared_expert_gate.*", - "enable": False, - }, # Skip the MOE router - {"quantizer_name": "*linear_attn.conv1d*", "enable": False}, - 
{"quantizer_name": "*mixer.conv1d*", "enable": False}, # Skip mamba conv1d - {"quantizer_name": "*output_layer*", "enable": False}, - {"quantizer_name": "output.*", "enable": False}, -] +_default_disabled_quantizer_cfg: list[QuantizerCfgEntry] = load_config( + "configs/ptq/units/default_disabled_quantizers" +) _mamba_moe_disabled_quantizer_cfg: list[QuantizerCfgEntry] = [ {"quantizer_name": "*fc1_latent_proj*", "enable": False}, # Skip Latent MOE @@ -280,21 +257,7 @@ def find_quant_cfg_entry_by_path( "algorithm": "max", } -FP8_DEFAULT_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": {"num_bits": (4, 3), "axis": None}, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": {"num_bits": (4, 3), "axis": None}, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "max", -} +FP8_DEFAULT_CFG: dict[str, Any] = load_config("configs/ptq/presets/model/fp8") MAMBA_MOE_FP8_AGGRESSIVE_CFG = { "quant_cfg": [ @@ -539,14 +502,7 @@ def find_quant_cfg_entry_by_path( # KV-cache configs are designed to be merged with a primary quantization config (e.g. # FP8_DEFAULT_CFG) that already contains _base_disable_all. They intentionally omit both # _base_disable_all and "algorithm" because these are provided by the primary config. -FP8_KV_CFG = { - "quant_cfg": [ - { - "quantizer_name": "*[kv]_bmm_quantizer", - "cfg": {"num_bits": (4, 3)}, - }, - ] -} +FP8_KV_CFG: dict[str, Any] = load_config("configs/ptq/presets/kv/fp8") FP8_AFFINE_KV_CFG = { "quant_cfg": [ diff --git a/modelopt_recipes/configs/numerics/fp8.yaml b/modelopt_recipes/configs/numerics/fp8.yaml new file mode 100644 index 00000000000..dec6c20a58c --- /dev/null +++ b/modelopt_recipes/configs/numerics/fp8.yaml @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# FP8 E4M3 quantizer attributes (per-tensor; used for weight/activation/KV). +# ``axis: null`` is explicit to match the hardcoded ``FP8_DEFAULT_CFG`` shape — +# downstream code that keys on ``"axis" in cfg`` sees the same dict layout. +num_bits: e4m3 +axis: diff --git a/modelopt_recipes/configs/numerics/nvfp4.yaml b/modelopt_recipes/configs/numerics/nvfp4.yaml new file mode 100644 index 00000000000..0639e51c140 --- /dev/null +++ b/modelopt_recipes/configs/numerics/nvfp4.yaml @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NVFP4 E2M1 blockwise quantizer attributes with FP8 E4M3 scales (dynamic calibration, the default). 
+num_bits: e2m1 +block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 diff --git a/modelopt_recipes/configs/numerics/nvfp4_static.yaml b/modelopt_recipes/configs/numerics/nvfp4_static.yaml new file mode 100644 index 00000000000..9dda0cae918 --- /dev/null +++ b/modelopt_recipes/configs/numerics/nvfp4_static.yaml @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NVFP4 E2M1 blockwise quantizer attributes with FP8 E4M3 scales (static calibration). +num_bits: e2m1 +block_sizes: + -1: 16 + type: static + scale_bits: e4m3 diff --git a/modelopt_recipes/configs/ptq/presets/README.md b/modelopt_recipes/configs/ptq/presets/README.md new file mode 100644 index 00000000000..402e8c5265d --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/README.md @@ -0,0 +1,32 @@ +# PTQ Preset Configs + +This directory holds preset quantization configurations that serve as the +YAML source of truth for the hardcoded `*_CFG` dicts in +`modelopt.torch.quantization.config` (e.g., `FP8_DEFAULT_CFG`, +`FP8_KV_CFG`). + +Presets compose from the reusable snippets in `configs/numerics/` and +`configs/ptq/units/` via the `$import` system, and are split into two +kinds: + +- **`model/`** — *full* quantization presets. 
Each file is a complete, + self-contained config (it sets `algorithm` and a full `quant_cfg` with + a base-disable-all prefix + standard exclusions) and can be passed + directly to `mtq.quantize()`. Example: `model/fp8.yaml` + (the YAML source of `FP8_DEFAULT_CFG`). +- **`kv/`** — *partial* KV-cache quantization fragments. Each file + contains only the KV-specific `quant_cfg` entries (no `algorithm`, no + base-disable-all). They are **not** standalone — they are designed to + be merged on top of a `model/` preset via `$import` to produce a + complete config. Example: `kv/fp8.yaml` (the YAML source of + `FP8_KV_CFG`). + +**Note:** The main purpose of these presets is to support the existing +`hf_ptq.py` script's `--qformat` / `--kv_cache_qformat` flags and other +code paths that reference +the hardcoded `*_CFG` dicts, maintaining backward compatibility during +the transition to recipe-based workflows. Users are encouraged to use +`load_recipe` with full recipe files under `general/` or `models/` +instead. Some or all of these presets may be deprecated or removed in +future releases as the recipe-based workflow becomes the standard entry +point. diff --git a/modelopt_recipes/configs/ptq/presets/kv/fp8.yaml b/modelopt_recipes/configs/ptq/presets/kv/fp8.yaml new file mode 100644 index 00000000000..f23ba541457 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/kv/fp8.yaml @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# FP8 E4M3 KV cache quantization preset. +# Equivalent to the hardcoded FP8_KV_CFG in config.py. +# This is a partial config (no algorithm, no base_disable_all) — designed +# to be merged with a primary model quantization config. +imports: + fp8_kv: configs/ptq/units/fp8_kv + +quant_cfg: + - $import: fp8_kv diff --git a/modelopt_recipes/configs/ptq/presets/model/fp8.yaml b/modelopt_recipes/configs/ptq/presets/model/fp8.yaml new file mode 100644 index 00000000000..3f7ef9f8606 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/fp8.yaml @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# FP8 per-tensor weight and activation (W8A8), max calibration. +# Equivalent to the hardcoded FP8_DEFAULT_CFG in config.py. 
+imports: + base_disable_all: configs/ptq/units/base_disable_all + w8a8_fp8_fp8: configs/ptq/units/w8a8_fp8_fp8 + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + +algorithm: max +quant_cfg: + - $import: base_disable_all + - $import: w8a8_fp8_fp8 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/units/README.md b/modelopt_recipes/configs/ptq/units/README.md new file mode 100644 index 00000000000..50cf028c15b --- /dev/null +++ b/modelopt_recipes/configs/ptq/units/README.md @@ -0,0 +1,17 @@ +# PTQ Config Units + +Reusable building blocks for composing PTQ quantization configurations. +Each file defines one or more `quant_cfg` entries that can be imported +into recipes or presets via `$import`. + +Units are **not** standalone configs — they don't have `algorithm` or +`metadata`. They are meant to be composed into complete configs by +recipes (under `general/` or `models/`) or presets (under `presets/`). + +| File | Description | +|------|-------------| +| `base_disable_all.yaml` | Deny-all entry: disables all quantizers as the first step | +| `default_disabled_quantizers.yaml` | Standard exclusions (LM head, routers, BatchNorm, etc.) | +| `fp8_kv.yaml` | FP8 E4M3 KV cache quantizer entry | +| `w8a8_fp8_fp8.yaml` | FP8 weight + activation quantizer entries (W8A8) | +| `w4a4_nvfp4_nvfp4.yaml` | NVFP4 weight + activation quantizer entries (W4A4) | diff --git a/modelopt_recipes/configs/ptq/units/base_disable_all.yaml b/modelopt_recipes/configs/ptq/units/base_disable_all.yaml new file mode 100644 index 00000000000..35bdf2c6a45 --- /dev/null +++ b/modelopt_recipes/configs/ptq/units/base_disable_all.yaml @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Disable all quantizers by default (deny-all-then-configure pattern). + + - quantizer_name: '*' + enable: false diff --git a/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml b/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml new file mode 100644 index 00000000000..a8c04357d7d --- /dev/null +++ b/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard quantizer exclusions: layers that should not be quantized. 
+ + - quantizer_name: '*block_sparse_moe.gate*' + enable: false + - quantizer_name: '*linear_attn.conv1d*' + enable: false + - quantizer_name: '*lm_head*' + enable: false + - quantizer_name: '*mixer.conv1d*' + enable: false + - quantizer_name: '*mlp.gate.*' + enable: false + - quantizer_name: '*mlp.shared_expert_gate.*' + enable: false + - quantizer_name: '*output_layer*' + enable: false + - quantizer_name: '*proj_out.*' + enable: false + - quantizer_name: '*router*' + enable: false + - quantizer_name: 'output.*' + enable: false + - parent_class: 'nn.BatchNorm1d' + quantizer_name: '*' + enable: false + - parent_class: 'nn.BatchNorm2d' + quantizer_name: '*' + enable: false + - parent_class: 'nn.BatchNorm3d' + quantizer_name: '*' + enable: false + - parent_class: 'nn.LeakyReLU' + quantizer_name: '*' + enable: false diff --git a/modelopt_recipes/configs/ptq/units/fp8_kv.yaml b/modelopt_recipes/configs/ptq/units/fp8_kv.yaml new file mode 100644 index 00000000000..85ff617ead4 --- /dev/null +++ b/modelopt_recipes/configs/ptq/units/fp8_kv.yaml @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# FP8 E4M3 KV cache quantization. +# +# This snippet uses multi-document YAML (separated by ---) because it is a +# list-valued snippet that also needs to $import another snippet. 
YAML only +# allows one root node per document, so a file cannot be both a mapping +# (for imports) and a list (for entries). The first document holds the +# imports, the second holds the list content that references them. +imports: + fp8: configs/numerics/fp8 +--- + - quantizer_name: '*[kv]_bmm_quantizer' + cfg: + $import: fp8 diff --git a/modelopt_recipes/configs/ptq/units/w4a4_nvfp4_nvfp4.yaml b/modelopt_recipes/configs/ptq/units/w4a4_nvfp4_nvfp4.yaml new file mode 100644 index 00000000000..2fc516e5dc1 --- /dev/null +++ b/modelopt_recipes/configs/ptq/units/w4a4_nvfp4_nvfp4.yaml @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# W4A4 NVFP4: NVFP4 E2M1 dynamic weight and activation quantizers. +imports: + nvfp4: configs/numerics/nvfp4 +--- + - quantizer_name: '*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*input_quantizer' + cfg: + $import: nvfp4 diff --git a/modelopt_recipes/configs/ptq/units/w8a8_fp8_fp8.yaml b/modelopt_recipes/configs/ptq/units/w8a8_fp8_fp8.yaml new file mode 100644 index 00000000000..c55cbf1d6b9 --- /dev/null +++ b/modelopt_recipes/configs/ptq/units/w8a8_fp8_fp8.yaml @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# W8A8 FP8: FP8 E4M3 weight and activation quantizers. +imports: + fp8: configs/numerics/fp8 +--- + - quantizer_name: '*weight_quantizer' + cfg: + $import: fp8 + - quantizer_name: '*input_quantizer' + cfg: + $import: fp8 diff --git a/modelopt_recipes/general/ptq/fp8_default-fp8_kv.yaml b/modelopt_recipes/general/ptq/fp8_default-fp8_kv.yaml index c80904e8eb9..85267c86726 100644 --- a/modelopt_recipes/general/ptq/fp8_default-fp8_kv.yaml +++ b/modelopt_recipes/general/ptq/fp8_default-fp8_kv.yaml @@ -13,55 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + w8a8_fp8_fp8: configs/ptq/units/w8a8_fp8_fp8 + fp8_kv: configs/ptq/units/fp8_kv + metadata: recipe_type: ptq description: FP8 per-tensor weight and activation (W8A8), FP8 KV cache, max calibration. 
quantize: algorithm: max quant_cfg: - - quantizer_name: '*' - enable: false - - quantizer_name: '*input_quantizer' - cfg: - num_bits: e4m3 - axis: - - quantizer_name: '*weight_quantizer' - cfg: - num_bits: e4m3 - axis: - - quantizer_name: '*[kv]_bmm_quantizer' - enable: true - cfg: - num_bits: e4m3 - - quantizer_name: '*block_sparse_moe.gate*' - enable: false - - quantizer_name: '*linear_attn.conv1d*' - enable: false - - quantizer_name: '*lm_head*' - enable: false - - quantizer_name: '*mixer.conv1d*' - enable: false - - quantizer_name: '*mlp.gate.*' - enable: false - - quantizer_name: '*mlp.shared_expert_gate.*' - enable: false - - quantizer_name: '*output_layer*' - enable: false - - quantizer_name: '*proj_out.*' - enable: false - - quantizer_name: '*router*' - enable: false - - quantizer_name: 'output.*' - enable: false - - parent_class: 'nn.BatchNorm1d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.BatchNorm2d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.BatchNorm3d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.LeakyReLU' - quantizer_name: '*' - enable: false + - $import: base_disable_all + - $import: w8a8_fp8_fp8 + - $import: fp8_kv + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yaml index 862929ef34c..36594d41e20 100644 --- a/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yaml @@ -13,63 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + w4a4_nvfp4_nvfp4: configs/ptq/units/w4a4_nvfp4_nvfp4 + fp8_kv: configs/ptq/units/fp8_kv + metadata: recipe_type: ptq description: NVFP4 W4A4, FP8 KV cache, max calibration. 
quantize: algorithm: max quant_cfg: - - quantizer_name: '*' - enable: false - - quantizer_name: '*weight_quantizer' - enable: true - cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 - - quantizer_name: '*input_quantizer' - enable: true - cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 - - quantizer_name: '*[kv]_bmm_quantizer' - enable: true - cfg: - num_bits: e4m3 - - quantizer_name: '*block_sparse_moe.gate*' - enable: false - - quantizer_name: '*linear_attn.conv1d*' - enable: false - - quantizer_name: '*lm_head*' - enable: false - - quantizer_name: '*mixer.conv1d*' - enable: false - - quantizer_name: '*mlp.gate.*' - enable: false - - quantizer_name: '*mlp.shared_expert_gate.*' - enable: false - - quantizer_name: '*output_layer*' - enable: false - - quantizer_name: '*proj_out.*' - enable: false - - quantizer_name: '*router*' - enable: false - - quantizer_name: 'output.*' - enable: false - - parent_class: 'nn.BatchNorm1d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.BatchNorm2d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.BatchNorm3d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.LeakyReLU' - quantizer_name: '*' - enable: false + - $import: base_disable_all + - $import: w4a4_nvfp4_nvfp4 + - $import: fp8_kv + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml b/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml index 99098c9d6d0..6aabb04a150 100644 --- a/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml @@ -13,6 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4_static: configs/numerics/nvfp4_static + nvfp4: configs/numerics/nvfp4 + metadata: recipe_type: ptq description: NVFP4 weight and activation (W4A4), gptq layerwise calibration. @@ -22,53 +28,13 @@ quantize: layerwise: true layerwise_checkpoint_dir: output/layerwise_ckpts/ quant_cfg: - - quantizer_name: '*' - enable: false + - $import: base_disable_all - quantizer_name: '*weight_quantizer' cfg: - block_sizes: - -1: 16 - type: static - scale_bits: e4m3 - num_bits: e2m1 + $import: nvfp4_static - quantizer_name: '*input_quantizer' cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + $import: nvfp4 - quantizer_name: '*[kv]_bmm_quantizer' enable: false - - quantizer_name: '*block_sparse_moe.gate*' - enable: false - - quantizer_name: '*linear_attn.conv1d*' - enable: false - - quantizer_name: '*lm_head*' - enable: false - - quantizer_name: '*mixer.conv1d*' - enable: false - - quantizer_name: '*mlp.gate.*' - enable: false - - quantizer_name: '*mlp.shared_expert_gate.*' - enable: false - - quantizer_name: '*output_layer*' - enable: false - - quantizer_name: '*proj_out.*' - enable: false - - quantizer_name: '*router*' - enable: false - - quantizer_name: 'output.*' - enable: false - - parent_class: 'nn.BatchNorm1d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.BatchNorm2d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.BatchNorm3d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.LeakyReLU' - quantizer_name: '*' - enable: false + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml index 7c557039631..619597c0288 100644 --- a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml +++ 
b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml @@ -13,6 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + fp8_kv: configs/ptq/units/fp8_kv + metadata: recipe_type: ptq description: NVFP4 static weight and dynamic activation for expert layers only (W4A4), FP8 KV cache, max layerwise calibration. @@ -24,73 +30,18 @@ quantize: # `model.language_model.layers` (layerwise_calibrate can't find them otherwise). layerwise: false quant_cfg: - - quantizer_name: '*' - enable: false + - $import: base_disable_all - quantizer_name: '*mlp.experts*weight_quantizer' - enable: true cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + $import: nvfp4 - quantizer_name: '*mlp.experts*input_quantizer' - enable: true cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + $import: nvfp4 - quantizer_name: '*block_sparse_moe*weight_quantizer' - enable: true cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + $import: nvfp4 - quantizer_name: '*block_sparse_moe*input_quantizer' - enable: true - cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 - - quantizer_name: '*[kv]_bmm_quantizer' - enable: true cfg: - num_bits: e4m3 - - quantizer_name: '*block_sparse_moe.gate*' - enable: false - - quantizer_name: '*linear_attn.conv1d*' - enable: false - - quantizer_name: '*lm_head*' - enable: false - - quantizer_name: '*mixer.conv1d*' - enable: false - - quantizer_name: '*mlp.gate.*' - enable: false - - quantizer_name: '*mlp.shared_expert_gate.*' - enable: false - - quantizer_name: '*output_layer*' - enable: false - - quantizer_name: '*proj_out.*' - enable: false - - quantizer_name: '*router*' - enable: false - - quantizer_name: 'output.*' 
- enable: false - - parent_class: 'nn.BatchNorm1d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.BatchNorm2d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.BatchNorm3d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.LeakyReLU' - quantizer_name: '*' - enable: false + $import: nvfp4 + - $import: fp8_kv + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yaml index 0222274af03..9e300b25018 100644 --- a/modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yaml @@ -13,79 +13,30 @@ # See the License for the specific language governing permissions and # limitations under the License. +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + fp8_kv: configs/ptq/units/fp8_kv + metadata: recipe_type: ptq description: NVFP4 static weight and dynamic activation for all linear layers (W4A4), FP8 KV cache, max calibration. 
quantize: algorithm: max quant_cfg: - - quantizer_name: '*' - enable: false + - $import: base_disable_all - quantizer_name: '*mlp*weight_quantizer' - enable: true cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + $import: nvfp4 - quantizer_name: '*mlp*input_quantizer' - enable: true cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + $import: nvfp4 - quantizer_name: '*block_sparse_moe*weight_quantizer' - enable: true cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + $import: nvfp4 - quantizer_name: '*block_sparse_moe*input_quantizer' - enable: true - cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 - - quantizer_name: '*[kv]_bmm_quantizer' - enable: true cfg: - num_bits: e4m3 - - quantizer_name: '*block_sparse_moe.gate*' - enable: false - - quantizer_name: '*linear_attn.conv1d*' - enable: false - - quantizer_name: '*lm_head*' - enable: false - - quantizer_name: '*mixer.conv1d*' - enable: false - - quantizer_name: '*mlp.gate.*' - enable: false - - quantizer_name: '*mlp.shared_expert_gate.*' - enable: false - - quantizer_name: '*output_layer*' - enable: false - - quantizer_name: '*proj_out.*' - enable: false - - quantizer_name: '*router*' - enable: false - - quantizer_name: 'output.*' - enable: false - - parent_class: 'nn.BatchNorm1d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.BatchNorm2d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.BatchNorm3d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.LeakyReLU' - quantizer_name: '*' - enable: false + $import: nvfp4 + - $import: fp8_kv + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/general/ptq/nvfp4_omlp_only-fp8_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_omlp_only-fp8_kv.yaml index 3fdd79888d5..2c83641137d 100644 --- a/modelopt_recipes/general/ptq/nvfp4_omlp_only-fp8_kv.yaml +++ 
b/modelopt_recipes/general/ptq/nvfp4_omlp_only-fp8_kv.yaml @@ -13,95 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + fp8_kv: configs/ptq/units/fp8_kv + metadata: recipe_type: ptq description: NVFP4 static weight and dynamic activation for all linear layers including output projections, FP8 KV cache, max calibration. quantize: algorithm: max quant_cfg: - - quantizer_name: '*' - enable: false + - $import: base_disable_all - quantizer_name: '*mlp*weight_quantizer' - enable: true cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + $import: nvfp4 - quantizer_name: '*mlp*input_quantizer' - enable: true cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + $import: nvfp4 - quantizer_name: '*block_sparse_moe*weight_quantizer' - enable: true cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + $import: nvfp4 - quantizer_name: '*block_sparse_moe*input_quantizer' - enable: true cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + $import: nvfp4 - quantizer_name: '*o_proj*weight_quantizer' - enable: true cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + $import: nvfp4 - quantizer_name: '*o_proj*input_quantizer' - enable: true - cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 - - quantizer_name: '*[kv]_bmm_quantizer' - enable: true cfg: - num_bits: e4m3 - - quantizer_name: '*block_sparse_moe.gate*' - enable: false - - quantizer_name: '*linear_attn.conv1d*' - enable: false - - quantizer_name: '*lm_head*' - enable: false - - quantizer_name: '*mixer.conv1d*' - enable: false - - quantizer_name: '*mlp.gate.*' - enable: false - - quantizer_name: 
'*mlp.shared_expert_gate.*' - enable: false - - quantizer_name: '*output_layer*' - enable: false - - quantizer_name: '*proj_out.*' - enable: false - - quantizer_name: '*router*' - enable: false - - quantizer_name: 'output.*' - enable: false - - parent_class: 'nn.BatchNorm1d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.BatchNorm2d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.BatchNorm3d' - quantizer_name: '*' - enable: false - - parent_class: 'nn.LeakyReLU' - quantizer_name: '*' - enable: false + $import: nvfp4 + - $import: fp8_kv + - $import: default_disabled_quantizers diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index 6926d89a5d2..814f8985d84 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -110,6 +110,8 @@ def test_load_recipe_builtin_description(): "general/ptq/nvfp4_default-fp8_kv", "general/ptq/nvfp4_default-fp8_cast_kv", "general/ptq/nvfp4_default-nvfp4_cast_kv", + "general/ptq/nvfp4_default-none_kv_gptq", + "general/ptq/nvfp4_experts_only-fp8_kv", "general/ptq/nvfp4_mlp_only-fp8_kv", "general/ptq/nvfp4_omlp_only-fp8_kv", ] @@ -252,3 +254,838 @@ def _sort_key(entry): assert sorted(python_entries, key=_sort_key) == sorted(yaml_entries, key=_sort_key) assert model_cfg["algorithm"] == yaml_data["quantize"]["algorithm"] + + +# --------------------------------------------------------------------------- +# imports — named config snippet resolution +# --------------------------------------------------------------------------- + + +def test_import_resolves_cfg_reference(tmp_path): + """$import in cfg is replaced with the imported config dict.""" + (tmp_path / "fp8.yml").write_text("num_bits: e4m3\naxis:\n") + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" fp8: {tmp_path / 'fp8.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg:\n" + f" - quantizer_name: 
'*weight_quantizer'\n" + f" cfg:\n" + f" $import: fp8\n" + ) + recipe = load_recipe(recipe_file) + entry = recipe.quantize["quant_cfg"][0] + assert entry["cfg"] == {"num_bits": (4, 3), "axis": None} + + +def test_import_same_name_used_twice(tmp_path): + """The same import can be referenced in multiple quant_cfg entries.""" + (tmp_path / "fp8.yml").write_text("num_bits: e4m3\naxis:\n") + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" fp8: {tmp_path / 'fp8.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg:\n" + f" - quantizer_name: '*weight_quantizer'\n" + f" cfg:\n" + f" $import: fp8\n" + f" - quantizer_name: '*input_quantizer'\n" + f" cfg:\n" + f" $import: fp8\n" + ) + recipe = load_recipe(recipe_file) + assert recipe.quantize["quant_cfg"][0]["cfg"] == recipe.quantize["quant_cfg"][1]["cfg"] + + +def test_import_multiple_snippets(tmp_path): + """Multiple imports with different names resolve independently.""" + (tmp_path / "fp8.yml").write_text("num_bits: e4m3\naxis:\n") + (tmp_path / "nvfp4.yml").write_text("num_bits: e2m1\nblock_sizes:\n -1: 16\n") + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" fp8: {tmp_path / 'fp8.yml'}\n" + f" nvfp4: {tmp_path / 'nvfp4.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg:\n" + f" - quantizer_name: '*weight_quantizer'\n" + f" cfg:\n" + f" $import: nvfp4\n" + f" - quantizer_name: '*[kv]_bmm_quantizer'\n" + f" cfg:\n" + f" $import: fp8\n" + ) + recipe = load_recipe(recipe_file) + assert recipe.quantize["quant_cfg"][0]["cfg"]["num_bits"] == (2, 1) + assert recipe.quantize["quant_cfg"][1]["cfg"]["num_bits"] == (4, 3) + + +def test_import_inline_cfg_not_affected(tmp_path): + """Inline dict cfg entries without $import are not touched.""" + (tmp_path / "fp8.yml").write_text("num_bits: e4m3\naxis:\n") + recipe_file = tmp_path / "recipe.yml" + 
recipe_file.write_text( + f"imports:\n" + f" fp8: {tmp_path / 'fp8.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg:\n" + f" - quantizer_name: '*weight_quantizer'\n" + f" cfg:\n" + f" $import: fp8\n" + f" - quantizer_name: '*input_quantizer'\n" + f" cfg:\n" + f" num_bits: 8\n" + f" axis: 0\n" + ) + recipe = load_recipe(recipe_file) + assert recipe.quantize["quant_cfg"][1]["cfg"] == {"num_bits": 8, "axis": 0} + + +def test_import_unknown_reference_raises(tmp_path): + """Referencing an undefined import name raises ValueError.""" + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + "imports:\n" + " fp8: configs/numerics/fp8\n" + "metadata:\n" + " recipe_type: ptq\n" + "quantize:\n" + " algorithm: max\n" + " quant_cfg:\n" + " - quantizer_name: '*weight_quantizer'\n" + " cfg:\n" + " $import: nonexistent\n" + ) + with pytest.raises(ValueError, match=r"Unknown \$import reference"): + load_recipe(recipe_file) + + +def test_import_empty_path_raises(tmp_path): + """Import with empty config path raises ValueError.""" + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + "imports:\n" + " fp8:\n" + "metadata:\n" + " recipe_type: ptq\n" + "quantize:\n" + " algorithm: max\n" + " quant_cfg: []\n" + ) + with pytest.raises(ValueError, match="empty config path"): + load_recipe(recipe_file) + + +def test_import_not_a_dict_raises(tmp_path): + """Import section that is not a dict raises ValueError.""" + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + "imports:\n" + " - configs/numerics/fp8\n" + "metadata:\n" + " recipe_type: ptq\n" + "quantize:\n" + " algorithm: max\n" + " quant_cfg: []\n" + ) + with pytest.raises(ValueError, match="must be a dict"): + load_recipe(recipe_file) + + +def test_import_no_imports_section(tmp_path): + """Recipes without imports load normally.""" + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + "metadata:\n" + " recipe_type: ptq\n" + 
"quantize:\n" + " algorithm: max\n" + " quant_cfg:\n" + " - quantizer_name: '*'\n" + " enable: false\n" + ) + recipe = load_recipe(recipe_file) + assert recipe.quantize["quant_cfg"][0]["enable"] is False + + +def test_import_builtin_recipe_with_imports(): + """Built-in recipes using $import load and resolve correctly.""" + recipe = load_recipe("general/ptq/fp8_default-fp8_kv") + assert recipe.quantize + # Verify $import was resolved — cfg should be a dict, not a {$import: ...} marker + for entry in recipe.quantize["quant_cfg"]: + if "cfg" in entry and entry["cfg"] is not None: + assert "$import" not in entry["cfg"], f"Unresolved $import in {entry}" + + +def test_import_entry_single_element_list(tmp_path): + """$import splices a single-element list snippet into quant_cfg.""" + (tmp_path / "disable.yml").write_text("- quantizer_name: '*'\n enable: false\n") + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" disable_all: {tmp_path / 'disable.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg:\n" + f" - $import: disable_all\n" + ) + recipe = load_recipe(recipe_file) + assert len(recipe.quantize["quant_cfg"]) == 1 + entry = recipe.quantize["quant_cfg"][0] + assert entry["quantizer_name"] == "*" + assert entry["enable"] is False + + +def test_import_entry_non_list_raises(tmp_path): + """$import in quant_cfg list position raises if snippet is not a list.""" + (tmp_path / "disable.yml").write_text("quantizer_name: '*'\nenable: false\n") + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" disable_all: {tmp_path / 'disable.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg:\n" + f" - $import: disable_all\n" + ) + with pytest.raises(ValueError, match="must resolve to a list"): + load_recipe(recipe_file) + + +def test_import_entry_list_splice(tmp_path): + """$import as a quant_cfg list entry splices 
a list-valued snippet.""" + (tmp_path / "disables.yml").write_text( + "- quantizer_name: '*lm_head*'\n enable: false\n" + "- quantizer_name: '*router*'\n enable: false\n" + ) + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" disables: {tmp_path / 'disables.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg:\n" + f" - quantizer_name: '*'\n" + f" enable: false\n" + f" - $import: disables\n" + ) + recipe = load_recipe(recipe_file) + assert len(recipe.quantize["quant_cfg"]) == 3 + assert recipe.quantize["quant_cfg"][1]["quantizer_name"] == "*lm_head*" + assert recipe.quantize["quant_cfg"][2]["quantizer_name"] == "*router*" + + +def test_import_entry_sibling_keys_with_list_snippet_raises(tmp_path): + """$import with sibling keys raises when the import resolves to a list (not a dict).""" + (tmp_path / "disable.yml").write_text("- quantizer_name: '*'\n enable: false\n") + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" disable_all: {tmp_path / 'disable.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg:\n" + f" - $import: disable_all\n" + f" quantizer_name: '*extra*'\n" + ) + with pytest.raises(ValueError, match="must resolve to a dict"): + load_recipe(recipe_file) + + +def test_import_cfg_extend(tmp_path): + """$import in cfg with extra non-conflicting keys extends the snippet.""" + (tmp_path / "fp8.yml").write_text("num_bits: e4m3\n") + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" fp8: {tmp_path / 'fp8.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg:\n" + f" - quantizer_name: '*weight_quantizer'\n" + f" cfg:\n" + f" $import: fp8\n" + f" axis: 0\n" + ) + recipe = load_recipe(recipe_file) + cfg = recipe.quantize["quant_cfg"][0]["cfg"] + assert cfg == {"num_bits": (4, 3), "axis": 0} + + +def 
test_import_cfg_inline_overrides_import(tmp_path): + """Inline keys override imported values (highest precedence).""" + (tmp_path / "fp8.yml").write_text("num_bits: e4m3\naxis:\n") + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" fp8: {tmp_path / 'fp8.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg:\n" + f" - quantizer_name: '*weight_quantizer'\n" + f" cfg:\n" + f" $import: fp8\n" + f" num_bits: 8\n" + ) + recipe = load_recipe(recipe_file) + cfg = recipe.quantize["quant_cfg"][0]["cfg"] + # inline num_bits: 8 overrides imported num_bits: e4m3 → (4,3) + assert cfg["num_bits"] == 8 + # imported axis: None is preserved (no inline override) + assert cfg["axis"] is None + + +def test_import_in_non_cfg_dict_value(tmp_path): + """$import resolves in any dict value, not just cfg (tested via load_config to skip validation).""" + (tmp_path / "extra.yml").write_text("foo: bar\nbaz: 42\n") + config_file = tmp_path / "config.yml" + config_file.write_text( + f"imports:\n" + f" extra: {tmp_path / 'extra.yml'}\n" + f"quant_cfg:\n" + f" - quantizer_name: '*weight_quantizer'\n" + f" my_field:\n" + f" $import: extra\n" + ) + data = load_config(config_file) + entry = data["quant_cfg"][0] + assert entry["my_field"] == {"foo": "bar", "baz": 42} + + +def test_import_in_multiple_dict_values(tmp_path): + """$import resolves independently in multiple dict values of the same entry.""" + (tmp_path / "fp8.yml").write_text("num_bits: e4m3\n") + (tmp_path / "extra.yml").write_text("foo: bar\n") + config_file = tmp_path / "config.yml" + config_file.write_text( + f"imports:\n" + f" fp8: {tmp_path / 'fp8.yml'}\n" + f" extra: {tmp_path / 'extra.yml'}\n" + f"quant_cfg:\n" + f" - quantizer_name: '*weight_quantizer'\n" + f" cfg:\n" + f" $import: fp8\n" + f" my_field:\n" + f" $import: extra\n" + ) + data = load_config(config_file) + entry = data["quant_cfg"][0] + assert entry["cfg"] == {"num_bits": (4, 3)} + 
assert entry["my_field"] == {"foo": "bar"} + + +def test_import_cfg_multi_import(tmp_path): + """$import with a list of names merges non-overlapping snippets.""" + (tmp_path / "bits.yml").write_text("num_bits: e4m3\n") + (tmp_path / "axis.yml").write_text("axis: 0\n") + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" bits: {tmp_path / 'bits.yml'}\n" + f" axis: {tmp_path / 'axis.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg:\n" + f" - quantizer_name: '*weight_quantizer'\n" + f" cfg:\n" + f" $import: [bits, axis]\n" + ) + recipe = load_recipe(recipe_file) + cfg = recipe.quantize["quant_cfg"][0]["cfg"] + assert cfg == {"num_bits": (4, 3), "axis": 0} + + +def test_import_cfg_multi_import_later_overrides_earlier(tmp_path): + """In $import list, later snippets override earlier ones on key conflicts.""" + (tmp_path / "a.yml").write_text("num_bits: e4m3\naxis: 0\n") + (tmp_path / "b.yml").write_text("num_bits: 8\n") + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" a: {tmp_path / 'a.yml'}\n" + f" b: {tmp_path / 'b.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg:\n" + f" - quantizer_name: '*weight_quantizer'\n" + f" cfg:\n" + f" $import: [a, b]\n" + ) + recipe = load_recipe(recipe_file) + cfg = recipe.quantize["quant_cfg"][0]["cfg"] + # b overrides a's num_bits; a's axis is preserved + assert cfg["num_bits"] == 8 + assert cfg["axis"] == 0 + + +def test_import_cfg_multi_import_with_extend(tmp_path): + """$import list + inline keys all merge without conflicts.""" + (tmp_path / "bits.yml").write_text("num_bits: e4m3\n") + (tmp_path / "extra.yml").write_text("fake_quant: false\n") + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" bits: {tmp_path / 'bits.yml'}\n" + f" extra: {tmp_path / 'extra.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + 
f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg:\n" + f" - quantizer_name: '*weight_quantizer'\n" + f" cfg:\n" + f" $import: [bits, extra]\n" + f" axis: 0\n" + ) + recipe = load_recipe(recipe_file) + cfg = recipe.quantize["quant_cfg"][0]["cfg"] + assert cfg == {"num_bits": (4, 3), "fake_quant": False, "axis": 0} + + +def test_import_dir_format(tmp_path): + """Imports in recipe.yml work with the directory recipe format.""" + (tmp_path / "fp8.yml").write_text("num_bits: e4m3\naxis:\n") + (tmp_path / "recipe.yml").write_text( + f"imports:\n" + f" fp8: {tmp_path / 'fp8.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f" description: Dir with imports.\n" + ) + (tmp_path / "quantize.yml").write_text( + "algorithm: max\n" + "quant_cfg:\n" + " - quantizer_name: '*weight_quantizer'\n" + " cfg:\n" + " $import: fp8\n" + ) + recipe = load_recipe(tmp_path) + assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3), "axis": None} + + +# --------------------------------------------------------------------------- +# imports — multi-document snippets +# --------------------------------------------------------------------------- + + +def test_import_multi_document_list_snippet(tmp_path): + """List snippet using multi-document YAML (imports --- content) resolves $import.""" + (tmp_path / "fp8.yml").write_text("num_bits: e4m3\n") + (tmp_path / "kv.yaml").write_text( + f"imports:\n" + f" fp8: {tmp_path / 'fp8.yml'}\n" + f"---\n" + f"- quantizer_name: '*[kv]_bmm_quantizer'\n" + f" cfg:\n" + f" $import: fp8\n" + ) + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" kv: {tmp_path / 'kv.yaml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg:\n" + f" - $import: kv\n" + ) + recipe = load_recipe(recipe_file) + assert len(recipe.quantize["quant_cfg"]) == 1 + assert recipe.quantize["quant_cfg"][0]["quantizer_name"] == "*[kv]_bmm_quantizer" + assert recipe.quantize["quant_cfg"][0]["cfg"] 
== {"num_bits": (4, 3)} + + +def test_import_builtin_fp8_kv_snippet(): + """Built-in fp8_kv snippet uses multi-document format and resolves correctly.""" + recipe = load_recipe("general/ptq/fp8_default-fp8_kv") + kv_entries = [ + e for e in recipe.quantize["quant_cfg"] if e.get("quantizer_name") == "*[kv]_bmm_quantizer" + ] + assert len(kv_entries) == 1 + assert kv_entries[0]["cfg"]["num_bits"] == (4, 3) + + +# --------------------------------------------------------------------------- +# imports — general tree-wide resolution (not just quant_cfg) +# --------------------------------------------------------------------------- + + +def test_import_in_top_level_dict_value(tmp_path): + """$import resolves in a top-level dict value (not inside any list).""" + (tmp_path / "algo.yml").write_text("method: gptq\nuse_layerwise: true\n") + config_file = tmp_path / "config.yml" + config_file.write_text( + f"imports:\n algo: {tmp_path / 'algo.yml'}\nalgorithm:\n $import: algo\nquant_cfg: []\n" + ) + data = load_config(config_file) + assert data["algorithm"] == {"method": "gptq", "use_layerwise": True} + + +def test_import_in_nested_dict(tmp_path): + """$import resolves in deeply nested dicts.""" + (tmp_path / "settings.yml").write_text("lr: 0.001\nepochs: 10\n") + config_file = tmp_path / "config.yml" + config_file.write_text( + f"imports:\n" + f" settings: {tmp_path / 'settings.yml'}\n" + f"training:\n" + f" optimizer:\n" + f" params:\n" + f" $import: settings\n" + ) + data = load_config(config_file) + assert data["training"]["optimizer"]["params"] == {"lr": 0.001, "epochs": 10} + + +def test_import_list_splice_outside_quant_cfg(tmp_path): + """$import list splice works in any list, not just quant_cfg.""" + (tmp_path / "extra_tasks.yml").write_text("- name: task_b\n- name: task_c\n") + config_file = tmp_path / "config.yml" + config_file.write_text( + f"imports:\n" + f" extra: {tmp_path / 'extra_tasks.yml'}\n" + f"tasks:\n" + f" - name: task_a\n" + f" - $import: extra\n" + f" - 
name: task_d\n" + ) + data = load_config(config_file) + assert data["tasks"] == [ + {"name": "task_a"}, + {"name": "task_b"}, + {"name": "task_c"}, + {"name": "task_d"}, + ] + + +def test_import_in_nested_list_of_dicts(tmp_path): + """$import in dict values within a nested list resolves correctly.""" + (tmp_path / "defaults.yml").write_text("timeout: 30\nretries: 3\n") + config_file = tmp_path / "config.yml" + config_file.write_text( + f"imports:\n" + f" defaults: {tmp_path / 'defaults.yml'}\n" + f"stages:\n" + f" - name: build\n" + f" config:\n" + f" $import: defaults\n" + f" verbose: true\n" + f" - name: test\n" + f" config:\n" + f" $import: defaults\n" + ) + data = load_config(config_file) + assert data["stages"][0]["config"] == {"timeout": 30, "retries": 3, "verbose": True} + assert data["stages"][1]["config"] == {"timeout": 30, "retries": 3} + + +def test_import_mixed_tree(tmp_path): + """$import resolves at multiple levels in the same config.""" + (tmp_path / "fp8.yml").write_text("num_bits: e4m3\n") + (tmp_path / "disables.yml").write_text("- quantizer_name: '*lm_head*'\n enable: false\n") + (tmp_path / "meta.yml").write_text("version: 2\nauthor: test\n") + config_file = tmp_path / "config.yml" + config_file.write_text( + f"imports:\n" + f" fp8: {tmp_path / 'fp8.yml'}\n" + f" disables: {tmp_path / 'disables.yml'}\n" + f" meta: {tmp_path / 'meta.yml'}\n" + f"info:\n" + f" $import: meta\n" + f"items:\n" + f" - name: a\n" + f" cfg:\n" + f" $import: fp8\n" + f" - $import: disables\n" + ) + data = load_config(config_file) + # Top-level dict import + assert data["info"] == {"version": 2, "author": "test"} + # Dict import inside list entry + assert data["items"][0]["cfg"] == {"num_bits": (4, 3)} + # List splice + assert data["items"][1] == {"quantizer_name": "*lm_head*", "enable": False} + + +# --------------------------------------------------------------------------- +# imports — recursive resolution and cycle detection +# 
--------------------------------------------------------------------------- + + +def test_import_recursive(tmp_path): + """A list snippet can import a dict snippet (recursive resolution via multi-doc).""" + # base: dict snippet with FP8 attributes + (tmp_path / "fp8.yml").write_text("num_bits: e4m3\n") + # mid: list snippet that imports base and uses $import in cfg + (tmp_path / "mid.yaml").write_text( + f"imports:\n" + f" fp8: {tmp_path / 'fp8.yml'}\n" + f"---\n" + f"- quantizer_name: '*weight_quantizer'\n" + f" cfg:\n" + f" $import: fp8\n" + ) + # recipe imports mid + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" mid: {tmp_path / 'mid.yaml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg:\n" + f" - $import: mid\n" + ) + recipe = load_recipe(recipe_file) + cfg = recipe.quantize["quant_cfg"][0]["cfg"] + assert cfg == {"num_bits": (4, 3)} + + +def test_import_circular_raises(tmp_path): + """Circular imports are detected and raise ValueError.""" + (tmp_path / "a.yml").write_text(f"imports:\n b: {tmp_path / 'b.yml'}\nnum_bits: 8\n") + (tmp_path / "b.yml").write_text(f"imports:\n a: {tmp_path / 'a.yml'}\nnum_bits: 4\n") + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" a: {tmp_path / 'a.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg: []\n" + ) + with pytest.raises(ValueError, match="Circular import"): + load_recipe(recipe_file) + + +def test_import_circular_via_path_aliases_raises(tmp_path): + """Circular detection survives path aliases (absolute vs relative vs no-suffix). + + ``a.yml`` imports ``b`` using the absolute path with ``.yml`` suffix, while + ``b.yml`` imports back using the relative path without suffix. Without path + canonicalization these are distinct strings, and the cycle goes undetected. 
+ """ + (tmp_path / "a.yml").write_text(f"imports:\n b: {tmp_path / 'b.yml'}\nnum_bits: 8\n") + # b imports a via a sibling-relative path + no suffix, so the import key + # differs textually from the absolute path a was loaded under. + (tmp_path / "b.yml").write_text("imports:\n a: ./a\nnum_bits: 4\n") + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" a: {tmp_path / 'a.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg: []\n" + ) + import os + + cwd = os.getcwd() + os.chdir(tmp_path) + try: + with pytest.raises(ValueError, match="Circular import"): + load_recipe(recipe_file) + finally: + os.chdir(cwd) + + +def test_import_cross_file_same_name_no_conflict(tmp_path): + """Same import name in parent and child resolve independently (scoped). + + This test intentionally exercises both sides of the scope boundary: + + * Parent's ``fmt`` → fp8 (resolved when the recipe's own ``$import: fmt`` + fires). + * Child's ``fmt`` → nvfp4 (resolved inside ``child.yml`` before the parent + ever sees the snippet). + + Both values must survive together in the final recipe — if the names were + accidentally shared across files, one would clobber the other. + """ + (tmp_path / "fp8.yml").write_text("num_bits: e4m3\n") + (tmp_path / "nvfp4.yml").write_text("num_bits: e2m1\n") + # child.yml uses its own "fmt" (→ nvfp4) via an inline $import. When the + # parent imports `child`, the snippet it sees has inner.$import already + # resolved in child's scope. 
+ (tmp_path / "child.yml").write_text( + f"imports:\n fmt: {tmp_path / 'nvfp4.yml'}\ninner:\n $import: fmt\n" + ) + recipe_file = tmp_path / "recipe.yml" + recipe_file.write_text( + f"imports:\n" + f" fmt: {tmp_path / 'fp8.yml'}\n" + f" child: {tmp_path / 'child.yml'}\n" + f"metadata:\n" + f" recipe_type: ptq\n" + f"quantize:\n" + f" algorithm: max\n" + f" quant_cfg:\n" + f" - quantizer_name: '*weight_quantizer'\n" + f" cfg:\n" + f" $import: fmt\n" + f" - quantizer_name: '*input_quantizer'\n" + f" cfg:\n" + f" $import: child\n" + ) + recipe = load_recipe(recipe_file) + # Parent's "fmt" resolves to fp8 (e4m3), not child's nvfp4. + assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} + # Child's "fmt" resolves to nvfp4 (e2m1), not parent's fp8. + assert recipe.quantize["quant_cfg"][1]["cfg"] == {"inner": {"num_bits": (2, 1)}} + + +# --------------------------------------------------------------------------- +# Coverage: _load_raw_config edge cases +# --------------------------------------------------------------------------- + + +def test_load_config_path_object(tmp_path): + """load_config accepts a Path object.""" + cfg_file = tmp_path / "test.yaml" + cfg_file.write_text("key: value\n") + data = load_config(cfg_file) + assert data == {"key": "value"} + + +def test_load_config_path_without_suffix(tmp_path): + """load_config probes .yml/.yaml suffixes for a Path without suffix.""" + cfg_file = tmp_path / "test.yaml" + cfg_file.write_text("key: value\n") + data = load_config(tmp_path / "test") # no suffix + assert data == {"key": "value"} + + +def test_load_config_empty_yaml(tmp_path): + """load_config returns empty dict for empty YAML file.""" + cfg_file = tmp_path / "empty.yaml" + cfg_file.write_text("") + data = load_config(cfg_file) + assert data == {} + + +def test_load_config_null_yaml(tmp_path): + """load_config returns empty dict for YAML file containing only null.""" + cfg_file = tmp_path / "null.yaml" + cfg_file.write_text("---\n") + data = 
load_config(cfg_file) + assert data == {} + + +def test_load_config_multi_doc_dict_dict(tmp_path): + """Multi-document YAML with two dicts merges them.""" + cfg_file = tmp_path / "multi.yaml" + cfg_file.write_text("imports:\n fp8: some/path\n---\nalgorithm: max\n") + from modelopt.torch.opt.config_loader import _load_raw_config + + data = _load_raw_config(cfg_file) + assert data["imports"] == {"fp8": "some/path"} + assert data["algorithm"] == "max" + + +def test_load_config_multi_doc_null_content(tmp_path): + """Multi-document YAML where second doc is null treats content as empty dict.""" + cfg_file = tmp_path / "multi_null.yaml" + cfg_file.write_text("key: value\n---\n") + from modelopt.torch.opt.config_loader import _load_raw_config + + data = _load_raw_config(cfg_file) + assert data == {"key": "value"} + + +def test_load_config_multi_doc_first_not_dict_raises(tmp_path): + """Multi-document YAML with non-dict first document raises ValueError.""" + cfg_file = tmp_path / "bad_multi.yaml" + cfg_file.write_text("- item1\n---\nkey: value\n") + with pytest.raises(ValueError, match="first YAML document must be a mapping"): + load_config(cfg_file) + + +def test_load_config_multi_doc_second_not_dict_or_list_raises(tmp_path): + """Multi-document YAML with scalar second document raises ValueError.""" + cfg_file = tmp_path / "bad_multi2.yaml" + cfg_file.write_text("key: value\n---\njust a string\n") + with pytest.raises(ValueError, match="second YAML document must be a mapping or list"): + load_config(cfg_file) + + +def test_load_config_three_docs_raises(tmp_path): + """YAML with 3+ documents raises ValueError.""" + cfg_file = tmp_path / "three_docs.yaml" + cfg_file.write_text("a: 1\n---\nb: 2\n---\nc: 3\n") + with pytest.raises(ValueError, match="expected 1 or 2 YAML documents"): + load_config(cfg_file) + + +def test_load_config_invalid_type_raises(): + """load_config with non-string/Path/Traversable raises ValueError.""" + with pytest.raises(ValueError, match="Invalid 
config file"): + load_config(12345) + + +def test_load_config_list_valued_yaml(tmp_path): + """load_config handles top-level YAML list.""" + cfg_file = tmp_path / "list.yaml" + cfg_file.write_text("- name: a\n value: 1\n- name: b\n value: 2\n") + data = load_config(cfg_file) + assert isinstance(data, list) + assert len(data) == 2 + assert data[0] == {"name": "a", "value": 1} + + +# --------------------------------------------------------------------------- +# Coverage: _resolve_imports edge cases +# --------------------------------------------------------------------------- + + +def test_import_dict_value_resolves_to_list_raises(tmp_path): + """$import in dict value position raises when snippet is a list.""" + (tmp_path / "entries.yml").write_text("- a: 1\n- b: 2\n") + config_file = tmp_path / "config.yml" + config_file.write_text( + f"imports:\n entries: {tmp_path / 'entries.yml'}\nmy_field:\n $import: entries\n" + ) + with pytest.raises(ValueError, match="must resolve to a dict"): + load_config(config_file) + + +def test_import_imports_not_a_dict_raises(tmp_path): + """imports section that is a list raises ValueError.""" + config_file = tmp_path / "config.yml" + config_file.write_text("imports:\n - some/path\nkey: value\n") + with pytest.raises(ValueError, match="must be a dict"): + load_config(config_file) diff --git a/tools/precommit/check_modelopt_recipes.py b/tools/precommit/check_modelopt_recipes.py index b964b4b040d..2c5706ee73b 100644 --- a/tools/precommit/check_modelopt_recipes.py +++ b/tools/precommit/check_modelopt_recipes.py @@ -51,14 +51,25 @@ def _check_quant_cfg(quant_cfg, label: str) -> list[str]: if not isinstance(entry, dict): errors.append( f"{label}: quant_cfg[{i}] must be a dict with " - f"'quantizer_name', got {type(entry).__name__}. " + f"'quantizer_name' or '$import', got {type(entry).__name__}. 
" "See https://nvidia.github.io/Model-Optimizer/guides/_quant_cfg.html" ) continue + # {$import: name} entries are resolved at load time + if "$import" in entry: + ref = entry["$import"] + if not isinstance(ref, (str, list)) or ( + isinstance(ref, list) and not all(isinstance(r, str) for r in ref) + ): + errors.append( + f"{label}: quant_cfg[{i}] '$import' must be a string or list of strings, " + f"got {type(ref).__name__}: {ref!r}" + ) + continue if "quantizer_name" not in entry: errors.append( f"{label}: quant_cfg[{i}] is missing 'quantizer_name'. " - "Each entry must have an explicit 'quantizer_name' key. " + "Each entry must have an explicit 'quantizer_name' or '$import' key. " "See https://nvidia.github.io/Model-Optimizer/guides/_quant_cfg.html" ) return errors