From 984150bcd4ceb847102f501f6fee14fd352f1e45 Mon Sep 17 00:00:00 2001 From: Ronian526 Date: Fri, 24 Feb 2023 17:43:02 -0800 Subject: [PATCH 1/3] - update _mac-test-mps.yml file --- .github/workflows/_mac-test-mps.yml | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/.github/workflows/_mac-test-mps.yml b/.github/workflows/_mac-test-mps.yml index 9348cb590e95f..0b50032f5ff32 100644 --- a/.github/workflows/_mac-test-mps.yml +++ b/.github/workflows/_mac-test-mps.yml @@ -97,6 +97,19 @@ jobs: ${CONDA_RUN} python3 test/test_modules.py -k mps --verbose + - name: Run MPS Test Ops + id: test_3 + env: + ENV_NAME: conda-test-env-${{ github.run_id }} + PYTORCH_TEST_WITH_SLOW: 1 + shell: arch -arch arm64 bash {0} + # During bring up of NN don't show this as an error. + continue-on-error: true + run: | + # shellcheck disable=SC1090 + set -ex + ${CONDA_RUN} python3 test/test_ops.py -k mps --verbose + - name: Print remaining test logs shell: bash if: always() From 68fef584618927f9f9788f4252c4846d98abf611 Mon Sep 17 00:00:00 2001 From: Ronian526 Date: Fri, 24 Feb 2023 16:22:58 -0800 Subject: [PATCH 2/3] - remove fft related test from blocklist --- test/test_ops.py | 43 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index d40b625f93eac..5e22308639318 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -262,6 +262,10 @@ def test_numpy_ref(self, device, dtype, op): @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one) def test_compare_cpu(self, device, dtype, op): + msg = _get_mps_error_msg(device, dtype, op, []) + if msg is not None: + self.skipTest(msg) + def to_cpu(arg): if isinstance(arg, torch.Tensor): return arg.to(device='cpu') @@ -484,6 +488,15 @@ def test_python_ref_torch_fallback(self, device, dtype, op): @parametrize('executor', ['aten', 'nvfuser']) @skipIfTorchInductor("Takes too long for inductor") def test_python_ref_executor(self, device, dtype, op, executor): + if device == "mps" and executor == 'nvfuser': + return + MPS_BLOCKLIST = [ + "_refs.floor_divide", # hard crash on unsupoorted ComplexFloat + "_refs.where", # hard crash on unsupoorted ComplexFloat + ] + msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST) + if msg is not None: + self.skipTest(msg) # TODO: Not all dtypes are supported with nvfuser from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map if executor == "nvfuser" and dtype not in _torch_dtype_to_nvfuser_dtype_map: @@ -673,6 +686,9 @@ def test_out_warning(self, device, op): else list(supported_dtypes)[0] ) + msg = _get_mps_error_msg(device, dtype, op, []) + if msg is not None: + self.skipTest(msg) samples = op.sample_inputs(device, dtype) for sample in samples: # calls it normally to get the expected result @@ -792,6 +808,13 @@ def _any_nonempty(out): @ops(_ops_and_refs, dtypes=OpDTypes.any_one) @skipIfTorchInductor("Inductor does not support complex dtype yet") def test_out(self, device, dtype, op): + MPS_BLOCKLIST = [ + "_refs._conversions.complex", # hard crash on unsupoorted ComplexFloat + "bitwise_not", # seg fault + ] + msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST) + if msg is not None: + self.skipTest(msg) # Prefers running in float32 but has a fallback for the first listed supported dtype samples = op.sample_inputs(device, dtype) for sample in samples: @@ -980,7 +1003,12 @@ def _case_four_transform(t): @skipIfTorchInductor("Inductor does not support complex dtype yet") def test_variant_consistency_eager(self, device, dtype, op): # Acquires variants (method variant, inplace variant, operator variant, inplace_operator variant, aliases) - + MPS_BLOCKLIST = [ + "nn.functional.max_pool2d", # hard crash: buffer is not large enough + ] + msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST) + if msg is not None: + self.skipTest(msg) method = op.method_variant inplace = op.inplace_variant operator = op.operator_variant @@ -1431,6 +1459,9 @@ class TestCompositeCompliance(TestCase): ) @ops(op_db, allowed_dtypes=(torch.float,)) def test_operator(self, device, dtype, op): + msg = _get_mps_error_msg(device, dtype, op, []) + if msg is not None: + self.skipTest(msg) samples = op.sample_inputs(device, dtype, requires_grad=False) for sample in samples: @@ -1444,6 +1475,10 @@ def test_operator(self, device, dtype, op): ) @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,)) def test_backward(self, device, dtype, op): + msg = _get_mps_error_msg(device, dtype, op, []) + if msg is not None: + self.skipTest(msg) + samples = op.sample_inputs(device, dtype, requires_grad=True) for sample in samples: @@ -1461,6 +1496,9 @@ def test_backward(self, device, dtype, op): ) @ops(op_db, allowed_dtypes=(torch.float,)) def test_forward_ad(self, device, dtype, op): + msg = _get_mps_error_msg(device, dtype, op, []) + if msg is not None: + self.skipTest(msg) if torch.float not in op.supported_backward_dtypes(device): raise unittest.SkipTest("Does not support autograd") @@ -2167,6 +2205,9 @@ def _test_fake_crossref_helper(self, device, dtype, op, context): @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,)) @skipOps('TestFakeTensor', 'test_fake_crossref_backward_no_amp', fake_backward_xfails) def test_fake_crossref_backward_no_amp(self, device, dtype, op): + msg = _get_mps_error_msg(device, dtype, op, []) + if msg is not None: + self.skipTest(msg) self._test_fake_crossref_helper(device, dtype, op, contextlib.nullcontext) @skipIfRocm From ad6f918e269b71dde80b5577f9b0e393ac13cb85 Mon Sep 17 00:00:00 2001 From: Ronian526 Date: Wed, 22 Feb 2023 15:18:43 -0800 Subject: [PATCH 3/3] Enable test_ops.py for mps - add file to lint runner - add file to git runner - skip tests that are crashing on MPS inside test_ops.py --- .github/workflows/_mac-test-mps.yml | 4 +- .github/workflows/lint.yml | 2 +- test/test_ops.py | 112 ++++++++++++++++++++++++---- 3 files changed, 100 insertions(+), 18 deletions(-) diff --git a/.github/workflows/_mac-test-mps.yml b/.github/workflows/_mac-test-mps.yml index 0b50032f5ff32..7160510229b2b 100644 --- a/.github/workflows/_mac-test-mps.yml +++ b/.github/workflows/_mac-test-mps.yml @@ -101,14 +101,14 @@ jobs: id: test_3 env: ENV_NAME: conda-test-env-${{ github.run_id }} - PYTORCH_TEST_WITH_SLOW: 1 shell: arch -arch arm64 bash {0} # During bring up of NN don't show this as an error. continue-on-error: true run: | # shellcheck disable=SC1090 set -ex - ${CONDA_RUN} python3 test/test_ops.py -k mps --verbose + # TODO(https://github.com/pytorch/pytorch/issues/79293) + ${CONDA_RUN} PYTORCH_TEST_WITH_SLOW=1 python3 test/test_ops.py -k mps --verbose - name: Print remaining test logs shell: bash diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 28737339bafca..525bf3ae69847 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -69,7 +69,7 @@ jobs: # shellcheck disable=SC1090 set -ex set +e - if ! ${CONDA_RUN} lintrunner --force-color test/*.py aten/src/ATen/native/mps/*.h aten/src/ATen/native/mps/*.mm aten/src/ATen/native/mps/operations/*; then + if ! ${CONDA_RUN} lintrunner --force-color aten/src/ATen/native/mps/operations/* test/test_mps.py test/test_modules.py test/test_ops.py; then echo "" echo -e "\e[1m\e[36mYou can reproduce these results locally by using \`lintrunner\`.\e[0m" echo -e "\e[1m\e[36mSee https://github.com/pytorch/pytorch/wiki/lintrunner for setup instructions.\e[0m" diff --git a/test/test_ops.py b/test/test_ops.py index 5e22308639318..bbce0430e1593 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -18,6 +18,7 @@ from torch.testing._internal.common_dtype import ( floating_and_complex_types_and, all_types_and_complex_and, + get_all_dtypes, ) from torch.testing._internal.common_utils import ( @@ -110,6 +111,17 @@ aten = torch.ops.aten +MPS_DTYPES = get_all_dtypes() +for t in [torch.double, torch.cdouble, torch.cfloat, torch.int8, torch.bfloat16]: + del MPS_DTYPES[MPS_DTYPES.index(t)] + +def _get_mps_error_msg(device, dtype, op, mps_blocklist): + if torch.backends.mps.is_available() and device == "mps" and dtype not in MPS_DTYPES: + return f"MPS doesn't support {str(dtype)} datatype" + if op.name.startswith(tuple(mps_blocklist)): + return "MPS doesn't support op " + str(op.name) + return None + # Tests that apply to all operators and aren't related to any particular # system class TestCommon(TestCase): @@ -256,13 +268,15 @@ def test_numpy_ref(self, device, dtype, op): ) # Tests that the cpu and gpu results are consistent - @onlyCUDA @suppress_warnings @slowTest @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one) def test_compare_cpu(self, device, dtype, op): - msg = _get_mps_error_msg(device, dtype, op, []) + MPS_BLOCKLIST = [ + "stft", # hard crash on unsupoorted ComplexFloat + ] + msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST) if msg is not None: self.skipTest(msg) @@ -275,7 +289,7 @@ def to_cpu(arg): for sample in samples: cpu_sample = sample.transform(to_cpu) - cuda_results = op(sample.input, *sample.args, **sample.kwargs) + gpu_results = op(sample.input, *sample.args, **sample.kwargs) cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs) # output_process_fn_grad has a very unfortunate name @@ -283,12 +297,12 @@ def to_cpu(arg): # that are not completely well-defined. Think svd and muliplying the singular vectors by -1. # CPU and CUDA implementations of the SVD can return valid SVDs that are different. # We use this function to compare them. - cuda_results = sample.output_process_fn_grad(cuda_results) + gpu_results = sample.output_process_fn_grad(gpu_results) cpu_results = cpu_sample.output_process_fn_grad(cpu_results) # Lower tolerance because we are running this as a `@slowTest` # Don't want the periodic tests to fail frequently - self.assertEqual(cuda_results, cpu_results, atol=1e-3, rtol=1e-3) + self.assertEqual(gpu_results, cpu_results, atol=1e-3, rtol=1e-3) # Tests that experimental Python References can propagate shape, dtype, # and device metadata properly. @@ -483,7 +497,6 @@ def test_python_ref_torch_fallback(self, device, dtype, op): self._ref_test_helper(contextlib.nullcontext, device, dtype, op) @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") - @onlyCUDA @ops(python_ref_db) @parametrize('executor', ['aten', 'nvfuser']) @skipIfTorchInductor("Takes too long for inductor") @@ -491,6 +504,11 @@ def test_python_ref_executor(self, device, dtype, op, executor): if device == "mps" and executor == 'nvfuser': return MPS_BLOCKLIST = [ + "_refs.fft.fft", # hard crash on unsupoorted ComplexFloat + "_refs.fft.ifft", # hard crash on unsupoorted ComplexFloat + "_refs.fft.ihfft", # hard crash on unsupoorted ComplexFloat + "_refs.fft.rfft2", # hard crash on unsupoorted ComplexFloat + "_refs.fft.rfft", # hard crash on unsupoorted ComplexFloat "_refs.floor_divide", # hard crash on unsupoorted ComplexFloat "_refs.where", # hard crash on unsupoorted ComplexFloat ] @@ -676,6 +694,12 @@ def test_noncontiguous_samples(self, device, dtype, op): @ops(_ops_and_refs, dtypes=OpDTypes.none) @skipIfTorchInductor("Inductor does not support complex dtype yet") def test_out_warning(self, device, op): + MPS_BLOCKLIST = [ + "_refs.fft.fft", # hard crash on unsupoorted ComplexFloat + "_refs.fft.ifft", # hard crash on unsupoorted ComplexFloat + "_refs.fft.ihfft", # hard crash on unsupoorted ComplexFloat + "_refs.fft.rfft", # hard crash on unsupoorted ComplexFloat + ] # Prefers running in float32 but has a fallback for the first listed supported dtype supported_dtypes = op.supported_dtypes(self.device_type) if len(supported_dtypes) == 0: @@ -686,7 +710,7 @@ def test_out_warning(self, device, op): else list(supported_dtypes)[0] ) - msg = _get_mps_error_msg(device, dtype, op, []) + msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST) if msg is not None: self.skipTest(msg) samples = op.sample_inputs(device, dtype) @@ -732,7 +756,7 @@ def _extract_strides(out): # NOTE: only extracts on the CPU and CUDA device types since some # device types don't have storage def _extract_data_ptrs(out): - if self.device_type != "cpu" and self.device_type != "cuda": + if self.device_type != "cpu" and self.device_type != "cuda" and self.device_type != "mps": return () if isinstance(out, torch.Tensor): @@ -810,7 +834,17 @@ def _any_nonempty(out): def test_out(self, device, dtype, op): MPS_BLOCKLIST = [ "_refs._conversions.complex", # hard crash on unsupoorted ComplexFloat - "bitwise_not", # seg fault + "_refs.fft.fft", # hard crash on unsupoorted ComplexFloat + "_refs.fft.ifft", # hard crash on unsupoorted ComplexFloat + "_refs.fft.ihfft", # hard crash on unsupoorted ComplexFloat + "_refs.fft.rfft2", # hard crash on unsupoorted ComplexFloat + "_refs.fft.rfft", # hard crash on unsupoorted ComplexFloat + "bitwise_not", # hard crash on unsupoorted ComplexFloat + "fft.fft", # hard crash on unsupoorted ComplexFloat + "fft.ifft", # hard crash on unsupoorted ComplexFloat + "fft.ihfft", # hard crash on unsupoorted ComplexFloat + "fft.rfft2", # hard crash on unsupoorted ComplexFloat + "fft.rfft", # hard crash on unsupoorted ComplexFloat ] msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST) if msg is not None: @@ -859,7 +893,7 @@ def _extract_strides(out): # NOTE: only extracts on the CPU and CUDA device types since some # device types don't have storage def _extract_data_ptrs(out): - if self.device_type != "cpu" and self.device_type != "cuda": + if self.device_type != "cpu" and self.device_type != "cuda" and self.device_type != "mps": return () if isinstance(out, torch.Tensor): @@ -1004,7 +1038,13 @@ def _case_four_transform(t): def test_variant_consistency_eager(self, device, dtype, op): # Acquires variants (method variant, inplace variant, operator variant, inplace_operator variant, aliases) MPS_BLOCKLIST = [ + "fft.fft", # hard crash on unsupoorted ComplexFloat + "fft.ifft", # hard crash on unsupoorted ComplexFloat + "fft.ihfft", # hard crash on unsupoorted ComplexFloat + "fft.rfft2", # hard crash on unsupoorted ComplexFloat + "fft.rfft", # hard crash on unsupoorted ComplexFloat "nn.functional.max_pool2d", # hard crash: buffer is not large enough + "stft", # hard crash on unsupoorted ComplexFloat ] msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST) if msg is not None: @@ -1191,6 +1231,9 @@ def _test_inplace_preserve_storage(samples, variants): @ops(op_db, allowed_dtypes=(torch.complex32,)) @skipIfTorchInductor("Inductor does not support complex dtype yet") def test_complex_half_reference_testing(self, device, dtype, op): + msg = _get_mps_error_msg(device, dtype, op, []) + if msg is not None: + self.skipTest(msg) if not op.supports_dtype(torch.complex32, device): unittest.skip("Does not support complex32") @@ -1459,7 +1502,10 @@ class TestCompositeCompliance(TestCase): ) @ops(op_db, allowed_dtypes=(torch.float,)) def test_operator(self, device, dtype, op): - msg = _get_mps_error_msg(device, dtype, op, []) + MPS_BLOCKLIST = [ + "stft", # hard crash on unsupoorted ComplexFloat + ] + msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST) if msg is not None: self.skipTest(msg) samples = op.sample_inputs(device, dtype, requires_grad=False) @@ -1475,7 +1521,15 @@ def test_operator(self, device, dtype, op): ) @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,)) def test_backward(self, device, dtype, op): - msg = _get_mps_error_msg(device, dtype, op, []) + MPS_BLOCKLIST = [ + "fft.fft", # hard crash on unsupoorted ComplexFloat + "fft.ifft", # hard crash on unsupoorted ComplexFloat + "fft.ihfft", # hard crash on unsupoorted ComplexFloat + "fft.rfft2", # hard crash on unsupoorted ComplexFloat + "fft.rfft", # hard crash on unsupoorted ComplexFloat + "stft", # hard crash on unsupoorted ComplexFloat + ] + msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST) if msg is not None: self.skipTest(msg) @@ -1496,7 +1550,10 @@ def test_backward(self, device, dtype, op): ) @ops(op_db, allowed_dtypes=(torch.float,)) def test_forward_ad(self, device, dtype, op): - msg = _get_mps_error_msg(device, dtype, op, []) + MPS_BLOCKLIST = [ + "stft", # hard crash on unsupoorted ComplexFloat + ] + msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST) if msg is not None: self.skipTest(msg) if torch.float not in op.supported_backward_dtypes(device): @@ -1632,6 +1689,10 @@ def clone_and_perform_view(input, **kwargs): @ops(ops_and_refs, allowed_dtypes=(torch.cfloat,)) @skipIfTorchInductor("Inductor does not support complex dtype yet") def test_conj_view(self, device, dtype, op): + msg = _get_mps_error_msg(device, dtype, op, []) + if msg is not None: + self.skipTest(msg) + if not op.test_conjugated_samples: self.skipTest("Operation doesn't support conjugated inputs.") math_op_physical = torch.conj_physical @@ -1675,6 +1736,9 @@ def test_neg_view(self, device, dtype, op): @ops(ops_and_refs, allowed_dtypes=(torch.cdouble,)) @skipIfTorchInductor("Inductor does not support complex dtype yet") def test_neg_conj_view(self, device, dtype, op): + msg = _get_mps_error_msg(device, dtype, op, []) + if msg is not None: + self.skipTest(msg) if not op.test_neg_view: self.skipTest("Operation not tested with tensors with negative bit.") if not op.test_conjugated_samples: @@ -2050,6 +2114,17 @@ def test_refs_are_in_decomp_table(self, op): class TestFakeTensor(TestCase): def _test_fake_helper(self, device, dtype, op, context): + if(device == "cpu"): + return + MPS_BLOCKLIST = [ + "bfloat16", # hard crash on unsupoorted type byte size + "cdouble", # hard crash on unsupoorted type byte size + "cfloat", # hard crash on unsupoorted type byte size + "chalf", # hard crash on unsupoorted type byte size + ] + msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST) + if msg is not None: + self.skipTest(msg) name = op.name if op.variant_test_name: name += "." + op.variant_test_name @@ -2201,11 +2276,18 @@ def _test_fake_crossref_helper(self, device, dtype, op, context): op.gradcheck_wrapper) @skipIfRocm - @onlyCUDA @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,)) @skipOps('TestFakeTensor', 'test_fake_crossref_backward_no_amp', fake_backward_xfails) def test_fake_crossref_backward_no_amp(self, device, dtype, op): - msg = _get_mps_error_msg(device, dtype, op, []) + MPS_BLOCKLIST = [ + "fft.fft", # hard crash on unsupoorted ComplexFloat + "fft.ifft", # hard crash on unsupoorted ComplexFloat + "fft.ihfft", # hard crash on unsupoorted ComplexFloat + "fft.rfft2", # hard crash on unsupoorted ComplexFloat + "fft.rfft", # hard crash on unsupoorted ComplexFloat + "stft", # hard crash on unsupoorted ComplexFloat + ] + msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST) if msg is not None: self.skipTest(msg) self._test_fake_crossref_helper(device, dtype, op, contextlib.nullcontext)