diff --git a/ane_transformers/reference/layer_norm.py b/ane_transformers/reference/layer_norm.py index a59eeb0..4c52cb0 100644 --- a/ane_transformers/reference/layer_norm.py +++ b/ane_transformers/reference/layer_norm.py @@ -73,7 +73,7 @@ def forward(self, inputs): out = zero_mean * denom if self.elementwise_affine: - out = (out + self.bias.view(1, self.num_channels, 1, 1) - ) * self.weight.view(1, self.num_channels, 1, 1) + out = (out * self.weight.view(1, self.num_channels, 1, 1) + ) + self.bias.view(1, self.num_channels, 1, 1) return out