89 changes: 75 additions & 14 deletions modelopt/torch/export/unified_export_megatron.py
@@ -61,6 +61,7 @@
get_weight_block_size,
get_weight_scaling_factor,
get_weight_scaling_factor_2,
process_layer_quant_config,
to_quantized_weight,
)

@@ -169,6 +170,7 @@ def __init__(
self.all_rules = self._populate_rule_book()
self.rules = self.all_rules[self.arch]
self.exclude_modules = []
self.layer_config_dict = {}

if not hasattr(model, "_modelopt_state"):
return
@@ -324,22 +326,32 @@ def save_pretrained(
print(f"Successfully loaded {len(mtp_state_dict)} MTP tensors")

combined_exclude_modules = self._gather_exclude_modules()
combined_layer_config_dict = self._gather_layer_config_dict()

if is_last_stage_main_rank and quantization is not None:
self._hf_quant_config = {
if combined_layer_config_dict:
quantization_config = process_layer_quant_config(combined_layer_config_dict)
quantization_config["exclude_modules"] = combined_exclude_modules
else:
quantization_config = {
"quant_algo": quantization,
"exclude_modules": combined_exclude_modules,
}
if quantization == "NVFP4": # update block size
quantization_config["group_size"] = 16

if hasattr(self, "kv_cache_dtype"):
quantization_config["kv_cache_quant_algo"] = self.kv_cache_dtype

raw_hf_quant_config = {
"producer": {
"name": "modelopt",
"version": __version__,
},
"quantization": {
"quant_algo": quantization,
"exclude_modules": combined_exclude_modules,
},
"quantization": quantization_config,
}
if quantization == "NVFP4": # update block size
self._hf_quant_config["quantization"]["group_size"] = 16
if hasattr(self, "kv_cache_dtype"):
self._hf_quant_config["quantization"]["kv_cache_quant_algo"] = self.kv_cache_dtype
# Use one serving-facing config for both hf_quant_config.json and config.json.
self._hf_quant_config = convert_hf_quant_config_format(raw_hf_quant_config)
Collaborator: do we change the format of hf_quant_config.json by this change?
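
A hedged sketch of the format question, using only names visible in this hunk (the version string and exclude list are placeholders):

# Hypothetical values for illustration; the exact output of
# convert_hf_quant_config_format is not shown in this diff.
raw_hf_quant_config = {
    "producer": {"name": "modelopt", "version": "x.y.z"},  # placeholder version
    "quantization": {
        "quant_algo": "NVFP4",
        "group_size": 16,
        "exclude_modules": ["lm_head"],  # illustrative exclude list
    },
}
# Before this change, hf_quant_config.json received the raw dict above while
# only config.json["quantization_config"] received the converted form. Now
# both files receive convert_hf_quant_config_format(raw_hf_quant_config), so
# the on-disk format of hf_quant_config.json does change.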

with open(save_directory + "/hf_quant_config.json", "w") as f:
json.dump(self._hf_quant_config, f, indent=4)

@@ -359,10 +371,9 @@ def save_pretrained(
# Newer versions of VLLM expect config.json with hf_quant_config
config_json_file = save_directory + "/config.json"
if self._hf_quant_config and os.path.exists(config_json_file):
converted_quant_config = convert_hf_quant_config_format(self._hf_quant_config)
with open(config_json_file) as f:
config_dict = json.load(f)
config_dict["quantization_config"] = converted_quant_config
config_dict["quantization_config"] = self._hf_quant_config
with open(config_json_file, "w") as f:
json.dump(config_dict, f, indent=4)

@@ -803,9 +814,7 @@ def _get_quantized_state(
name_to_value = {}
qformat: str = self._get_quantization_format(module)
if qformat is None and "norm" not in prefix:
# Add exclude layers for hf_quant_config. Note that if the prefix is not an empty
# string then it usually ends with "." which needs to be removed.
self.exclude_modules.append(prefix.removesuffix("."))
self._record_excluded_module(prefix)
block_size = get_weight_block_size(module)
jenchen13 marked this conversation as resolved.

name_to_value = self._get_weight_bias(module, dtype, name_to_value)
@@ -850,6 +859,27 @@ def _get_weight_scales(self, quantized_state: dict[str, Any], qformat: str):

return weight_scale, weight_scale_2

def _record_layer_quant_config(self, prefix: str, qformat: str | None, block_size: int):
"""Record per-HF-layer quantization metadata for mixed precision exports."""
if qformat in (None, QUANTIZATION_NONE):
return

layer_name = prefix.removesuffix(".")
if "{" in layer_name or not layer_name:
return

self.layer_config_dict[layer_name + ".quantization"] = qformat
self.layer_config_dict[layer_name + ".awq_block_size"] = block_size

def _record_excluded_module(self, prefix: str):
"""Record an unquantized HF module prefix for hf_quant_config."""
layer_name = prefix.removesuffix(".")
if "{" in layer_name or not layer_name:
return

if layer_name not in self.exclude_modules:
self.exclude_modules.append(layer_name)
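
For context, a hedged sketch of what these two recorders accumulate during an export (the HF module names and formats below are illustrative, not taken from a real run):

# Hypothetical state after visiting a mixed-precision model:
layer_config_dict = {
    "model.layers.0.mlp.gate_proj.quantization": "NVFP4",
    "model.layers.0.mlp.gate_proj.awq_block_size": 16,
    "model.layers.0.self_attn.q_proj.quantization": "FP8",
    "model.layers.0.self_attn.q_proj.awq_block_size": 0,  # 0 assumed for non-block formats
}
exclude_modules = ["lm_head"]  # unquantized modules, recorded once each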

def _name_remapping(
self,
module: torch.nn.Module | torch.Tensor,
@@ -866,6 +896,7 @@ def _name_remapping(
return

name_to_value, qformat, block_size = self._get_quantized_state(module, dtype, prefix=prefix)
self._record_layer_quant_config(prefix, qformat, block_size)

weight = name_to_value.pop("weight")
weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
@@ -906,6 +937,8 @@ def _gated_mlp_slicing(

gate_proj_prefix = prefix + gate_proj_name + "."
up_proj_prefix = prefix + up_proj_name + "."
self._record_layer_quant_config(gate_proj_prefix, qformat, block_size)
self._record_layer_quant_config(up_proj_prefix, qformat, block_size)

ffn_hidden_size = module.config.ffn_hidden_size
gate_proj_weight = weight[:ffn_hidden_size, :]
@@ -986,6 +1019,7 @@ def _grouped_mlp_slicing(self, module, prefix, parallel_config=None):

for expert_id in range(num_experts):
expert_prefix = prefix.format(expert_id) + "."
self._record_layer_quant_config(expert_prefix, qformat, block_size)
weight_key = f"weight{expert_id}"

if weight_key not in state_dict:
@@ -1030,6 +1064,18 @@ def _qkv_slicing(
q_proj_prefix = prefix + q_proj_name + "."
k_proj_prefix = prefix + k_proj_name + "."
v_proj_prefix = prefix + v_proj_name + "."
self._record_layer_quant_config(q_proj_prefix, qformat, block_size)
self._record_layer_quant_config(k_proj_prefix, qformat, block_size)
self._record_layer_quant_config(v_proj_prefix, qformat, block_size)
if qformat in (None, QUANTIZATION_NONE):
# MCore stores Q/K/V as one fused linear_qkv module, but HF exports them
# as separate q_proj/k_proj/v_proj modules. Record the HF names so
# runtime quant configs do not miss excluded fused-QKV projections.
fused_prefix = prefix.removesuffix(".")
self.exclude_modules = [m for m in self.exclude_modules if m != fused_prefix]
self._record_excluded_module(q_proj_prefix)
self._record_excluded_module(k_proj_prefix)
self._record_excluded_module(v_proj_prefix)

config = module.config
hidden_size = config.hidden_size
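
A hedged sketch of the exclude-list rewrite above (module names are illustrative):

# Before: the fused MCore-side attention module was recorded as excluded.
exclude_modules = ["model.layers.0.self_attn"]
# After: the fused entry is filtered out and the HF-facing projection names
# are recorded instead, matching the exported checkpoint layout.
exclude_modules = [
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.v_proj",
]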
@@ -1179,6 +1225,7 @@ def _pack_name_remapping(self, module, prefix, layer_type=None):
weight_scale_list.append(weight_scale)
weight_scale_2_list.append(weight_scale_2)
input_scale_list.append(input_scale)
self._record_layer_quant_config(prefix, qformat, block_size)

merged_weight = torch.stack(weight_list, dim=0)

@@ -1247,6 +1294,7 @@ def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None):
weight_scale_2_list.append(weight_scale_2)
input_scale_list.append(input_scale)
bias_list.append(bias)
self._record_layer_quant_config(prefix, qformat, block_size)

merged_weight = torch.stack(weight_list, dim=0)

@@ -1349,6 +1397,19 @@ def _gather_exclude_modules(self):
combined_exclude_modules.update(modules)
return sorted(combined_exclude_modules)

def _gather_layer_config_dict(self):
"""Get per-layer quantization metadata from all ranks for hf_quant_config."""
if not torch.distributed.is_initialized():
return dict(sorted(self.layer_config_dict.items()))

all_layer_config_dicts = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(all_layer_config_dicts, self.layer_config_dict)
combined_layer_config_dict = {}
for layer_config_dict in all_layer_config_dicts:
if layer_config_dict:
combined_layer_config_dict.update(layer_config_dict)
return dict(sorted(combined_layer_config_dict.items()))
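
A minimal sketch of the gather-and-merge behavior this helper performs (two hypothetical pipeline ranks):

# Rank 0 exports early layers, rank 1 exports late layers (illustrative names):
rank0 = {"model.layers.0.mlp.down_proj.quantization": "NVFP4"}
rank1 = {"model.layers.31.mlp.down_proj.quantization": "NVFP4"}
# all_gather_object collects every rank's dict on every rank; the merged,
# sorted result is identical everywhere, giving deterministic JSON output.
merged = dict(sorted({**rank0, **rank1}.items()))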


def export_mcore_gpt_to_hf(
model: torch.nn.Module,
7 changes: 7 additions & 0 deletions modelopt/torch/quantization/calib/mse.py
@@ -181,10 +181,12 @@ def __init__(
axis: int | tuple | list | None = None,
quant_func: Callable | None = None,
error_func: Callable | None = None,
fp8_scale_sweep_stride: int = 1,
):
"""Initialize NVFP4 MSE calibrator with per-block and global amax."""
super().__init__(amax=amax, axis=axis, quant_func=quant_func, error_func=error_func)
self._global_amax = global_amax
self._fp8_scale_sweep_stride = max(1, fp8_scale_sweep_stride or 1)

def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
if candidates.ndim != 0: # Called during final compute amax
@@ -197,4 +199,9 @@ def _generate_candidates(self, device: torch.device) -> torch.Tensor:
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
fp8_values = fp8_values[valid_mask]
if self._fp8_scale_sweep_stride > 1:
candidates = fp8_values[:: self._fp8_scale_sweep_stride]
if candidates[-1] != fp8_values[-1]:
candidates = torch.cat([candidates, fp8_values[-1:]])
fp8_values = candidates
return fp8_values / 448.0
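
As a worked example of the new stride (assuming the mask above yields all 126 positive finite E4M3 values, which it does on torch builds with float8 support):

import torch

# Standalone sketch of the subsampling logic, not a call into the calibrator.
uint8_values = torch.arange(256, dtype=torch.uint8)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
fp8_values = fp8_values[torch.isfinite(fp8_values) & (fp8_values > 0)]  # 126 values
stride = 4
candidates = fp8_values[::stride]  # 32 values; the last sits below the E4M3 max
if candidates[-1] != fp8_values[-1]:
    candidates = torch.cat([candidates, fp8_values[-1:]])  # keep 448.0 as a candidate
print(candidates.numel())  # 33 candidates instead of 126, roughly a 4x smaller sweep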
8 changes: 8 additions & 0 deletions modelopt/torch/quantization/config.py
@@ -1327,6 +1327,14 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
"start_multiplier, and stop_multiplier are ignored.",
)

fp8_scale_sweep_stride: int | None = ModeloptField(
default=1,
ge=1,
title="Stride for FP8 scale sweep candidates.",
description="Subsample every Nth valid FP8 E4M3 scale candidate when fp8_scale_sweep is True. "
"A value of 1 preserves the exhaustive sweep.",
)

distributed_sync: bool | None = ModeloptField(
default=True,
title="Whether to sync the amax across the distributed processes.",
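A hedged usage sketch of the new field (assuming MseCalibConfig is constructed with keyword arguments, as is usual for these pydantic-style configs; the quantize entry point itself is outside this diff):

from modelopt.torch.quantization.config import MseCalibConfig

# Exhaustive FP8 scale sweep (default behavior, stride 1).
exhaustive = MseCalibConfig(fp8_scale_sweep=True)
# Subsample every 4th E4M3 candidate, trading a little search resolution
# for a faster calibration pass.
strided = MseCalibConfig(fp8_scale_sweep=True, fp8_scale_sweep_stride=4)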
68 changes: 45 additions & 23 deletions modelopt/torch/quantization/model_calib.py
@@ -110,26 +110,44 @@ def _has_expert_parallelism(module: nn.Module) -> bool:
return ps is not None and ps.expert_model_parallel_group.is_initialized()


def _check_moe_calibration_complete(quantizer, parallel_state):
"""Raise error if MoE calibration is incomplete (some ranks have amax, others don't)."""
def _is_dynamic_block_quantizer(quantizer) -> bool:
block_sizes = getattr(quantizer, "block_sizes", None)
if isinstance(block_sizes, dict):
return block_sizes.get("type") == "dynamic"
return getattr(block_sizes, "type", None) == "dynamic"


def _iter_leaf_quantizers(quantizer):
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
_check_moe_calibration_complete(_q, parallel_state)
yield from _iter_leaf_quantizers(_q)
return
for group in [
parallel_state.data_parallel_group,
parallel_state.expert_model_parallel_group,
parallel_state.tensor_parallel_group,
]:
if not group.is_initialized():
yield quantizer


def _check_moe_calibration_complete(quantizer, parallel_state):
"""Raise error if MoE calibration is incomplete across distributed MoE ranks."""
for leaf_quantizer in _iter_leaf_quantizers(quantizer):
if _is_dynamic_block_quantizer(leaf_quantizer):
continue
has_amax = getattr(quantizer, "_amax", None) is not None
amax_states = DistributedProcessGroup.get_dist_syncd_obj(has_amax, group, lambda objs: objs)
if any(amax_states) and not all(amax_states):
raise RuntimeError(
"MoE calibration incomplete: some experts received no tokens during calibration. "
"Increase --calib-size to ensure all experts see calibration data."

has_amax = getattr(leaf_quantizer, "_amax", None) is not None
for group in [
parallel_state.data_parallel_group,
parallel_state.expert_model_parallel_group,
parallel_state.tensor_parallel_group,
]:
if not group.is_initialized():
continue
amax_states = DistributedProcessGroup.get_dist_syncd_obj(
has_amax, group, lambda objs: objs
)
if any(amax_states) and not all(amax_states):
raise RuntimeError(
"MoE calibration incomplete: some experts received no tokens during "
"calibration. Increase --calib-size to ensure all experts see calibration "
"data."
)
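
To make the traversal concrete, a minimal sketch of how these helpers interact (the quantizer objects are stand-ins, not the real TensorQuantizer API):

# A SequentialQuantizer behaves like an iterable of child quantizers, and
# children may nest; _iter_leaf_quantizers flattens them depth-first:
#   sq = SequentialQuantizer(int4_weight_quantizer, fp8_weight_quantizer)
#   list(_iter_leaf_quantizers(sq)) -> [int4_weight_quantizer, fp8_weight_quantizer]
#
# Dynamic block quantizers (block_sizes["type"] == "dynamic") compute scales on
# the fly and hold no static _amax, so they are skipped rather than flagged as
# missing calibration. For the remaining leaves, amax presence gathered across
# a 4-rank group as [True, True, False, True] means one expert saw no
# calibration tokens, which raises the RuntimeError above.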


@torch.no_grad()
@@ -175,13 +193,13 @@ def max_calibrate(

def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
"""Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
sync_quantizer_amax_across_dp_ep(_q, parallel_state)
return
if getattr(quantizer, "_amax", None) is not None:
quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
for leaf_quantizer in _iter_leaf_quantizers(quantizer):
if _is_dynamic_block_quantizer(leaf_quantizer):
continue
leaf_quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
leaf_quantizer.sync_amax_across_distributed_group(
parallel_state.expert_model_parallel_group
)
# TODO: create sync_bias_across_distributed_group

# Step 2: Sync amax across data parallelism
@@ -226,7 +244,7 @@ def sync_quantizer_amax_across_tp(
)
# Skip amax sync for INT4 / W4A8 block quantization
# Sync amax for NVFP4 (dynamic per-block, static per-tensor quantized scale)
if getattr(quantizer.block_sizes, "type", None) == "dynamic":
if _is_dynamic_block_quantizer(quantizer):
return

if quantizer.axis in axes_for_sync and quantizer.amax is not None:
@@ -314,6 +332,7 @@ def mse_calibrate(
start_multiplier: float = 0.25,
stop_multiplier: float = 4.0,
fp8_scale_sweep: bool = False,
fp8_scale_sweep_stride: int = 1,
):
"""Calibrate the model using MSE-based amax search.

Expand All @@ -333,6 +352,8 @@ def mse_calibrate(
for NVFP4 per-block quantization instead of using multipliers.
This is specifically designed for optimizing the FP8-quantized
per-block scales in NVFP4 format (default: False).
fp8_scale_sweep_stride: Subsample every Nth FP8 E4M3 candidate when
fp8_scale_sweep is enabled. A value of 1 preserves the exhaustive sweep.

See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for
details on the remaining arguments.
@@ -388,6 +409,7 @@ def mse_calibrate(
axis=module._calibrator._axis,
global_amax=module.global_amax,
quant_func=partial(_mse_quant_func, quantizer=module),
fp8_scale_sweep_stride=fp8_scale_sweep_stride,
)
continue

8 changes: 6 additions & 2 deletions modelopt/torch/quantization/plugins/megatron.py
@@ -399,7 +399,9 @@ def _get_shard_axis_dict(self, state_dict):
"""
shard_axis_dict = {}
for k in state_dict:
if "weight_quantizer." in k:
# Static NVFP4 _global_amax is a scalar shared by all TP ranks; only shard
# per-block/per-channel weight quantizer state.
if "weight_quantizer." in k and not k.endswith("._global_amax"):
weight_quantizer_axis = self.get_submodule(k.rsplit(".", 1)[0]).axis
if weight_quantizer_axis is not None:
shard_axis_dict[k] = 0
@@ -427,7 +429,9 @@ def _get_shard_axis_dict(self, state_dict):
"""
shard_axis_dict = {}
for k in state_dict:
if "weight_quantizer." in k:
# Static NVFP4 _global_amax is a scalar shared by all TP ranks; only shard
# per-block/per-channel weight quantizer state.
if "weight_quantizer." in k and not k.endswith("._global_amax"):
weight_quantizer_axis = None
if isinstance(self.weight_quantizer, TensorQuantizer):
weight_quantizer_axis = self.weight_quantizer.axis
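For illustration, a hedged sketch of the resulting shard map (the key below is one example of weight-quantizer state, not an exhaustive list):

# Column-parallel linear with a per-channel weight quantizer:
shard_axis_dict = {
    "weight_quantizer._amax": 0,  # per-channel amax shards along dim 0
    # "weight_quantizer._global_amax" is intentionally absent: the static
    # NVFP4 global amax is a replicated scalar, so sharding it would
    # corrupt resharded checkpoints.
}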