diff --git a/.github/workflows/_mac-test-mps.yml b/.github/workflows/_mac-test-mps.yml
index cda21cc074801..8ee8ae883daa5 100644
--- a/.github/workflows/_mac-test-mps.yml
+++ b/.github/workflows/_mac-test-mps.yml
@@ -110,6 +110,19 @@ jobs:
           # TODO(https://github.com/pytorch/pytorch/issues/79293)
           ${CONDA_RUN} python3 test/test_nn.py -k mps --verbose
 
+      - name: Run MPS Test Ops
+        id: test_3
+        env:
+          ENV_NAME: conda-test-env-${{ github.run_id }}
+        shell: arch -arch arm64 bash {0}
+        # During bring up of test_ops on MPS don't show this as an error.
+        continue-on-error: true
+        run: |
+          # shellcheck disable=SC1090
+          set -ex
+          # TODO(https://github.com/pytorch/pytorch/issues/79293)
+          PYTORCH_TEST_WITH_SLOW=1 ${CONDA_RUN} python3 test/test_ops.py -k mps --verbose
+
       - name: Print remaining test logs
         shell: bash
         if: always()
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 537881e34da95..4b372243c8681 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -69,7 +69,7 @@ jobs:
           # shellcheck disable=SC1090
           set -ex
           set +e
-          if ! ${CONDA_RUN} lintrunner --force-color test/*.py aten/src/ATen/native/mps/*.h aten/src/ATen/native/mps/*.mm aten/src/ATen/native/mps/operations/*; then
+          if ! ${CONDA_RUN} lintrunner --force-color aten/src/ATen/native/mps/operations/* test/test_mps.py test/test_modules.py test/test_ops.py; then
             echo ""
             echo -e "\e[1m\e[36mYou can reproduce these results locally by using \`lintrunner\`.\e[0m"
             echo -e "\e[1m\e[36mSee https://github.com/pytorch/pytorch/wiki/lintrunner for setup instructions.\e[0m"
diff --git a/test/test_ops.py b/test/test_ops.py
index d40b625f93eac..bbce0430e1593 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -18,6 +18,7 @@
 from torch.testing._internal.common_dtype import (
     floating_and_complex_types_and,
     all_types_and_complex_and,
+    get_all_dtypes,
 )
 
 from torch.testing._internal.common_utils import (
@@ -110,6 +111,28 @@
 
 aten = torch.ops.aten
 
+# Dtypes the tests in this file may use on MPS; the rest are skipped there.
+_MPS_EXCLUDED_DTYPES = (torch.double, torch.cdouble, torch.cfloat, torch.int8, torch.bfloat16)
+MPS_DTYPES = [t for t in get_all_dtypes() if t not in _MPS_EXCLUDED_DTYPES]
+
+
+def _get_mps_error_msg(device, dtype, op, mps_blocklist):
+    """Return the reason to skip a test on MPS, or None if it should run.
+
+    Only MPS devices ("mps", "mps:0", ...) are ever skipped; for any other
+    device this returns None so blocklisted ops keep running on CPU/CUDA.
+    On MPS the test is skipped when `dtype` is not in MPS_DTYPES or when
+    `op.name` starts with one of the prefixes in `mps_blocklist`.
+    """
+    if torch.device(device).type != "mps":
+        return None
+    if dtype not in MPS_DTYPES:
+        return f"MPS doesn't support {str(dtype)} datatype"
+    if op.name.startswith(tuple(mps_blocklist)):
+        return "MPS doesn't support op " + str(op.name)
+    return None
+
+
 # Tests that apply to all operators and aren't related to any particular
 #   system
 class TestCommon(TestCase):
@@ -256,12 +279,18 @@ def test_numpy_ref(self, device, dtype, op):
                 )
 
     # Tests that the cpu and gpu results are consistent
-    @onlyCUDA
     @suppress_warnings
     @slowTest
     @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one)
     def test_compare_cpu(self, device, dtype, op):
 
+        MPS_BLOCKLIST = [
+            "stft",  # hard crash on unsupported ComplexFloat
+        ]
+        msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
+
         def to_cpu(arg):
             if isinstance(arg, torch.Tensor):
                 return arg.to(device='cpu')
@@ -271,7 +300,7 @@ def to_cpu(arg):
 
         for sample in samples:
             cpu_sample = sample.transform(to_cpu)
-            cuda_results = op(sample.input, *sample.args, **sample.kwargs)
+            gpu_results = op(sample.input, *sample.args, **sample.kwargs)
             cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)
 
             # output_process_fn_grad has a very unfortunate name
@@ -279,12 +308,12 @@ def to_cpu(arg):
             # that are not completely well-defined. Think svd and muliplying the singular vectors by -1.
             # CPU and CUDA implementations of the SVD can return valid SVDs that are different.
             # We use this function to compare them.
-            cuda_results = sample.output_process_fn_grad(cuda_results)
+            gpu_results = sample.output_process_fn_grad(gpu_results)
             cpu_results = cpu_sample.output_process_fn_grad(cpu_results)
 
             # Lower tolerance because we are running this as a `@slowTest`
             # Don't want the periodic tests to fail frequently
-            self.assertEqual(cuda_results, cpu_results, atol=1e-3, rtol=1e-3)
+            self.assertEqual(gpu_results, cpu_results, atol=1e-3, rtol=1e-3)
 
     # Tests that experimental Python References can propagate shape, dtype,
     # and device metadata properly.
@@ -479,11 +508,24 @@ def test_python_ref_torch_fallback(self, device, dtype, op):
         self._ref_test_helper(contextlib.nullcontext, device, dtype, op)
 
     @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
-    @onlyCUDA
     @ops(python_ref_db)
     @parametrize('executor', ['aten', 'nvfuser'])
     @skipIfTorchInductor("Takes too long for inductor")
     def test_python_ref_executor(self, device, dtype, op, executor):
+        if torch.device(device).type == "mps" and executor == 'nvfuser':
+            self.skipTest("nvfuser executor is not run on MPS")
+        MPS_BLOCKLIST = [
+            "_refs.fft.fft",  # hard crash on unsupported ComplexFloat
+            "_refs.fft.ifft",  # hard crash on unsupported ComplexFloat
+            "_refs.fft.ihfft",  # hard crash on unsupported ComplexFloat
+            "_refs.fft.rfft2",  # hard crash on unsupported ComplexFloat
+            "_refs.fft.rfft",  # hard crash on unsupported ComplexFloat
+            "_refs.floor_divide",  # hard crash on unsupported ComplexFloat
+            "_refs.where",  # hard crash on unsupported ComplexFloat
+        ]
+        msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
         # TODO: Not all dtypes are supported with nvfuser
         from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map
         if executor == "nvfuser" and dtype not in _torch_dtype_to_nvfuser_dtype_map:
@@ -663,6 +705,12 @@ def test_noncontiguous_samples(self, device, dtype, op):
     @ops(_ops_and_refs, dtypes=OpDTypes.none)
     @skipIfTorchInductor("Inductor does not support complex dtype yet")
     def test_out_warning(self, device, op):
+        MPS_BLOCKLIST = [
+            "_refs.fft.fft",  # hard crash on unsupported ComplexFloat
+            "_refs.fft.ifft",  # hard crash on unsupported ComplexFloat
+            "_refs.fft.ihfft",  # hard crash on unsupported ComplexFloat
+            "_refs.fft.rfft",  # hard crash on unsupported ComplexFloat
+        ]
         # Prefers running in float32 but has a fallback for the first listed supported dtype
         supported_dtypes = op.supported_dtypes(self.device_type)
         if len(supported_dtypes) == 0:
@@ -673,6 +721,9 @@ def test_out_warning(self, device, op):
             else list(supported_dtypes)[0]
         )
 
+        msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
         samples = op.sample_inputs(device, dtype)
         for sample in samples:
             # calls it normally to get the expected result
@@ -716,7 +767,7 @@ def _extract_strides(out):
             # NOTE: only extracts on the CPU and CUDA device types since some
             #   device types don't have storage
             def _extract_data_ptrs(out):
-                if self.device_type != "cpu" and self.device_type != "cuda":
+                if self.device_type not in ("cpu", "cuda", "mps"):
                     return ()
 
                 if isinstance(out, torch.Tensor):
@@ -792,6 +843,23 @@ def _any_nonempty(out):
     @ops(_ops_and_refs, dtypes=OpDTypes.any_one)
     @skipIfTorchInductor("Inductor does not support complex dtype yet")
     def test_out(self, device, dtype, op):
+        MPS_BLOCKLIST = [
+            "_refs._conversions.complex",  # hard crash on unsupported ComplexFloat
+            "_refs.fft.fft",  # hard crash on unsupported ComplexFloat
+            "_refs.fft.ifft",  # hard crash on unsupported ComplexFloat
+            "_refs.fft.ihfft",  # hard crash on unsupported ComplexFloat
+            "_refs.fft.rfft2",  # hard crash on unsupported ComplexFloat
+            "_refs.fft.rfft",  # hard crash on unsupported ComplexFloat
+            "bitwise_not",  # hard crash on unsupported ComplexFloat
+            "fft.fft",  # hard crash on unsupported ComplexFloat
+            "fft.ifft",  # hard crash on unsupported ComplexFloat
+            "fft.ihfft",  # hard crash on unsupported ComplexFloat
+            "fft.rfft2",  # hard crash on unsupported ComplexFloat
+            "fft.rfft",  # hard crash on unsupported ComplexFloat
+        ]
+        msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
         # Prefers running in float32 but has a fallback for the first listed supported dtype
         samples = op.sample_inputs(device, dtype)
         for sample in samples:
@@ -836,7 +904,7 @@ def _extract_strides(out):
             # NOTE: only extracts on the CPU and CUDA device types since some
             #   device types don't have storage
             def _extract_data_ptrs(out):
-                if self.device_type != "cpu" and self.device_type != "cuda":
+                if self.device_type not in ("cpu", "cuda", "mps"):
                     return ()
 
                 if isinstance(out, torch.Tensor):
@@ -980,7 +1048,18 @@ def _case_four_transform(t):
     @skipIfTorchInductor("Inductor does not support complex dtype yet")
     def test_variant_consistency_eager(self, device, dtype, op):
         # Acquires variants (method variant, inplace variant, operator variant, inplace_operator variant, aliases)
-
+        MPS_BLOCKLIST = [
+            "fft.fft",  # hard crash on unsupported ComplexFloat
+            "fft.ifft",  # hard crash on unsupported ComplexFloat
+            "fft.ihfft",  # hard crash on unsupported ComplexFloat
+            "fft.rfft2",  # hard crash on unsupported ComplexFloat
+            "fft.rfft",  # hard crash on unsupported ComplexFloat
+            "nn.functional.max_pool2d",  # hard crash: buffer is not large enough
+            "stft",  # hard crash on unsupported ComplexFloat
+        ]
+        msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
         method = op.method_variant
         inplace = op.inplace_variant
         operator = op.operator_variant
@@ -1163,6 +1242,9 @@ def _test_inplace_preserve_storage(samples, variants):
     @ops(op_db, allowed_dtypes=(torch.complex32,))
     @skipIfTorchInductor("Inductor does not support complex dtype yet")
     def test_complex_half_reference_testing(self, device, dtype, op):
+        msg = _get_mps_error_msg(device, dtype, op, [])
+        if msg is not None:
+            self.skipTest(msg)
         if not op.supports_dtype(torch.complex32, device):
             unittest.skip("Does not support complex32")
 
@@ -1431,6 +1513,12 @@ class TestCompositeCompliance(TestCase):
     )
     @ops(op_db, allowed_dtypes=(torch.float,))
     def test_operator(self, device, dtype, op):
+        MPS_BLOCKLIST = [
+            "stft",  # hard crash on unsupported ComplexFloat
+        ]
+        msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
         samples = op.sample_inputs(device, dtype, requires_grad=False)
 
         for sample in samples:
@@ -1444,6 +1532,18 @@ def test_operator(self, device, dtype, op):
     )
     @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
     def test_backward(self, device, dtype, op):
+        MPS_BLOCKLIST = [
+            "fft.fft",  # hard crash on unsupported ComplexFloat
+            "fft.ifft",  # hard crash on unsupported ComplexFloat
+            "fft.ihfft",  # hard crash on unsupported ComplexFloat
+            "fft.rfft2",  # hard crash on unsupported ComplexFloat
+            "fft.rfft",  # hard crash on unsupported ComplexFloat
+            "stft",  # hard crash on unsupported ComplexFloat
+        ]
+        msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
+
         samples = op.sample_inputs(device, dtype, requires_grad=True)
 
         for sample in samples:
@@ -1461,6 +1561,12 @@ def test_backward(self, device, dtype, op):
     )
     @ops(op_db, allowed_dtypes=(torch.float,))
     def test_forward_ad(self, device, dtype, op):
+        MPS_BLOCKLIST = [
+            "stft",  # hard crash on unsupported ComplexFloat
+        ]
+        msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
         if torch.float not in op.supported_backward_dtypes(device):
             raise unittest.SkipTest("Does not support autograd")
 
@@ -1594,6 +1700,10 @@ def clone_and_perform_view(input, **kwargs):
     @ops(ops_and_refs, allowed_dtypes=(torch.cfloat,))
     @skipIfTorchInductor("Inductor does not support complex dtype yet")
     def test_conj_view(self, device, dtype, op):
+        msg = _get_mps_error_msg(device, dtype, op, [])
+        if msg is not None:
+            self.skipTest(msg)
+
         if not op.test_conjugated_samples:
             self.skipTest("Operation doesn't support conjugated inputs.")
         math_op_physical = torch.conj_physical
@@ -1637,6 +1747,9 @@ def test_neg_view(self, device, dtype, op):
     @ops(ops_and_refs, allowed_dtypes=(torch.cdouble,))
     @skipIfTorchInductor("Inductor does not support complex dtype yet")
     def test_neg_conj_view(self, device, dtype, op):
+        msg = _get_mps_error_msg(device, dtype, op, [])
+        if msg is not None:
+            self.skipTest(msg)
         if not op.test_neg_view:
             self.skipTest("Operation not tested with tensors with negative bit.")
         if not op.test_conjugated_samples:
@@ -2012,6 +2125,15 @@ def test_refs_are_in_decomp_table(self, op):
 
 class TestFakeTensor(TestCase):
     def _test_fake_helper(self, device, dtype, op, context):
+        MPS_BLOCKLIST = [
+            "bfloat16",  # hard crash on unsupported type byte size
+            "cdouble",  # hard crash on unsupported type byte size
+            "cfloat",  # hard crash on unsupported type byte size
+            "chalf",  # hard crash on unsupported type byte size
+        ]
+        msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
         name = op.name
         if op.variant_test_name:
             name += "." + op.variant_test_name
@@ -2163,10 +2285,20 @@ def _test_fake_crossref_helper(self, device, dtype, op, context):
                     op.gradcheck_wrapper)
 
     @skipIfRocm
-    @onlyCUDA
     @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
     @skipOps('TestFakeTensor', 'test_fake_crossref_backward_no_amp', fake_backward_xfails)
     def test_fake_crossref_backward_no_amp(self, device, dtype, op):
+        MPS_BLOCKLIST = [
+            "fft.fft",  # hard crash on unsupported ComplexFloat
+            "fft.ifft",  # hard crash on unsupported ComplexFloat
+            "fft.ihfft",  # hard crash on unsupported ComplexFloat
+            "fft.rfft2",  # hard crash on unsupported ComplexFloat
+            "fft.rfft",  # hard crash on unsupported ComplexFloat
+            "stft",  # hard crash on unsupported ComplexFloat
+        ]
+        msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
         self._test_fake_crossref_helper(device, dtype, op, contextlib.nullcontext)
 
     @skipIfRocm