2 changes: 1 addition & 1 deletion examples/vllm_serve/fakequant_worker.py

@@ -20,7 +20,6 @@
 
 import torch
 from transformers import AutoTokenizer
-from vllm.v1.worker.gpu_worker import Worker as BaseWorker
 from vllm_ptq_utils import calibrate_fun, get_quant_config
 from vllm_reload_utils import (
     convert_dict_to_vllm,
@@ -38,6 +37,7 @@
 )
 from modelopt.torch.utils import safe_load
 from modelopt.torch.utils.dataset_utils import get_dataset_dataloader
+from vllm.v1.worker.gpu_worker import Worker as BaseWorker
 
 quant_config: dict[str, Any] = {
     "dataset": os.environ.get("QUANT_DATASET", "cnn_dailymail"),
4 changes: 2 additions & 2 deletions examples/vllm_serve/vllm_ptq_utils.py

@@ -20,11 +20,11 @@
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from vllm.sampling_params import SamplingParams
-from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
 
 import modelopt.torch.quantization as mtq
 from modelopt.recipe import ModelOptPTQRecipe, load_recipe
+from vllm.sampling_params import SamplingParams
+from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
 
 
 def _create_new_data_cls(data_cls, **kwargs):
2 changes: 1 addition & 1 deletion examples/vllm_serve/vllm_reload_utils.py

@@ -20,7 +20,6 @@
 from typing import Any
 
 import torch
-from vllm.distributed.parallel_state import get_tp_group
 
 from modelopt.torch.export.plugins.vllm_fakequant_hf import (
     infer_quantizer_prefix_remap,
@@ -38,6 +37,7 @@
 )
 from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
 from modelopt.torch.quantization.utils import is_quantized
+from vllm.distributed.parallel_state import get_tp_group
 
 
 def _union_quantizer_keys_across_ranks(local_quantizer_keys: list[str]) -> set[str]:
3 changes: 2 additions & 1 deletion examples/vllm_serve/vllm_serve_fakequant.py

@@ -55,8 +55,9 @@
 from pathlib import Path
 
 import uvloop
-import vllm
 from packaging import version
+
+import vllm
 from vllm.entrypoints.openai.api_server import run_server
 from vllm.entrypoints.openai.cli_args import make_arg_parser
 
15 changes: 14 additions & 1 deletion modelopt/torch/export/unified_export_hf.py

@@ -1189,12 +1189,25 @@ def export_hf_checkpoint(
     try:
         post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype)
 
-        if hf_quant_config is not None:
+        # Only treat the export as quantized when at least one quant_algo field is set.
+        # get_quant_config always returns a dict (even for sparsity-only or unmodified models),
+        # so emitting hf_quant_config.json unconditionally produces a file with
+        # "quant_algo": null that downstream loaders (e.g. TensorRT-LLM) reject as a
+        # malformed pre-quantized checkpoint.
+        quantization_details = (hf_quant_config or {}).get("quantization", {})
+        is_quantized_export = (
+            quantization_details.get("quant_algo") is not None
+            or quantization_details.get("kv_cache_quant_algo") is not None
+        )
+
+        if is_quantized_export:
             # Save hf_quant_config.json for backward compatibility
             with open(f"{export_dir}/hf_quant_config.json", "w") as file:
                 json.dump(hf_quant_config, file, indent=4)
 
             hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
+        else:
+            hf_quant_config = None
 
         # Remove hf_quantizer from model so post_state_dict can be exported.
         if getattr(model, "hf_quantizer", None) is not None:
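
For readers skimming the diff: the new guard can be exercised in isolation. The sketch below is not part of the PR; it restates the predicate as a standalone function and feeds it the two shapes the added comment describes. Only the nesting visible in the diff ("quantization" wrapping "quant_algo" / "kv_cache_quant_algo") is taken from the source; the "FP8" value is a hypothetical placeholder.

def is_quantized_export(hf_quant_config: dict | None) -> bool:
    # Same predicate as the new code in export_hf_checkpoint: the export
    # counts as quantized only if at least one algo field is actually set.
    quantization_details = (hf_quant_config or {}).get("quantization", {})
    return (
        quantization_details.get("quant_algo") is not None
        or quantization_details.get("kv_cache_quant_algo") is not None
    )

# Weight-quantized model: hf_quant_config.json is written, then converted
# via convert_hf_quant_config_format. ("FP8" is an illustrative value.)
assert is_quantized_export({"quantization": {"quant_algo": "FP8"}})

# Sparsity-only or unmodified model: get_quant_config still returned a dict,
# but both algo fields are None, so the file write is skipped.
assert not is_quantized_export({"quantization": {"quant_algo": None, "kv_cache_quant_algo": None}})
assert not is_quantized_export(None)

Note that the else branch matters as much as the skipped file write: resetting hf_quant_config to None keeps the remainder of export_hf_checkpoint on the unquantized path instead of handing callers a config whose fields are all null.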