From 8b7b64b6a231318ea0318dcaf8dda36ac24b007d Mon Sep 17 00:00:00 2001 From: Adam Stachowicz Date: Thu, 21 May 2026 12:44:35 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=A4=96=20Fix=20qsoftmax=20qnnpack=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/xpu/quantization/core/test_quantized_op_xpu.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/xpu/quantization/core/test_quantized_op_xpu.py b/test/xpu/quantization/core/test_quantized_op_xpu.py index e4c133b2dd..cb1998421f 100644 --- a/test/xpu/quantization/core/test_quantized_op_xpu.py +++ b/test/xpu/quantization/core/test_quantized_op_xpu.py @@ -12,6 +12,7 @@ import torch from torch.nn.modules.utils import _pair from torch.testing._internal.common_device_type import instantiate_device_type_tests +from torch.testing._internal.common_quantized import override_quantized_engine from torch.testing._internal.common_utils import run_tests, TestCase try: @@ -75,7 +76,13 @@ def _test_max_pool2d_pt2e(self): self.assertEqual(a_pool, a_hat, msg="ops.quantized.max_pool2d results are off") +def _test_qsoftmax_qnnpack(self): + with override_quantized_engine("qnnpack"): + self.test_qsoftmax_xpu() + + TestQuantizedOps.test_max_pool2d_pt2e = _test_max_pool2d_pt2e +TestQuantizedOps.test_qsoftmax_qnnpack = _test_qsoftmax_qnnpack instantiate_device_type_tests( TestQuantizedOps, globals(), only_for="xpu", allow_xpu=True