diff --git a/test/xpu/quantization/core/test_quantized_op_xpu.py b/test/xpu/quantization/core/test_quantized_op_xpu.py index e4c133b2dd..cb1998421f 100644 --- a/test/xpu/quantization/core/test_quantized_op_xpu.py +++ b/test/xpu/quantization/core/test_quantized_op_xpu.py @@ -12,6 +12,7 @@ import torch from torch.nn.modules.utils import _pair from torch.testing._internal.common_device_type import instantiate_device_type_tests +from torch.testing._internal.common_quantized import override_quantized_engine from torch.testing._internal.common_utils import run_tests, TestCase try: @@ -75,7 +76,13 @@ def _test_max_pool2d_pt2e(self): self.assertEqual(a_pool, a_hat, msg="ops.quantized.max_pool2d results are off") +def _test_qsoftmax_qnnpack(self): + with override_quantized_engine("qnnpack"): + self.test_qsoftmax_xpu() + + TestQuantizedOps.test_max_pool2d_pt2e = _test_max_pool2d_pt2e +TestQuantizedOps.test_qsoftmax_qnnpack = _test_qsoftmax_qnnpack instantiate_device_type_tests( TestQuantizedOps, globals(), only_for="xpu", allow_xpu=True