diff --git a/optimum/amd/brevitas/configuration.py b/optimum/amd/brevitas/configuration.py
index b5f4e637..8cd2dba1 100644
--- a/optimum/amd/brevitas/configuration.py
+++ b/optimum/amd/brevitas/configuration.py
@@ -137,5 +137,8 @@ def __post_init__(self):
             self.activations_group_size = None
             self.activations_param_method = None
 
+    def add_bias_to_linear(self):
+        return self.apply_bias_correction and self.device == "auto"
+
     def requires_fx_graph(self):
         return self.activations_equalization == "cross_layer" or self.apply_weight_equalization
diff --git a/optimum/amd/brevitas/quantizer.py b/optimum/amd/brevitas/quantizer.py
index d06e143b..88e7ecb6 100644
--- a/optimum/amd/brevitas/quantizer.py
+++ b/optimum/amd/brevitas/quantizer.py
@@ -192,6 +192,8 @@ def quantize(
     if use_accelerate:
         remove_hooks(model)
         device = None
+        if quantization_config.add_bias_to_linear():
+            model = add_zero_bias_to_linear(model)
     else:
         device = next(model.parameters()).device
 
@@ -244,6 +246,7 @@ def quantize(
         apply_bias_correction(
             model,
             calibration_dataset,
+            skip_if_no_bias=use_accelerate,  # We can't add keys to the state dict if accelerate is being used
         )
         logger.info("Bias Correction applied.")
 
@@ -331,7 +334,21 @@ def apply_calibration(model: torch.nn.Module, dataset: List[Dict]) -> None:
 
 
 @torch.no_grad()
-def apply_bias_correction(model: torch.nn.Module, dataset: List[Dict]) -> None:
-    with bias_correction_mode(model):
+def apply_bias_correction(model: torch.nn.Module, dataset: List[Dict], skip_if_no_bias: bool = False) -> None:
+    with bias_correction_mode(model, skip_if_no_bias=skip_if_no_bias):
         for inps in tqdm(dataset):
             model(**inps)
+
+
+@torch.no_grad()
+def add_zero_bias_to_linear(model: torch.nn.Module) -> torch.nn.Module:
+    for name, module in model.named_modules():
+        if type(module) == torch.nn.Linear:
+            if module.bias is None:
+                module.register_parameter(
+                    "bias",
+                    torch.nn.Parameter(
+                        torch.zeros((module.weight.shape[0],), device=module.weight.device, dtype=module.weight.dtype)
+                    ),
+                )
+    return model
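
For context on why the helper is shaped this way, here is a minimal standalone sketch (not part of the diff) of the idea behind add_zero_bias_to_linear: registering an all-zero bias on a bias-less torch.nn.Linear creates the "bias" state-dict key that bias correction can later write into, while leaving the layer's output numerically unchanged. The toy layer and shapes below are illustrative, not taken from the PR.

import torch

# Hypothetical toy layer; in the PR, whole models pass through add_zero_bias_to_linear.
linear = torch.nn.Linear(4, 3, bias=False)
x = torch.randn(2, 4)
out_before = linear(x)

# Mirror the core of add_zero_bias_to_linear for a single module: register a
# zero-initialized bias matching the layer's output features, device, and dtype.
linear.register_parameter(
    "bias",
    torch.nn.Parameter(
        torch.zeros((linear.weight.shape[0],), device=linear.weight.device, dtype=linear.weight.dtype)
    ),
)

assert "bias" in linear.state_dict()          # the key now exists in the state dict
assert torch.allclose(linear(x), out_before)  # a zero bias leaves outputs unchanged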