diff --git a/.lintrunner.toml b/.lintrunner.toml index 91a3034d96..631c004e67 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -13,6 +13,11 @@ exclude_patterns = [ 'test/regressions/test_unary.py', 'tools/codegen/remove_headers.py', '.github/scripts/*.py', + 'test/xpu/test_*.py', + 'test/xpu/**/test_*.py', + 'test/xpu/dynamo/test_*.py', + 'test/xpu/dynamo/test_subclasses_xpu.py', + 'test/xpu/dynamo/test_structured_trace_xpu.py', ] command = [ 'uv', @@ -810,6 +815,8 @@ exclude_patterns = [ # Port of upstream test/dynamo/test_misc.py which legitimately calls # ShapeEnv.create_unbacked_* to test ShapeEnv equality. "test/xpu/dynamo/test_misc_xpu.py", + "test/xpu/dynamo/test_structured_trace_xpu.py", + "test/xpu/dynamo/test_subclasses_xpu.py", ] command = [ 'python3', diff --git a/test/xpu/dynamo/test_activation_checkpointing_xpu.py b/test/xpu/dynamo/test_activation_checkpointing_xpu.py new file mode 100644 index 0000000000..a154dcf5d0 --- /dev/null +++ b/test/xpu/dynamo/test_activation_checkpointing_xpu.py @@ -0,0 +1,3335 @@ +# Owner(s): ["module: dynamo"] +import contextlib +import copy +import functools +import math +import re +import unittest +from importlib import import_module + +import torch +import torch._dynamo.config +import torch._dynamo.test_case +import torch._functorch.config +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from functorch.compile import ( + default_partition, + min_cut_rematerialization_partition, + nop, +) +from torch._dynamo.backends.common import aot_autograd +from torch._dynamo.testing import ( + AotEagerAndRecordGraphs, + CompileCounterWithBackend, + normalize_gm, +) +from torch._higher_order_ops.wrap import tag_activation_checkpoint +from torch.testing._internal.common_device_type import instantiate_device_type_tests +from torch.testing._internal.common_utils import ( + IS_WINDOWS, + parametrize, + skipIfHpu, + TEST_CUDA, +) +from 
torch.testing._internal.inductor_utils import HAS_GPU_AND_TRITON +from torch.testing._internal.triton_utils import requires_gpu_and_triton +from torch.testing._internal.two_tensor import TwoTensor +from torch.utils.checkpoint import ( + checkpoint, + CheckpointPolicy, + create_selective_checkpoint_contexts, +) + +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" + + +if HAS_GPU_AND_TRITON: + import triton + from triton import language as tl + + @triton.jit + def add_one_kernel( + in_ptr0, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + output = x + 1 + tl.store(out_ptr + offsets, output, mask=mask) + + +requires_distributed = functools.partial( + unittest.skipIf, not dist.is_available(), "requires distributed" +) + + +def checkpoint_wrapper(fn): + def inner(*args): + return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True) + + return inner + + +def _grad(*args, **kwargs): + return torch.autograd.grad(*args, **kwargs) + + +def count_ops( + gm, args, freq=None, freq_ge=None, op=None, freqs=None, freqs_ge=None, ops=None +): + def match_rng_op(node, op): + if isinstance(node.target, torch._ops.HigherOrderOperator): + if node.name == "run_and_save_rng_state": + return node.args[0] == op + elif node.name == "run_with_rng_state": + return node.args[1] == op + elif node.name == "graphsafe_run_with_rng_state": + return node.args[0] == op + return False + + # assert ((freq or freq_ge) and op) or ((freqs or freqs_ge) and ops) + if op is not None: + if isinstance(op, list): + raise AssertionError("Expected op to not be a list") + ops = [op] + if freq is not None: + freqs = [freq] + if freq_ge is not None: + freqs_ge = [freq_ge] + if freqs: + for op, freq in zip(ops, freqs): + actual_count = 0 + for node in gm.graph.nodes: + if 
match_rng_op(node, op) or node.target == op: + actual_count += 1 + err_msg = f"In graph {gm}, expected {op} to have occurred {freq} times in the graph, but got {actual_count}." + if actual_count != freq: + raise AssertionError(err_msg) + else: + if freqs_ge is None: + raise AssertionError("Expected freqs_ge to not be None") + for op, freq_ge in zip(ops, freqs_ge): + actual_count = 0 + for node in gm.graph.nodes: + if match_rng_op(node, op) or node.target == op: + actual_count += 1 + if actual_count < freq_ge: + raise AssertionError( + f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}." + ) + return gm + + +def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: set[str]): + if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: # fwd graph + return_node = list(graph.nodes)[-1] + if return_node.target != "output": + raise AssertionError( + f"Expected return_node.target to be 'output', got {return_node.target}" + ) + for x in return_node.args[0]: + fwd_outputs.add(str(x)) + + +class _InvalidContext: + def __init__(self) -> None: + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +def _invalid_context_gen(): + return _InvalidContext(), _InvalidContext() + + +def find_first_node(gm, func): + for node in gm.graph.nodes: + if node.target is func: + return node + return None + + +def op_count(gm): + result = 0 + for node in gm.graph.nodes: + if "call" in node.op: + result += 1 + return result + + +def _get_custom_policy(no_recompute_list=None, must_recompute_list=None): + def _custom_policy(ctx, func, *args, **kwargs): + if no_recompute_list is not None and func in no_recompute_list: + return CheckpointPolicy.MUST_SAVE + if must_recompute_list is not None and func in must_recompute_list: + return CheckpointPolicy.MUST_RECOMPUTE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + + return _custom_policy + + +class 
ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): + def _validate( + self, + fn, + backend, + *args, + skip_check=False, + fullgraph=True, + compiled_autograd=False, + ): + cloned_args = [] + for arg in args: + cloned_args.append(arg.detach().clone().requires_grad_(arg.requires_grad)) + + cloned_fn = copy.deepcopy(fn) + + torch.manual_seed(0) + expected = fn(*args) + expected.sum().backward() + + torch.manual_seed(0) + compiled_fn = torch.compile(cloned_fn, fullgraph=fullgraph, backend=backend) + ctx = contextlib.nullcontext() + if compiled_autograd: + ctx = torch._dynamo.compiled_autograd._enable( + lambda gm: torch.compile(gm, fullgraph=fullgraph, backend=backend) + ) + with ctx: + result = compiled_fn(*cloned_args) + result.sum().backward() + + if not skip_check: + self.assertEqual( + result, + expected, + msg="Output mismatch between torch.compile and eager versions", + ) + for arg, cloned_arg in zip(args, cloned_args): + self.assertEqual( + arg.grad, + cloned_arg.grad, + msg="Gradient mismatch between torch.compile and eager versions", + ) + + def _compare_orig_and_checkpointed_fns( + self, orig_fn, checkpointed_fn, *args, fullgraph=True + ): + # The original version and the checkpointed version of the same function + # should produce the same outputs and the same gradients under torch.compile. 
+ + def clone_args(args): + cloned_args = [] + for arg in args: + cloned_args.append( + arg.detach().clone().requires_grad_(arg.requires_grad) + ) + return cloned_args + + def run(compiler): + # Run original version + cloned_args_orig_fn = clone_args(args) + torch.manual_seed(0) + compiled_orig_fn = compiler(orig_fn) + result_orig_fn = compiled_orig_fn(*cloned_args_orig_fn) + result_orig_fn.sum().backward() + + # Run checkpointed version + cloned_args_checkpointed_fn = clone_args(args) + torch.manual_seed(0) + compiled_checkpointed_fn = compiler(copy.deepcopy(checkpointed_fn)) + result_checkpointed_fn = compiled_checkpointed_fn( + *cloned_args_checkpointed_fn + ) + result_checkpointed_fn.sum().backward() + + # Check that outputs and gradients are equal + self.assertEqual( + result_orig_fn, + result_checkpointed_fn, + msg="Output mismatch between the original version and the checkpointed version of the same function", + ) + for cloned_arg_orig_fn, cloned_arg_checkpointed_fn in zip( + cloned_args_orig_fn, cloned_args_checkpointed_fn + ): + self.assertEqual( + cloned_arg_orig_fn.grad, + cloned_arg_checkpointed_fn.grad, + msg="Gradient mismatch between the original version and the checkpointed version of the same function", + ) + + run(functools.partial(torch.compile, fullgraph=fullgraph)) + if fullgraph: + + def export_compiler(fn): + class WrapAsModule(nn.Module): + def forward(self, *args, **kwargs): + return fn(*args, **kwargs) + + mod = WrapAsModule() + + def runtime_wrapper(*runtime_args): + from torch.export import _trace + + gm = _trace._export_to_torch_ir( + f=mod, + args=tuple(clone_args(args)), + kwargs={}, + dynamic_shapes=None, + preserve_module_call_signature=(), + restore_fqn=False, + prefer_deferred_runtime_asserts_over_guards=False, + _log_export_usage=False, + ) + # NOTE: this is necessary for rng to be added to the exported graph + return torch.compile( + gm, fullgraph=False, backend="aot_eager_decomp_partition" + )(*runtime_args) + + return 
runtime_wrapper + + run(export_compiler) + + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_function(self, device, partition_fn): + def gn(x, y): + return torch.sigmoid(torch.matmul(x, y)) + + def fn(x, y): + return torch.utils.checkpoint.checkpoint( + gn, torch.sin(x), y, use_reentrant=True + ) + + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + + fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default) + bw_compiler = functools.partial( + count_ops, freq=3, op=torch.ops.aten.mm.default + ) # mm recomputed in the bwd + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + self._validate(fn, backend, x, y) + + @requires_gpu_and_triton + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_function_via_global_checkpoint(self, device, partition_fn): + def gn(x, y): + return torch.sigmoid(torch.matmul(x, y)) + + def fn(x, y): + # This goes through VariableBuilder + return checkpoint(gn, torch.sin(x), y, use_reentrant=True) + + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + + fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default) + bw_compiler = functools.partial( + count_ops, freq=3, op=torch.ops.aten.mm.default + ) # mm recomputed in the bwd + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + self._validate(fn, backend, x, y) + + @requires_gpu_and_triton + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_function_with_kwargs(self, device, partition_fn): + def gn(x, y): + return torch.sigmoid(torch.matmul(x, y)) + + def fn(x, y): + return 
torch.utils.checkpoint.checkpoint( + gn, torch.sin(x), y, use_reentrant=False + ) + + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + + fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default) + bw_compiler = functools.partial( + count_ops, freq=3, op=torch.ops.aten.mm.default + ) # mm recomputed in the bwd + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + self._validate(fn, backend, x, y) + + @requires_gpu_and_triton + def test_checkpoint_shows_tags_in_tlparse(self, device): + def gn(x, y): + return torch.sigmoid(torch.matmul(x, y)) + + def fn(x, y): + return torch.utils.checkpoint.checkpoint( + gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False + ) + + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + + def partition_fn(joint_gm, *args, **kwargs): + gm_str = joint_gm.print_readable(print_output=False) + # Check for the pattern with any graph ID (the ID depends on test order) + self.assertTrue( + re.search(r"# ac_graph_id: \d+ - PREFER_RECOMPUTE", gm_str), + f"Expected ac_graph_id pattern not found in:\n{gm_str}", + ) + return min_cut_rematerialization_partition(joint_gm, *args, **kwargs) + + backend = aot_autograd( + fw_compiler=nop, bw_compiler=nop, partition_fn=partition_fn + ) + _ = torch.compile(fn, backend=backend)(x, y) + + @requires_gpu_and_triton + def test_ac_tags_through_custom_autograd_function(self, device): + class MyMM(torch.autograd.Function): + @staticmethod + def forward(x, w): + return x @ w + + @staticmethod + def setup_context(ctx, inputs, output): + x, w = inputs + ctx.save_for_backward(x, w) + + @staticmethod + def backward(ctx, grad): + x, w = ctx.saved_tensors + return grad @ w.t(), x.t() @ grad + + def gn(x, w): + return MyMM.apply(x, w) + + def fn(x, w): + return torch.utils.checkpoint.checkpoint(gn, 
x, w, use_reentrant=False) + + x = torch.randn(4, 4, device=device, requires_grad=True) + w = torch.randn(4, 4, device=device, requires_grad=True) + + def partition_fn(joint_gm, *args, **kwargs): + fwd_mm_nodes = [ + node + for node in joint_gm.graph.nodes + if node.op == "call_function" + and node.target == torch.ops.aten.mm.default + and node.meta.get("partitioner_tag") == "is_forward" + ] + self.assertTrue( + fwd_mm_nodes, "Expected forward mm nodes in the joint graph" + ) + for node in fwd_mm_nodes: + self.assertIn("recompute", node.meta) + self.assertIn("ac_graph_id", node.meta) + return min_cut_rematerialization_partition(joint_gm, *args, **kwargs) + + backend = aot_autograd( + fw_compiler=nop, bw_compiler=nop, partition_fn=partition_fn + ) + out = torch.compile(fn, backend=backend)(x, w) + out.sum().backward() + + @requires_gpu_and_triton + def test_sac_tags_through_custom_autograd_function(self, device): + class MyMM(torch.autograd.Function): + @staticmethod + def forward(x, w): + return x @ w + + @staticmethod + def setup_context(ctx, inputs, output): + x, w = inputs + ctx.save_for_backward(x, w) + + @staticmethod + def backward(ctx, grad): + x, w = ctx.saved_tensors + return grad @ w.t(), x.t() @ grad + + def gn(x, w): + return MyMM.apply(x, w) + + context_fn = functools.partial( + torch.utils.checkpoint.create_selective_checkpoint_contexts, + lambda ctx, + op, + *args, + **kwargs: torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE, + ) + + def fn(x, w): + return torch.utils.checkpoint.checkpoint( + gn, x, w, use_reentrant=False, context_fn=context_fn + ) + + x = torch.randn(4, 4, device=device, requires_grad=True) + w = torch.randn(4, 4, device=device, requires_grad=True) + + def partition_fn(joint_gm, *args, **kwargs): + fwd_mm_nodes = [ + node + for node in joint_gm.graph.nodes + if node.op == "call_function" + and node.target == torch.ops.aten.mm.default + and node.meta.get("partitioner_tag") == "is_forward" + ] + self.assertTrue( + 
fwd_mm_nodes, "Expected forward mm nodes in the joint graph" + ) + for node in fwd_mm_nodes: + self.assertIn("recompute", node.meta) + self.assertIn("ac_graph_id", node.meta) + return min_cut_rematerialization_partition(joint_gm, *args, **kwargs) + + backend = aot_autograd( + fw_compiler=nop, bw_compiler=nop, partition_fn=partition_fn + ) + out = torch.compile(fn, backend=backend)(x, w) + out.sum().backward() + + @requires_gpu_and_triton + def test_tangent_placeholders_have_is_backward_tag(self, device): + """Test that tangent placeholders in the joint graph are tagged with is_backward.""" + + def gn(x, y): + return torch.sigmoid(torch.matmul(x, y)) + + def fn(x, y): + return torch.utils.checkpoint.checkpoint( + gn, torch.sin(x), y, use_reentrant=False + ) + + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + + def partition_fn(joint_gm, *args, **kwargs): + # Check partitioner_tag on placeholder nodes + for node in joint_gm.graph.nodes: + if node.op == "placeholder": + if "tangents" in str(node.target): + self.assertTrue( + "is_backward" in node.meta.get("partitioner_tag", "") + ) + else: + self.assertTrue( + "is_forward" in node.meta.get("partitioner_tag", "") + ) + return min_cut_rematerialization_partition(joint_gm, *args, **kwargs) + + backend = aot_autograd( + fw_compiler=nop, bw_compiler=nop, partition_fn=partition_fn + ) + _ = torch.compile(fn, backend=backend)(x, y) + + @requires_gpu_and_triton + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_sequential_layers(self, device, partition_fn): + def gn(x): + x = x.cos() + for _ in range(3): + x = torch.mm(x, x) + x = x.cos() + return x + + def fn(x): + x = torch.utils.checkpoint.checkpoint(gn, x) + x = torch.utils.checkpoint.checkpoint(gn, x) + return x + + x = torch.randn(4, 4, device=device, requires_grad=True) + + fw_compiler = functools.partial(count_ops, freq=6, 
op=torch.ops.aten.mm.default) + bw_compiler = functools.partial( + count_ops, + freqs=[2, 18], + ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default], + ) # mm recomputed in the bwd + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + self._validate(fn, backend, x) + + @requires_gpu_and_triton + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_multiple_checkpoints(self, device, partition_fn): + def gn(x, y): + return torch.sigmoid(torch.matmul(x, y)) + + def fn(x, y): + x = torch.sin(x) + z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + x = torch.sin(z) + z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + return z + + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + + fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default) + bw_compiler = functools.partial( + count_ops, freq=6, op=torch.ops.aten.mm.default + ) # mm recomputed in the bwd + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + self._validate(fn, backend, x, y) + + @requires_gpu_and_triton + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_module(self, device, partition_fn): + class MockModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(10, 10) + + def forward(self, x): + return torch.sigmoid(self.linear(x)) + + mod = MockModule().to(device) + + def fn(x): + return torch.utils.checkpoint.checkpoint( + mod, torch.sin(x), use_reentrant=True + ) + + x = torch.randn(10, 10, device=device, requires_grad=True) + + fw_compiler = functools.partial( + count_ops, freq=1, op=torch.ops.aten.sigmoid.default + ) + bw_compiler = functools.partial( + count_ops, 
freq=1, op=torch.ops.aten.sigmoid.default + ) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + self._validate(fn, backend, x) + + @requires_gpu_and_triton + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_decomps(self, device, partition_fn): + # Ensures that tags are passed on through decompositions as well + class MockModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(10, 10) + + def forward(self, x): + return torch.nn.functional.gelu(self.linear(x)) + + mod = MockModule().to(device) + + def fn(x): + return torch.utils.checkpoint.checkpoint( + mod, torch.sin(x), use_reentrant=True + ) + + x = torch.randn(10, 10, device=device, requires_grad=True) + + fw_compiler = functools.partial( + count_ops, freq=1, op=torch.ops.aten.erf.default + ) + bw_compiler = functools.partial( + count_ops, freq=1, op=torch.ops.aten.erf.default + ) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + decompositions=lambda: import_module( + "torch._inductor.compile_fx" + ).select_decomp_table(), + ) + self._validate(fn, backend, x) + + @requires_gpu_and_triton + @torch._inductor.config.patch(fallback_random=True) + def test_tags_recomputed_rand(self, device): + def gn(x, y): + return torch.sigmoid(torch.rand_like(x) * y) * x + + def fn(x, y): + x = torch.sin(x) + x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + x = torch.sin(x) + z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + return z + + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + + # fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default) + # bw_compiler = functools.partial( + # count_ops, freq=6, op=torch.ops.aten.mm.default + # ) # mm recomputed 
in the bwd + # backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = "inductor" + self._validate(fn, backend, x, y) + + @requires_gpu_and_triton + @torch._inductor.config.patch(fallback_random=True) + def test_tags_rand(self, device): + def gn(x, y): + x = torch.mm(x, y) + x = torch.mm(x, y) + return x + + def fn(x, y): + x = torch.sin(x) + x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + x = torch.sin(x) + # x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + return x + + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + + # fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default) + # bw_compiler = functools.partial( + # count_ops, freq=6, op=torch.ops.aten.mm.default + # ) # mm recomputed in the bwd + # backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + # backend = "aot_eager" + backend = "inductor" + self._validate(fn, backend, x, y) + + @requires_gpu_and_triton + @torch._inductor.config.patch(fallback_random=True) + def test_tags_dropout(self, device): + # Figure out a way to test the number of inductor_random calls + class MockModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(10, 10) + self.dropout = torch.nn.Dropout(0.2) + + def forward(self, x): + return self.dropout(self.linear(x)) + + mod = MockModule().to(device) + + def fn(x): + return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True) + + x = torch.randn(10, 10, device=device, requires_grad=True) + backend = "inductor" + # rand decomps do not have have numerical results as eager + self._validate(fn, backend, x, skip_check=True) + + @skipIfHpu + @torch._functorch.config.patch(recompute_views=True) + @torch._inductor.config.patch(fx_graph_cache=False) + def test_tags_must_save_tensor_that_has_backward_hook(self): + def my_post_forward_hook(submod, args, 
output): + output.register_hook(my_backward_hook) + return output + + def my_backward_hook(grad): + return grad + + class MySubmod(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = torch.matmul(x, x) + z = y * y + return z + + class MyMod(torch.nn.Module): + def __init__(self): + super().__init__() + self.submod = MySubmod() + self.norm = torch.nn.LayerNorm(4) + + def forward(self, x): + out = torch.utils.checkpoint.checkpoint( + self.submod, x, use_reentrant=False + ) + norm_out = self.norm(out) + return norm_out + + def _factory_fn(): + mod = MyMod() + x = torch.ones(4, 4, dtype=torch.float32, requires_grad=True) + backend = "inductor" + return mod, x, backend + + mod_no_hook, x, backend = _factory_fn() + mod_no_hook_fwd_outputs = set() + + with torch._inductor.config.patch( + post_grad_custom_pre_pass=functools.partial( + collect_fwd_graph_outputs, fwd_outputs=mod_no_hook_fwd_outputs + ) + ): + self._validate( + mod_no_hook, backend, x, fullgraph=True, compiled_autograd=True + ) + + torch._dynamo.reset() + mod_with_hook, x, backend = _factory_fn() + mod_with_hook.submod.register_forward_hook(my_post_forward_hook) + mod_with_hook_fwd_outputs = set() + + with torch._inductor.config.patch( + post_grad_custom_pre_pass=functools.partial( + collect_fwd_graph_outputs, fwd_outputs=mod_with_hook_fwd_outputs + ) + ): + self._validate( + mod_with_hook, backend, x, fullgraph=True, compiled_autograd=True + ) + + # If `z` has a backward hook, result of `z = y * y` should also be saved in addition to the usual saved tensors. 
+ mod_no_hook_fwd_outputs_no_primal = { + x for x in mod_no_hook_fwd_outputs if not x.startswith("primals_") + } + mod_with_hook_fwd_outputs_no_primal = { + x for x in mod_with_hook_fwd_outputs if not x.startswith("primals_") + } + additional_saved_tensors = ( + mod_with_hook_fwd_outputs_no_primal - mod_no_hook_fwd_outputs_no_primal + ) + expected_additional_saved_tensors = {"mul"} + self.assertEqual( + additional_saved_tensors, + expected_additional_saved_tensors, + f""" +Expected additional saved tensors: {expected_additional_saved_tensors} but got: {additional_saved_tensors}. +Non-primal fwd outputs from model w/ backward hook: {mod_with_hook_fwd_outputs_no_primal}. +Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no_primal}.""", + ) + + @requires_gpu_and_triton + def test_fallback(self, device): + def gn(x, y): + torch._dynamo.graph_break() + a = torch.sigmoid(torch.matmul(x, y)) + torch._dynamo.graph_break() + return torch.cos(a) + + def fn(x, y): + return torch.cos(checkpoint(gn, torch.sin(x), y, use_reentrant=False)) + + x = torch.randn(4, 4, requires_grad=True, device=device) + y = torch.randn(4, 4, requires_grad=True, device=device) + args = (x, y) + + backend = "aot_eager" + cnt = CompileCounterWithBackend(backend) + + expected = fn(*args) + result = torch.compile(fn, backend=cnt)(*args) + + self.assertEqual(result, expected) + + # One graph for torch.sin on the input, and other for torch.cos. 
+ self.assertEqual(cnt.frame_count, 2) + self.assertEqual(cnt.op_count, 2) + self.assertEqual(len(cnt.graphs), 2) + + @requires_gpu_and_triton + def test_kwargs(self, device): + def gn(x, y, z=None): + a = torch.matmul(x, y) + if z is not None: + return torch.matmul(a, z) + return a + + def fn(x, y, z): + return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z)) + + x = torch.randn(4, 4, requires_grad=True, device=device) + y = torch.randn(4, 4, requires_grad=True, device=device) + z = torch.randn(4, 4, requires_grad=True, device=device) + args = (x, y, z) + + backend = "aot_eager" + cnt = CompileCounterWithBackend(backend) + + expected = fn(*args) + result = torch.compile(fn, backend=cnt)(*args) + + self.assertEqual(result, expected) + + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(len(cnt.graphs), 1) + + wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint) + # one for checkpoint, and 3 for x, y, z + self.assertEqual(len(wrap_node.args), 4) + + body_function = getattr(cnt.graphs[0], wrap_node.args[0].name) + self.assertEqual(op_count(body_function), 2) + + @requires_gpu_and_triton + def test_symints_location(self, device): + def gn(x, y): + return torch.matmul(x, torch.nn.functional.dropout(y, 0.5)) + + def fn(x, y): + return torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + + backend = "aot_eager" + cnt = CompileCounterWithBackend(backend) + opt_fn = torch.compile(fn, backend=cnt) + + x = torch.randn(4, 4, requires_grad=True, device=device) + y = torch.randn(4, 4, requires_grad=True, device=device) + args = (x, y) + expected = fn(*args) + result = opt_fn(*args) + + x = torch.randn(5, 5, requires_grad=True, device=device) + y = torch.randn(5, 5, requires_grad=True, device=device) + args = (x, y) + expected = fn(*args) + result = opt_fn(*args) + + self.assertEqual(result.shape, expected.shape) + self.assertEqual(cnt.frame_count, 2) + self.assertEqual(len(cnt.graphs), 2) + wrap_node = find_first_node(cnt.graphs[0], 
tag_activation_checkpoint) + self.assertEqual(len(wrap_node.args), 3) + + @requires_gpu_and_triton + @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn): + def context_fn_must_recompute_mm(): + must_recompute_list = [ + torch.ops.aten.mm.default, + ] + return create_selective_checkpoint_contexts( + _get_custom_policy( + must_recompute_list=must_recompute_list, + ), + ) + + def context_fn_no_recompute_mm(): + no_recompute_list = [ + torch.ops.aten.mm.default, + ] + return create_selective_checkpoint_contexts( + _get_custom_policy( + no_recompute_list=no_recompute_list, + ), + ) + + def _test(context_fn, bw_compiler, partition_fn): + def gn(x): + return torch.cos(torch.sin(torch.matmul(x, x) @ x)) + + def fn(x): + return torch.utils.checkpoint.checkpoint( + gn, + x, + use_reentrant=False, + context_fn=context_fn, + ) + + x = torch.randn(4, 4, requires_grad=True, device=device) + + fw_compiler = functools.partial( + count_ops, + freq=2, + op=torch.ops.aten.mm.default, + ) + + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + self._validate(fn, backend, x) + + _test( + context_fn=context_fn_must_recompute_mm, + bw_compiler=functools.partial( + count_ops, + freq=6, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6) + op=torch.ops.aten.mm.default, + ), + partition_fn=partition_fn, + ) + _test( + context_fn=context_fn_no_recompute_mm, + bw_compiler=functools.partial( + count_ops, + freq=4, # 2 bwd mm ops per fwd matmul + op=torch.ops.aten.mm.default, + ), + partition_fn=partition_fn, + ) + + def test_sac_with_partial_context_fn(self): + class CustomPolicy: + def __init__(self): + super().__init__() + + def __call__(self, ctx, out, func, *args, **kwargs): + return 
CheckpointPolicy.MUST_SAVE + + def f(x, y): + return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y + + context_fn1 = functools.partial( + create_selective_checkpoint_contexts, CustomPolicy() + ) + + def fn(x, y): + return torch.utils.checkpoint.checkpoint( + f, + x, + y, + use_reentrant=False, + context_fn=context_fn1, + ) + + opt_fn = torch.compile(fn, backend="aot_eager_decomp_partition", fullgraph=True) + a = torch.randn(4, 4, requires_grad=True, device="cpu") + b = torch.randn(4, 4, requires_grad=True, device="cpu") + + expected = fn(a, b) + result = opt_fn(a, b) + self.assertEqual(result, expected) + + @requires_gpu_and_triton + @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_must_not_recompute_gemm( + self, device, partition_fn + ): + def selective_checkpointing_context_fn(): + no_recompute_list = [ + torch.ops.aten.mm.default, + ] + return create_selective_checkpoint_contexts( + _get_custom_policy(no_recompute_list=no_recompute_list) + ) + + def gn(x, y): + return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y + + def fn(x, y): + return torch.utils.checkpoint.checkpoint( + gn, + x, + y, + use_reentrant=False, + context_fn=selective_checkpointing_context_fn, + ) + + x = torch.randn(4, 4, requires_grad=True, device=device) + y = torch.randn(4, 4, requires_grad=True, device=device) + + fw_compiler = functools.partial( + count_ops, + freq=2, + op=torch.ops.aten.mm.default, + ) + bw_compiler = functools.partial( + count_ops, + # We would've expected 6 here + # (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6) + # if we didn't enable selective checkpointing. 
+ freq=4, + op=torch.ops.aten.mm.default, + ) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + self._validate(fn, backend, x, y) + self._compare_orig_and_checkpointed_fns(gn, fn, x, y) + + @requires_gpu_and_triton + @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization( + self, device, partition_fn + ): + def selective_checkpointing_context_fn(): + no_recompute_list = [ + torch.ops.aten.mm.default, + ] + return create_selective_checkpoint_contexts( + _get_custom_policy(no_recompute_list=no_recompute_list) + ) + + def gn(x, y): + return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y + + def fn(x, y): + return torch.utils.checkpoint.checkpoint( + gn, + x, + y, + use_reentrant=False, + context_fn=selective_checkpointing_context_fn, + ) + + x = torch.randn(4, 4, requires_grad=True, device=device) + y = torch.randn(4, 4, requires_grad=True, device=device) + + fw_compiler = functools.partial( + count_ops, + freq=1, + op=torch.ops.aten.sigmoid.default, + ) + bw_compiler = functools.partial( + count_ops, + # Main check here is just that sigmoid is properly recomputed + # (we will see a sigmoid() and sigmoid_backward() in the bw graph) + freq=1, + op=torch.ops.aten.sigmoid.default, + ) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + disable_functionalization=True, + ) + self._validate(fn, backend, x, y) + self._compare_orig_and_checkpointed_fns(gn, fn, x, y) + + @requires_gpu_and_triton + @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_triton_kernel(self, device, 
partition_fn): + # Copy of the above test, but make sure that having a triton kernel in the + # region does not error. + def add_one(x): + out = torch.empty_like(x) + n_elements = x.numel() + add_one_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4) + return out + + class AddOne(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return add_one(x) + + @staticmethod + def backward(ctx, x): + return x + + def selective_checkpointing_context_fn(): + no_recompute_list = [ + torch.ops.aten.mm.default, + ] + return create_selective_checkpoint_contexts( + _get_custom_policy(no_recompute_list=no_recompute_list) + ) + + def gn(x, y): + return ( + torch.sigmoid(torch.matmul(torch.matmul(AddOne.apply(x.sin()), y), y)) + * y + ) + + def fn(x, y): + return torch.utils.checkpoint.checkpoint( + gn, + x, + y, + use_reentrant=False, + context_fn=selective_checkpointing_context_fn, + ) + + x = torch.randn(4, 4, requires_grad=True, device=device) + y = torch.randn(4, 4, requires_grad=True, device=device) + + fw_compiler = functools.partial( + count_ops, + freq=2, + op=torch.ops.aten.mm.default, + ) + bw_compiler = functools.partial( + count_ops, + # We would've expected 6 here + # (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6) + # if we didn't enable selective checkpointing. 
+ freq=4, + op=torch.ops.aten.mm.default, + ) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + self._validate(fn, backend, x, y) + self._compare_orig_and_checkpointed_fns(gn, fn, x, y) + + @requires_gpu_and_triton + @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn): + def selective_checkpointing_context_fn(): + no_recompute_list = [ + torch.ops.aten.mm.default, + ] + return create_selective_checkpoint_contexts( + _get_custom_policy(no_recompute_list=no_recompute_list) + ) + + def gn(x, y): + return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y + + def fn(x, y): + return torch.utils.checkpoint.checkpoint( + gn, + x, + y, + use_reentrant=False, + context_fn=selective_checkpointing_context_fn, + ) + + rand_tensor = torch.randn(4, 4, requires_grad=True, device=device) + + # tensor subclasses as inputs + x = TwoTensor(rand_tensor, rand_tensor.clone()) + y = TwoTensor(rand_tensor.clone(), rand_tensor.clone()) + + fw_compiler = functools.partial( + count_ops, + freq=4, + op=torch.ops.aten.mm.default, + ) + bw_compiler = functools.partial( + count_ops, + # We would've expected 12 here + # (4 matmul recompute and 4 mm ops per fwd matmul, so 4 + 2 * 4 = 12) + # if we didn't enable selective checkpointing. 
+ freq=8, + op=torch.ops.aten.mm.default, + ) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + self._validate(fn, backend, x, y) + self._compare_orig_and_checkpointed_fns(gn, fn, x, y) + + @requires_gpu_and_triton + @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn): + def _get_custom_policy(meta): + no_recompute_list = [ + torch.ops.aten.mm.default, + ] + + def _custom_policy(mode, func, *args, **kwargs): + mm_count_key = f"{mode}_mm_count" + if mm_count_key not in meta: + meta[mm_count_key] = 0 + if func == torch.ops.aten.mm.default: + meta[mm_count_key] += 1 + # Saves output of all compute ops, except second mm + # (i.e. we will hint the partitioner to recompute second mm in backward pass) + return func in no_recompute_list and not ( + func == torch.ops.aten.mm.default and meta[mm_count_key] == 2 + ) + + return _custom_policy + + def selective_checkpointing_context_fn(): + meta = {} + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + + def gn(x, y): + return torch.sigmoid( + torch.sigmoid(torch.matmul(torch.matmul(x, y) * y, y) * y) + ) + + def fn(x, y): + return torch.utils.checkpoint.checkpoint( + gn, + x, + y, + use_reentrant=False, + context_fn=selective_checkpointing_context_fn, + ) + + x = torch.randn(4, 4, requires_grad=True, device=device) + y = torch.randn(4, 4, requires_grad=True, device=device) + + fw_compiler = functools.partial( + count_ops, + freq=2, + op=torch.ops.aten.mm.default, + ) + bw_compiler = functools.partial( + count_ops, + # Q: How do we come to this number 4? + # A: We have 2 matmuls in the forward pass, each matmul contributes 2 `mm` ops in the backward pass, + # so we have at least 4 `mm` ops in backward pass. 
It's "at least" because whether second matmul in + # the forward pass is recomputed in the backward pass is up to the partitioner to decide. + freq_ge=4, + op=torch.ops.aten.mm.default, + ) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + self._validate(fn, backend, x, y) + self._compare_orig_and_checkpointed_fns(gn, fn, x, y) + + @requires_gpu_and_triton + @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn): + def selective_checkpointing_context_fn(no_recompute_list): + return create_selective_checkpoint_contexts( + _get_custom_policy(no_recompute_list=no_recompute_list) + ) + + def gn(x, y): + return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y + + def fn(x, y): + return torch.utils.checkpoint.checkpoint( + gn, + x, + y, + use_reentrant=False, + context_fn=functools.partial( + selective_checkpointing_context_fn, [torch.ops.aten.mm.default] + ), + ) + + x = torch.randn(4, 4, requires_grad=True, device=device) + y = torch.randn(4, 4, requires_grad=True, device=device) + + fw_compiler = functools.partial( + count_ops, + freq=2, + op=torch.ops.aten.mm.default, + ) + bw_compiler = functools.partial( + count_ops, + # We would've expected 6 here + # (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6) + # if we didn't enable selective checkpointing. 
+ freq=4, + op=torch.ops.aten.mm.default, + ) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + self._validate(fn, backend, x, y) + self._compare_orig_and_checkpointed_fns(gn, fn, x, y) + + @requires_gpu_and_triton + @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn): + def selective_checkpointing_context_fn(): + no_recompute_list = [ + torch.ops.aten.mm.default, + torch.ops.aten.sigmoid.default, + ] + return create_selective_checkpoint_contexts( + _get_custom_policy(no_recompute_list=no_recompute_list), + ) + + def gn(x, y): + return torch.sigmoid(torch.selu(torch.matmul(torch.matmul(x, y), y))).relu() + + def fn(x, y): + return torch.utils.checkpoint.checkpoint( + gn, + x, + y, + use_reentrant=False, + context_fn=selective_checkpointing_context_fn, + ) + + x = torch.randn(4, 4, requires_grad=True, device=device) + y = torch.randn(4, 4, requires_grad=True, device=device) + + fw_compiler = functools.partial( + count_ops, + freqs=[2, 1], + ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default], + ) + bw_compiler = functools.partial( + count_ops, + freqs=[4, 0], + ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default], + ) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + self._validate(fn, backend, x, y) + self._compare_orig_and_checkpointed_fns(gn, fn, x, y) + + @requires_gpu_and_triton + @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_list_ops(self, device, partition_fn): + def selective_checkpointing_context_fn(): + # recompute everything + 
no_recompute_list = [] + return create_selective_checkpoint_contexts( + _get_custom_policy(no_recompute_list=no_recompute_list) + ) + + def gn(x, y): + return torch.cat([x, y]).sin() + + def fn(x, y): + return torch.utils.checkpoint.checkpoint( + gn, + x, + y, + use_reentrant=False, + context_fn=selective_checkpointing_context_fn, + ) + + x = torch.randn(4, 4, requires_grad=True, device=device) + y = torch.randn(4, 4, requires_grad=True, device=device) + + fw_compiler = functools.partial( + count_ops, + freqs=[1], + ops=[torch.ops.aten.cat.default], + ) + bw_compiler = functools.partial( + count_ops, + freqs=[1], + ops=[torch.ops.aten.cat.default], + ) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + self._validate(fn, backend, x, y) + self._compare_orig_and_checkpointed_fns(gn, fn, x, y) + + @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @unittest.skip( + "In-place op support in selective checkpointing + torch.compile " + "requires TorchDispatchMode + torch.compile work to complete" + ) + @requires_gpu_and_triton + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn): + def selective_checkpointing_context_fn(): + no_recompute_list = [ + torch.ops.aten.mm.default, + torch.ops.aten.sigmoid.default, + ] + return create_selective_checkpoint_contexts( + _get_custom_policy(no_recompute_list=no_recompute_list) + ) + + def gn(x, y): + return torch.sigmoid( + torch.selu_(torch.matmul(torch.matmul(x, y), y)) + ).relu_() + + def fn(x, y): + return torch.utils.checkpoint.checkpoint( + gn, + x, + y, + use_reentrant=False, + context_fn=selective_checkpointing_context_fn, + ) + + x = torch.randn(4, 4, requires_grad=True, device=device) + y = torch.randn(4, 4, requires_grad=True, device=device) + + fw_compiler = functools.partial( + count_ops, + freqs=[2, 1], 
+ ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default], + ) + bw_compiler = functools.partial( + count_ops, + freqs=[4, 0], + ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default], + ) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + self._validate(fn, backend, x, y) + self._compare_orig_and_checkpointed_fns(gn, fn, x, y) + + @requires_gpu_and_triton + @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @torch._inductor.config.patch(fallback_random=True) + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_random_op(self, device, partition_fn): + for preserve_rng_state in [True, False]: + + def selective_checkpointing_context_fn(): + no_recompute_list = [ + torch.ops.aten.sigmoid.default, + ] + return create_selective_checkpoint_contexts( + _get_custom_policy(no_recompute_list=no_recompute_list) + ) + + def gn(x): + return torch.sigmoid(torch.dropout(torch.sigmoid(x), p=0.5, train=True)) + + def fn(x): + return torch.utils.checkpoint.checkpoint( + gn, + x, + use_reentrant=False, + # Regardless of whether `preserve_rng_state` is True or False, + # we will always preserve RNG state when using `torch.compile`. + preserve_rng_state=preserve_rng_state, + context_fn=selective_checkpointing_context_fn, + ) + + x = torch.randn(4, 4, requires_grad=True, device=device) + + fw_compiler = functools.partial( + count_ops, + freqs=[2, 1], + ops=[ + torch.ops.aten.sigmoid.default, + torch.ops.aten.native_dropout.default, + ], + ) + bw_compiler = functools.partial( + count_ops, + # NOTE: This unit test expects `dropout` to be recomputed (notice the count for `native_dropout` is 1). 
+ freqs=[0, 1], + ops=[ + torch.ops.aten.sigmoid.default, + torch.ops.aten.native_dropout.default, + ], + ) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + + # NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager, + # because eager version doesn't preserve RNG state while torch.compile still does. + # Hence when `preserve_rng_state` is False, we skip the output and gradient comparison + # between torch.compile and eager. + self._validate(fn, backend, x, skip_check=not preserve_rng_state) + self._compare_orig_and_checkpointed_fns(gn, fn, x) + + @requires_gpu_and_triton + @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_invalid_context(self, partition_fn): + def gn(x, y): + return torch.sigmoid(torch.matmul(x, y)) * y + + def fn(x, y): + return torch.utils.checkpoint.checkpoint( + gn, + x, + y, + use_reentrant=False, + context_fn=_invalid_context_gen, + ) + + x = torch.randn(4, 4, requires_grad=True) + y = torch.randn(4, 4, requires_grad=True) + + fw_compiler = functools.partial( + count_ops, + freq=1, + op=torch.ops.aten.mm.default, + ) + bw_compiler = functools.partial( + count_ops, + freq_ge=2, + op=torch.ops.aten.mm.default, + ) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + with self.assertRaisesRegex( + Exception, "must generate a tuple of two `TorchDispatchMode`s" + ): + self._validate(fn, backend, x, y) + + @requires_gpu_and_triton + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_parametrization(self, partition_fn): + def sac_policy(): + def _recomp_policy(): + def _custom_policy(ctx, func, *args, **kwargs): + 
to_recompute = func in { + torch.ops.aten.mul.Tensor, + torch.ops.aten.sigmoid.default, + } + return ( + CheckpointPolicy.MUST_RECOMPUTE + if to_recompute + else CheckpointPolicy.MUST_SAVE + ) + + return _custom_policy + + return create_selective_checkpoint_contexts(_recomp_policy()) + + class Parametrization(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def parametrization(self, x): + return torch.sigmoid(torch.mul(x, x)) + + def forward(self, x): + return checkpoint( + self.parametrization, x, use_reentrant=False, context_fn=sac_policy + ) + + def apply_parametrization(model): + modules = list(model.modules()) + + for mod in modules: + params_dict = dict(mod.named_parameters(recurse=False)) + for p_name, p in params_dict.items(): + mod.register_parameter(p_name, nn.Parameter(p)) + nn.utils.parametrize.register_parametrization( + mod, p_name, Parametrization(), unsafe=True + ) + + return model + + class MLPModule(nn.Module): + def __init__(self) -> None: + super().__init__() + torch.manual_seed(5) + self.net1 = nn.Linear(16, 16, bias=False) + + def forward(self, x): + return self.net1(x) + + def reset_parameters(self): + self.net1.reset_parameters() + + fw_compiler = functools.partial( + count_ops, + freqs=[1, 1], + ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default], + ) + bw_compiler = functools.partial( + count_ops, + freqs=[ + # 1 from mul recompute, 1 from mul backward + # w/o CSE, we have one extra mul + 3 if partition_fn is default_partition else 2, + 1, + ], + ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default], + ) + + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) + + model = MLPModule() + model = apply_parametrization(model) + model_compiled = torch.compile( + copy.deepcopy(model), backend=backend, fullgraph=True + ) + input = torch.randn(8, 16, requires_grad=True) + input_compiled = copy.deepcopy(input) + + out = model(input) + 
out.sum().backward() + out_compiled = model_compiled(input_compiled) + out_compiled.sum().backward() + + self.assertEqual(out, out_compiled) + self.assertEqual(input.grad, input_compiled.grad) + + @requires_gpu_and_triton + def test_autocast_flash_attention(self, device): + def fn(primals_1, primals_2, primals_3): + return torch.ops.aten._scaled_dot_product_efficient_attention.default( + primals_1, primals_2, primals_3, None, True, scale=0.17677669529663687 + )[0] + + def gn(*args): + return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True) + + with torch.autocast(device_type=device): + x = torch.randn(4, 2, 16, 32, device=device, requires_grad=True) + y = torch.randn(4, 2, 16, 32, device=device, requires_grad=True) + z = torch.randn(4, 2, 16, 32, device=device, requires_grad=True) + args = (x, y, z) + + torch.manual_seed(0) + ref = gn(*args) + + opt_gn = torch.compile(gn, backend="aot_eager_decomp_partition") + torch.manual_seed(0) + res = opt_gn(*args) + self.assertEqual(ref, res) + + @requires_gpu_and_triton + def test_error_msg(self, device): + class MockModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + x = torch.sin(x) + torch._dynamo.graph_break() + x = torch.cos(x) + return x + + mod = MockModule().to(device) + + def fn(x): + return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True) + + x = torch.randn(4, 4).to(device) + opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager_decomp_partition") + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, "User-inserted graph break" + ): + opt_fn(x) + + @requires_gpu_and_triton + def test_list_inputs(self, device): + class MockModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, ys): + a = torch.sin(x) # noqa: F841 + b = torch.cos(ys[0]) + c = torch.cos(ys[1]) + return (x, [b, c]) + + mod = MockModule().to(device) + + def fn(x, ys): + return 
torch.utils.checkpoint.checkpoint(mod, x, ys, use_reentrant=True) + + x = torch.randn(4, 4).to(device) + y = torch.randn(4, 4).to(device) + z = torch.randn(4, 4).to(device) + ref = fn(x, [y, z]) + opt_fn = torch.compile(fn, backend="aot_eager_decomp_partition", fullgraph=True) + res = opt_fn(x, [y, z]) + self.assertEqual(ref, res) + + @requires_gpu_and_triton + def test_pattern_matcher(self, device): + # Check that the sdpa op is recomputed in the backward graph + # tests percolate_tags + + @checkpoint_wrapper + def dot_prod_attention( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + return ( + torch.matmul(query, key.transpose(-2, -1)) + .mul(1.0 / math.sqrt(key.shape[-1])) + .softmax(dim=-1) + .matmul(value) + ) + + def fn(query, key, value): + # Checks that sin is not recomputed in the backward graph + return dot_prod_attention(query.sin(), key, value) + + tensor_shape = (4, 2, 16, 32) + dtype = torch.float16 + args1 = [ + torch.randn(tensor_shape, device=device, dtype=dtype, requires_grad=True), + torch.randn(tensor_shape, device=device, dtype=dtype, requires_grad=True), + torch.randn(tensor_shape, device=device, dtype=dtype, requires_grad=True), + ] + + # Save the AOT graphs + aot_graphs = [] + from torch._inductor import compile_fx + + def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): + aot_graphs.append(graph) + return compile_fx.compile_fx_inner(graph, example_inputs, *args, **kwargs) + + backend = functools.partial( + compile_fx.compile_fx, inner_compile=debug_compile_fx_inner + ) + + opt_fn = torch.compile(fn, backend=backend, fullgraph=True) + opt_fn(*args1).sum().backward() + + fwd_graph = aot_graphs[0] + # Determine which fused attention backend is expected based on the + # prioritization logic in sdp_utils.cpp:check_prefer_cudnn_attention. 
+ dprops = torch.get_device_module(device_type).get_device_properties(device) + cudnn_version = ( + torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else 0 + ) + prefer_cudnn = ( + cudnn_version > 91500 and dprops.major in (9, 10) and dprops.minor in (0, 3) + ) + if prefer_cudnn and torch.version.cuda and TEST_CUDA: + sdpa_op = torch.ops.aten._scaled_dot_product_cudnn_attention.default + else: + sdpa_op = torch.ops.aten._scaled_dot_product_flash_attention.default + self.assertTrue(count_ops(fwd_graph, [], freq=1, op=sdpa_op)) + bwd_graph = aot_graphs[1] + # Check that sin is not recomputed in the backward graph - checks percolate tags + self.assertTrue(count_ops(bwd_graph, [], freq=0, op=torch.ops.aten.sin.default)) + # Check that the sdpa op is recomputed in the backward graph + self.assertTrue(count_ops(bwd_graph, [], freq=1, op=sdpa_op)) + + @requires_distributed() + @requires_gpu_and_triton + def test_distributed_utils_checkpoint_wrapper(self): + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper as dist_checkpoint_wrapper, + ) + + class MockModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(4, 4) + self.c = 2 + + def forward(self, x): + x = torch.sin(x) + x = self.linear(x) + x = torch.cos(x) + return x * self.c + + mod = dist_checkpoint_wrapper(MockModule()) + x = torch.randn(4, 4) + ref = mod(x) + opt_mod = torch.compile( + mod, backend="aot_eager_decomp_partition", fullgraph=True + ) + res = opt_mod(x) + self.assertEqual(ref, res) + + @requires_distributed() + @requires_gpu_and_triton + def test_dynamo_does_not_trace_getattr_as_top_frame(self): + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointWrapper, + ) + + cnt = CompileCounterWithBackend("aot_eager_decomp_partition") + + lin = torch.nn.Linear(1, 1) + mod = torch.nn.Sequential(lin, lin) + mod = CheckpointWrapper(mod) + 
mod._checkpoint_wrapped_module.a = torch.ones(1, 1) + + def fn(x): + return mod(x) * mod.a + + opt_fn = torch.compile(fn, backend=cnt, fullgraph=True) + x = torch.randn(1, 1) + + self.assertEqual(opt_fn(x), fn(x)) + + def test_return_same_element_twice(self): + def gn(x): + y = torch.sin(x) + return y, y + + def fn(x): + return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True) + + x = torch.randn(4, 4, requires_grad=True) + ref = fn(x) + + backend = AotEagerAndRecordGraphs() + opt_fn = torch.compile(fn, backend=backend, fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref[0], res[0]) + self.assertEqual(ref[1], res[1]) + + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[4, 4]"): + l_x_ = L_x_ + + wrap_body_0 = self.wrap_body_0 + tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = True); wrap_body_0 = l_x_ = None + getitem: "f32[4, 4]" = tag_activation_checkpoint[0]; tag_activation_checkpoint = None + return (getitem,) + + class wrap_body_0(torch.nn.Module): + def forward(self, l_x_: "f32[4, 4]"): + y: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None + return (y,) +""", + ) + + @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True) + def test_nonlocal_mutation(self): + counter = 0 + + def gn(x): + nonlocal counter + counter += 1 + return torch.sin(x) + + def fn(x): + return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True) + + x = torch.randn(4, 4, requires_grad=True) + fn(x).sum().backward() + # The mutation is reapplied in the backward as well + self.assertEqual(counter, 2) + counter = 0 + + opt_fn = torch.compile(fn, backend="aot_eager_decomp_partition", fullgraph=True) + opt_fn(x).sum().backward() + # The mutation is not reapplied in the backward because the flag was on. 
+ self.assertEqual(counter, 1) + + @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True) + def test_nonlocal_list_mutation(self): + def gn(x, z): + out = x.sin() + z.append(out) + return torch.cos(torch.sin(torch.matmul(x, x) @ x)), out + + def fn(x): + z = [] + + out1, out2 = torch.utils.checkpoint.checkpoint( + gn, + x, + z, + use_reentrant=False, + ) + + return out1, z[0] + + x = torch.randn(4, 4, requires_grad=True) + ref = fn(x) + + opt_fn = torch.compile(fn, backend="aot_eager_decomp_partition", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref[0], res[0]) + self.assertEqual(ref[1], res[1]) + + @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True) + def test_nonlocal_list_mutation_hidden(self): + def gn(x, z): + o = torch.matmul(x, x) @ x + out = x.sin() + z.append(out) + return torch.cos(torch.sin(o)), torch.sin(x) + + def fn(x): + z = [] + + outs = torch.utils.checkpoint.checkpoint( + gn, + x, + z, + use_reentrant=False, + ) + out1 = outs[0] + # Check that the extra output pytree handling is done properly + out2 = outs[-1] + + return out1 + out2, z[0] + + x = torch.randn(4, 4, requires_grad=True) + ref = fn(x) + + backend = AotEagerAndRecordGraphs() + opt_fn = torch.compile(fn, backend=backend, fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref[0], res[0]) + self.assertEqual(ref[1], res[1]) + + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[4, 4]"): + l_x_ = L_x_ + + wrap_body_0 = self.wrap_body_0 + tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = False); wrap_body_0 = l_x_ = None + getitem_6: "f32[4, 4]" = tag_activation_checkpoint[0] + getitem_7: "f32[4, 4]" = tag_activation_checkpoint[1] + getitem_8: "f32[4, 4]" = tag_activation_checkpoint[2]; tag_activation_checkpoint = None + + add: "f32[4, 4]" = 
getitem_6 + getitem_7; getitem_6 = getitem_7 = None + return (add, getitem_8) + + class wrap_body_0(torch.nn.Module): + def forward(self, l_x_: "f32[4, 4]"): + matmul: "f32[4, 4]" = torch.matmul(l_x_, l_x_) + o: "f32[4, 4]" = matmul @ l_x_; matmul = None + + out: "f32[4, 4]" = l_x_.sin() + + sin_1: "f32[4, 4]" = torch.sin(o); o = None + cos: "f32[4, 4]" = torch.cos(sin_1); sin_1 = None + sin_2: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None + return (cos, sin_2, out) +""", + ) + + self.assertExpectedInline( + normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[4, 4]"): + mm: "f32[4, 4]" = torch.ops.aten.mm.default(primals_1, primals_1) + mm_1: "f32[4, 4]" = torch.ops.aten.mm.default(mm, primals_1); mm = None + + sin: "f32[4, 4]" = torch.ops.aten.sin.default(primals_1) + + sin_1: "f32[4, 4]" = torch.ops.aten.sin.default(mm_1); mm_1 = None + cos: "f32[4, 4]" = torch.ops.aten.cos.default(sin_1); sin_1 = None + sin_2: "f32[4, 4]" = torch.ops.aten.sin.default(primals_1) + + add: "f32[4, 4]" = torch.ops.aten.add.Tensor(cos, sin_2); cos = sin_2 = None + return (add, sin, primals_1) +""", + ) + + def test_frozen_dataclass_pytree_output(self): + import dataclasses + + from torch.utils import _pytree as pytree + + @dataclasses.dataclass(frozen=True) + class InputNode: + x: torch.Tensor + + @dataclasses.dataclass(frozen=True) + class OutputNode: + y: torch.Tensor + + pytree.register_dataclass(InputNode) + pytree.register_dataclass(OutputNode) + + def cleanup(): + # Clean up pytree registrations to avoid leaking state to other tests. + # We manually remove from the registries since _deregister_pytree_node + # has issues with classes registered without serialized_type_name. 
+ for cls in [InputNode, OutputNode]: + pytree.SUPPORTED_NODES.pop(cls, None) + pytree.SUPPORTED_SERIALIZED_TYPES.pop(cls, None) + pytree.CONSTANT_NODES.discard(cls) + + try: + + class TinyMLP(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(4, 8) + self.fc2 = nn.Linear(8, 4) + + def forward(self, inp: InputNode) -> OutputNode: + h = self.fc1(inp.x) + h = torch.nn.functional.silu(h) + y = self.fc2(h) + return OutputNode(y=y) + + mlp = TinyMLP() + + def checkpointed_forward(inp): + return torch.utils.checkpoint.checkpoint( + mlp.forward, + inp, + use_reentrant=False, + preserve_rng_state=True, + ) + + input_eager = InputNode(x=torch.randn(2, 4, requires_grad=True)) + torch.manual_seed(0) + output_eager = checkpointed_forward(input_eager) + output_eager.y.sum().backward() + + input_compiled = InputNode( + x=input_eager.x.detach().clone().requires_grad_(True) + ) + torch.manual_seed(0) + compiled_fn = torch.compile( + checkpointed_forward, + fullgraph=True, + backend="aot_eager_decomp_partition", + ) + output_compiled = compiled_fn(input_compiled) + output_compiled.y.sum().backward() + + self.assertEqual(output_eager.y, output_compiled.y) + self.assertEqual(input_eager.x.grad, input_compiled.x.grad) + finally: + cleanup() + + def test_checkpoint_with_record_function(self): + # Test that record_function ops are allowed inside checkpointed functions. + # record_function is technically "impure" but safe to duplicate during + # activation checkpointing recompute since it only sets up profiling spans. + # This test verifies: + # 1. No assertion error about impure ops in AC + # 2. Forward graph contains record_function ops + # 3. 
Code produces correct results + def gn(x, y): + with torch.profiler.record_function("matmul_region"): + return torch.sigmoid(torch.matmul(x, y)) + + def fn(x, y): + return torch.utils.checkpoint.checkpoint( + gn, torch.sin(x), y, use_reentrant=False + ) + + x = torch.randn(4, 4, requires_grad=True) + y = torch.randn(4, 4, requires_grad=True) + + # Verify record_function_enter_new appears in forward graph + fw_compiler = functools.partial( + count_ops, freq=1, op=torch.ops.profiler._record_function_enter_new.default + ) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=nop, + partition_fn=default_partition, + ) + # Enable capture_profiler_record_function to trace record_function ops + with torch._dynamo.config.patch(capture_profiler_record_function=True): + self._validate(fn, backend, x, y) + + def _get_sac_annotations(self, checkpointed_fn, policy_fn, decompositions=None): + annotations = [] + + def capture_partition(joint_gm, joint_args, **kwargs): + for node in joint_gm.graph.nodes: + if node.op == "call_function": + recompute = node.meta.get("recompute", None) + if recompute is not None: + annotations.append( + f"{node.name}: {node.target} -> {recompute.name}" + ) + return min_cut_rematerialization_partition(joint_gm, joint_args, **kwargs) + + backend = aot_autograd( + fw_compiler=nop, + bw_compiler=nop, + partition_fn=capture_partition, + decompositions=decompositions, + ) + + def fn(x): + return checkpoint( + checkpointed_fn, + x, + use_reentrant=False, + context_fn=functools.partial( + create_selective_checkpoint_contexts, policy_fn + ), + ) + + x = torch.randn(4, requires_grad=True) + torch._dynamo.reset() + compiled = torch.compile(fn, backend=backend) + out = compiled(x) + out.sum().backward() + return "\n".join(annotations) + + def test_pre_mode_decomp_has_sac_ignored_ops(self): + SAVE_OPS = { + torch.ops.aten.sin.default, + torch.ops.aten.add.Tensor, + torch.ops.aten.cos.default, + } + + def policy_fn(ctx, func, *args, **kwargs): + if 
func in SAVE_OPS: + return CheckpointPolicy.MUST_SAVE + return CheckpointPolicy.PREFER_RECOMPUTE + + @torch._dynamo.allow_in_graph + def op_with_detach(x): + a = x.sin() + out = a.detach() + a + out = out.cos() + return out + + self.assertExpectedInline( + self._get_sac_annotations(op_with_detach, policy_fn), + """\ +sin: aten.sin.default -> MUST_SAVE +detach_1: aten.detach.default -> PREFER_RECOMPUTE +add: aten.add.Tensor -> MUST_SAVE +cos: aten.cos.default -> MUST_SAVE""", + ) + + def test_post_mode_decomp(self): + from torch._inductor.compile_fx import select_decomp_table + + def policy_fn(ctx, func, *args, **kwargs): + if func == torch.ops.aten.silu.default: + return CheckpointPolicy.MUST_SAVE + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn(x): + x = x.sin() + x = torch.nn.functional.silu(x) + x = x.cos() + return x + + self.assertExpectedInline( + self._get_sac_annotations( + fn, policy_fn, decompositions=select_decomp_table + ), + """\ +sin: aten.sin.default -> PREFER_RECOMPUTE +neg: aten.neg.default -> MUST_SAVE +exp: aten.exp.default -> MUST_SAVE +add: aten.add.Tensor -> MUST_SAVE +div: aten.div.Tensor -> MUST_SAVE +cos: aten.cos.default -> PREFER_RECOMPUTE""", + ) + + def test_multi_output_op(self): + def policy_fn(ctx, func, *args, **kwargs): + if func == torch.ops.aten.topk.default: + return CheckpointPolicy.MUST_SAVE + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn(x): + x = x.sin() + vals, idxs = torch.topk(x, k=2) + out = vals.sum() + out = out.cos() + return out + + self.assertExpectedInline( + self._get_sac_annotations(fn, policy_fn), + """\ +sin: aten.sin.default -> PREFER_RECOMPUTE +topk: aten.topk.default -> MUST_SAVE +getitem: -> MUST_SAVE +getitem_1: -> MUST_SAVE +sum_1: aten.sum.default -> PREFER_RECOMPUTE +cos: aten.cos.default -> PREFER_RECOMPUTE""", + ) + + +class RematerializeACNodesPassTests(torch._dynamo.test_case.TestCase): + """Tests for AC reordering optimization in full graph (forward+backward in one graph).""" + + def 
count_op(self, gm, target): + return sum(1 for n in gm.graph.nodes if n.target == target) + + def _compile_and_capture(self, fn, remat_using_tags_for_fwd_loss_bwd_graph, inputs): + captured_gm = None + + def compiler(gm, example_inputs): + nonlocal captured_gm + captured_gm = gm + return gm.forward + + backend = aot_autograd( + fw_compiler=compiler, + bw_compiler=None, + partition_fn=None, + ) + + with ( + torch._functorch.config.patch( + remat_using_tags_for_fwd_loss_bwd_graph=remat_using_tags_for_fwd_loss_bwd_graph + ), + torch._dynamo.config.patch(trace_autograd_ops=True), + ): + compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) + result = compiled_fn(*inputs) + + return result, captured_gm + + @unittest.skipIf(not HAS_GPU_AND_TRITON, "GPU not available") + def test_ac_rematerialize_simple_forward_backward(self): + x = torch.randn(4, 4, requires_grad=True) + y = torch.randn(4, 4, requires_grad=True) + + def simple_fwd_bwd(x, y): + z = torch.utils.checkpoint.checkpoint( + lambda a, b: torch.sigmoid(torch.matmul(a, b)), + x, + y, + use_reentrant=False, + ) + loss = z.sum() + + dx, dy = _grad(loss, (x, y)) + + return dx.detach(), dy.detach() + + (dx1, dy1), gm_without = self._compile_and_capture( + simple_fwd_bwd, False, (x, y) + ) + (dx2, dy2), gm_with = self._compile_and_capture(simple_fwd_bwd, True, (x, y)) + + self.assertTrue(torch.allclose(dx1, dx2)) + self.assertTrue(torch.allclose(dy1, dy2)) + + mm_with = self.count_op(gm_with, torch.ops.aten.mm.default) + mm_without = self.count_op(gm_without, torch.ops.aten.mm.default) + sigmoid_with = self.count_op(gm_with, torch.ops.aten.sigmoid.default) + sigmoid_without = self.count_op(gm_without, torch.ops.aten.sigmoid.default) + self.assertEqual(mm_with, 4, "mm should be recomputed in backward") + self.assertEqual(mm_without, 3) + self.assertEqual(sigmoid_with, 2, "sigmoid should be recomputed in backward") + self.assertEqual(sigmoid_without, 1) + + self.assertExpectedInline( + gm_with.code.strip(), + 
"""\ +def forward(self, arg0_1, arg1_1): + mm = torch.ops.aten.mm.default(arg0_1, arg1_1) + sigmoid = torch.ops.aten.sigmoid.default(mm); mm = None + sum_1 = torch.ops.aten.sum.default(sigmoid); sigmoid = None + ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format); sum_1 = None + expand = torch.ops.aten.expand.default(ones_like, [4, 4]); ones_like = None + mm_recomputed = torch.ops.aten.mm.default(arg0_1, arg1_1) + sigmoid_recomputed = torch.ops.aten.sigmoid.default(mm_recomputed); mm_recomputed = None + detach_2_recomputed = torch.ops.aten.detach.default(sigmoid_recomputed); sigmoid_recomputed = None + detach_4 = torch.ops.aten.detach.default(detach_2_recomputed); detach_2_recomputed = None + sigmoid_backward = torch.ops.aten.sigmoid_backward.default(expand, detach_4); expand = detach_4 = None + t = torch.ops.aten.t.default(arg0_1); arg0_1 = None + mm_2 = torch.ops.aten.mm.default(t, sigmoid_backward); t = None + t_1 = torch.ops.aten.t.default(arg1_1); arg1_1 = None + mm_3 = torch.ops.aten.mm.default(sigmoid_backward, t_1); sigmoid_backward = t_1 = None + detach_5 = torch.ops.aten.detach.default(mm_3); mm_3 = None + detach_6 = torch.ops.aten.detach.default(mm_2); mm_2 = None + return (detach_5, detach_6)""", + ) + + def test_ac_rematerialize_with_rng_ops_raises_error(self): + x = torch.randn(4, 4, requires_grad=True) + + def fwd_bwd_with_rng(x): + z = torch.utils.checkpoint.checkpoint( + lambda a: torch.sigmoid(a + torch.rand_like(a)), x, use_reentrant=False + ) + loss = z.sum() + + dx = _grad(loss, x)[0] + + return dx + + with self.assertRaisesRegex( + torch._dynamo.exc.BackendCompilerFailed, + "Activation checkpoint rematerialization in `forward-loss-backward` graph does not support RNG ops in recompute regions.", + ): + self._compile_and_capture(fwd_bwd_with_rng, True, (x,)) + + def test_ac_rematerialize_with_no_annotations(self): + x = torch.randn(4, 4, requires_grad=True) + + def fwd_bwd(x): + z = 
torch.utils.checkpoint.checkpoint( + lambda a: torch.sigmoid(a + 4), x, use_reentrant=False + ) + loss = z.sum() + return _grad(loss, x)[0] + + result_with, gm_with = self._compile_and_capture(fwd_bwd, True, (x,)) + result_without, gm_without = self._compile_and_capture(fwd_bwd, False, (x,)) + + self.assertTrue(torch.allclose(result_with, result_without)) + + # autograd_backward tagging is automatic now, so remat should still work + sigmoid_with = self.count_op(gm_with, torch.ops.aten.sigmoid.default) + sigmoid_without = self.count_op(gm_without, torch.ops.aten.sigmoid.default) + self.assertEqual(sigmoid_with, 2, "sigmoid should be recomputed in backward") + self.assertEqual(sigmoid_without, 1) + + def test_ac_rematerialize_with_selective_checkpoint_policy(self): + x = torch.randn(4, 128, requires_grad=True) + w1 = torch.randn(128, 128, requires_grad=True) + b1 = torch.randn(128, requires_grad=True) + + def policy_fn(ctx, op, *args, **kwargs): + if op == torch.ops.aten.addmm.default: + return torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE + return torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE + + context_fn = functools.partial( + torch.utils.checkpoint.create_selective_checkpoint_contexts, policy_fn + ) + + def fwd_bwd_with_policy(x, w1, b1): + def checkpoint_fn(inp, w, b): + linear = torch.nn.functional.linear(inp, w, b) + return torch.relu(linear) + + result = torch.utils.checkpoint.checkpoint( + checkpoint_fn, x, w1, b1, use_reentrant=False, context_fn=context_fn + ) + loss = result.sum() + + dx, dw, db = _grad(loss, (x, w1, b1)) + return dx, dw, db + + result_with, gm_with = self._compile_and_capture( + fwd_bwd_with_policy, True, (x, w1, b1) + ) + result_without, gm_without = self._compile_and_capture( + fwd_bwd_with_policy, False, (x, w1, b1) + ) + + torch.testing.assert_close(result_with[0], result_without[0]) + torch.testing.assert_close(result_with[1], result_without[1]) + torch.testing.assert_close(result_with[2], result_without[2]) + + 
addmm_without = self.count_op(gm_without, torch.ops.aten.addmm.default) + relu_without = self.count_op(gm_without, torch.ops.aten.relu.default) + + addmm_with = self.count_op(gm_with, torch.ops.aten.addmm.default) + relu_with = self.count_op(gm_with, torch.ops.aten.relu.default) + + self.assertEqual(addmm_without, addmm_with) + self.assertEqual(relu_with, relu_without + 1) + + recomputed_nodes = [ + n.name for n in gm_with.graph.nodes if "_recomputed" in n.name + ] + self.assertNotIn("addmm_recomputed", recomputed_nodes) + + self.assertTrue( + any("relu" in name for name in recomputed_nodes), + f"Expected relu_recomputed but got: {recomputed_nodes}", + ) + + def _compile_with_joint_graph_pass_and_capture(self, fn, inputs): + from torch._inductor.fx_passes.joint_graph import joint_graph_passes + + captured_gm_before = None + captured_gm_after = None + + def custom_compiler(gm, example_inputs): + nonlocal captured_gm_before, captured_gm_after + import copy + + captured_gm_before = copy.deepcopy(gm) + joint_graph_passes(gm) + captured_gm_after = gm + return gm.forward + + backend = aot_autograd( + fw_compiler=custom_compiler, + bw_compiler=None, + partition_fn=None, + ) + + with torch._dynamo.config.patch(trace_autograd_ops=True): + compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) + result = compiled_fn(*inputs) + + return result, captured_gm_before, captured_gm_after + + def test_joint_graph_passes_view_optimization(self): + x = torch.randn(4, 4, requires_grad=True) + + def fwd_bwd_with_views(x): + def checkpoint_fn(a): + b = a.view(16) + c = b.view(4, 4) + return torch.sigmoid(c) + + z = torch.utils.checkpoint.checkpoint( + checkpoint_fn, + x, + use_reentrant=False, + ) + loss = z.sum() + + dx = _grad(loss, x)[0] + + return dx.detach() + + result, gm_before, gm_after = self._compile_with_joint_graph_pass_and_capture( + fwd_bwd_with_views, (x,) + ) + + result_eager = torch.autograd.grad(torch.sigmoid(x).sum(), x)[0] + 
self.assertTrue(torch.allclose(result, result_eager, atol=1e-5)) + + view_count_before = self.count_op(gm_before, torch.ops.aten.view.default) + view_count_after = self.count_op(gm_after, torch.ops.aten.view.default) + self.assertTrue(view_count_after == 0) + self.assertTrue(view_count_before == 6) + + self.assertExpectedInline( + gm_after.code.strip(), + """\ +def forward(self, arg0_1): + sigmoid = torch.ops.aten.sigmoid.default(arg0_1) + sum_1 = torch.ops.aten.sum.default(sigmoid); sigmoid = None + ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format); sum_1 = None + expand = torch.ops.aten.expand.default(ones_like, [4, 4]); ones_like = None + sigmoid_recomputed = torch.ops.aten.sigmoid.default(arg0_1); arg0_1 = None + detach_3_recomputed = torch.ops.aten.detach.default(sigmoid_recomputed); sigmoid_recomputed = None + detach_5 = torch.ops.aten.detach.default(detach_3_recomputed); detach_3_recomputed = None + sigmoid_backward = torch.ops.aten.sigmoid_backward.default(expand, detach_5); expand = detach_5 = None + detach_6 = torch.ops.aten.detach.default(sigmoid_backward); sigmoid_backward = None + return (detach_6,)""", + ) + + def test_joint_graph_passes_permute_optimization(self): + x = torch.randn(4, 4, requires_grad=True) + + def fwd_bwd_with_permute(x): + def checkpoint_fn(a): + b = a.permute(1, 0) + c = b.permute(1, 0) + return torch.sigmoid(c) + + z = torch.utils.checkpoint.checkpoint( + checkpoint_fn, + x, + use_reentrant=False, + ) + loss = z.sum() + + dx = _grad(loss, x)[0] + + return dx.detach() + + result, gm_before, gm_after = self._compile_with_joint_graph_pass_and_capture( + fwd_bwd_with_permute, (x,) + ) + + result_eager = torch.autograd.grad(torch.sigmoid(x).sum(), x)[0] + self.assertTrue(torch.allclose(result, result_eager, atol=1e-5)) + + permute_count_before = self.count_op(gm_before, torch.ops.aten.permute.default) + permute_count_after = self.count_op(gm_after, 
torch.ops.aten.permute.default) + self.assertTrue(permute_count_after == 0) + self.assertTrue(permute_count_before == 6) + + self.assertExpectedInline( + gm_after.code.strip(), + """\ +def forward(self, arg0_1): + sigmoid = torch.ops.aten.sigmoid.default(arg0_1) + sum_1 = torch.ops.aten.sum.default(sigmoid); sigmoid = None + ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format); sum_1 = None + expand = torch.ops.aten.expand.default(ones_like, [4, 4]); ones_like = None + sigmoid_recomputed = torch.ops.aten.sigmoid.default(arg0_1); arg0_1 = None + detach_3_recomputed = torch.ops.aten.detach.default(sigmoid_recomputed); sigmoid_recomputed = None + detach_5 = torch.ops.aten.detach.default(detach_3_recomputed); detach_3_recomputed = None + sigmoid_backward = torch.ops.aten.sigmoid_backward.default(expand, detach_5); expand = detach_5 = None + detach_6 = torch.ops.aten.detach.default(sigmoid_backward); sigmoid_backward = None + return (detach_6,)""", + ) + + @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True) + def test_attr_compile_submodules_in_checkpoint_wrapper(self): + """Compiling submodules inside a checkpointed block should not hit the + recompile limit due to WeakKeyDictionary guards in the pack_hook.""" + from torch.utils.checkpoint import checkpoint + + class Block(nn.Module): + def __init__(self, dim): + super().__init__() + self.norm1 = nn.RMSNorm(dim) + self.linear1 = nn.Linear(dim, dim, bias=False) + self.norm2 = nn.RMSNorm(dim) + self.linear2 = nn.Linear(dim, dim, bias=False) + self.norm3 = nn.RMSNorm(dim) + self.linear3 = nn.Linear(dim, dim, bias=False) + + def forward(self, x): + x = x + self.linear1(self.norm1(x)) + x = x + self.linear2(self.norm2(x)) + x = x + self.linear3(self.norm3(x)) + return x + + class CheckpointedBlock(nn.Module): + def __init__(self, block): + super().__init__() + self.block = block + + def forward(self, x): + return checkpoint(self.block, x, 
use_reentrant=False) + + dim = 32 + block = Block(dim) + + x_ref = torch.randn(4, dim, requires_grad=True) + ref = block(x_ref) + ref.sum().backward() + + block_cp = Block(dim) + block_cp.load_state_dict(block.state_dict()) + wrapped = CheckpointedBlock(block_cp) + + for _, submod in wrapped.block.named_children(): + submod.compile(backend="aot_eager") + + with torch._dynamo.config.patch(recompile_limit=2): + x_test = x_ref.detach().clone().requires_grad_(True) + result = wrapped(x_test) + result.sum().backward() + + self.assertEqual(ref, result) + self.assertEqual(x_ref.grad, x_test.grad) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_multiple_user_phase_annotations_errors(self): + x = torch.randn(4, 4, requires_grad=True) + w = torch.randn(4, 4, requires_grad=True) + + def fn(x, w): + z = torch.utils.checkpoint.checkpoint( + lambda a, b: torch.sin(a @ b), x, w, use_reentrant=False + ) + loss = z.sum() + with torch.fx.traceback.annotate({"phase": "backward"}): + dx, dw = _grad(loss, (x, w)) + # Non-backward computation between two backward annotations + out = dx + dw + with torch.fx.traceback.annotate({"phase": "backward"}): + out = out * 2 + return out.detach() + + with self.assertRaisesRegex(RuntimeError, "backward regions annotated"): + self._compile_and_capture(fn, True, (x, w)) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_user_phase_annotation_with_extra_autograd_grad(self): + """Only the user-annotated backward region gets rematerialization.""" + x = torch.randn(4, 4, requires_grad=True) + w1 = torch.randn(4, 4, requires_grad=True) + w2 = torch.randn(4, 4, requires_grad=True) + + def fn(x, w1, w2): + z1 = torch.utils.checkpoint.checkpoint( + lambda a, b: torch.sin(a @ b), x, w1, use_reentrant=False + ) + z2 = torch.utils.checkpoint.checkpoint( + lambda a, b: torch.sigmoid(a @ b), x, w2, use_reentrant=False + ) + loss1 = z1.sum() + loss2 = z2.sum() + # Only the first backward is 
annotated + with torch.fx.traceback.annotate({"phase": "backward"}): + dx1 = _grad(loss1, (x,)) + # Second backward NOT annotated — should not get remat + dx2 = _grad(loss2, (x,)) + return (dx1[0] + dx2[0]).detach() + + _, gm_with = self._compile_and_capture(fn, True, (x, w1, w2)) + + self.assertExpectedInline( + gm_with.code.strip(), + """\ +def forward(self, arg0_1, arg1_1, arg2_1): + mm = torch.ops.aten.mm.default(arg0_1, arg1_1) + sin = torch.ops.aten.sin.default(mm); mm = None + mm_1 = torch.ops.aten.mm.default(arg0_1, arg2_1) + sigmoid = torch.ops.aten.sigmoid.default(mm_1); mm_1 = None + detach_4 = torch.ops.aten.detach.default(sigmoid) + sum_1 = torch.ops.aten.sum.default(sin); sin = None + sum_2 = torch.ops.aten.sum.default(sigmoid); sigmoid = None + ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format); sum_1 = None + expand = torch.ops.aten.expand.default(ones_like, [4, 4]); ones_like = None + mm_recomputed = torch.ops.aten.mm.default(arg0_1, arg1_1); arg0_1 = None + cos = torch.ops.aten.cos.default(mm_recomputed); mm_recomputed = None + mul = torch.ops.aten.mul.Tensor(expand, cos); expand = cos = None + t = torch.ops.aten.t.default(arg1_1); arg1_1 = None + mm_3 = torch.ops.aten.mm.default(mul, t); mul = t = None + ones_like_1 = torch.ops.aten.ones_like.default(sum_2, pin_memory = False, memory_format = torch.preserve_format); sum_2 = None + expand_1 = torch.ops.aten.expand.default(ones_like_1, [4, 4]); ones_like_1 = None + detach_5 = torch.ops.aten.detach.default(detach_4); detach_4 = None + sigmoid_backward = torch.ops.aten.sigmoid_backward.default(expand_1, detach_5); expand_1 = detach_5 = None + t_1 = torch.ops.aten.t.default(arg2_1); arg2_1 = None + mm_4 = torch.ops.aten.mm.default(sigmoid_backward, t_1); sigmoid_backward = t_1 = None + add = torch.ops.aten.add.Tensor(mm_3, mm_4); mm_3 = mm_4 = None + detach_6 = torch.ops.aten.detach.default(add); add = None + return (detach_6,)""", + ) + + 
def test_chunked_loss_remat(self): + """Chunked loss pattern: multiple backward regions from chunk_loss.backward() + calls, but only the final x.backward() region needs remat. The pass should + skip the chunk backwards and only rematerialize for the final one.""" + dim = 32 + chunksz = 4 + + class Block(nn.Module): + def __init__(self): + super().__init__() + self.l1 = nn.Linear(dim, dim, bias=False) + self.l2 = nn.Linear(dim, dim, bias=False) + + def _fn(self, x): + return self.l2(F.gelu(self.l1(x), approximate="tanh")) + + def forward(self, x): + return checkpoint(self._fn, x, use_reentrant=False) + + class ChunkedLoss(nn.Module): + def __init__(self): + super().__init__() + self.block = Block() + self.head = nn.Linear(dim, dim, bias=False) + + def forward(self, x, y): + x = self.block(x) + x_detached = x.detach().requires_grad_() + total = 0 + for start in range(0, x_detached.shape[0], chunksz): + end = start + chunksz + chunk_loss = ( + F.mse_loss( + self.head(x_detached[start:end]), + y[start:end], + reduction="sum", + ) + / x_detached.shape[0] + ) + chunk_loss.backward() + total = total + chunk_loss.detach() + x.backward(x_detached.grad) + return total + + model = ChunkedLoss() + x = torch.randn(12, dim) + y = torch.randn(12, dim) + + result_with, gm_with = self._compile_and_capture(model, True, (x, y)) + result_without, gm_without = self._compile_and_capture(model, False, (x, y)) + + torch.testing.assert_close(result_with, result_without) + + # Without remat, gelu appears once (forward only). + # With remat, gelu is duplicated into the backward region. 
+ gelu_without = self.count_op(gm_without, torch.ops.aten.gelu.default) + gelu_with = self.count_op(gm_with, torch.ops.aten.gelu.default) + self.assertEqual(gelu_without, 1) + self.assertEqual(gelu_with, 2, "gelu should be recomputed in backward") + + def test_two_backward_regions_needing_remat_errors(self): + """Two independent backward calls that both need recompute should error.""" + dim = 32 + + class Block(nn.Module): + def __init__(self): + super().__init__() + self.l1 = nn.Linear(dim, dim, bias=False) + self.l2 = nn.Linear(dim, dim, bias=False) + + def _fn(self, x): + return self.l2(F.gelu(self.l1(x), approximate="tanh")) + + def forward(self, x): + return checkpoint(self._fn, x, use_reentrant=False) + + class TwoBackwards(nn.Module): + def __init__(self): + super().__init__() + self.block1 = Block() + self.block2 = Block() + + def forward(self, x): + y = self.block1(x) + z = self.block2(y) + z.sum().backward(retain_graph=True) + y.sum().backward() + return z.detach() + + x = torch.randn(8, dim, requires_grad=True) + + with self.assertRaisesRegex( + torch._dynamo.exc.BackendCompilerFailed, + "require recomputation", + ): + self._compile_and_capture(TwoBackwards(), True, (x,)) + + +instantiate_device_type_tests( + ActivationCheckpointingViaTagsTests, globals(), except_for="cpu" +) + + +class ActivationCheckpointingNonStrictTracerTests(torch._dynamo.test_case.TestCase): + """Tests for non-strict tracing flag interaction with checkpoint.""" + + @staticmethod + def _count_backward_regions(gm): + regions = 0 + in_backward = False + for node in gm.graph.nodes: + is_backward = bool(node.meta.get("autograd_backward", False)) + if is_backward and not in_backward: + regions += 1 + in_backward = is_backward + return regions + + def test_backward_nodes_have_seq_nr_under_non_strict(self): + class Model(nn.Module): + def __init__(self): + super().__init__() + self.w = nn.Parameter(torch.randn(4, 4)) + + def forward(self, x): + return checkpoint( + lambda x: torch.sin(x @ 
self.w), x, use_reentrant=False + ) + + gm = self._trace_train_step(Model(), torch.randn(2, 4)) + forward_seq_nrs = { + node.meta["seq_nr"] + for node in gm.graph.nodes + if node.op == "call_function" + and not node.meta.get("autograd_backward", False) + and "seq_nr" in node.meta + } + backward_seq_nrs = { + node.meta["seq_nr"] + for node in gm.graph.nodes + if node.op == "call_function" + and node.meta.get("autograd_backward", False) + and "seq_nr" in node.meta + } + + self.assertTrue(forward_seq_nrs) + self.assertTrue(backward_seq_nrs) + self.assertSetEqual(backward_seq_nrs, forward_seq_nrs) + + def test_patch_autograd_grad_requires_non_strict_tracing(self): + x = torch.randn(2, 4, requires_grad=True) + loss = torch.sin(x).sum() + + with torch.compiler._patch_autograd_grad(): + with self.assertRaisesRegex( + AssertionError, + "_patch_autograd_grad\\(\\) must be used under " + "_non_strict_tracing_context\\(\\)", + ): + torch.autograd.grad(loss, (x,)) + + def test_patch_engine_backward_requires_non_strict_tracing(self): + x = torch.randn(2, 4, requires_grad=True) + loss = torch.sin(x).sum() + + with torch.compiler._patch_engine_backward(): + with self.assertRaisesRegex( + AssertionError, + "_patch_engine_backward\\(\\) must be used under " + "_non_strict_tracing_context\\(\\)", + ): + loss.backward() + + def test_patch_autograd_grad_does_not_leak_backward_tag(self): + from torch.fx.experimental.proxy_tensor import make_fx + from torch.fx.traceback import preserve_node_meta + + x = torch.randn(2, 4, requires_grad=True) + + def fn(x): + with torch.fx.traceback.annotate({"ac_region_id": 0}): + y = torch.sin(x) + torch.autograd.grad(y.sum(), (x,)) + return torch.neg(y) + + with ( + torch.compiler._non_strict_tracing_context(), + torch.compiler._patch_autograd_grad(), + preserve_node_meta(), + ): + gm = make_fx(fn)(x) + + backward_nodes = [ + node for node in gm.graph.nodes if node.meta.get("autograd_backward", False) + ] + self.assertTrue(backward_nodes) + + neg_nodes 
= gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.neg.default + ) + self.assertEqual(len(neg_nodes), 1) + self.assertNotIn("autograd_backward", neg_nodes[0].meta) + self.assertEqual(neg_nodes[0].meta.get("custom", {}), {"ac_region_id": 0}) + + def test_patch_engine_backward_does_not_leak_backward_tag(self): + from torch.fx.experimental.proxy_tensor import make_fx + from torch.fx.traceback import preserve_node_meta + + x = torch.randn(2, 4, requires_grad=True) + + def fn(x): + with torch.fx.traceback.annotate({"ac_region_id": 0}): + y = torch.sin(x) + y.sum().backward() + return torch.neg(y) + + with ( + torch.compiler._non_strict_tracing_context(), + torch.compiler._patch_engine_backward(), + preserve_node_meta(), + ): + gm = make_fx(fn)(x) + + backward_nodes = [ + node for node in gm.graph.nodes if node.meta.get("autograd_backward", False) + ] + self.assertTrue(backward_nodes) + + neg_nodes = gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.neg.default + ) + self.assertEqual(len(neg_nodes), 1) + self.assertNotIn("autograd_backward", neg_nodes[0].meta) + self.assertEqual(neg_nodes[0].meta.get("custom", {}), {"ac_region_id": 0}) + + def test_patch_autograd_grad_mlp_has_single_contiguous_backward_region(self): + class Block(nn.Module): + def __init__(self, dim): + super().__init__() + self.fc1 = nn.Linear(dim, dim * 2) + self.fc2 = nn.Linear(dim * 2, dim) + + def forward(self, x): + return x + self.fc2(torch.relu(self.fc1(x))) + + class Model(nn.Module): + def __init__(self, dim=32, depth=6): + super().__init__() + self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)]) + + def forward(self, x): + for block in self.blocks: + x = block(x) + return x + + gm = self._trace_train_step(Model(), torch.randn(4, 16, 32)) + self.assertEqual(self._count_backward_regions(gm), 1) + + def _trace_train_step(self, mod, x): + import torch.utils._pytree as pytree + from torch.fx.experimental.proxy_tensor import make_fx + from torch.fx.traceback 
import preserve_node_meta + from torch.nn.utils import stateless + + params = dict(mod.named_parameters()) + flat_params, params_spec = pytree.tree_flatten(params) + params_len = len(flat_params) + + def train_step(*all_args): + p = pytree.tree_unflatten(list(all_args[:params_len]), params_spec) + with stateless._reparametrize_module(mod, p): + loss = mod(all_args[params_len]).sum() + return torch.autograd.grad(loss, all_args[:params_len]) + + full_args = (*flat_params, x) + + with ( + torch.compiler._non_strict_tracing_context(), + torch.compiler._patch_autograd_grad(), + preserve_node_meta(), + ): + return make_fx(train_step)(*full_args) + + def test_checkpoint_traces_through_eager_ac_under_non_strict(self): + class Model(nn.Module): + def __init__(self): + super().__init__() + self.w = nn.Parameter(torch.randn(4, 4)) + + def forward(self, x): + return checkpoint( + lambda x: torch.sin(x @ self.w), x, use_reentrant=False + ) + + gm = self._trace_train_step(Model(), torch.randn(2, 4)) + + # Everything is recomputed in backward: mm_1 is the recomputed mm. 
+ self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, arg0_1, arg1_1): + mm = torch.ops.aten.mm.default(arg1_1, arg0_1) + sin = torch.ops.aten.sin.default(mm); mm = None + sum_1 = torch.ops.aten.sum.default(sin); sin = None + ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format); sum_1 = None + expand = torch.ops.aten.expand.default(ones_like, [2, 4]); ones_like = None + mm_1 = torch.ops.aten.mm.default(arg1_1, arg0_1); arg0_1 = None + detach = torch.ops.aten.detach.default(mm_1); mm_1 = None + detach_1 = torch.ops.aten.detach.default(detach); detach = None + cos = torch.ops.aten.cos.default(detach_1); detach_1 = None + mul = torch.ops.aten.mul.Tensor(expand, cos); expand = cos = None + t = torch.ops.aten.t.default(arg1_1); arg1_1 = None + mm_2 = torch.ops.aten.mm.default(t, mul); t = mul = None + return (mm_2,)""", + ) + + def test_sac_traces_through_eager_ac_under_non_strict(self): + def policy_fn(ctx, func, *args, **kwargs): + if func == torch.ops.aten.mm.default: + return CheckpointPolicy.PREFER_RECOMPUTE + return CheckpointPolicy.MUST_SAVE + + def context_fn(): + return create_selective_checkpoint_contexts(policy_fn) + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.w = nn.Parameter(torch.randn(4, 4)) + + def forward(self, x): + def fn(x): + return torch.sin(x @ self.w) + + return checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + + gm = self._trace_train_step(Model(), torch.randn(2, 4)) + + # mm is PREFER_RECOMPUTE so it gets recomputed in backward (mm_1). + # sin is MUST_SAVE so its output is saved via detach. 
+ self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, arg0_1, arg1_1): + mm = torch.ops.aten.mm.default(arg1_1, arg0_1) + sin = torch.ops.aten.sin.default(mm); mm = None + detach = torch.ops.aten.detach.default(sin); detach = None + sum_1 = torch.ops.aten.sum.default(sin); sin = None + ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format); sum_1 = None + expand = torch.ops.aten.expand.default(ones_like, [2, 4]); ones_like = None + mm_1 = torch.ops.aten.mm.default(arg1_1, arg0_1); arg0_1 = None + detach_1 = torch.ops.aten.detach.default(mm_1); mm_1 = None + detach_2 = torch.ops.aten.detach.default(detach_1); detach_1 = None + cos = torch.ops.aten.cos.default(detach_2); detach_2 = None + mul = torch.ops.aten.mul.Tensor(expand, cos); expand = cos = None + t = torch.ops.aten.t.default(arg1_1); arg1_1 = None + mm_2 = torch.ops.aten.mm.default(t, mul); t = mul = None + return (mm_2,)""", + ) + + def test_checkpoint_with_rng_op_under_non_strict(self): + from torch.fx.experimental.proxy_tensor import make_fx + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.w = nn.Parameter(torch.randn(4, 4)) + + def forward(self, x): + def fn(x): + return torch.nn.functional.dropout(torch.sin(x @ self.w), p=0.5) + + return checkpoint(fn, x, use_reentrant=False, preserve_rng_state=True) + + mod = Model() + x = torch.randn(2, 4) + + import torch.utils._pytree as pytree + from torch.nn.utils import stateless + + params = dict(mod.named_parameters()) + flat_params, params_spec = pytree.tree_flatten(params) + params_len = len(flat_params) + + def train_step(*all_args): + p = pytree.tree_unflatten(list(all_args[:params_len]), params_spec) + with stateless._reparametrize_module(mod, p): + loss = mod(all_args[params_len]).sum() + return torch.autograd.grad(loss, all_args[:params_len]) + + full_args = (*flat_params, x) + + torch.manual_seed(42) + with 
torch.compiler._non_strict_tracing_context(): + gm = make_fx(train_step)(*full_args) + + # The traced graph has two independent bernoulli_ calls (forward and + # recomputed). At replay time these produce different dropout masks + # because get/set_rng_state are silently dropped by make_fx, so the + # traced graph is NOT bitwise equivalent to eager. This is a known + # limitation of tracing through eager checkpoint. + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, arg0_1, arg1_1): + mm = torch.ops.aten.mm.default(arg1_1, arg0_1) + sin = torch.ops.aten.sin.default(mm); mm = None + empty_like = torch.ops.aten.empty_like.default(sin) + bernoulli_ = torch.ops.aten.bernoulli_.float(empty_like); empty_like = None + div_ = torch.ops.aten.div_.Scalar(bernoulli_, 0.5); bernoulli_ = None + mul = torch.ops.aten.mul.Tensor(sin, div_); sin = div_ = None + sum_1 = torch.ops.aten.sum.default(mul); mul = None + ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format); sum_1 = None + expand = torch.ops.aten.expand.default(ones_like, [2, 4]); ones_like = None + mm_1 = torch.ops.aten.mm.default(arg1_1, arg0_1); arg0_1 = None + detach = torch.ops.aten.detach.default(mm_1) + sin_1 = torch.ops.aten.sin.default(mm_1); mm_1 = None + empty_like_1 = torch.ops.aten.empty_like.default(sin_1); sin_1 = None + bernoulli__1 = torch.ops.aten.bernoulli_.float(empty_like_1); empty_like_1 = None + div__1 = torch.ops.aten.div_.Scalar(bernoulli__1, 0.5); bernoulli__1 = None + mul_1 = torch.ops.aten.mul.Tensor(expand, div__1); expand = div__1 = None + detach_1 = torch.ops.aten.detach.default(detach); detach = None + cos = torch.ops.aten.cos.default(detach_1); detach_1 = None + mul_2 = torch.ops.aten.mul.Tensor(mul_1, cos); mul_1 = cos = None + t = torch.ops.aten.t.default(arg1_1); arg1_1 = None + mm_2 = torch.ops.aten.mm.default(t, mul_2); t = mul_2 = None + return (mm_2,)""", + ) + + +class 
ActivationCheckpointingNestedCompileTests(torch._dynamo.test_case.TestCase): + @requires_gpu_and_triton + def test_checkpoint_recompute_preserves_nested_fx_trace_policy(self): + from torch._guards import tracing, TracingContext + from torch._subclasses import FakeTensorMode + from torch.fx.experimental.proxy_tensor import make_fx + from torch.fx.traceback import preserve_node_meta + + compiled_f = torch.compile(lambda x: x.sin().cos(), fullgraph=True) + + @contextlib.contextmanager + def skip_nested_compile(): + prev = torch._dynamo.config.error_on_nested_fx_trace + torch._dynamo.config.error_on_nested_fx_trace = False + try: + yield + finally: + torch._dynamo.config.error_on_nested_fx_trace = prev + + class M(torch.nn.Module): + def forward(self, x): + return checkpoint(self.block, x, use_reentrant=False) + + def block(self, x): + return compiled_f(x) + + m = getattr(M(), device_type)() + x = torch.randn(8, device=device_type, requires_grad=True) + + def fn(x): + y = m(x).sum() + (gx,) = torch.autograd.grad(y, (x,)) + return y.detach(), gx + + fake_mode = FakeTensorMode( + allow_non_fake_inputs=True, + shape_env=torch.fx.experimental.symbolic_shapes.ShapeEnv(), + ) + fx_x = fake_mode.from_tensor(x, static_shapes=True) + if x.requires_grad and not fx_x.requires_grad: + fx_x.requires_grad_(True) + + ctx = TracingContext(fake_mode) + + with ( + fake_mode, + tracing(ctx), + preserve_node_meta(), + skip_nested_compile(), + torch.compiler._non_strict_tracing_context(), + ): + gm = make_fx(fn)(fx_x) + + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, x_1): + sin = torch.ops.aten.sin.default(x_1) + cos = torch.ops.aten.cos.default(sin); sin = None + sum_1 = torch.ops.aten.sum.default(cos); cos = None + ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format) + expand = torch.ops.aten.expand.default(ones_like, [8]); ones_like = None + detach = torch.ops.aten.detach.default(x_1) + sin_1 = 
torch.ops.aten.sin.default(x_1); x_1 = None + detach_1 = torch.ops.aten.detach.default(sin_1); sin_1 = None + detach_2 = torch.ops.aten.detach.default(detach_1); detach_1 = None + sin_2 = torch.ops.aten.sin.default(detach_2); detach_2 = None + neg = torch.ops.aten.neg.default(sin_2); sin_2 = None + mul = torch.ops.aten.mul.Tensor(expand, neg); expand = neg = None + detach_3 = torch.ops.aten.detach.default(detach); detach = None + cos_1 = torch.ops.aten.cos.default(detach_3); detach_3 = None + mul_1 = torch.ops.aten.mul.Tensor(mul, cos_1); mul = cos_1 = None + detach_4 = torch.ops.aten.detach.default(sum_1); sum_1 = None + return (detach_4, mul_1)""", + ) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/xpu/dynamo/test_debug_utils_xpu.py b/test/xpu/dynamo/test_debug_utils_xpu.py new file mode 100644 index 0000000000..b11f4baf9b --- /dev/null +++ b/test/xpu/dynamo/test_debug_utils_xpu.py @@ -0,0 +1,901 @@ +# Owner(s): ["module: dynamo"] + +import os +from unittest.mock import patch + +import torch +import torch._dynamo +import torch._dynamo.config +from torch._dynamo import debug_utils +from torch._dynamo.debug_utils import aot_graph_input_parser, generate_env_vars_string +from torch._dynamo.test_case import TestCase +from torch.fx.experimental.proxy_tensor import make_fx +from torch.testing._internal.common_device_type import instantiate_device_type_tests + +f32 = torch.float32 +i64 = torch.int64 +i32 = torch.int32 + + +class TestDebugUtils(TestCase): + def test_cast_model_to_fp64_dtype_args(self): + # Test that dtype arguments are converted to fp64 + + def fn(x): + return ( + torch.ops.prims.convert_element_type(x, torch.float16), + x.to(torch.float16), + torch.full(x.shape, 2, dtype=torch.float32, device=x.device), + x.new_empty(x.shape), + ) + + x = torch.randn(32, device="cpu") + decomps = torch._decomp.core_aten_decompositions() + fx = make_fx(fn, decomposition_table=decomps)(x) + + 
self.assertExpectedInline( + fx.code.lstrip(), + """\ +def forward(self, x_1): + convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float16) + _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float16); x_1 = None + full = torch.ops.aten.full.default([32], 2, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) + return (convert_element_type, _to_copy, full, empty) + """, + ) + + _, fp64_examples = debug_utils.cast_to_fp64(fx, (x,)) + self.assertEqual(fp64_examples, (x.to(torch.float64),)) + + self.assertExpectedInline( + fx.code.lstrip(), + """\ +def forward(self, x_1): + convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float64) + _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float64); x_1 = None + full = torch.ops.aten.full.default([32], 2, dtype = torch.float64, device = device(type='cpu'), pin_memory = False) + empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float64, layout = torch.strided, device = device(type='cpu'), pin_memory = False) + return (convert_element_type, _to_copy, full, empty) + """, + ) + + @patch.dict( + os.environ, + { + "TORCHINDUCTOR_MAX_AUTOTUNE": "1", + "TEST_ENV": "1", + "TORCHINDUCTOR_ENV_SINGLE_QUOTES": "inductor_'env'", + "TORCHINDUCTOR_ENV_DOUBLE_QUOTES": 'inductor_"env"', + }, + ) + def test_generate_env_vars_string(self): + env_strings = generate_env_vars_string() + self.assertIn( + """os.environ['TORCHINDUCTOR_MAX_AUTOTUNE'] = '1' +""", + env_strings, + ) + self.assertIn( + """os.environ['TORCHINDUCTOR_ENV_SINGLE_QUOTES'] = 'inductor_"env"' +""", + env_strings, + ) + self.assertIn( + """os.environ['TORCHINDUCTOR_ENV_DOUBLE_QUOTES'] = 'inductor_"env"' +""", + env_strings, + ) + self.assertIn( + """import os +""", + env_strings, + ) + self.assertNotIn( + """TEST_ENV 
+""", + env_strings, + ) + + +class TestDebugUtilsDevice(TestCase): + def test_aot_graph_parser(self, device): + def forward( + self, + primals_1: "f32[1001, 6]", + primals_2: "f32[1001]", + primals_3: "f32[1001, 64]", + primals_4: "f32[4190]", + primals_5: "f32[4190]", + primals_6: "f32[1739, 4190]", + primals_48: "f32[6144, 4191]", + ): + _tensor_constant0: i64[4190] = self._tensor_constant0 + lift_fresh_copy: i64[4190] = torch.ops.aten.lift_fresh_copy.default( + _tensor_constant0 + ) + _tensor_constant0 = None + index: f32[6144, 4190] = torch.ops.aten.index.Tensor( # noqa: F841 + primals_48, [None, lift_fresh_copy] + ) + lift_fresh_copy = None + + _tensor_constant1: i64[6] = self._tensor_constant1 + lift_fresh_copy_1: i64[6] = torch.ops.aten.lift_fresh_copy.default( + _tensor_constant1 + ) + _tensor_constant1 = None + index_1: f32[6144, 6] = torch.ops.aten.index.Tensor( + primals_48, [None, lift_fresh_copy_1] + ) + primals_48 = lift_fresh_copy_1 = None + permute: f32[6, 1001] = torch.ops.aten.permute.default(primals_1, [1, 0]) + primals_1 = None + addmm: f32[6144, 1001] = torch.ops.aten.addmm.default( + primals_2, index_1, permute + ) + primals_2 = permute = None + amax: f32[6144, 1] = torch.ops.aten.amax.default(addmm, [-1], True) + sub: f32[6144, 1001] = torch.ops.aten.sub.Tensor(addmm, amax) + exp: f32[6144, 1001] = torch.ops.aten.exp.default(sub) + sub = None + sum_1: f32[6144, 1] = torch.ops.aten.sum.dim_IntList(exp, [-1], True) + div: f32[6144, 1001] = torch.ops.aten.div.Tensor(exp, sum_1) + exp = None + + full_default: i32[6144, 1001] = torch.ops.aten.full.default( + [6144, 1001], + 1, + dtype=torch.int32, + layout=torch.strided, + device=device, + pin_memory=False, + ) + + iota: i32[1001] = torch.ops.prims.iota.default( + 1001, + start=0, + step=1, + dtype=torch.int32, + device=device, + requires_grad=False, + ) + + mul: i32[6144, 1001] = torch.ops.aten.mul.Tensor(full_default, iota) + full_default = iota = None + + iota_1: i32[6144] = 
torch.ops.prims.iota.default( + 6144, + start=0, + step=1001, + dtype=torch.int32, + device=device, + requires_grad=False, + ) + view: i32[6150144] = torch.ops.aten.reshape.default(mul, [-1]) + mul = None + view_1: f32[6150144] = torch.ops.aten.reshape.default(div, [-1]) + div = None + _embedding_bag = torch.ops.aten._embedding_bag.default( + primals_3, view, iota_1, False, 0, False, view_1 + ) + + return _embedding_bag + + kwargs = aot_graph_input_parser(forward, device=device) + # runs successfully + forward(**kwargs) + + def test_sym_aot_graph_parser(self, device): + def forward( + self, + primals_1: "f32[1001, 6]", + primals_2: "f32[s0]", # noqa: F821 + primals_3: "Sym(s0)", # noqa: F821, + primals_4: "f32[s1]", # noqa: F821, + primals_5: "Sym(s1)", # noqa: F821, + ): + _tensor_constant0: i64[4190] = self._tensor_constant0 + + kwargs = aot_graph_input_parser( + forward, device=device, sym_shapes={"s0": 10}, default_sym_shape=5 + ) + + self.assertEqual(list(kwargs["primals_2"].shape), [10]) + self.assertEqual(kwargs["primals_3"], 10) + + self.assertEqual(list(kwargs["primals_4"].shape), [5]) + self.assertEqual(kwargs["primals_5"], 5) + + +instantiate_device_type_tests(TestDebugUtils, globals()) + +devices = ["cuda", "hpu", "xpu"] +instantiate_device_type_tests( + TestDebugUtilsDevice, globals(), only_for=devices, allow_xpu=True +) + + +class TestBackendOverrideIntegration(TestCase): + def setUp(self): + super().setUp() + torch._dynamo.reset() + self._backends_called = [] + + def tearDown(self): + torch._dynamo.reset() + super().tearDown() + + def _fn_with_4_graphs(self, x): + x = x + 1 + torch._dynamo.graph_break() + x = x * 2 + torch._dynamo.graph_break() + x = x - 1 + torch._dynamo.graph_break() + x = x / 2 + return x + + def _run_with_override(self, device, override_config, default_backend="eager"): + from torch._dynamo.graph_id_filter import ( + _create_backend_router, + get_backend_override_for_compile_id, + ) + + torch._dynamo.reset() + # Clear the router 
cache to ensure fresh routers for each test + _create_backend_router.cache_clear() + self._backends_called.clear() + original_get_override = get_backend_override_for_compile_id + + # Pre-parse the config to build a mapping of graph_id -> backend_str + # by using the same parsing logic but extracting the original strings + backend_str_map: dict[int, str] = {} + if override_config: + for rule_str in override_config.split(";"): + rule_str = rule_str.strip() + if not rule_str or ":" not in rule_str: + continue + colon_idx = rule_str.find(":") + filter_str = rule_str[:colon_idx].strip() + backend_str = rule_str[colon_idx + 1 :].strip() + # Parse the filter to extract graph IDs + from torch._dynamo.graph_id_filter import GraphIdFilter + + gf = GraphIdFilter(filter_str) + # Store the backend_str for any graph that matches this filter + for graph_id in range(100): # Check first 100 graphs + if graph_id in gf and graph_id not in backend_str_map: + backend_str_map[graph_id] = backend_str + + def tracking_get_override(compile_id, config_str): + result = original_get_override(compile_id, config_str) + if result is not None: + graph_id = compile_id.frame_id + if graph_id in backend_str_map: + self._backends_called.append(backend_str_map[graph_id]) + return result + + with ( + patch.object( + torch._dynamo.config, "debug_backend_override", override_config + ), + patch( + "torch._dynamo.output_graph.get_backend_override_for_compile_id", + tracking_get_override, + ), + ): + compiled_fn = torch.compile(self._fn_with_4_graphs, backend=default_backend) + compiled_fn(torch.randn(10, device=device)) + + return self._backends_called.copy() + + def test_no_override(self, device): + result = self._run_with_override(device, "") + self.assertEqual(result, []) + + def test_override_all_graphs(self, device): + result = self._run_with_override(device, ">=0:aot_eager") + self.assertEqual(result, ["aot_eager", "aot_eager", "aot_eager", "aot_eager"]) + + def test_override_greater_than(self, 
device): + result = self._run_with_override(device, ">0:eager") + self.assertEqual(result, ["eager", "eager", "eager"]) + + def test_override_less_than(self, device): + result = self._run_with_override(device, "<2:aot_eager") + self.assertEqual(result, ["aot_eager", "aot_eager"]) + + def test_override_less_or_equal(self, device): + result = self._run_with_override(device, "<=1:aot_eager") + self.assertEqual(result, ["aot_eager", "aot_eager"]) + + def test_override_single_id(self, device): + result = self._run_with_override(device, "1:aot_eager") + self.assertEqual(result, ["aot_eager"]) + + def test_override_multiple_ids(self, device): + result = self._run_with_override(device, "0,2:aot_eager") + self.assertEqual(result, ["aot_eager", "aot_eager"]) + + def test_override_range(self, device): + result = self._run_with_override(device, "1-2:eager") + self.assertEqual(result, ["eager", "eager"]) + + def test_multiple_rules(self, device): + result = self._run_with_override(device, "0:aot_eager;1:inductor;3:eager") + self.assertEqual(result, ["aot_eager", "inductor", "eager"]) + + def test_conflicting_rules_raise(self, device): + with self.assertRaisesRegex( + torch._dynamo.exc.InternalTorchDynamoError, + "Conflicting backend override", + ): + self._run_with_override(device, ">=0:aot_eager;>=1:inductor") + + def test_complex_config(self, device): + result = self._run_with_override(device, "0:aot_eager;>=2:inductor") + self.assertEqual(result, ["aot_eager", "inductor", "inductor"]) + + def test_override_with_backward(self, device): + """Verify that backend override works when backward compilation occurs.""" + from torch._dynamo.graph_id_filter import ( + _create_backend_router, + get_backend_override_for_compile_id, + ) + + torch._dynamo.reset() + _create_backend_router.cache_clear() + overrides_applied = [] + original_get_override = get_backend_override_for_compile_id + + def tracking_get_override(compile_id, config_str): + result = original_get_override(compile_id, 
config_str) + if result is not None: + overrides_applied.append(compile_id.frame_id) + return result + + def fn(x): + return (x * 2 + 1).sum() + + with ( + patch.object( + torch._dynamo.config, "debug_backend_override", ">=0:aot_eager" + ), + patch( + "torch._dynamo.output_graph.get_backend_override_for_compile_id", + tracking_get_override, + ), + ): + compiled_fn = torch.compile(fn, backend="eager") + x = torch.randn(10, device=device, requires_grad=True) + result = compiled_fn(x) + result.backward() + + self.assertEqual(overrides_applied, [0]) + self.assertIsNotNone(x.grad) + + +instantiate_device_type_tests( + TestBackendOverrideIntegration, globals(), except_for="mps" +) + + +class TestInductorConfigOverrideIntegration(TestCase): + def setUp(self): + super().setUp() + torch._dynamo.reset() + + def tearDown(self): + torch._dynamo.reset() + super().tearDown() + + def test_config_router_single_graph(self, device): + from torch._dynamo.graph_id_filter import GraphConfigRouter + + router = GraphConfigRouter("0:triton.cudagraph_skip_dynamic_graphs=False") + self.assertEqual( + router.get_value_for_graph(0), + {"triton.cudagraph_skip_dynamic_graphs": False}, + ) + self.assertIsNone(router.get_value_for_graph(1)) + + def test_config_router_multiple_options(self, device): + from torch._dynamo.graph_id_filter import GraphConfigRouter + + router = GraphConfigRouter( + "0:triton.cudagraphs=False,triton.cudagraph_trees=False" + ) + self.assertEqual( + router.get_value_for_graph(0), + {"triton.cudagraphs": False, "triton.cudagraph_trees": False}, + ) + + def test_config_router_comparison(self, device): + from torch._dynamo.graph_id_filter import GraphConfigRouter + + router = GraphConfigRouter(">1:triton.cudagraphs=True") + self.assertIsNone(router.get_value_for_graph(0)) + self.assertIsNone(router.get_value_for_graph(1)) + self.assertEqual(router.get_value_for_graph(2), {"triton.cudagraphs": True}) + + def test_config_router_range(self, device): + from 
torch._dynamo.graph_id_filter import GraphConfigRouter + + router = GraphConfigRouter("1-3:triton.cudagraphs=False") + self.assertIsNone(router.get_value_for_graph(0)) + self.assertEqual(router.get_value_for_graph(1), {"triton.cudagraphs": False}) + self.assertEqual(router.get_value_for_graph(2), {"triton.cudagraphs": False}) + self.assertEqual(router.get_value_for_graph(3), {"triton.cudagraphs": False}) + self.assertIsNone(router.get_value_for_graph(4)) + + def test_config_router_value_types(self, device): + from torch._dynamo.graph_id_filter import GraphConfigRouter + + router = GraphConfigRouter( + "0:bool_opt=True,int_opt=42,float_opt=3.14,str_opt=hello,none_opt=None" + ) + config = router.get_value_for_graph(0) + self.assertEqual(config["bool_opt"], True) + self.assertEqual(config["int_opt"], 42) + self.assertAlmostEqual(config["float_opt"], 3.14) + self.assertEqual(config["str_opt"], "hello") + self.assertIsNone(config["none_opt"]) + + def test_config_router_aggregation(self, device): + from torch._dynamo.graph_id_filter import GraphConfigRouter + + router = GraphConfigRouter("0:a=1;>=0:b=2") + # Graph 0 matches both rules, configs are merged + self.assertEqual(router.get_value_for_graph(0), {"a": 1, "b": 2}) + # Graph 1 matches only the second rule + self.assertEqual(router.get_value_for_graph(1), {"b": 2}) + + def test_config_router_conflict_raises(self, device): + from torch._dynamo.graph_id_filter import GraphConfigRouter + + with self.assertRaisesRegex(ValueError, "Conflicting config override"): + GraphConfigRouter("0:a=1;>=0:a=2") + + def test_config_router_same_value_no_conflict(self, device): + from torch._dynamo.graph_id_filter import GraphConfigRouter + + router = GraphConfigRouter("0:a=1;>=0:a=1") + self.assertEqual(router.get_value_for_graph(0), {"a": 1}) + self.assertEqual(router.get_value_for_graph(1), {"a": 1}) + + def test_config_router_aggregation_multiple_rules(self, device): + from torch._dynamo.graph_id_filter import GraphConfigRouter + + 
router = GraphConfigRouter("0:a=1;1:b=2;>=0:c=3") + self.assertEqual(router.get_value_for_graph(0), {"a": 1, "c": 3}) + self.assertEqual(router.get_value_for_graph(1), {"b": 2, "c": 3}) + self.assertEqual(router.get_value_for_graph(2), {"c": 3}) + + def test_backend_router_conflict_raises(self, device): + from torch._dynamo.graph_id_filter import GraphBackendRouter + + with self.assertRaisesRegex(ValueError, "Conflicting backend override"): + GraphBackendRouter("0-5:eager;3-10:inductor") + + def test_backend_router_same_backend_no_conflict(self, device): + from torch._dynamo.graph_id_filter import GraphBackendRouter + + router = GraphBackendRouter("0:eager;>=0:eager") + self.assertIsNotNone(router.get_value_for_graph(0)) + + def test_get_inductor_config_override_empty(self, device): + from torch._dynamo.graph_id_filter import ( + get_inductor_config_override_for_compile_id, + ) + + result = get_inductor_config_override_for_compile_id(None, "") + self.assertIsNone(result) + + def test_combined_backend_and_config_override(self, device): + """ + Test combining backend override with config override. + + Scenario: Default backend is eager, but override all graphs to use + inductor with cudagraphs enabled, and additionally override graph 1 + to use cudagraph_skip_dynamic_graphs=False. 
+ """ + from torch._dynamo.graph_id_filter import ( + _create_backend_router, + _create_inductor_config_router, + ) + + torch._dynamo.reset() + _create_backend_router.cache_clear() + _create_inductor_config_router.cache_clear() + + backends_used: list[str] = [] + configs_applied: list[dict] = [] + + def fn(x): + x = x + 1 + torch._dynamo.graph_break() + x = x * 2 + torch._dynamo.graph_break() + x = x - 1 + return x + + from torch._dynamo import output_graph + + original_wrap = output_graph._wrap_with_inductor_config + + def tracking_wrap(compiler_fn, config_patches): + configs_applied.append(config_patches) + return original_wrap(compiler_fn, config_patches) + + backend_override = ">=0:inductor" + from torch._dynamo.graph_id_filter import get_backend_override_for_compile_id + + original_get_backend = get_backend_override_for_compile_id + + def tracking_get_backend(compile_id, config_str): + result = original_get_backend(compile_id, config_str) + if result is not None: + backends_used.append("inductor") + return result + + # Use both overrides: + # - Backend: all graphs use inductor + # - Config: all graphs enable cudagraphs, graph 1 also disables + # cudagraph_skip_dynamic_graphs + with ( + patch.object( + torch._dynamo.config, + "debug_backend_override", + backend_override, + ), + patch.object( + torch._dynamo.config, + "debug_inductor_config_override", + "1:triton.cudagraph_skip_dynamic_graphs=False;>=0:triton.cudagraphs=True", + ), + patch( + "torch._dynamo.output_graph.get_backend_override_for_compile_id", + tracking_get_backend, + ), + patch.object(output_graph, "_wrap_with_inductor_config", tracking_wrap), + ): + compiled_fn = torch.compile(fn, backend="eager") + compiled_fn(torch.randn(10, device=device)) + + self.assertEqual(len(backends_used), 3) + self.assertEqual(backends_used, ["inductor", "inductor", "inductor"]) + + # All matching rules are aggregated. Graph 1 matches both rules. 
+ self.assertEqual(len(configs_applied), 3) + self.assertEqual(configs_applied[0], {"triton.cudagraphs": True}) + self.assertEqual( + configs_applied[1], + { + "triton.cudagraph_skip_dynamic_graphs": False, + "triton.cudagraphs": True, + }, + ) + self.assertEqual(configs_applied[2], {"triton.cudagraphs": True}) + + def test_multiple_config_overrides_with_backend(self, device): + """ + Test multiple config overrides applied to different graphs with backend override. + + Scenario: Default backend is eager, override graphs 0,2 to use inductor, + and apply different config overrides to each. + """ + from torch._dynamo.graph_id_filter import ( + _create_backend_router, + _create_inductor_config_router, + ) + + torch._dynamo.reset() + _create_backend_router.cache_clear() + _create_inductor_config_router.cache_clear() + + backends_used: list[str] = [] + configs_applied: list[dict] = [] + + def fn(x): + x = x + 1 + torch._dynamo.graph_break() + x = x * 2 + torch._dynamo.graph_break() + x = x - 1 + torch._dynamo.graph_break() + x = x / 2 + return x + + from torch._dynamo import output_graph + + original_wrap = output_graph._wrap_with_inductor_config + + def tracking_wrap(compiler_fn, config_patches): + configs_applied.append(config_patches) + return original_wrap(compiler_fn, config_patches) + + # Build backend tracking map + backend_override = "0,2:inductor" + backend_str_map: dict[int, str] = {} + from torch._dynamo.graph_id_filter import GraphIdFilter + + for rule_str in backend_override.split(";"): + if ":" not in rule_str: + continue + colon_idx = rule_str.find(":") + filter_str = rule_str[:colon_idx].strip() + backend_str = rule_str[colon_idx + 1 :].strip() + gf = GraphIdFilter(filter_str) + for gid in range(10): + if gid in gf and gid not in backend_str_map: + backend_str_map[gid] = backend_str + + from torch._dynamo.graph_id_filter import get_backend_override_for_compile_id + + original_get_backend = get_backend_override_for_compile_id + + def 
tracking_get_backend(compile_id, config_str): + result = original_get_backend(compile_id, config_str) + if result is not None and compile_id.frame_id in backend_str_map: + backends_used.append(backend_str_map[compile_id.frame_id]) + return result + + # Use both overrides: + # - Backend: graphs 0,2 use inductor (graphs 1,3 stay with eager) + # - Config: graph 0 disables cudagraphs, graph 2 disables skip_dynamic_graphs + with ( + patch.object( + torch._dynamo.config, + "debug_backend_override", + backend_override, + ), + patch.object( + torch._dynamo.config, + "debug_inductor_config_override", + "0:triton.cudagraphs=False;2:triton.cudagraph_skip_dynamic_graphs=False", + ), + patch( + "torch._dynamo.output_graph.get_backend_override_for_compile_id", + tracking_get_backend, + ), + patch.object(output_graph, "_wrap_with_inductor_config", tracking_wrap), + ): + compiled_fn = torch.compile(fn, backend="eager") + compiled_fn(torch.randn(10, device=device)) + + self.assertEqual(len(backends_used), 2) + self.assertEqual(backends_used, ["inductor", "inductor"]) + + self.assertEqual(len(configs_applied), 2) + self.assertIn({"triton.cudagraphs": False}, configs_applied) + self.assertIn({"triton.cudagraph_skip_dynamic_graphs": False}, configs_applied) + + def test_config_override_backward_propagation(self, device): + """ + Verify that inductor config overrides are active at inductor compile + time for both forward and backward, across multiple graph breaks. 
+ """ + import torch._functorch.config + from torch._dynamo.graph_id_filter import _create_inductor_config_router + from torch._inductor import ( + compile_fx as compile_fx_mod, + config as inductor_config, + ) + + torch._dynamo.reset() + _create_inductor_config_router.cache_clear() + + TRACKED_CONFIGS = [ + "triton.cudagraphs", + "triton.dense_indexing", + "triton.cudagraph_skip_dynamic_graphs", + ] + + def _read_config(key): + obj = inductor_config + for part in key.split("."): + obj = getattr(obj, part) + return obj + + baseline = {k: _read_config(k) for k in TRACKED_CONFIGS} + configs_at_compile: dict[tuple[int, bool], dict] = {} + + original_compile_fx = compile_fx_mod.compile_fx + original_inner_compile = compile_fx_mod.compile_fx_inner + + def tracking_inner_compile(gm, example_inputs, **kwargs): + compile_id = torch._guards.CompileContext.current_compile_id() + is_backward = kwargs.get("is_backward", False) + snapshot = {k: _read_config(k) for k in TRACKED_CONFIGS} + configs_at_compile[(compile_id.frame_id, is_backward)] = snapshot + return original_inner_compile(gm, example_inputs, **kwargs) + + def tracking_compile_fx(model_, example_inputs_, *args, **kwargs): + # Inject tracking inner_compile so compile_fx's config.patch + # wrapping covers it for both forward and backward. 
+ if "inner_compile" not in kwargs: + kwargs["inner_compile"] = tracking_inner_compile + return original_compile_fx(model_, example_inputs_, *args, **kwargs) + + def fn(x): + y = x * 2 + 1 + torch._dynamo.graph_break() + z = y.sin() + torch._dynamo.graph_break() + return z.exp().sum() + + # Overlapping rules (all three configs default to False): + config_override = ( + ">=1:triton.cudagraphs=True;" + "0-1:triton.dense_indexing=True;" + "1:triton.cudagraph_skip_dynamic_graphs=True" + ) + expected_overrides = { + 0: {"triton.dense_indexing": True}, + 1: { + "triton.cudagraphs": True, + "triton.dense_indexing": True, + "triton.cudagraph_skip_dynamic_graphs": True, + }, + 2: { + "triton.cudagraphs": True, + }, + } + + with ( + patch.object( + torch._dynamo.config, + "debug_inductor_config_override", + config_override, + ), + patch.object(compile_fx_mod, "compile_fx", tracking_compile_fx), + patch.object(torch._functorch.config, "enable_autograd_cache", False), + ): + compiled_fn = torch.compile(fn) + x = torch.randn(10, device=device, requires_grad=True) + result = compiled_fn(x) + result.backward() + + # Verify each graph has fwd+bwd, correct overrides, no cross-graph + # leak, and identical configs for forward and backward. 
+ for gid in range(3): + self.assertIn((gid, False), configs_at_compile, f"graph {gid} fwd missing") + self.assertIn((gid, True), configs_at_compile, f"graph {gid} bwd missing") + expected = {**baseline, **expected_overrides[gid]} + for is_bw in [False, True]: + phase = "backward" if is_bw else "forward" + self.assertEqual( + configs_at_compile[(gid, is_bw)], + expected, + f"graph {gid} {phase}: config mismatch", + ) + + self.assertIsNotNone(x.grad) + + +instantiate_device_type_tests( + TestInductorConfigOverrideIntegration, globals(), only_for=["cpu", "cuda"] +) + + +class TestConfigOverrideValidation(TestCase): + def setUp(self): + super().setUp() + from torch._dynamo.graph_id_filter import ( + _validate_backend_names, + _validate_dynamo_config_keys, + _validate_inductor_config_keys, + ) + + _validate_backend_names.cache_clear() + _validate_dynamo_config_keys.cache_clear() + _validate_inductor_config_keys.cache_clear() + torch._dynamo.reset() + + def tearDown(self): + torch._dynamo.reset() + super().tearDown() + + @torch._dynamo.config.patch( + debug_backend_override="0:not_a_real_backend", + ) + def test_invalid_backend_raises_on_compile(self): + def fn(x): + return x + 1 + + with self.assertRaisesRegex(ValueError, "not_a_real_backend"): + torch.compile(fn, backend="eager")(torch.randn(4)) + + @torch._dynamo.config.patch( + debug_dynamo_config_override="0:nonexistent_dynamo_option=True", + ) + def test_invalid_dynamo_config_raises_on_compile(self): + def fn(x): + return x + 1 + + with self.assertRaisesRegex(ValueError, "nonexistent_dynamo_option"): + torch.compile(fn, backend="eager")(torch.randn(4)) + + @torch._dynamo.config.patch( + debug_inductor_config_override="0:nonexistent_inductor_option=True", + ) + def test_invalid_inductor_config_raises_on_compile(self): + def fn(x): + return x + 1 + + with self.assertRaisesRegex(ValueError, "nonexistent_inductor_option"): + torch.compile(fn, backend="eager")(torch.randn(4)) + + +class 
TestDynamoConfigOverrideIntegration(TestCase): + def setUp(self): + super().setUp() + torch._dynamo.reset() + + def tearDown(self): + torch._dynamo.reset() + super().tearDown() + + @torch._dynamo.config.patch( + specialize_float=False, + verbose=False, + debug_dynamo_config_override=( + "0:specialize_float=True;1:verbose=True,recompile_limit=10" + ), + ) + def test_dynamo_config_override_per_graph(self): + """Per-graph dynamo config overrides target the right graphs. + + Graph 0: specialize_float overridden True (base False) + Graph 1: verbose+recompile_limit overridden (multiple keys) + Graph 2: no override, keeps base values + """ + from torch._dynamo.graph_id_filter import _create_dynamo_config_router + + _create_dynamo_config_router.cache_clear() + + observed: dict[int, dict] = {} + + def capturing_backend(gm, example_inputs): + fid = torch._guards.CompileContext.current_compile_id().frame_id + observed[fid] = { + "specialize_float": torch._dynamo.config.specialize_float, + "verbose": torch._dynamo.config.verbose, + "recompile_limit": torch._dynamo.config.recompile_limit, + } + return gm + + def fn(x): + x = x + 1 + torch._dynamo.graph_break() + x = x * 2 + torch._dynamo.graph_break() + return x - 1 + + torch.compile(fn, backend=capturing_backend)(torch.randn(4)) + + self.assertTrue(observed[0]["specialize_float"]) + self.assertFalse(observed[0]["verbose"]) + + self.assertFalse(observed[1]["specialize_float"]) + self.assertTrue(observed[1]["verbose"]) + self.assertEqual(observed[1]["recompile_limit"], 10) + + self.assertFalse(observed[2]["specialize_float"]) + self.assertFalse(observed[2]["verbose"]) + + def test_dynamo_config_override_warning(self): + from torch._dynamo.graph_id_filter import _create_dynamo_config_router + + _create_dynamo_config_router.cache_clear() + with self.assertWarnsRegex( + UserWarning, "TORCH_COMPILE_OVERRIDE_DYNAMO_CONFIGS" + ): + _create_dynamo_config_router("0:specialize_float=True") + + +if __name__ == "__main__": + from 
torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/xpu/dynamo/test_export_xpu.py b/test/xpu/dynamo/test_export_xpu.py new file mode 100644 index 0000000000..70238f0c7f --- /dev/null +++ b/test/xpu/dynamo/test_export_xpu.py @@ -0,0 +1,4688 @@ +# Owner(s): ["module: dynamo"] +""" +PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes +with test_export_persist_assert) +""" + +import copy +import functools +import inspect +import io +import operator +import os +import subprocess +import sys +import unittest +from collections.abc import Sequence +from enum import Enum +from unittest.mock import patch + +import torch +import torch._dynamo +import torch._dynamo.test_case +import torch._dynamo.testing +from functorch.experimental.control_flow import cond +from torch._dynamo import config +from torch._dynamo.exc import UserError +from torch._dynamo.testing import normalize_gm +from torch._higher_order_ops.out_dtype import out_dtype +from torch._subclasses import fake_tensor +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.symbolic_shapes import ( + ConstraintViolationError, + DimDynamic, + ShapeEnv, + StatelessSymbolicContext, +) +from torch.testing._internal import common_utils +from torch.testing._internal.common_device_type import instantiate_device_type_tests + + +@torch._dynamo.assume_constant_result +def dynamo_assume_constant_result_global_function(): + return "test" + + +class ExportTests(torch._dynamo.test_case.TestCase): + # TODO(voz): Refactor to a shared test function. 
+ # The tests in this file are a little redundant, + # They all take a func, run it with eager, then export it, then compare + def test_export(self): + def pre_attention_state_ops(input, mems, state): + lc_key = state[0] + lc_val = state[1] + bar = [] + for _ in range(4): + bar2 = [] + for _ in range(3): + bar2.append( + lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1]) + ) + bar.append(bar2) + + return bar + + def func(): + mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]]) + state = [ + torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), + torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), + ] + i = torch.tensor( + [ + [0.0313, -0.1487, -0.3846, -0.5321], + [-1.7073, 1.3331, -0.0890, -1.4935], + [-0.8314, -0.1862, -0.5935, 1.5232], + ] + ) + return pre_attention_state_ops(i, mems, state) + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func() + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)() + out_graph = exported[0] + + dynamo_result = out_graph() + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_no_tensor_computation_fail(self): + with self.assertRaisesRegex( + AssertionError, + "Failed to produce a graph", + ): + inp = [torch.randn(3)] + inp2 = 2 + inps = [inp, inp2] + + def func(x, y): + return x + + torch._dynamo.export(func, same_signature=False)(*inps) + + def test_no_tensor_computation(self): + inp = [torch.randn(3)] + inp2 = 2 + inps = [inp, inp2] + + def func(x, y): + return x + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + self.assertExpectedInline( + out_graph.code.strip(), + """\ +def forward(self, x, y): + arg0, arg1, = 
fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec) + x = arg0 + return pytree.tree_unflatten([x], self._out_spec)""", + ) + + def test_export_empty_graph_no_error(self): + def func(x): + return len(x) + + exported = torch._dynamo.export(func)(torch.randn(5)) + out_graph = exported[0] + result = out_graph(torch.randn(5)) + self.assertEqual(result, 5) + + def test_no_tensor_computation_2(self): + inp = torch.randn(3) + inp2 = 2 + inps = [inp, inp2] + + def func(x, y): + return y + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + self.assertExpectedInline( + out_graph.code.strip(), + """\ +def forward(self, x, y): + arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec) + x = arg0 + return pytree.tree_unflatten([2], self._out_spec)""", + ) + + def test_export_mismatched_out(self): + def func(x): + y = x + 1 + return ([x, x], (y, y)) + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(torch.tensor([[[1.3737, 0.1]]])) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(torch.tensor([[[1.3737, 0.1]]])) + out_graph = exported[0] + + dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]])) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_export_shape_control_flow_1(self): + def func(x): + if x.shape[0] > 10: + return x.cos() + return x.sin() + + opt_func = torch.compile(func, backend="eager") + real_result = opt_func(torch.ones(6, 4)) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(torch.ones(6, 4)) + out_graph, out_guards = exported + + dynamo_result = out_graph(torch.ones(6, 4)) + + from torch._guards import GuardSource + + 
self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + hit = False + for guard in out_guards: + if guard.source == GuardSource.SHAPE_ENV: + hit = True + self.assertExpectedInline( + guard.code_list, + """["L['x'].stride()[0] == L['x'].size()[1]", "L['x'].stride()[1] == 1", "L['x'].storage_offset() == 0", "2 <= L['x'].size()[0] and L['x'].size()[0] <= 10", "2 <= L['x'].size()[1]"]""", + ) + break + + self.assertTrue(hit) + + def test_export_control_flow_with_getattr(self): + class Animal(Enum): + COW = "moo" + + class MyModule(torch.nn.Module): + def __init__(self, a): + super().__init__() + self.a = a + + def forward(self, x): + if self.a == Animal.COW.value: + return x * x + else: + raise ValueError("bad") + + module = MyModule("moo") + input = (torch.ones(4, 3),) + resA = module(*input) + graph, _ = torch._dynamo.export(module)(*input) + resB = graph(*input) + self.assertTrue(torch._dynamo.utils.same(resA, resB)) + + def test_export_graph_bypass(self): + inp = [ + torch.tensor([0.1, 0.1]), + torch.tensor([0.2, 0.2]), + torch.tensor([0.3, 0.3]), + ] + + def func(x): + first = x[2] + second = x[2] + return first * second + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(inp) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(inp) + out_graph = exported[0] + + dynamo_result = out_graph(inp) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_list_unpack(self): + inp = [ + torch.tensor([0.1, 0.1]), + torch.tensor([0.2, 0.2]), + torch.tensor([0.3, 0.3]), + ] + + def func(x): + first = x[2] + second = x[2] + return x[0], first * second, x[1], x[2] + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(inp) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(inp) + out_graph = exported[0] + + dynamo_result = out_graph(inp) + + self.assertTrue(torch._dynamo.utils.same(real_result, 
dynamo_result)) + + def test_export_with_shallow_list_copy_wo_side_effects(self): + def f(x): + y = x.copy() + return y[0] + y[1] + + inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])] + gm = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")( + inp + ).graph_module + self.assertTrue(torch._dynamo.utils.same(gm(inp), f(inp))) + + def test_export_with_shallow_list_copy_with_side_effects(self): + def f(x): + y = x.copy() + x[0] = x[1] + y.append(torch.tensor([[100]])) + return x[0] + x[1], y[0] + y[1], y[2] + + inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])] + gm = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")( + inp + ).graph_module + res = gm(inp) + ref = f(inp) + self.assertTrue(torch._dynamo.utils.same(res, ref)) + self.assertEqual(res[0], res[1]) + + def test_export_mismatched_out_2(self): + def func(x): + y = x + 1 + return ([x, x], (y, y)) + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(torch.tensor([[[1.3737, 0.1]]])) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(torch.tensor([[[1.3737, 0.1]]])) + out_graph = exported[0] + + dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]])) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_export_graph_with_list(self): + inp = [ + torch.tensor([0.1, 0.1]), + torch.tensor([0.2, 0.2]), + torch.tensor([0.3, 0.3]), + torch.tensor([0.4, 0.4]), + ] + + def func(x): + first = x[2] + second = x[2] + return first * second, x + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(inp) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(inp) + out_graph = exported[0] + + dynamo_result = out_graph(inp) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_export_graph_with_complex_reorder(self): + inp = [ + torch.tensor([0.1, 0.1]), + 
torch.tensor([0.2, 0.2]), + torch.tensor([0.3, 0.3]), + torch.tensor([0.4, 0.4]), + ] + + def func(x): + first = x[0] + second = x[1] + third = x[2] + return third, first, second, first * second, first * third + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(inp) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(inp) + out_graph = exported[0] + + dynamo_result = out_graph(inp) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_immutable_list_dict(self): + class M(torch.nn.Module): + def forward(self, x1, x2): + return [x1 + x2], {"moo1": x1 * x1, "moo2": x2 * x2} + + x1 = torch.randn(2, 3) + x2 = torch.randn(2, 3) + model = M() + + fx_model = make_fx( + model, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + _error_on_data_dependent_ops=True, + )(*[x1, x2]) + ep = torch.export.export(fx_model, (x1, x2)) + res = torch.compile(ep.module(), backend="eager", dynamic=True, fullgraph=True)( + x1, x2 + ) + self.assertTrue(torch._dynamo.utils.same(res, M()(x1, x2))) + + def test_dupes(self): + inp = torch.tensor([0.1, 0.1]) + + def func(x): + y = x + 1 + return y, y + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(inp) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(inp) + out_graph = exported[0] + + dynamo_result = out_graph(inp) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_dupes_2(self): + inp = torch.tensor([0.1, 0.1]) + + def func(x): + y = x + 1 + return y, y + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(inp) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(inp) + out_graph = exported[0] + + dynamo_result = out_graph(inp) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_dupes_and_bypass(self): + inp = 
torch.tensor([0.1, 0.1]) + inp2 = torch.tensor([0.4, 0.4]) + inps = [inp, inp2] + + def func(x, z): + y = x + 1 + return y, y, z + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_dupes_and_bypass_with_non_tensor_arg(self): + inp = torch.tensor([0.1, 0.1]) + inp2 = torch.tensor([0.1, 0.1]) + inp3 = 4 + inps = [inp, inp2, inp3] + + def func(x, z, k): + y = x + k + return y, y, z + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_dupes_and_bypass_reorder_with_non_tensor_arg(self): + inp = torch.tensor([0.1, 0.1]) + inp2 = torch.tensor([0.1, 0.1]) + inp3 = 4 + inps = [inp, inp2, inp3] + + def func(x, z, k): + y = x + k + return z, y, y + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + @config.patch(capture_scalar_outputs=True) + def test_dupes_and_bypass_with_non_tensor_output(self): + inp = torch.tensor([0.1, 0.1]) + inp2 = torch.tensor([0.1, 0.1]) + inp3 = 4 + inps = [inp, inp2, inp3] + + def func(x, z, k): + y = x + k + return y[0].item(), y, z + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps) + + torch._dynamo.reset() + + exported = 
torch._dynamo.export(func)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_zeroes_in_and_out_different_shape_on_test(self): + inp = torch.zeros(10) + inp2 = torch.zeros(10) + inp3 = torch.zeros(10) + inps = [inp, inp2, inp3] + + inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] + + def func(a, b, c): + return [[a], [b, c], [a + b], [[c + c]]] + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps_rand) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps_rand) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + @config.patch(capture_scalar_outputs=True) + def test_zeroes_in_new_shape_scalar_out(self): + inp = torch.zeros(10) + inp2 = torch.zeros(10) + inp3 = torch.zeros(10) + inps = [inp, inp2, inp3] + + inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] + + def func(a, b, c): + return a[0].item() + b[0].item() + c[0].item() + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps_rand) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps_rand) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + @config.patch(capture_scalar_outputs=True) + def test_zeroes_in_new_shape_scalar_out_permute(self): + inp = torch.zeros(10) + inp2 = torch.zeros(10) + inp3 = torch.zeros(10) + inps = [inp, inp2, inp3] + + inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] + + def func(a, b, c): + return b[0].item() + c[0].item() + a[0].item() + a[0].item() + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps_rand) + + torch._dynamo.reset() + + exported = 
torch._dynamo.export(func)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps_rand) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + @config.patch(capture_scalar_outputs=True) + def test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass(self): + inp = torch.zeros(10) + inp2 = torch.zeros(10) + inp3 = torch.zeros(10) + inps = [inp, inp2, inp3] + + inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] + + def func(a, b, c): + return a, b[0].item() + c[0].item() + a[0].item() + a[0].item(), a + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps_rand) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps_rand) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_func_return(self): + inp = torch.zeros(10) + inp2 = torch.zeros(10) + inp3 = torch.zeros(10) + inps = [inp, inp2, inp3] + + inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] + + def func(a, b, c): + x = a + b + c + + def func2(y): + return x * y + + return func2(x) + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps_rand) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps_rand) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_dict_return(self): + inp = torch.zeros(10) + inp2 = torch.zeros(10) + inp3 = torch.zeros(10) + inps = [inp, inp2, inp3] + + inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] + + def func(a, b, c): + x = a + b + c + return {"a": x} + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps_rand) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(*inps) + out_graph = exported[0] 
+ + dynamo_result = out_graph(*inps_rand) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_export_with_aten_graph(self): + def pre_attention_state_ops(input, mems, state): + lc_key = state[0] + lc_val = state[1] + bar = [] + for _ in range(4): + bar2 = [] + for _ in range(3): + bar2.append( + lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1]) + ) + bar.append(bar2) + + return bar + + def func(): + mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]]) + state = [ + torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), + torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), + ] + i = torch.tensor( + [ + [0.0313, -0.1487, -0.3846, -0.5321], + [-1.7073, 1.3331, -0.0890, -1.4935], + [-0.8314, -0.1862, -0.5935, 1.5232], + ] + ) + return pre_attention_state_ops(i, mems, state) + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func() + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)() + out_graph = exported[0] + + dynamo_result = out_graph() + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_export_no_tensor_computation_with_aten_graph(self): + inp = [torch.randn(3)] + inp2 = 2 + inps = [inp, inp2] + + def func(x, y): + return x + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + self.assertExpectedInline( + out_graph.code.strip(), + """\ +def forward(self, x, y): + arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec) + arg0_1 = arg0 + return pytree.tree_unflatten([arg0_1], self._out_spec)""", + ) + + def test_no_tensor_computation_2_with_aten_graph(self): + inp = 
torch.randn(3) + inp2 = 2 + inps = [inp, inp2] + + def func(x, y): + return y + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + self.assertExpectedInline( + out_graph.code.strip(), + """\ +def forward(self, x, y): + arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec) + arg0_1 = arg0 + return pytree.tree_unflatten([2], self._out_spec)""", + ) + + def test_export_mismatched_out_with_aten_graph(self): + def func(x): + y = x + 1 + return ([x, x], (y, y)) + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(torch.tensor([[[1.3737, 0.1]]])) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)( + torch.tensor([[[1.3737, 0.1]]]) + ) + out_graph = exported[0] + + dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]])) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_export_graph_bypass_with_aten_graph(self): + inp = [ + torch.tensor([0.1, 0.1]), + torch.tensor([0.2, 0.2]), + torch.tensor([0.3, 0.3]), + ] + + def func(x): + first = x[2] + second = x[2] + return first * second + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(inp) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)(inp) + out_graph = exported[0] + + dynamo_result = out_graph(inp) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_list_unpack_with_aten_graph(self): + inp = [ + torch.tensor([0.1, 0.1]), + torch.tensor([0.2, 0.2]), + torch.tensor([0.3, 0.3]), + ] + + def func(x): + first = x[2] + second = x[2] + return x[0], first * second, x[1], x[2] + + 
opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(inp) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)(inp) + out_graph = exported[0] + + dynamo_result = out_graph(inp) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_export_mismatched_out_2_with_aten_graph(self): + def func(x): + y = x + 1 + return ([x, x], (y, y)) + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(torch.tensor([[[1.3737, 0.1]]])) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)( + torch.tensor([[[1.3737, 0.1]]]) + ) + out_graph = exported[0] + + dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]])) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_export_graph_with_list_with_aten_graph(self): + inp = [ + torch.tensor([0.1, 0.1]), + torch.tensor([0.2, 0.2]), + torch.tensor([0.3, 0.3]), + torch.tensor([0.4, 0.4]), + ] + + def func(x): + first = x[2] + second = x[2] + return first * second, x + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(inp) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)(inp) + out_graph = exported[0] + + dynamo_result = out_graph(inp) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_export_graph_with_complex_reorder_with_aten_graph(self): + inp = [ + torch.tensor([0.1, 0.1]), + torch.tensor([0.2, 0.2]), + torch.tensor([0.3, 0.3]), + torch.tensor([0.4, 0.4]), + ] + + def func(x): + first = x[0] + second = x[1] + third = x[2] + return third, first, second, first * second, first * third + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(inp) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)(inp) + 
out_graph = exported[0] + + dynamo_result = out_graph(inp) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_dupes_with_aten_graph(self): + inp = torch.tensor([0.1, 0.1]) + + def func(x): + y = x + 1 + return y, y + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(inp) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)(inp) + out_graph = exported[0] + + dynamo_result = out_graph(inp) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_dupes_2_with_aten_graph(self): + inp = torch.tensor([0.1, 0.1]) + + def func(x): + y = x + 1 + return y, y + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(inp) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)(inp) + out_graph = exported[0] + + dynamo_result = out_graph(inp) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_dupes_and_bypass_with_aten_graph(self): + inp = torch.tensor([0.1, 0.1]) + inp2 = torch.tensor([0.4, 0.4]) + inps = [inp, inp2] + + def func(x, z): + y = x + 1 + return y, y, z + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_dupes_and_bypass_with_non_tensor_arg_with_aten_graph(self): + inp = torch.tensor([0.1, 0.1]) + inp2 = torch.tensor([0.1, 0.1]) + inp3 = 4 + inps = [inp, inp2, inp3] + + def func(x, z, k): + y = x + k + return y, y, z + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, 
aten_graph=True)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_dupes_and_bypass_reorder_with_non_tensor_arg_with_aten_graph(self): + inp = torch.tensor([0.1, 0.1]) + inp2 = torch.tensor([0.1, 0.1]) + inp3 = 4 + inps = [inp, inp2, inp3] + + def func(x, z, k): + y = x + k + return z, y, y + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + @config.patch(capture_scalar_outputs=True) + def test_dupes_and_bypass_with_non_tensor_output_with_aten_graph(self): + inp = torch.tensor([0.1, 0.1]) + inp2 = torch.tensor([0.1, 0.1]) + inp3 = 4 + inps = [inp, inp2, inp3] + + def func(x, z, k): + y = x + k + return y[0].item(), y, z + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_zeroes_in_and_out_different_shape_on_test_with_aten_graph(self): + inp = torch.zeros(10) + inp2 = torch.zeros(10) + inp3 = torch.zeros(10) + inps = [inp, inp2, inp3] + + inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] + + def func(a, b, c): + return [[a], [b, c], [a + b], [[c + c]]] + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps_rand) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps_rand) + + 
self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_func_return_with_aten_graph(self): + inp = torch.zeros(10) + inp2 = torch.zeros(10) + inp3 = torch.zeros(10) + inps = [inp, inp2, inp3] + + inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] + + def func(a, b, c): + x = a + b + c + + def func2(y): + return x * y + + return func2(x) + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps_rand) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps_rand) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_dict_return_with_aten_graph(self): + inp = torch.zeros(10) + inp2 = torch.zeros(10) + inp3 = torch.zeros(10) + inps = [inp, inp2, inp3] + + inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] + + def func(a, b, c): + x = a + b + c + return {"a": x} + + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps_rand) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps_rand) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_export_with_stack_trace(self): + inp = torch.randn(4, 4) + + class MyBlock(torch.nn.Module): + def forward(self, x): + x = torch.nn.functional.linear(x, torch.randn(4, 4)) + return torch.cos(x).relu() + 1 + + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.block = MyBlock() + + def forward(self, x): + out = self.block(x) + return out + + exported = torch._dynamo.export(MyModule(), aten_graph=False)(inp) + out_graph = exported[0] + + for node in out_graph.graph.nodes: + if node.op not in {"placeholder", "output"}: + self.assertTrue(node.stack_trace is not None) + 
self.assertTrue(node.meta["nn_module_stack"] is not None) + self.assertTrue(node.meta["source_fn_stack"] is not None) + + torch._dynamo.reset() + + exported = torch._dynamo.export(MyModule(), aten_graph=True)(inp) + out_graph = exported[0] + for node in out_graph.graph.nodes: + if node.op == "call_function": + self.assertTrue(node.stack_trace is not None) + self.assertTrue(node.meta["nn_module_stack"] is not None) + self.assertTrue(node.meta["source_fn_stack"] is not None) + self.assertTrue(node.meta["val"] is not None) + self.assertTrue(node.meta["original_aten"] is not None) + + def test_export_preserves_nn_module_stack_for_get_attr(self): + inp = torch.randn(4, 4) + + class MyBlock(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(1, 1)) + self.buffer = torch.nn.Buffer(torch.ones(1, 1)) + + def forward(self, x): + x = torch.nn.functional.linear(x, torch.randn(4, 4)) + return torch.cos(x).relu() + self.weight + self.buffer + + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.block = MyBlock() + + def forward(self, x): + out = self.block(x) + return out + + m = MyModule() + exported = torch._dynamo.export(m, aten_graph=False)(inp) + out_graph = exported[0] + + attr_access_count = 0 + for node in out_graph.graph.nodes: + if node.op == "get_attr": + attr_access_count += 1 + self.assertTrue(node.meta["nn_module_stack"] is not None) + self.assertEqual(attr_access_count, 2) + + torch._dynamo.reset() + + exported = torch._dynamo.export(m, aten_graph=True)(inp) + out_graph = exported[0] + + attr_access_count = 0 + for node in out_graph.graph.nodes: + if node.op == "get_attr": + attr_access_count += 1 + self.assertTrue(node.meta["nn_module_stack"] is not None) + self.assertEqual(attr_access_count, 2) + + def test_export_compare_optimize_with_make_fx(self): + inp = torch.tensor([0.1, 0.1]) + linear = torch.nn.Linear(2, 2) + + def func(x): + x = x + 1 + y = x.t() + y = 
y.relu() + y = linear(y) + return y + + exported = torch._dynamo.export(func, aten_graph=True)(inp) + out_graph = exported[0] + export_result = out_graph(inp) + + torch._dynamo.reset() + + def compiler(gm, sample_inputs): + def fw(*args): + aten_gm = make_fx(gm)(*args) + return aten_gm(*args) + + return fw + + opt_func = torch.compile(func, backend=compiler, fullgraph=True, dynamic=True) + make_fx_result_through_backend = opt_func(inp) + + fx_g = make_fx(func)(inp) + make_fx_result_through_direct = fx_g(inp) + + self.assertTrue( + torch._dynamo.utils.same(make_fx_result_through_backend, export_result) + ) + self.assertTrue( + torch._dynamo.utils.same(make_fx_result_through_direct, export_result) + ) + + def test_export_with_constant_method_on_module(self): + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.param = torch.nn.Parameter(torch.rand(4, 2)) + self.linear = torch.nn.Linear(2, 2) + + @torch._dynamo.assume_constant_result + def helper_fn(self, x): + return torch.nonzero(x) + + def forward(self, x): + y = torch.sin(x) + x = self.linear(x) + y = self.helper_fn(x) + return y + + module = MyModule() + real_result = module(torch.tensor([[1.0, 0], [0, 0]])) + module = MyModule() + graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) + result = graph(torch.tensor([[1.0, 0.0], [0, 0]])) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + + def test_export_with_constant_method_on_module_invoke_twice(self): + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.param = torch.nn.Parameter(torch.rand(4, 2)) + self.linear = torch.nn.Linear(2, 2) + + @torch._dynamo.assume_constant_result + def helper_fn(self, x): + return torch.nonzero(x) + + def forward(self, x): + y = torch.sin(x) + x = self.linear(x) + y = self.helper_fn(x) + 
self.helper_fn(x) + return y + + module = MyModule() + real_result = module(torch.tensor([[1.0, 0], [0, 0]])) + module = MyModule() + graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) + result = graph(torch.tensor([[1.0, 0.0], [0, 0]])) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + + def test_export_with_constant_free_function(self): + @torch._dynamo.assume_constant_result + def helper_fn(x): + return torch.nonzero(x) + + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.param = torch.nn.Parameter(torch.rand(4, 2)) + self.linear = torch.nn.Linear(2, 2) + + @torch._dynamo.assume_constant_result + def helper_fn(self, x): + return torch.nonzero(x) + + def forward(self, x): + y = torch.sin(x) + x = self.linear(x) + y = helper_fn(x) + self.helper_fn(x) + return y + + module = MyModule() + real_result = module(torch.tensor([[1.0, 0], [0, 0]])) + module = MyModule() + graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) + result = graph(torch.tensor([[1.0, 0.0], [0, 0]])) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + + def test_export_with_constant_global_function(self): + class MyModule(torch.nn.Module): + def forward(self): + a = dynamo_assume_constant_result_global_function() + b = dynamo_assume_constant_result_global_function() + return a + b + + module = MyModule() + graph, _ = torch._dynamo.export(module)() + result = graph() + self.assertEqual(result, "testtest") + + def test_export_with_constant_free_function_and_class_method(self): + @torch._dynamo.assume_constant_result + def helper_fn(x): + return torch.nonzero(x) + + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() 
+ self.param = torch.nn.Parameter(torch.rand(4, 2)) + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + y = torch.sin(x) + x = self.linear(x) + y = helper_fn(x) + return y + + module = MyModule() + real_result = module(torch.tensor([[1.0, 0], [0, 0]])) + module = MyModule() + graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) + result = graph(torch.tensor([[1.0, 0.0], [0, 0]])) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + + def test_export_with_constant_free_function_and_class_method_multiarg(self): + @torch._dynamo.assume_constant_result + def helper_fn(x): + return torch.nonzero(x) + + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.param = torch.nn.Parameter(torch.rand(4, 2)) + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x, z): + y = torch.sin(x) + x = self.linear(x) + y = helper_fn(x) + helper_fn(z) + return y + + module = MyModule() + real_result = module( + torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]]) + ) + module = MyModule() + graph, _ = torch._dynamo.export(module)( + torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]]) + ) + result = graph( + torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[1.0, 0.0], [0, 0]]) + ) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + result = graph( + torch.tensor([[1, 0], [0.25, 0.25]]), torch.tensor([[1, 0], [0.25, 0.25]]) + ) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + + def test_export_with_constant_free_function_and_class_method_multiarg_diff(self): + @torch._dynamo.assume_constant_result + def helper_fn(x): + return torch.nonzero(x) + + class MyModule(torch.nn.Module): + def forward(self, x, z): + y = helper_fn(x) + helper_fn(z) + return y + + module = MyModule() + real_result = module( + torch.tensor([[1.0, 0], 
[0, 0]]), torch.tensor([[1.0, 0], [0, 0]]) + ) + module = MyModule() + graph, _ = torch._dynamo.export(module)( + torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[0.0, 0], [0.5, 0]]) + ) + result = graph( + torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[0.0, 1.0], [0, 0]]) + ) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + result = graph( + torch.tensor([[1, 0], [0.25, 0.25]]), + torch.tensor([[0.33, 0.33], [0.25, 0.25]]), + ) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + + def test_export_with_constant_tuple_nonzero(self): + class MyModule(torch.nn.Module): + @torch._dynamo.assume_constant_result + def helper_fn(self, x): + return (torch.nonzero(x), torch.nonzero(x)) + + def forward(self, x): + y = torch.tensor([0.5]) + elements = self.helper_fn(x) + all_y = [] + for element in elements: + for item in element: + all_y.append(y * item) + return all_y + + module = MyModule() + real_result = module(torch.tensor([1.0, 1.0])) + graph, _ = torch._dynamo.export(module)(torch.tensor([1.0, 1.0])) + + # Tensor input can be almost anything here, and the result will capture what we + # made constant at compile time. + result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + + def test_export_with_constant_list_nonzero(self): + class MyModule(torch.nn.Module): + @torch._dynamo.assume_constant_result + def helper_fn(self, x): + return [torch.nonzero(x), torch.nonzero(x)] + + def forward(self, x): + y = torch.tensor([0.5]) + elements = self.helper_fn(x) + all_y = [] + for element in elements: + for item in element: + all_y.append(y * item) + return all_y + + module = MyModule() + real_result = module(torch.tensor([1.0, 1.0])) + graph, _ = torch._dynamo.export(module)(torch.tensor([1.0, 1.0])) + + # Tensor input can be almost anything here, and the result will capture what we + # made constant at compile time. 
+ result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + + def test_export_with_constant_list_nonzero_free_function(self): + @torch._dynamo.assume_constant_result + def helper_fn(x): + return [torch.nonzero(x), torch.nonzero(x)] + + class MyModule(torch.nn.Module): + def forward(self, x): + y = torch.tensor([0.5]) + elements = helper_fn(x) + all_y = [] + for element in elements: + for item in element: + all_y.append(y * item) + return all_y + + module = MyModule() + real_result = module(torch.tensor([1.0, 1.0])) + graph, _ = torch._dynamo.export(module)(torch.tensor([1.0, 1.0])) + + # Tensor input can be almost anything here, and the result will capture what we + # made constant at compile time. + result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + + def test_export_with_constant_dict_values(self): + class MyModule(torch.nn.Module): + @torch._dynamo.assume_constant_result + def helper_fn(self, x): + return {"x": x, "x^2": x * x} + + def forward(self, x): + y = torch.tensor([0.5]) + elements = self.helper_fn(x) + y = y * elements["x"] + y = y * elements["x^2"] + return y + + module = MyModule() + real_result = module(torch.tensor([2.0, 2.0])) + graph, _ = torch._dynamo.export(module)(torch.tensor([2.0, 2.0])) + + # Tensor input can be almost anything here, and the result will capture what we + # made constant at compile time. 
+ result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + + def test_export_with_constant_none_control_flow(self): + class MyModule(torch.nn.Module): + @torch._dynamo.assume_constant_result + def helper_fn(self, x): + if x.item() < 0: + return None + else: + return x + + def forward(self, x): + y = torch.tensor([0.5]) + x = self.helper_fn(x) + if x is None: + return y + return y * x + + module = MyModule() + real_result = module(torch.tensor([-1])) + + # X is negative, so .item() < 0, which means we return y + self.assertEqual(real_result, torch.tensor([0.5])) + + graph, _ = torch._dynamo.export(module)(torch.tensor([-1])) + result = graph(torch.tensor([2])) + # X is positive, but we compiled helper_fn to return None, so it will still return y + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + + def test_export_with_constant_not_none_control_flow(self): + class MyModule(torch.nn.Module): + @torch._dynamo.assume_constant_result + def helper_fn(self, x): + if x.item() < 0: + return None + else: + return x + + def forward(self, x): + y = torch.tensor([0.5]) + x = self.helper_fn(x) + if x is None: + return y + return y * x + + module = MyModule() + real_result = module(torch.tensor([2])) + + # X is positive, so .item() > 0, which means we return y * x + self.assertEqual(real_result, torch.tensor([1.0])) + + graph, _ = torch._dynamo.export(module)(torch.tensor([2])) + result = graph(torch.tensor([-0.5])) + # X is negative, but we compiled helper_fn to return x, so it will still return y * x + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + + def test_export_with_constant_none_control_flow_free_func(self): + @torch._dynamo.assume_constant_result + def helper_fn(x): + if x.item() < 0: + return None + else: + return x + + class MyModule(torch.nn.Module): + def forward(self, x): + y = torch.tensor([0.5]) + x = helper_fn(x) + if x is None: + return y + return y 
* x + + module = MyModule() + real_result = module(torch.tensor([-1])) + + # X is negative, so .item() < 0, which means we return y + self.assertEqual(real_result, torch.tensor([0.5])) + + graph, _ = torch._dynamo.export(module)(torch.tensor([-1])) + result = graph(torch.tensor([2])) + # X is positive, but we compiled helper_fn to return None, so it will still return y + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + + def test_export_with_constant_not_none_control_flow_pos(self): + class MyModule(torch.nn.Module): + @torch._dynamo.assume_constant_result + def helper_fn(self, x): + if x.item() < 0: + return None + else: + return x + + def forward(self, x): + y = torch.tensor([0.5]) + x = self.helper_fn(x) + if x is None: + return y + return y * x + + module = MyModule() + real_result = module(torch.tensor([2])) + + # X is positive, so .item() > 0, which means we return y * x + self.assertEqual(real_result, torch.tensor([1.0])) + + graph, _ = torch._dynamo.export(module)(torch.tensor([2])) + result = graph(torch.tensor([-0.5])) + # X is negative, but we compiled helper_fn to return x, so it will still return y * x + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + + def test_export_with_constant_not_none_control_flow_free_func(self): + @torch._dynamo.assume_constant_result + def helper_fn(x): + if x.item() < 0: + return None + else: + return x + + class MyModule(torch.nn.Module): + def forward(self, x): + y = torch.tensor([0.5]) + x = helper_fn(x) + if x is None: + return y + return y * x + + module = MyModule() + real_result = module(torch.tensor([2])) + + # X is positive, so .item() > 0, which means we return y * x + self.assertEqual(real_result, torch.tensor([1.0])) + + graph, _ = torch._dynamo.export(module)(torch.tensor([2])) + result = graph(torch.tensor([-0.5])) + # X is negative, but we compiled helper_fn to return x, so it will still return y * x + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + + def 
test_export_with_constant_not_return_const(self): + class MyModule(torch.nn.Module): + @torch._dynamo.assume_constant_result + def helper_fn(self, x): + return self.val + + def forward(self, x): + y = torch.tensor([0.5]) + x = self.helper_fn(x) + if x == "A": + return y + return -1 + + module = MyModule() + module.val = "A" + resA = module(torch.tensor([2])) + graph, _ = torch._dynamo.export(module)(torch.tensor([2])) + module.val = "B" + resB = graph(torch.tensor([2])) + self.assertTrue(torch._dynamo.utils.same(resA, resB)) + + def test_export_with_builtin_op_on_assume_constant(self): + @torch._dynamo.assume_constant_result + def get_y(y) -> torch.Tensor: + return y + + class Bob(torch.nn.Module): + def __init__(self, p, val) -> None: + super().__init__() + self.p = p + self.y = torch.nn.Parameter(torch.tensor(val)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # This only looks dynamic but it's actually a constant value + if get_y(self.y) < self.p: + return torch.cat([x, x]) + else: + return x + + model = Bob(0.5, 0.3) + inp = torch.ones(3, 4) + graph, _ = torch._dynamo.export(model)(inp) + self.assertEqual(model(inp), graph(inp)) + + def test_export_with_constant_in_unspecialized_nn_module(self): + class Module(torch.nn.Module): + def __init__(self, y): + super().__init__() + self.y = y + + @torch._dynamo.assume_constant_result + def check(self): + return self.y[0].item() == 1 + + def forward(self, x): + # This line leads to module obj being tracked as UnspecializedNNModuleVariable in dynamo + self.device = x.device + + if self.check(): + return x + 1 + else: + return x + 2 + + model = Module(torch.tensor([1])) + inp = torch.ones(3, 4) + graph, _ = torch._dynamo.export(model)(inp) + self.assertEqual(model(inp), graph(inp)) + + def test_export_decomp(self): + def f(x): + return x.t() + x.t() + + def nop(x): + return x.cos() + + graph, _ = torch._dynamo.export( + f, + aten_graph=True, + decomposition_table={torch.ops.aten.t.default: nop}, + 
)(torch.randn(5)) + self.assertEqual( + len([n for n in graph.graph.nodes if n.target == torch.ops.aten.t.default]), + 0, + ) + + graph, _ = torch._dynamo.export(f, aten_graph=True, decomposition_table=None)( + torch.randn(5) + ) + self.assertEqual( + len([n for n in graph.graph.nodes if n.target == torch.ops.aten.t.default]), + 2, + ) + + def test_export_decomp_asserts_bad_args(self): + def f(x): + return x.t() + x.t() + + def nop(x): + return x.cos() + + with self.assertRaises(AssertionError): + torch._dynamo.export( + f, + (torch.randn(5)), + aten_graph=False, + decomposition_table={torch.ops.aten.t.default: nop}, + ) + + @config.patch(capture_scalar_outputs=True) + def test_export_with_module_layer(self): + from functorch.experimental.control_flow import cond + + class Module(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, pred, x): + def true_fn(val): + return self.linear(val) * torch.tensor(2) + + def false_fn(val): + return self.linear(val) * torch.tensor(-1) + + return cond(pred, true_fn, false_fn, [x]) + + mod = Module() + x = torch.randn([3, 3]) + pred = torch.tensor(x[0][0].item() < 0) + real_result = mod.forward(pred, x) + + torch._dynamo.reset() + + exported = torch._dynamo.export(mod.forward)(pred, x) + out_graph = exported[0] + + dynamo_result = out_graph(pred, x) + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + # New X, just to show we did not specialize + x = x * -1 + pred = torch.tensor(x[0][0].item() < 0) + real_result_2 = mod.forward(pred, x) + dynamo_result_2 = out_graph(pred, x) + self.assertTrue(torch._dynamo.utils.same(real_result_2, dynamo_result_2)) + + @config.patch(capture_scalar_outputs=True) + def test_export_with_cond_branches_calling_methods(self): + from functorch.experimental.control_flow import cond + + class Module(torch.nn.Module): + # ok + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(3, 
3) + + def t(self, val): + return val + 1 + + def f(self, val): + return val - 1 + + def true_fn(self, val): + return self.linear(val) + self.t(val) + + def false_fn(self, val): + return self.linear(val) - self.f(val) + + def forward(self, pred, x): + return cond(pred, self.true_fn, self.false_fn, [x]) + + mod = Module() + x = torch.randn([3, 3]) + pred = torch.tensor(x[0][0].item() < 0) + real_result = mod.forward(pred, x) + out_graph, _ = torch._dynamo.export(mod.forward)(pred, x) + dynamo_result = out_graph(pred, x) + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + @config.patch(capture_scalar_outputs=True) + def test_export_with_cond_closure(self): + from functorch.experimental.control_flow import cond + + class Foo(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, pred, x): + def true_fn(x): + return x * 2 + + def false_fn(x): + return x - 2 + + return cond(pred, true_fn, false_fn, [x]) + + class Bar(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, pred, x): + def true_fn(x): + return x * 2 + + def false_fn(x): + return x - 2 + + return cond(pred, true_fn, false_fn, [x + 1]) + + class FooBar(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, pred, x): + y = x + x + + def true_fn(x, y): + return self.linear(x) * (x + y) + + def false_fn(x, y): + return x * (y - x) + + return cond(pred, true_fn, false_fn, [x, y]) + + for Module in [Foo, Bar, FooBar]: + mod = Module() + x = torch.randn([3, 3], requires_grad=True) + pred = torch.tensor(x[0][0].item() < 0) + real_result = mod.forward(pred, x) + out_graph, _ = torch._dynamo.export(mod.forward)(pred, x) + dynamo_result = out_graph(pred, x) + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_export_with_cond_with_closed_function(self): + def hello(x): + return x + 1 + + def hi(x): + return x + 2 + + 
def foo(pred, x): + def true_fn(x): + return hello(x) + + def false_fn(x): + return hi(x) + + return cond(pred, true_fn, false_fn, [x]) + + x = torch.randn(5) + pred = x[0] > 0 + real_result = foo(pred, x) + out_graph, _ = torch._dynamo.export(foo)(pred, x) + dynamo_result = out_graph(pred, x) + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_export_with_cond_dynamic_shape_pred(self): + from functorch.experimental.control_flow import cond + + class Module(torch.nn.Module): + def forward(self, x): + def true_fn(x): + return x + x + + def false_fn(x): + return x[:2].clone() + + return cond(x.shape[0] <= 2, true_fn, false_fn, [x]) + + class Module2(torch.nn.Module): + def forward(self, x): + def true_fn(x): + return x + x + + def false_fn(x): + return x[:2].clone() + + return cond(x.shape[0] <= 2, true_fn, false_fn, (x,)) + + mods = [Module(), Module2()] + for mod in mods: + x = torch.randn(2, 2) + out_graph, _ = torch._dynamo.export(mod)(x) + self.assertExpectedInline( + out_graph.code.strip(), + """\ +def forward(self, x): + arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) + l_x_ = arg0 + sym_size_int = torch.ops.aten.sym_size.int(l_x_, 0) + le = sym_size_int <= 2; sym_size_int = None + cond_true_0 = self.cond_true_0 + cond_false_0 = self.cond_false_0 + cond = torch.ops.higher_order.cond(le, cond_true_0, cond_false_0, (l_x_,)); le = cond_true_0 = cond_false_0 = l_x_ = None + getitem_3 = cond[0] + sym_size_int_1 = torch.ops.aten.sym_size.int(getitem_3, 0); getitem_3 = None + ge = sym_size_int_1 >= 2; sym_size_int_1 = None + _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 2 on node 'ge'"); ge = _assert_scalar_default = None + getitem_2 = cond[0]; cond = None + return pytree.tree_unflatten([getitem_2], self._out_spec)""", + ) + self.assertExpectedInline( + out_graph.cond_true_0.code.strip(), + """\ +def forward(self, l_x_): + l_x__1 = l_x_ + add = l_x__1 + 
l_x__1; l_x__1 = None + return (add,)""", + ) + self.assertExpectedInline( + out_graph.cond_false_0.code.strip(), + """\ +def forward(self, l_x_): + l_x__1 = l_x_ + getitem = l_x__1[slice(None, 2, None)]; l_x__1 = None + clone = getitem.clone(); getitem = None + return (clone,)""", + ) + # We could successfully export branches that return different sizes + torch._dynamo.export(mod)(torch.randn(3, 2)) + + # We specialize into one of the branches since predicate is a python boolean. + test_x = torch.randn(3, 2) + mod(test_x) + + def test_export_with_map_cond(self): + from functorch.experimental.control_flow import cond, map + + class Module(torch.nn.Module): + def inner(self, x, pred): + def true_fn(x): + return x + x + + def false_fn(x): + return x * x + + return cond(pred, true_fn, false_fn, [x]) + + def forward(self, pred, xs): + def body(x, pred): + return self.inner(x, pred) + + return map(body, xs, pred) + + mod = Module() + x = torch.randn(3, 2, 1) + pred_x = torch.tensor(True) + + y = torch.randn(4, 3, 2) + pred_y = torch.tensor(False) + real_result = mod(pred_y, y) + + out_graph, _ = torch._dynamo.export(mod)(pred_x, x) + self.assertEqual(real_result, out_graph(pred_y, y)) + + def test_export_with_map_zero_sized_tensor(self): + from functorch.experimental.control_flow import map + + class Module(torch.nn.Module): + def forward(self, xs): + def body(x): + return x + 1 + + return map(body, xs) + + mod = Module() + xs = torch.randn(0, 2) + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, + "Observed exception", + ): + torch._dynamo.export(mod)(xs) + + def test_export_meta_val(self): + def f(x, y, z): + return x * y + z + + gm, _ = torch._dynamo.export( + f, + aten_graph=True, + )( + torch.ones(3, 2), + torch.zeros(3, 2), + torch.ones(3, 2), + ) + for node in gm.graph.nodes: + if node.op == "placeholder": + self.assertIn("val", node.meta) + + def test_input_container_type(self): + def f(x: torch.Tensor, y: list[torch.Tensor]) -> dict[str, 
torch.Tensor]: + return {"a": x.sum() + sum(y).sum()} + + inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)]) + + gm, _ = torch._dynamo.export(f, aten_graph=True)(*inp) + + self.assertEqual(gm(*inp), f(*inp)) + + @config.patch(assume_static_by_default=False) + def test_export_symbolic_shape(self): + def f(x: torch.Tensor) -> torch.Tensor: + return torch.empty(x.shape[0] * 2) + + inp = (torch.randn(6, 5),) + gm, _ = torch._dynamo.export(f, aten_graph=True)(*inp) + + has_sym_size = False + for node in gm.graph.nodes: + if node.target is torch.ops.aten.sym_size.int: + has_sym_size = True + + self.assertTrue(has_sym_size) + + @config.patch(assume_static_by_default=False) + def test_dynamic_slicing(self): + def f(x): + return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2] + + gm_aten_mode, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5)) + + inp = torch.randn(6, 7) + self.assertEqual(gm_aten_mode(inp).shape, f(inp).shape) + + count = 0 + # aten graph should flatten getitem calls to actual + # slice kernel call. + for node in gm_aten_mode.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.slice.Tensor + ): + count += 1 + + self.assertEqual(count, 2) + + gm_torch_mode, _ = torch._dynamo.export(f, aten_graph=False)(torch.randn(4, 5)) + + # In torch mode, the graph should contain 3 getitem methods + # one for x.shape[0]-2 and one for x.shape[1]-1 and one for slice + # this is because Tensor class has its' own getitem method + # which gets translated to aten.Slice later. 
+ count = 0 + for node in gm_torch_mode.graph.nodes: + if node.op == "call_function" and node.target == operator.getitem: + count += 1 + + self.assertEqual(count, 1) + self.assertEqual(gm_torch_mode(inp).shape, f(inp).shape) + + @config.patch(capture_scalar_outputs=True) + def test_dynamic_slicing_simple(self): + def f(x): + return x[slice(None, None, None)] + + gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5)) + + inp = torch.randn(6, 7) + self.assertEqual(gm(inp), f(inp)) + + def test_pre_dispatch_simple(self): + def f(x): + y = torch.ones_like(x) + return torch.matmul(x, y) + + gm, _ = torch._dynamo.export( + f, + aten_graph=True, + pre_dispatch=True, + tracing_mode="fake", + )( + torch.randn(5, 5), + ) + + inp = torch.randn(6, 6) + self.assertEqual(gm(inp), f(inp)) + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, x): + arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) + arg0_1 = arg0 + ones_like = torch.ops.aten.ones_like.default(arg0_1, pin_memory = False) + matmul = torch.ops.aten.matmul.default(arg0_1, ones_like); arg0_1 = ones_like = None + return pytree.tree_unflatten([matmul], self._out_spec)""", + ) + + @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) + def test_export_cond_in_aten_symbolic(self): + class ConditionOp(torch.nn.Module): + def true_fn(self, x, y): + return x * y + + def false_fn(self, x, y): + return x + y + + def forward(self, pred, x, y): + return cond(pred, self.true_fn, self.false_fn, [x, y]) + + model = ConditionOp() + inp = ( + torch.tensor(False), + torch.randn(4, 4), + torch.randn(4, 4), + ) + gm, _ = torch._dynamo.export(model, aten_graph=True)(*inp) + + gm.print_readable() + + self.assertEqual(gm(*inp), model(*inp)) + + def test_export_with_kwargs(self): + def fn_with_kwargs(pos0, tuple0, *myargs, mykw0=None, **mykwargs): + out = pos0 + for arg in tuple0: + out *= arg + for arg in myargs: + out *= arg + out *= mykw0 + out *= mykwargs["input0"] * 
mykwargs["input1"] + return out + + mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)} + tuple0 = (torch.randn(4), torch.randn(4)) + mykw0 = torch.randn(4) + pos0 = torch.randn(4) + myargs = [torch.randn(4), torch.randn(4)] + + expected_argument_names = [ + "pos0", + "tuple0", + "myargs_0", + "myargs_1", + "mykw0", + "input0", + "input1", + ] + self._test_export_preserving_original_signature( + fn_with_kwargs, + expected_argument_names, + pos0, + tuple0, + *myargs, + mykw0=mykw0, + **mykwargs, + ) + + def test_export_with_kwargs_and_empty_args(self): + def fn_with_kwargs(mykw0=None, **mykwargs): + out = mykw0 + out *= mykwargs["input0"] * mykwargs["input1"] + return out + + mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)} + mykw0 = torch.randn(4) + + expected_argument_names = ["mykw0"] + list(mykwargs.keys()) + self._test_export_preserving_original_signature( + fn_with_kwargs, expected_argument_names, mykw0, **mykwargs + ) + + def test_export_with_args_and_empty_kwargs(self): + def fn_with_kwargs(pos0, tuple0, *myargs): + out = pos0 + for arg in tuple0: + out *= arg + for arg in myargs: + out *= arg + return out + + tuple0 = (torch.randn(4), torch.randn(4)) + pos0 = torch.randn(4) + myargs = [torch.randn(4), torch.randn(4)] + + expected_argument_names = ["pos0", "tuple0", "myargs_0", "myargs_1"] + self._test_export_preserving_original_signature( + fn_with_kwargs, expected_argument_names, pos0, tuple0, *myargs + ) + + @common_utils.parametrize( + "default_value", + [ + common_utils.subtest(None, name="None"), + common_utils.subtest(42.0, name="float"), + common_utils.subtest( + # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output + torch.randn(4), + name="tensor", + decorators=[unittest.expectedFailure], + ), + common_utils.subtest( + # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output + (torch.randn(4),), + name="tuple", + decorators=[unittest.expectedFailure], 
+ ), + ], + ) + def test_export_with_args_with_default(self, default_value): + def fn(pos0, pos1_default=default_value): + out = pos0 + if pos1_default is None: + pos1_default = torch.randn(4) + if isinstance(pos1_default, tuple): + pos1_default = pos1_default[0] + out *= pos1_default + return out + + pos0 = torch.randn(4) + expected_argument_names = ["pos0"] + self._test_export_preserving_original_signature( + fn, expected_argument_names, pos0 + ) + + @common_utils.parametrize( + "default_value", + [ + common_utils.subtest(None, name="None"), + common_utils.subtest(42.0, name="float"), + common_utils.subtest( + # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output + torch.randn(4), + name="tensor", + decorators=[unittest.expectedFailure], + ), + common_utils.subtest( + # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output + (torch.randn(4),), + name="tuple", + decorators=[unittest.expectedFailure], + ), + ], + ) + def test_export_with_kwargs_with_default(self, default_value): + def fn(pos0, *, kw0, kw1_default=default_value, **kwargs): + out = pos0 + out += kw0 + if kw1_default is None: + kw1_default = torch.randn(4) + elif isinstance(kw1_default, tuple): + kw1_default = kw1_default[0] + out += kw1_default + out += kwargs["kw2"] + return out + + pos0 = torch.randn(4) + kw0 = torch.randn(4) + kw2 = torch.randn(4) + + args = (pos0,) + kwargs = {"kw0": kw0, "kw2": kw2} + expected_argument_names = ["pos0", "kw0", "kw2"] + self._test_export_preserving_original_signature( + fn, expected_argument_names, *args, **kwargs + ) + + def test_export_with_wrapped_fn(self): + # To ensure dynamo.export is robust to wrapped functions + # when it cannot use `inspect` to retrieve original signature + # info. 
+ def _fn(pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs): + out = pos0 + out += pos1 + out += kw0 + out += kw1 + for arg in args: + out += arg + for kwarg in kwargs.values(): + out += kwarg + return out + + def wrapped_fn(*args, **kwargs): + return _fn(*args, **kwargs) + + pos0 = torch.randn(4) + kw0 = torch.randn(4) + args = (pos0, torch.randn(4), torch.randn(4)) + kwargs = {"kw0": kw0, "kw2": torch.randn(4)} + expected_argument_names = [f"args_{i}" for i in range(len(args))] + list( + kwargs.keys() + ) + + self._test_export_preserving_original_signature( + wrapped_fn, expected_argument_names, *args, **kwargs + ) + + def test_export_with_functools_wrapped_method(self): + def test_decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + return x + + @test_decorator + def method_to_test(self, pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs): + out = pos0 + out += pos1 + out += kw0 + out += kw1 + for arg in args: + out += arg + for kwarg in kwargs.values(): + out += kwarg + return out + + pos0 = torch.randn(4) + pos1 = torch.randn(4) + unnamed_pos = torch.randn(4) + kw0 = torch.randn(4) + args = (pos0, pos1, unnamed_pos) + kwargs = {"kw0": kw0, "kw2": torch.randn(4), "unnamed_kw": torch.randn(4)} + expected_argument_names = [ + "pos0", + "pos1", + "args_0", # 3rd unnamed positional argument + ] + list(kwargs.keys()) + m = MyModule() + + self._test_export_preserving_original_signature( + m.method_to_test, expected_argument_names, *args, **kwargs + ) + + def test_export_with_functools_wrapped_fn(self): + def test_decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + @test_decorator + def _fn(pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs): + out = pos0 + out += pos1 + out += kw0 + out += kw1 + for arg in args: + out += 
arg + for kwarg in kwargs.values(): + out += kwarg + return out + + def wrapped_fn(*args, **kwargs): + return _fn(*args, **kwargs) + + pos0 = torch.randn(4) + kw0 = torch.randn(4) + args = (pos0, torch.randn(4), torch.randn(4)) + kwargs = {"kw0": kw0, "kw2": torch.randn(4)} + expected_argument_names = [f"args_{i}" for i in range(len(args))] + list( + kwargs.keys() + ) + + self._test_export_preserving_original_signature( + wrapped_fn, expected_argument_names, *args, **kwargs + ) + + def _test_export_preserving_original_signature( + self, fn, expected_argument_names: Sequence[str], *args, **kwargs + ): + torch._dynamo.reset() + exported = torch._dynamo.export( + fn, + *args, + **kwargs, + aten_graph=False, + ) + + out_graph = exported[0] + dynamo_result = out_graph(*args, **kwargs) + real_result = fn(*args, **kwargs) + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + # Check that the exported graph preserves same argument names. + self.assertEqual( + inspect.getfullargspec(out_graph.forward).args[1:], expected_argument_names + ) + + def test_dataclass_input_output(self): + from dataclasses import dataclass + + @dataclass + class Tensors: + x: torch.Tensor + y: torch.Tensor + + def f(t): + return t.x + t.y + + with self.assertRaisesRegex( + UserError, + "It looks like one of the inputs with type .*Tensors.* " + "is not supported or pytree-flattenable", + ): + torch._dynamo.export(f, aten_graph=False)( + Tensors(x=torch.randn(10), y=torch.randn(10)) + ) + + def f(x, y): + return Tensors(x=x.sin(), y=y.cos()) + + with self.assertRaisesRegex( + UserError, + "It looks like one of the outputs with type .*Tensors.* " + "is not supported or pytree-flattenable", + ): + torch._dynamo.export(f, aten_graph=False)(torch.randn(10), torch.randn(10)) + + def test_empty(self): + def f(x): + return x + + exported = torch._dynamo.export(f)(torch.randn(3, 3)) + out_graph = exported[0] + inp = torch.randn(3, 3) + self.assertTrue(torch._dynamo.utils.same(inp, 
out_graph(inp))) + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.a = torch.ones(3, 3) + + def forward(self): + return self.a + + exported = torch._dynamo.export(M())() + out_graph = exported[0] + self.assertTrue(torch._dynamo.utils.same(torch.ones(3, 3), out_graph())) + + def test_export_meta(self): + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.p = torch.nn.Parameter(torch.ones(2, 3)) + + def forward(self, x): + return self.p + x + + with torch.device("meta"): + m = MyModule() + + inp = torch.ones(2, 3, device="meta") + exported = torch._dynamo.export(m)(inp) + out_graph = exported[0] + dynamo_result = out_graph(inp) + self.assertEqual(dynamo_result, m(inp)) + + def test_constraint_violation_error_messages(self): + class Foo(torch.nn.Module): + def forward(self, x): + if x.shape[0] == x.shape[1] * 2: + return x + 1 + else: + return x + 2 + + foo = Foo() + + t = torch.zeros([8, 4]) + dim0 = torch.export.Dim("dim0", min=3, max=10) + dim1 = torch.export.Dim("dim1") + dynamic_shapes = {"x": (dim0, dim1)} + + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + "Constraints violated .*!(.*\n)*.*" + "by dim0 = 2\\*dim1(.*\n)*.*" + "Not all values of dim1 .* satisfy the generated guard 2 <= .* and .* <= 5(.*\n)*.*", + ): + torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes, strict=True) + + class Bar(torch.nn.Module): + def forward(self, x): + if x.shape[0] == 5: + return x + 1 + else: + return x + 2 + + bar = Bar() + + t = torch.zeros([5]) + dim0 = torch.export.Dim("dim0", min=3, max=8) + dynamic_shapes = {"x": (dim0,)} + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + "You marked.*but your code specialized it to be a constant.*" + "If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO", + ): + torch.export.export(bar, (t,), dynamic_shapes=dynamic_shapes, strict=True) + + class Qux(torch.nn.Module): + def forward(self, x): + if 
x.shape[0] > 5 and x.shape[0] < 10: + return x + 1 + else: + return x + 2 + + qux = Qux() + + t = torch.zeros([7]) + dim0 = torch.export.Dim("dim0", min=3, max=8) + dynamic_shapes = {"x": (dim0,)} + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + "Not all values.*satisfy the generated guard", + ): + torch.export.export(qux, (t,), dynamic_shapes=dynamic_shapes, strict=True) + + def test_untracked_inputs_in_constraints(self): + from copy import copy + + class Foo(torch.nn.Module): + def forward(self, x, y): + return y + 1 + + foo = Foo() + + x = torch.randn(2) + y = torch.randn(5, 4) + + dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y") + dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}} + + example_inputs = (copy(x), y) + ep = torch.export.export( + foo, example_inputs, dynamic_shapes=dynamic_shapes, strict=True + ) + ep.module()(torch.randn(3), y) # no specialization error + + def test_export_raise_guard_full_constraint(self): + y = torch.randn([3, 3, 3]) + + def my_dyn_fn(x): + if x.shape[0] == 3: + return x.sin() + return x.cos() + + torch._dynamo.export(my_dyn_fn)(y) + + with self.assertRaises(ConstraintViolationError): + torch._dynamo.export( + my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dimx")},) + )(y) + + def test_export_module_specify_constraints_signature(self): + y = torch.randn([3, 3, 3]) + + class Mod(torch.nn.Module): + def forward(self, x): + if x.shape[0] == 3: + return x.sin() + return x.cos() + + mod = Mod() + torch._dynamo.export(mod)(y) + + with self.assertRaisesRegex(ConstraintViolationError, "dimx = 3"): + torch._dynamo.export(mod, dynamic_shapes=({0: torch.export.Dim("dimx")},))( + y + ) + + def test_export_raise_guard_partial_constraint(self): + y = torch.randn([3, 3, 3]) + + def my_dyn_fn(x): + if x.shape[0] > 3: + return x.sin() + return x.cos() + + torch._dynamo.export(my_dyn_fn)(y) + + with self.assertRaises(ConstraintViolationError): + torch._dynamo.export( + my_dyn_fn, dynamic_shapes=({0: 
torch.export.Dim("dimx")},) + )(y) + + def test_export_raise_on_relationship(self): + y = torch.randn([3, 3, 3]) + + def my_dyn_fn(a, b, c): + if a.shape[0] == b.shape[1] == c.shape[2]: + return a.sin() + + return a.cos() + + torch._dynamo.export(my_dyn_fn)(y, y, y) + dim = torch.export.Dim("dim") + dynamic_shapes = ({0: dim}, {0: dim}, {0: dim}) + with self.assertRaises(ConstraintViolationError): + torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y) + dynamic_shapes = ({0: dim}, {1: dim}, {2: dim}) + torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y) + + def test_export_no_raise(self): + y = torch.randn([3, 3, 3]) + + def my_dyn_fn(a, b, c): + if a.shape[1] == 3: + return a.cos() + return a * b * c + + torch._dynamo.export(my_dyn_fn)(y, y, y) + dim = torch.export.Dim("dim") + dynamic_shapes = ({0: dim}, {0: dim}, {0: dim}) + torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y) + + def test_export_multi_dynamic_dim_unsafe_relationship(self): + x = torch.randn([3, 3, 3]) + y = torch.randn([2, 2, 2]) + z = torch.randn([3, 3, 3]) + + def my_dyn_fn(a, b, c): + if a.shape[0] == c.shape[0]: + return a.cos() + return a * c, b + + torch._dynamo.export(my_dyn_fn)(x, y, z) + dimx, dimy, dimz = torch.export.dims("dimx", "dimy", "dimz") + dynamic_shapes = ({0: dimx}, {0: dimy}, {0: dimz}) + with self.assertRaises(ConstraintViolationError): + torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z) + dimz = dimx + dynamic_shapes = ({0: dimx}, {0: dimy}, {0: dimz}) + torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z) + + def test_remove_redundant_dynamic_dim_in_error_message(self): + class Foo(torch.nn.Module): + def forward(self, x, y): + if x.shape[0] == y["k"].shape[0]: + return x + 1 + else: + return x - 1 + + foo = Foo() + + a = torch.randn(3) + b = torch.randn(3) + dim0_a, dim0_b = torch.export.dims("dim0_a", "dim0_b") + with self.assertRaisesRegex(torch._dynamo.exc.UserError, 
"dim0_b = dim0_a"): + torch.export.export( + foo, + (a, {"k": b}), + dynamic_shapes={"x": {0: dim0_a}, "y": {"k": {0: dim0_b}}}, + strict=True, + ) + + def test_enforce_equalities(self): + class Bar(torch.nn.Module): + def forward(self, x, y): + return torch.matmul(x, y) + + bar = Bar() + + batch, size = torch.export.dims("batch", "size") + dynamic_shapes = {"x": (batch, size, size), "y": (batch, size, size)} + + x = torch.randn(10, 3, 3) + y = torch.randn(10, 3, 4) + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + ".*y.*size.*2.* = 4 is not equal to .*x.*size.*1.* = 3", + ): + with torch._export.config.patch(use_new_tracer_experimental=True): + torch.export.export( + bar, (x, y), dynamic_shapes=dynamic_shapes, strict=True + ) + y = torch.randn(10, 3, 3) + with torch._export.config.patch(use_new_tracer_experimental=True): + ebar = torch.export.export( + bar, (x, y), dynamic_shapes=dynamic_shapes, strict=True + ) + + for node in ebar.graph_module.graph.nodes: + if node.op == "placeholder": + shape = node.meta["val"].shape + self.assertEqual(shape[1], shape[2]) + + @torch._dynamo.config.patch( + capture_dynamic_output_shape_ops=True, + specialize_int=True, + capture_scalar_outputs=True, + ) + def test_export_preserve_constraints_as_metadata_tensor(self): + def f(x): + b = x.nonzero() + torch._check(b.shape[0] >= 2) + torch._check(b.shape[0] <= 5) + return b + + y = torch.tensor([8, 8, 6]) + torch._dynamo.export( + f, + aten_graph=True, + tracing_mode="symbolic", + )(y) + + @config.patch( + capture_dynamic_output_shape_ops=True, + specialize_int=True, + capture_scalar_outputs=True, + ) + def test_exported_graph_serialization(self): + def f(x, y): + b = x.item() + return torch.empty((b, y.shape[0])) + + x = torch.tensor([3]) + y = torch.randn([8, 8, 6]) + example_inputs = [x, y] + dynamic_shapes = (None, {0: torch.export.Dim("dimy", min=6, max=10)}) + gm, _ = torch._dynamo.export( + f, + dynamic_shapes=dynamic_shapes, + aten_graph=True, + 
tracing_mode="symbolic", + )(*example_inputs) + + # Ensure the exported graph module with metadata is serializable, + # metadata won't be saved in the serialized module + buffer = io.BytesIO() + torch.save(gm, buffer) + + def test_export_dynamic_dim_not_1(self): + x = torch.randn([1, 1, 1]) + + def my_dyn_fn(a): + if a.shape[0] != 1: + return a.cos() + return a * a + + torch._dynamo.export(my_dyn_fn)(x) + with self.assertRaises(ConstraintViolationError): + torch._dynamo.export( + my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dimx")},) + )(x) + + def test_symbool(self): + def f(x): + a = torch.scalar_tensor(x.shape[0] > 4) + return x.sin().sum() + a.sum() + + gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4)) + self.assertEqual(gm(torch.ones(3, 4)), f(torch.ones(3, 4))) + + def test_export_multi_dynamic_dim_constraint(self): + x = torch.randn([3, 3, 3]) + y = torch.randn([2, 2, 2]) + z = torch.randn([3, 3, 3]) + + def my_dyn_fn(a, b, c): + if a.shape[0] == c.shape[0]: + return a.cos() + return a * c, b + + torch._dynamo.export(my_dyn_fn)(x, y, z) + dimx_0, dimx_1, dimx_2 = torch.export.dims("dimx_0", "dimx_1", "dimx_2") + dynamic_shapes = ({0: dimx_0, 1: dimx_1, 2: dimx_2}, None, None) + with self.assertRaises(ConstraintViolationError): + torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z) + dynamic_shapes = ({0: dimx_0, 1: dimx_1, 2: dimx_2}, None, {0: dimx_0}) + torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z) + + def test_export_dynamic_dim_range_constraint(self): + x = torch.ones(6, 4, 4) + dynamic_shapes = ({0: torch.export.Dim("dimx", min=5, max=6)},) + + def foo(x): + if x.shape[0] > 3: # ok + return x.sin() + return x.cos() + + torch._dynamo.export( + foo, + dynamic_shapes=dynamic_shapes, + aten_graph=True, + )(x) + + def bar(x): + if x.shape[0] > 5: # error + return x.sin() + return x.cos() + + with self.assertRaises(ConstraintViolationError): + torch._dynamo.export( + bar, + 
dynamic_shapes=dynamic_shapes, + aten_graph=True, + )(x) + + def test_trivial_constraint(self): + class Foo(torch.nn.Module): + def forward(self, x): + # complex divisibility condition + if (2 * x.shape[0] + 3) % (x.shape[0] - 3) == 0: + return x + 1 + else: + return x - 1 + + foo = Foo() + + class Bar(torch.nn.Module): + def forward(self, x): + # trivially true + if (2 * x.shape[0] + 2) % (x.shape[0] + 1) == 0: + return x + 1 + else: + return x - 1 + + bar = Bar() + + class Qux(torch.nn.Module): + def forward(self, x): + # simple divisibility condition (not trivially true) + if (3 * x.shape[0]) % 2 == 0: + return x + 1 + else: + return x - 1 + + qux = Qux() + + x = torch.randn(12) + dim0 = torch.export.Dim("dim0", max=100) + dynamic_shapes = {"x": (dim0,)} + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + r"Constraints violated \(dim0\)", + ): + torch.export.export(foo, (x,), dynamic_shapes=dynamic_shapes, strict=True) + + torch.export.export(bar, (x,), dynamic_shapes=dynamic_shapes, strict=True) + + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + r"Constraints violated \(dim0\)", + ): + torch.export.export(qux, (x,), dynamic_shapes=dynamic_shapes, strict=True) + + def test_list_contains(self): + def func(x): + assert x.size(-1) in [4, 5, 6], "bad" # noqa: S101 + return x + x + + inps = (torch.randn(1, 5),) + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + real_result = opt_func(*inps) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_list_not_contains(self): + def func(x): + assert x.size(0) not in [4, 5, 6], "bad1" # noqa: S101 + assert "monkey" not in ["cow", "pig"], "bad2" # noqa: S101 + return x + x + + inps = (torch.randn(1, 5),) + opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True) + 
real_result = opt_func(*inps) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, aten_graph=True)(*inps) + out_graph = exported[0] + + dynamo_result = out_graph(*inps) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + + def test_export_identity(self): + inp = torch.tensor([0.1, 0.1]) + + def func(x): + return x + + torch._dynamo.reset() + exported, _ = torch._dynamo.export(func)(inp) + dynamo_result = exported(inp) + self.assertTrue(torch._dynamo.utils.same(inp, dynamo_result)) + + def test_export_specialized_int(self): + class Foo(torch.nn.Module): + def __init__( + self, + input_dim, + ): + super().__init__() + self.torch_module = torch.nn.LayerNorm( + input_dim, eps=1e-5, elementwise_affine=True + ) + self.int_val = 100 + + def forward(self, input): + return input.cos() * self.int_val * self.torch_module.eps + + mod = Foo(128) + inp = torch.randn(3, 128) + + # In export, int & float in forward should always be specialized + gm, _ = torch._dynamo.export(mod, aten_graph=True)(inp) + count = 0 + for node in gm.graph.nodes: + if node.op == "placeholder": + count += 1 + self.assertEqual(count, 1) + + def test_export_with_nonzero_static(self): + class BasicModule(torch.nn.Module): + def __init__(self, static_size): + super().__init__() + self.static_size = static_size + + def forward(self, x): + return torch.nonzero_static(x, size=self.static_size) + + input_tensors = torch.tensor([6, 8]), torch.zeros(2, 3) + static_sizes = 3, 4 + for input_tensor, static_size in zip(input_tensors, static_sizes): + m = BasicModule(static_size) + gm, _ = torch._dynamo.export(m, aten_graph=True)(input_tensor) + res = gm(input_tensor) + self.assertEqual(res.size(0), static_size) + self.assertTrue( + torch._dynamo.utils.same( + res, torch.nonzero_static(input_tensor, size=static_size) + ) + ) + + def test_export_pass_arg_by_name(self): + class BasicModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.my_lin = 
torch.nn.Linear(3, 4, bias=True) + + def forward(self, x): + return self.my_lin(x) + + mod, input_tensor = BasicModule(), torch.randn(2, 3) + gm, _ = torch._dynamo.export(mod, aten_graph=True)(input_tensor) + ref = mod(x=input_tensor) + res = gm(x=input_tensor) + self.assertTrue(torch._dynamo.utils.same(ref, res)) + + def test_export_pass_arg_by_name_star_args(self): + class BasicModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.my_lin = torch.nn.Linear(3, 4, bias=True) + + def forward(self, *args): + return self.my_lin(args[0]) * self.my_lin(args[1]) + + mod, input_tensor, input_tensor2 = ( + BasicModule(), + torch.randn(2, 3), + torch.randn(2, 3), + ) + gm, _ = torch._dynamo.export(mod, aten_graph=True)(input_tensor, input_tensor2) + ref = mod(input_tensor, input_tensor2) + res = gm(input_tensor, input_tensor2) + self.assertTrue(torch._dynamo.utils.same(ref, res)) + + def test_export_dynamic_dim_cleanup(self): + y = torch.randn([3, 3, 3]) + + def my_dyn_fn(x): + return x.cos() + + torch._dynamo.export(my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dim")},))( + y + ) + + @config.patch(capture_dynamic_output_shape_ops=True) + def test_export_dynamic_control_flow_error(self): + def f(x): + if x.nonzero() > 3: + return x.cos() + return x.sin() + + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, + "Data-dependent branching", + ): + torch._dynamo.export(f, aten_graph=True)(torch.randn(5, 6)) + + @config.patch(assume_static_by_default=False) + def test_export_persist_assert(self): + def f(x): + assert x[0].sum() > 4, "Shape must be more than 4" # noqa: S101 + return x.cos() + x.sin() + + gm, _ = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")( + torch.ones(5, 4, 6) + ) + + def has_aten_op(gm, op): + for node in gm.graph.nodes: + if node.target == op: + return True + return False + + self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg)) + + gm.graph.eliminate_dead_code() + gm.recompile() + 
self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg)) + + with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"): + gm(torch.zeros(3, 4, 5)) + + @common_utils.parametrize( + "type_fn", + [ + common_utils.subtest(type, name="builtin"), + common_utils.subtest(lambda obj: obj.__class__, name="attr"), + ], + ) + def test_access_class_method_from_user_class(self, type_fn): + class A: + @classmethod + def func(cls): + return torch.Tensor([4, 5]) + + def f(x): + a = A() + return x.sum() + type_fn(a).func().sum() + + gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4)) + self.assertEqual(f(torch.ones(6, 4)), gm(torch.ones(6, 4))) + + def test_not_functionalize(self): + class Foo(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.buffer1 = torch.nn.Buffer(torch.ones(6, 2)) + + def forward(self, x): + x.add_(2) + return x.sum() + self.buffer1.sum() + + example_inputs = (torch.ones(1, 2, 3),) + gm, _ = torch._dynamo.export( + Foo(), + aten_graph=True, + tracing_mode="symbolic", + )(*example_inputs) + count = 0 + for node in gm.graph.nodes: + if node.target == torch.ops.aten.add_.Tensor: + count += 1 + self.assertEqual(count, 1) + test_inp = (torch.ones(1, 2, 3),) + test_inp_v2 = (torch.ones(1, 2, 3),) + self.assertEqual(gm(*test_inp), Foo()(*test_inp_v2)) + + def test_round_dynamic_shapes(self): + def f(x): + return x[: round(x.shape[0] / 2)] + + gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4)) + + self.assertEqual(f(torch.ones(6, 4)), gm(torch.ones(6, 4))) + + def test_cond_supported_pred_types(self): + def true_fn(x): + return x.cos() + + def false_fn(x): + return x.sin() + + def f_pred_traced_as_symnode_var(x): + return cond(x.shape[0] > 2, true_fn, false_fn, [x]) + + def f_pred_traced_as_tensor_var(x): + return cond(x.all(), true_fn, false_fn, [x]) + + def f_pred_complex_expression_traced_as_symnode_var(x): + return cond( + x.dim() > 1 and x.shape[1] > 5 and x.shape[1] <= 10, + true_fn, 
+ false_fn, + [x], + ) + + example_inputs = (torch.rand(5, 8),) + for f in [ + f_pred_traced_as_symnode_var, + f_pred_traced_as_tensor_var, + f_pred_complex_expression_traced_as_symnode_var, + ]: + gm, _ = torch._dynamo.export(f, aten_graph=True)(*example_inputs) + self.assertEqual(gm(*example_inputs), f(*example_inputs)) + + def test_sum_param(self): + # Setting a new attribute inside forward() + class Foo(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.a = torch.randn(3, 2) + + def forward(self, x): + self.b = 2 + return x.sum() + self.a.sum() + self.b + + torch._dynamo.export(Foo())(torch.randn(3, 2)) + + def test_mixed_real_and_fake_inputs(self): + class _TestPattern(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) + self.bn = torch.nn.BatchNorm2d(1) + + def forward(self, input): + running_std = torch.sqrt(self.bn.running_var + self.bn.eps) + scale_factor = self.bn.weight / running_std + weight_shape = [1] * len(self.conv.weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(self.conv.weight.shape) + bias_shape[1] = -1 + scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape) + zero_bias = torch.zeros_like(self.conv.bias, dtype=input.dtype) + conv = self.conv._conv_forward(input, scaled_weight, zero_bias) + conv_orig = conv / scale_factor.reshape(bias_shape) + conv_orig = conv_orig + self.conv.bias.reshape(bias_shape) + conv = self.bn(conv_orig) + return conv + + example_inputs = (torch.randn(1, 1, 3, 3),) + torch._dynamo.export( + _TestPattern(), + aten_graph=True, + )(*example_inputs) + + @config.patch( + capture_dynamic_output_shape_ops=True, + capture_scalar_outputs=True, + assume_static_by_default=False, + ) + def test_sym_contains(self): + def f(x, y): + return x.size(0) in y + + gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(2), torch.ones(3)) + + true_inp = (torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(5)) + false_inp = 
(torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(2)) + self.assertEqual(gm(*true_inp), f(*true_inp)) + self.assertEqual(gm(*false_inp), f(*false_inp)) + + def test_cond_raise_user_error_on_missing_args(self): + def true_fn(x): + return x.cos() + + def false_fn(x): + return x.sin() + + def f(x): + return cond(x.shape[0] > 10, true_fn, false_fn) + + # Now we allow torch.cond to handle empty args + example_inputs = (torch.rand(5),) + with self.assertRaisesRegex( + TypeError, + r"false_fn\(\) missing 1 required positional argument: 'x'", + ): + f(*example_inputs) + + def test_cond_raise_user_error_on_unsupported_pred(self): + def f_unsupported_pred(x): + pred = torch.nn.Module() + return cond(pred, lambda x: x.sin(), lambda x: x.cos(), [x]) + + example_inputs = (torch.rand(5),) + with self.assertRaisesRegex( + RuntimeError, + "Expected pred to be bool or tensor, but got Module()", + ): + f_unsupported_pred(*example_inputs) + + def test_cond_raise_user_error_on_non_list_operands(self): + def f_non_list_operands(x): + return cond(torch.tensor(True), lambda x: x.sin(), lambda x: x.cos(), x) + + example_inputs = (torch.rand(5),) + with self.assertRaisesRegex( + RuntimeError, + r"Expect operands to be a tuple of possibly nested dict/list/tuple", + ): + f_non_list_operands(*example_inputs) + + def test_cond_raise_user_error_on_non_tensor_operands(self): + def f_non_tensor_operands(x): + a: float = 3.14 + return cond( + torch.tensor(1234), lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a] + ) + + example_inputs = (torch.rand(5),) + with self.assertRaisesRegex( + RuntimeError, + r"Expect operands to be a tuple of possibly nested dict/list/tuple", + ): + f_non_tensor_operands(*example_inputs) + + def test_cond_raise_user_error_on_branch_args_mismatch(self): + def true_fn(x, y): + return x.sin() + + def false_fn(x): + return x.cos() + + def f_branch_args_mismatch(x, y): + return cond(torch.tensor([[[[True]]]]), true_fn, false_fn, [x, y]) + + example_inputs = (torch.rand(5), 
torch.rand(2)) + with self.assertRaisesRegex( + torch._dynamo.exc.UncapturedHigherOrderOpError, + r"Higher Order Operator: torch\.cond", + ): + torch._dynamo.export( + f_branch_args_mismatch, + aten_graph=True, + )( + *example_inputs, + ) + + @config.patch(suppress_errors=True) + def test_uncaptured_higher_order_op_error_not_suppresed(self): + def true_fn(x, y): + return x.sin() + + def false_fn(x): + return x.cos() + + def f_branch_args_mismatch(x, y): + return cond(torch.tensor([[[[100]]]]), true_fn, false_fn, [x, y]) + + example_inputs = (torch.rand(5), torch.rand(2)) + with self.assertRaisesRegex( + torch._dynamo.exc.UncapturedHigherOrderOpError, + r"Higher Order Operator: torch\.cond", + ): + torch._dynamo.export( + f_branch_args_mismatch, + aten_graph=True, + )( + *example_inputs, + ) + + def test_cond_raise_user_error_on_branch_return_non_tensor(self): + def f_branch_return_non_tensor(x): + return cond(x.shape[0] <= 5, lambda x: 3.14, lambda x: 3.14, [x]) + + example_inputs = (torch.rand(5),) + with self.assertRaisesRegex( + torch._dynamo.exc.UncapturedHigherOrderOpError, + r"Higher Order Operator: torch\.cond", + ): + torch._dynamo.export( + f_branch_return_non_tensor, + aten_graph=True, + )(*example_inputs) + + def test_cond_raise_user_error_on_branch_return_multiple_tensors(self): + def f_branch_return_multiple_tensors(pred, x, y): + return cond( + pred, + lambda x: (x.clone(), x.clone()), + lambda x: (x.clone(), x.clone()), + [y], + ) + + example_inputs = (torch.tensor(True), torch.randn(4), torch.randn(2)) + gm, _ = torch._dynamo.export( + f_branch_return_multiple_tensors, + aten_graph=True, + )(*example_inputs) + self.assertEqual( + gm(*example_inputs), f_branch_return_multiple_tensors(*example_inputs) + ) + + def test_multiple_outputs_op_with_evaluator(self): + class TopKModel(torch.nn.Module): + def forward(self, x): + values, _ = torch.topk(x, 3) + return torch.sum(values) + + x = torch.arange(1.0, 6.0, requires_grad=True) + 
torch._dynamo.export(TopKModel())(x) + + def test_cond_raise_user_error_on_mismatch_return_length(self): + def true_fn(x): + return x.clone() + + def false_fn(x): + return (x.clone(), x.clone()) + + def f_mismatch_return_length(x): + return cond(torch.tensor(100), true_fn, false_fn, [x]) + + example_inputs = (torch.rand(5),) + with self.assertRaisesRegex( + torch._dynamo.exc.TorchRuntimeError, + "Unmatched output spec from torch.cond branches", + ): + torch._dynamo.export( + f_mismatch_return_length, + aten_graph=True, + )(*example_inputs) + + def test_cond_raise_user_error_on_mismatch_return_tensor_meta(self): + def true_fn(x): + return torch.tensor([[3], [2]]) + + def false_fn(x): + return torch.tensor([3.14]) + + def f_return_tensor_mismatch(x): + return cond(x.shape[0] < 3, true_fn, false_fn, [x]) + + example_inputs = (torch.rand(5),) + with self.assertRaisesRegex( + torch._dynamo.exc.TorchRuntimeError, + "When merging two branches' output in torch.cond", + ): + torch._dynamo.export(f_return_tensor_mismatch, aten_graph=True)( + *example_inputs, + ) + + def test_byte_tensor_does_not_crash(self): + # See https://github.com/pytorch/pytorch/issues/100455 + def func(text): + tensor = torch.ByteTensor(list(bytes(text, "utf8"))) + return tensor + tensor + + text = "".join(chr(a % 90 + 40) for a in range(111)) + opt_func = torch.compile(func, backend="eager", dynamic=True) + for i in [99, 100]: + input = text[:i] + opt_func(input) + + def test_export_defaults_ok(self): + class DynamicSliceExportMod(torch.nn.Module): + def forward(self, x): + results = [] + for i in range(4): + results.append(x[: x.size(0) - i, i : x.size(2), i:3]) + return tuple(results) + + gm, _ = torch._dynamo.export(DynamicSliceExportMod(), aten_graph=True)( + torch.randn(5, 5, 5), + ) + + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, x): + arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) + arg0_1 = arg0 + sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0) 
+ slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3) + sub = sym_size_int - 1 + slice_2 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub); sub = None + slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_int); slice_2 = None + slice_4 = torch.ops.aten.slice.Tensor(slice_3, 2, 1, 3); slice_3 = None + sub_1 = sym_size_int - 2 + slice_5 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_1); sub_1 = None + slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 2, sym_size_int); slice_5 = None + slice_7 = torch.ops.aten.slice.Tensor(slice_6, 2, 2, 3); slice_6 = None + sub_2 = sym_size_int - 3 + slice_8 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_2); arg0_1 = sub_2 = None + slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 3, sym_size_int); slice_8 = sym_size_int = None + slice_10 = torch.ops.aten.slice.Tensor(slice_9, 2, 3, 3); slice_9 = None + return pytree.tree_unflatten([slice_1, slice_4, slice_7, slice_10], self._out_spec)""", + ) + + def test_capture_symbolic_tracing_simple_within_fake_mode(self): + from torch._dynamo.output_graph import config + + def f(x): + y = torch.randn(3) + return x + x * y + + with fake_tensor.FakeTensorMode( + shape_env=ShapeEnv( + allow_scalar_outputs=config.capture_scalar_outputs, + allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, + ), + ): + x = torch.randn(3) + + for aten_graph in [True, False]: + gm, _ = torch._dynamo.export(f, aten_graph=aten_graph)(x) + self.assertTrue( + isinstance(gm, torch.fx.GraphModule), + msg="test_capture_symbolic_tracing_simple_within_fake_mode_aten_graph_" + + str(aten_graph), + ) + + def test_export_with_symbool_inputs(self): + def f(pred: bool, x: torch.Tensor): + if pred: + return x.sin() + else: + return x.cos() + + x = torch.randn([3, 4]) + + def test_symbool_guards( + f, size_tests, exp_graph, exp_guard_code, exp_shape_env_guards + ): + shape_env = ShapeEnv() + with fake_tensor.FakeTensorMode( + shape_env=shape_env, + ) as fake_mode: + fake_x = fake_mode.from_tensor( + 
x, + symbolic_context=StatelessSymbolicContext( + dynamic_sizes=[DimDynamic.DYNAMIC for _ in range(x.dim())], + ), + ) + for i, size in enumerate(size_tests): + pred = fake_x.size(0) == size + gm, guards = torch._dynamo.export(f)(pred, x) + actual = normalize_gm(gm.print_readable(print_output=False)) + # TODO: This is naughty, EXPECTTEST_ACCEPT=1 doesn't work + self.assertExpectedInline(actual, exp_graph[i].format(size=size)) + dynamo_shape_env_guards = [ + guard + for guard in guards + if guard.guard_types is not None + and "SHAPE_ENV" in guard.guard_types + ] + self.assertEqual(len(dynamo_shape_env_guards), 1) + guard_code_on_predicate = [ + code + for code in dynamo_shape_env_guards[0].code_list + if "L['pred']" in code + ] + self.assertEqual(guard_code_on_predicate, exp_guard_code[i]) + outter_shape_env_guards = [ + str(guard.expr) for guard in shape_env.guards + ] + self.assertEqual(outter_shape_env_guards, exp_shape_env_guards[i]) + + true_graph = """\ +class GraphModule(torch.nn.Module): + def forward(self, pred, x): + arg0: "Sym(Eq(s26, {size}))"; arg1: "f32[s77, s27]"; + + arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {{}}), self._in_spec) + l_x_ = arg1 + + sin: "f32[s77, s27]" = l_x_.sin(); l_x_ = None + return pytree.tree_unflatten([sin], self._out_spec) +""" + false_graph = """\ +class GraphModule(torch.nn.Module): + def forward(self, pred, x): + arg0: "Sym(Eq(s26, {size}))"; arg1: "f32[s77, s27]"; + + arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {{}}), self._in_spec) + l_x_ = arg1 + + cos: "f32[s77, s27]" = l_x_.cos(); l_x_ = None + return pytree.tree_unflatten([cos], self._out_spec) +""" + true_guard_code = [ + "cast_symbool_to_symint_guardless(L['pred']) == 1", + ] + false_guard_code = [ + "cast_symbool_to_symint_guardless(L['pred']) != 1", + ] + test_symbool_guards( + f, + [3, 3, 4, 5], + [true_graph, true_graph, false_graph, false_graph], + [true_guard_code, true_guard_code, false_guard_code, false_guard_code], + # Outer shape env 
should have no guards in it because we never specialize on the outer symbool. + [[], [], [], []], + ) + + def test_input_global(self) -> None: + global bulbous_bouffant + bulbous_bouffant = torch.randn(3) + + def f(y): + return bulbous_bouffant + y + + torch._dynamo.export(f)(torch.randn(3)) + + def test_input_global_multiple_access(self) -> None: + global macademia + macademia = torch.randn(3) + + def g(y): + global macademia + y = macademia + y + return y + + def f(y): + global macademia + y = g(y) + return macademia + y + + torch._dynamo.export(f)(torch.randn(3)) + + def test_input_nonlocal(self) -> None: + arglebargle = torch.randn(3) + + def f(y): + return arglebargle + y + + torch._dynamo.export(f)(torch.randn(3)) + + def test_input_unused_nonlocal_ok(self) -> None: + arglebargle = torch.randn(3) + + def f(y): + x = arglebargle # noqa: F841 + return y + + torch._dynamo.export(f)(torch.randn(3)) + + def test_symbolic_tracing_within_fake_mode_with_constraints(self): + from torch._subclasses import fake_tensor + + fake_mode = fake_tensor.FakeTensorMode() + + class DynamicShapeSimpleModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, a, b, c) -> torch.Tensor: + d = (torch.matmul(a, b) + c) / 2 + d_s0 = d.shape[0] + d_s1 = d.shape[1] + d_s3 = d_s0 * d_s1 + e = d.view(d_s3) + return torch.cat([e, e]) + + with fake_mode: + model = DynamicShapeSimpleModel() + inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7)) + dim = torch.export.Dim("dim") + dynamic_shapes = ({0: dim}, None, {0: dim}) + for aten_graph in [True, False]: + gm = torch._dynamo.export( + model, + dynamic_shapes=dynamic_shapes, + aten_graph=aten_graph, + )(*inputs).graph_module + + # Since there are no parameters we can do this + inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7)) + self.assertEqual(model(*inputs), gm(*inputs)) + + def test_symbolic_tracing_within_fake_mode_with_constraints_with_parameters(self): + from 
torch._subclasses import fake_tensor + + fake_mode = fake_tensor.FakeTensorMode() + + # TODO: Seems to choke if you don't make a fresh model and + # just try to export Linear directly... + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + out = self.linear(x) + return out + + with fake_mode: + model = Model() + inputs = (torch.randn(10, 2, 2),) + dynamic_shapes = ({0: torch.export.Dim("dim")},) + for aten_graph in [True, False]: + torch._dynamo.export( + model, + dynamic_shapes=dynamic_shapes, + aten_graph=aten_graph, + )(*inputs).graph_module + + def test_capture_symbolic_tracing_within_fake_mode(self): + from torch._dynamo.output_graph import config + from torch._subclasses import fake_tensor + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2) + self.linear2 = torch.nn.Linear(2, 2) + + def forward(self, x): + out = self.linear(x) + out = self.linear2(out) + return out + + # User-instantiated FakeTensorMode + fake_mode = fake_tensor.FakeTensorMode( + allow_non_fake_inputs=False, + allow_fallback_kernels=True, + shape_env=ShapeEnv( + allow_scalar_outputs=config.capture_scalar_outputs, + allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, + ), + ) + # Fakefy input+model before exporting it + with fake_mode: + x = torch.rand(5, 2, 2) + model = Model() + + # Export the model with fake inputs and parameters + for aten_graph in [True, False]: + graph_module, _ = torch._dynamo.export(model, aten_graph=aten_graph)(x) + self.assertTrue( + isinstance(graph_module, torch.fx.GraphModule), + msg="test_capture_symbolic_tracing_within_fake_mode_aten_graph_" + + str(aten_graph), + ) + + def test_cond_op_param_buffer_lifted(self): + class A(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.buffer1 = 
torch.nn.Buffer(torch.zeros(6, 4)) + + def forward(self): + return self.buffer1.sum() + + class B(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.buffer2 = torch.nn.Buffer(torch.ones(6, 4)) + + def forward(self): + return self.buffer2.sum() + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.a = A() + self.b = B() + + def forward(self, x): + def true_fn(x): + return x.cos() + self.a() + + def false_fn(x): + return x.sin() + self.b() + + return (cond(x.shape[0] > 4, true_fn, false_fn, [x]),) + + gm, _ = torch._dynamo.export(M(), aten_graph=False)(torch.ones(6, 4)) + self.assertEqual(gm(torch.ones(6, 4)), M()(torch.ones(6, 4))) + self.assertEqual(gm(torch.ones(3, 4)), M()(torch.ones(3, 4))) + + def test_nested_cond_op_param_buffer_lifted(self): + class A(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4)) + + def forward(self): + return self.buffer1.sum() + + class B(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.buffer2 = torch.nn.Buffer(torch.ones(6, 4)) + + def forward(self): + return self.buffer2.sum() + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.a = A() + self.b = B() + + def forward(self, x): + def true_true_fn(x): + return x.cos() + self.a() + + def true_false_fn(x): + return x.cos() + self.a() + 1 + + def true_fn(x): + return cond(x.shape[0] > 5, true_true_fn, true_false_fn, [x]) + + def false_fn(x): + return x.sin() + self.b() + + return (cond(x.shape[0] > 4, true_fn, false_fn, [x]),) + + gm, _ = torch._dynamo.export(M(), aten_graph=False)(torch.ones(6, 4)) + self.assertEqual(gm(torch.ones(6, 4)), M()(torch.ones(6, 4))) + self.assertEqual(gm(torch.ones(5, 4)), M()(torch.ones(5, 4))) + self.assertEqual(gm(torch.ones(3, 4)), M()(torch.ones(3, 4))) + + def test_map_cond_param_buffer_lifted(self): + from functorch.experimental.control_flow import cond, map + + 
class A(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4)) + + def forward(self): + return self.buffer1.sum() + + class B(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.buffer2 = torch.nn.Buffer(torch.ones(6, 4)) + + def forward(self): + return self.buffer2.sum() + + class Module(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.a = A() + self.b = B() + + def inner(self, x, pred): + def true_fn(x): + return x + x + self.a() + + def false_fn(x): + return x * x + self.b() + + return cond(pred, true_fn, false_fn, [x]) + + def forward(self, pred, xs): + def body(x, pred): + return self.inner(x, pred) + self.b() + + return map(body, xs, pred) + + mod = Module() + x = torch.randn(3, 2, 1) + pred_x = torch.tensor(True) + + y = torch.randn(4, 3, 2) + pred_y = torch.tensor(False) + real_result = mod(pred_y, y) + + out_graph, _ = torch._dynamo.export(mod)(pred_x, x) + self.assertEqual(real_result, out_graph(pred_y, y)) + + def test_cond_free_variables_overlapping(self): + from functorch.experimental.control_flow import cond + + class Module(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, pred, x): + a = torch.ones(6, 4) + b = torch.ones(6, 4) + c = torch.ones(6, 4) + d = torch.ones(6, 4) + + def true_fn(x): + return x + x + a.cos() + b.cos() + d.cos() + + def false_fn(x): + return x * x + a.sin() + b.sin() + c.sin() + + return cond(pred, true_fn, false_fn, [x]) + + mod = Module() + x = torch.ones(6, 4) + pred_x = torch.tensor(True) + + out_graph, _ = torch._dynamo.export(mod)(pred_x, x) + self.assertExpectedInline( + out_graph.code.strip(), + """\ +def forward(self, pred, x): + arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec) + l_pred_ = arg0 + l_x_ = arg1 + a = torch.ones(6, 4) + b = torch.ones(6, 4) + c = torch.ones(6, 4) + d = torch.ones(6, 4) + cond_true_0 = self.cond_true_0 + 
cond_false_0 = self.cond_false_0 + cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, (a, b, l_x_, d, c)); l_pred_ = cond_true_0 = cond_false_0 = a = b = l_x_ = d = c = None + getitem = cond[0]; cond = None + return pytree.tree_unflatten([getitem], self._out_spec)""", + ) + + self.assertExpectedInline( + out_graph.cond_true_0.code.strip(), + """\ +def forward(self, a, b, l_x_, d_true_branch, c_false_branch): + a_1 = a + b_1 = b + l_x__1 = l_x_ + add = l_x__1 + l_x__1; l_x__1 = None + cos = a_1.cos(); a_1 = None + add_1 = add + cos; add = cos = None + cos_1 = b_1.cos(); b_1 = None + add_2 = add_1 + cos_1; add_1 = cos_1 = None + cos_2 = d_true_branch.cos(); d_true_branch = None + add_3 = add_2 + cos_2; add_2 = cos_2 = None + return (add_3,)""", + ) + + self.assertExpectedInline( + out_graph.cond_false_0.code.strip(), + """\ +def forward(self, a, b, l_x_, d_true_branch, c_false_branch): + a_1 = a + b_1 = b + l_x__1 = l_x_ + mul = l_x__1 * l_x__1; l_x__1 = None + sin = a_1.sin(); a_1 = None + add = mul + sin; mul = sin = None + sin_1 = b_1.sin(); b_1 = None + add_1 = add + sin_1; add = sin_1 = None + sin_2 = c_false_branch.sin(); c_false_branch = None + add_2 = add_1 + sin_2; add_1 = sin_2 = None + return (add_2,)""", + ) + + @unittest.skipIf( + common_utils.TEST_WITH_ASAN, + "Times out with ASAN, see https://github.com/pytorch/pytorch/issues/110416", + ) + def test_retracibility(self): + class MyLinear(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.randn(20, 98) + self.bias = torch.randn(20) + + def forward(self, x): + return torch.nn.functional.linear(x, self.weight, self.bias) + + class Foo(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(16, 33, 3) + self.linear = MyLinear() + + def forward(self, x): + a, b = x + a_conv = self.conv(a) + a_linear = self.linear(a_conv) + b_conv = self.conv(b) + b_linear = self.linear(b_conv) + return ( + a_linear.cos() + 
b_linear.sin(), + a_linear.sin() + b_linear.cos(), + ) + + inp_container = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)) + + gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True) + gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True) + + inp_test = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)) + + self.assertTrue(torch.allclose(gm(inp_test)[0], gm2(inp_test)[0])) + self.assertTrue(torch.allclose(gm(inp_test)[1], gm2(inp_test)[1])) + + def test_retracibility_dict_container_inp_out(self): + class MyLinear(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.randn(20, 98) + self.bias = torch.randn(20) + + def forward(self, x): + return torch.nn.functional.linear(x, self.weight, self.bias) + + class Foo(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(16, 33, 3) + self.linear = MyLinear() + + def forward(self, x): + a1, a2 = x["a"] + b = x["b"] + a1_conv = self.conv(a1) + a1_linear = self.linear(a1_conv) + a2_conv = self.conv(a2) + a2_linear = self.linear(a2_conv) + b_conv = self.conv(b) + b_linear = self.linear(b_conv) + return { + "a": [ + a1_linear.cos() + b_linear.sin(), + a1_linear.cos() + b_linear.sin(), + ], + "b": a2_linear.sin() + b_linear.cos(), + } + + inp_container = { + "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)), + "b": torch.randn(20, 16, 50, 100), + } + + gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True) + gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True) + + inp_test = { + "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)), + "b": torch.randn(20, 16, 50, 100), + } + + self.assertTrue(torch.allclose(gm(inp_test)["a"][0], gm2(inp_test)["a"][0])) + self.assertTrue(torch.allclose(gm(inp_test)["a"][1], gm2(inp_test)["a"][1])) + self.assertTrue(torch.allclose(gm(inp_test)["b"], gm2(inp_test)["b"])) + + def 
test_retracibility_nested_list_out(self): + class MyLinear(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.randn(20, 98) + self.bias = torch.randn(20) + + def forward(self, x): + return torch.nn.functional.linear(x, self.weight, self.bias) + + class Foo(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(16, 33, 3) + self.linear = MyLinear() + + def forward(self, x): + a1, a2 = x["a"] + b = x["b"] + a1_conv = self.conv(a1) + a1_linear = self.linear(a1_conv) + a2_conv = self.conv(a2) + a2_linear = self.linear(a2_conv) + b_conv = self.conv(b) + b_linear = self.linear(b_conv) + return [ + [ + a1_linear.cos() + b_linear.sin(), + a1_linear.cos() + b_linear.sin(), + ], + [ + a2_linear.sin() + b_linear.cos(), + a2_linear.sin() + b_linear.cos(), + ], + ] + + inp_container = { + "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)), + "b": torch.randn(20, 16, 50, 100), + } + + gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True) + gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True) + + inp_test = { + "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)), + "b": torch.randn(20, 16, 50, 100), + } + + self.assertTrue(torch.allclose(gm(inp_test)[0][0], gm2(inp_test)[0][0])) + self.assertTrue(torch.allclose(gm(inp_test)[0][1], gm2(inp_test)[0][1])) + self.assertTrue(torch.allclose(gm(inp_test)[1][0], gm2(inp_test)[1][0])) + self.assertTrue(torch.allclose(gm(inp_test)[1][1], gm2(inp_test)[1][1])) + + def test_fx_pytree(self): + def foo(args): + flat_args, spec = torch.utils._pytree.tree_flatten(args) + flat_args_fx = torch.fx._pytree.tree_flatten_spec(args, spec) + return flat_args_fx[0] + flat_args[0] + + inp_container = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)) + + gm, _ = torch._dynamo.export(foo, inp_container, aten_graph=True) + + self.assertTrue(torch.allclose(foo(inp_container), gm(inp_container))) + + 
@config.patch(suppress_errors=True) + @config.patch(verbose=True) + def test_export_with_map_zero_sized_tensor_suppress_errors(self): + from functorch.experimental.control_flow import map + + class Module(torch.nn.Module): + def forward(self, xs): + def body(x): + return x + 1 + + return map(body, xs) + + mod = Module() + xs = torch.randn(0, 2) + with self.assertRaises( + torch._dynamo.exc.Unsupported, + ): + torch._dynamo.export(mod, xs) + + def test_param_buffer_safe_from_mutation_simple(self): + class Module(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.buffer1 = torch.nn.Buffer(torch.zeros(5, 5)) + + def forward(self, x): + self.buffer1.add_(1) + return x + self.buffer1 + + gm, _ = torch._dynamo.export(Module(), torch.ones(5, 5), aten_graph=False) + buffers = list(gm.named_buffers()) + self.assertEqual(len(buffers), 1) + + name, buffer = buffers[0] + self.assertEqual(name, "L__self___buffer1") + + self.assertTrue(torch.allclose(buffer, torch.zeros(5))) + + def test_param_buffer_safe_from_mutation_recurse(self): + class Child(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.buffer2 = torch.nn.Buffer(torch.zeros(5)) + + def forward(self, x): + return x.sum() + self.buffer2.sum() + + class Module(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.buffer1 = torch.nn.Buffer(torch.zeros(5)) + self.child = Child() + + def forward(self, x): + self.buffer1.add_(1) + self.child.buffer2.add_(2) + return x.sum() + self.buffer1.sum() + self.child(x) + + gm, _ = torch._dynamo.export(Module(), torch.ones(5), aten_graph=False) + for _, buffer in gm.named_buffers(): + self.assertTrue(torch.allclose(buffer, torch.zeros(5))) + + def test_predispatch_with_higher_order(self): + def f(x): + return cond(x.shape[0] > 4, lambda x: x + 5, lambda x: x - 3, [x]) + + gm, _ = torch._dynamo.export(f, aten_graph=True, pre_dispatch=True)( + torch.randn(4, 4) + ) + inp1 = torch.randn(4, 4) + inp2 = torch.randn(6, 
4) + self.assertTrue(torch.allclose(f(inp1), gm(inp1))) + self.assertTrue(torch.allclose(f(inp2), gm(inp2))) + + def test_predispatch_with_higher_order_nested(self): + def f(x): + def true_fn(x): + return cond(x.shape[0] > 6, lambda x: x + 10, lambda x: x - 10, [x]) + + return cond(x.shape[0] > 4, true_fn, lambda x: x - 3, [x]) + + gm, _ = torch._dynamo.export(f, aten_graph=True, pre_dispatch=True)( + torch.randn(4, 4) + ) + inp1 = torch.randn(4, 4) + inp2 = torch.randn(6, 4) + inp3 = torch.randn(8, 4) + self.assertTrue(torch.allclose(f(inp1), gm(inp1))) + self.assertTrue(torch.allclose(f(inp2), gm(inp2))) + self.assertTrue(torch.allclose(f(inp3), gm(inp3))) + + def test_predispatch_with_for_out_dtype(self): + class M(torch.nn.Module): + def __init__(self, weight): + super().__init__() + self.weight = weight + + def forward(self, x): + return out_dtype(torch.ops.aten.mm.default, torch.int32, x, self.weight) + + weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8) + m = M(weight) + x = torch.randint(-128, 127, (5, 5), dtype=torch.int8) + gm, _ = torch._dynamo.export(m, x, aten_graph=True, pre_dispatch=True) + + self.assertTrue(torch.allclose(m(x), gm(x))) + + def test_predispatch_with_for_out_dtype_nested(self): + class M(torch.nn.Module): + def __init__(self, weight): + super().__init__() + self.weight = weight + + def true_fn(self, x): + return out_dtype( + torch.ops.aten.mm.default, torch.int32, x, self.weight + ).sum() + + def false_fn(self, x): + return out_dtype( + torch.ops.aten.mul.Tensor, torch.int32, x, self.weight + ).sum() + + def forward(self, x): + return cond(x.sum() != 0, self.true_fn, self.false_fn, [x]) + + weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8) + m = M(weight) + x = torch.ones((5, 5), dtype=torch.int8) + gm, _ = torch._dynamo.export(m, x, aten_graph=True, pre_dispatch=True) + + self.assertTrue(torch.allclose(m(x), gm(x))) + y = torch.zeros((5, 5), dtype=torch.int8) + self.assertTrue(torch.allclose(m(y), gm(y))) + + 
self.assertExpectedInline( + gm.true_graph_0.code.strip(), + """\ +def forward(self, arg0_1, arg1_1): + out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mm.default, torch.int32, arg1_1, arg0_1); arg1_1 = arg0_1 = None + sum_1 = torch.ops.aten.sum.default(out_dtype); out_dtype = None + return (sum_1,)""", + ) + + self.assertExpectedInline( + gm.false_graph_0.code.strip(), + """\ +def forward(self, arg0_1, arg1_1): + out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mul.Tensor, torch.int32, arg1_1, arg0_1); arg1_1 = arg0_1 = None + sum_1 = torch.ops.aten.sum.default(out_dtype); out_dtype = None + return (sum_1,)""", + ) + + def test_export_nn_module_stack_patched_module(self): + def forward(self, x, y): + return x * y + + class Toplevel(torch.nn.Module): + def __init__(self, m): + super().__init__() + self.m = m + + def forward(self, x, y): + return self.m(x, y) + + class M(torch.nn.Module): + def forward(self, x, y): + return x + y + + t = Toplevel(M()) + t.m.forward = forward.__get__(t.m, M) + x, y = torch.rand(3), torch.rand(3) + gm, _ = torch._dynamo.export(t, x, y) + + self.assertTrue(torch.allclose(forward(None, x, y), gm(x, y))) + for node in gm.graph.nodes: + if node.op == "call_function": + self.assertIn("nn_module_stack", node.meta) + + def test_preserve_fx_node_metadata(self): + class Module1(torch.nn.Module): + def forward(self, x): + return torch.sin(x) + + class Module2(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.mod1 = Module1() + + def forward(self, x): + x = torch.cos(x) + x = self.mod1(x) + x = torch.relu(x) + return x + + def fn(x): + return torch.abs(x) + + mod = Module2() + inp = torch.randn(3, 3) + + gm, _ = torch._dynamo.export(mod)(inp) + + # replace relu with fn + gm_edit = copy.deepcopy(gm) + for nd in gm_edit.graph.nodes: + if nd.target == torch.relu: + nd.target = fn + nd.meta.clear() + break + gm_edit.recompile() + + gm2, _ = torch._dynamo.export(gm_edit)(inp) + + 
self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, x): + arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) + l_x_ = arg0 + x = torch.cos(l_x_); l_x_ = None + x_1 = torch.sin(x); x = None + x_2 = torch.relu(x_1); x_1 = None + return pytree.tree_unflatten([x_2], self._out_spec)""", + ) + + def _constais_op(gm, target): + for nd in gm.graph.nodes: + if nd.target == target: + return True + return False + + self.assertTrue(_constais_op(gm_edit, torch.cos)) + self.assertTrue(_constais_op(gm_edit, torch.sin)) + self.assertTrue(not _constais_op(gm_edit, torch.relu)) + + self.assertExpectedInline( + gm2.code.strip(), + """\ +def forward(self, x): + arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) + l_x_ = arg0 + x = torch.cos(l_x_); l_x_ = None + x_1 = torch.sin(x); x = None + x_2 = torch.abs(x_1); x_1 = None + return pytree.tree_unflatten([x_2], self._out_spec)""", + ) + + # check for other metadata + for op in (torch.sin, torch.cos): + nd1 = next(filter(lambda nd: nd.target == op, gm.graph.nodes)) + nd2 = next(filter(lambda nd: nd.target == op, gm2.graph.nodes)) + self.assertTrue( + ("nn_module_stack" in nd1.meta) == ("nn_module_stack" in nd2.meta) + ) + if "nn_module_stack" in nd1.meta: + self.assertEqual( + nd1.meta["nn_module_stack"], nd2.meta["nn_module_stack"] + ) + self.assertEqual(nd1.meta["stack_trace"], nd2.meta["stack_trace"]) + + def test_preserve_fx_node_metadata_recompile(self): + def fn(x): + return torch.sin(x) + + gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3)) + do_export = torch._dynamo.export(gm) + torch.compile(fn, backend="eager")(torch.randn(3, 3)) + gm1, _ = do_export(torch.randn(3, 3)) + gm2, _ = do_export(torch.randn(5, 3)) + + self.assertExpectedInline( + gm1.code.strip(), + """\ +def forward(self, x): + arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) + l_x_ = arg0 + sin = torch.sin(l_x_); l_x_ = None + return pytree.tree_unflatten([sin], self._out_spec)""", + ) + 
self.assertExpectedInline( + gm2.code.strip(), + """\ +def forward(self, x): + arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) + l_x_ = arg0 + sin = torch.sin(l_x_); l_x_ = None + return pytree.tree_unflatten([sin], self._out_spec)""", + ) + + def test_preserve_fx_node_metadata_inline(self): + def f1(x): + return torch.sin(x) + + gm, _ = torch._dynamo.export(f1)(torch.randn(3, 3)) + + def f2(x): + x = torch.cos(x) + return gm(x) + + gm2, _ = torch._dynamo.export(f2)(torch.randn(3, 3)) + + self.assertExpectedInline( + gm2.code.strip(), + """\ +def forward(self, x): + arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) + l_x_ = arg0 + x = torch.cos(l_x_); l_x_ = None + sin = torch.sin(x); x = None + return pytree.tree_unflatten([sin], self._out_spec)""", + ) + + def test_preserve_fx_node_metadata_graph_break(self): + def fn(x): + x = torch.sin(x) + x = torch.abs(x) + return torch.cos(x) + + def bad_fn(x): + torch._dynamo.graph_break() + return x + + gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3)) + + # replace abs with graph break + gm_edit = copy.deepcopy(gm) + for nd in gm_edit.graph.nodes: + if nd.target == torch.abs: + nd.target = bad_fn + nd.meta.clear() + break + gm_edit.recompile() + + expected = [ + """x = torch.sin(l_x_)""", + """cos = torch.cos(l_nested_frame_values_0_1_)""", + ] + + def test_backend(gm: torch.fx.GraphModule, example_inputs): + self.assertTrue(expected) + # Normalize output for dynamic and not + for nd in gm.graph.nodes: + if "example_value" in nd.meta: + del nd.meta["example_value"] + self.assertIn(expected[0], gm.print_readable(print_output=False)) + expected.pop(0) + return gm.forward + + torch._dynamo.reset() + opt_gm_edit = torch.compile(gm_edit, backend=test_backend) + opt_gm_edit(torch.randn(3, 3)) + + def test_torch_inference_mode_ctx(self): + @torch.inference_mode() + def fn(x): + return x + 1 + + gm, _ = torch._dynamo.export(fn, torch.rand(2, 2)) + + inp = torch.randn(2, 2) + out = gm(inp) + 
self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, x): + arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) + l_args_0_ = arg0 + _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True) + add = l_args_0_ + 1; l_args_0_ = None + _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None + return pytree.tree_unflatten([add], self._out_spec)""", + ) + self.assertEqual(out.requires_grad, False) + with self.assertRaisesRegex( + RuntimeError, + "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.", + ): + out.requires_grad = True + + @torch.inference_mode(False) + def fn_no_inference(x): + return x + 1 + + gm_no_inference, _ = torch._dynamo.export(fn_no_inference, torch.rand(2, 2)) + self.assertExpectedInline( + gm_no_inference.code.strip(), + """\ +def forward(self, x): + arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) + l_args_0_ = arg0 + _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(False) + add = l_args_0_ + 1; l_args_0_ = None + _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None + return pytree.tree_unflatten([add], self._out_spec)""", + ) + + inp = torch.randn(2, 2) + out = gm_no_inference(inp) + self.assertEqual(out.requires_grad, False) + out.requires_grad = True + + def fn(x): + with torch.inference_mode(): + return x + 1 + + gm, _ = torch._dynamo.export(fn)(torch.rand(2, 2)) + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, x): + arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) + l_x_ = arg0 + _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True) + add = l_x_ + 1; l_x_ = None + _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = 
_exit_inference_mode = None + return pytree.tree_unflatten([add], self._out_spec)""", + ) + inp = torch.randn(2, 2, requires_grad=True) + out = gm(inp) + self.assertEqual(out.requires_grad, False) + + def test_export_masking_with_no_grad(self): + def fn(x, b, y): + x = x.clone() + x[b] = y + return x + + def fn_no_grad(x, b, y): + with torch.no_grad(): + return fn(x, b, y) + + def fn_inference_mode(x, b, y): + with torch.inference_mode(): + return fn(x, b, y) + + x = torch.randn(4, requires_grad=True) + b = torch.tensor([True, False, True, False]) + y = torch.randn(2, requires_grad=True) + + gm, _ = torch._dynamo.export(fn_no_grad)(x, b, y) + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, x, b, y): + arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec) + l_x_ = arg0 + l_b_ = arg1 + l_y_ = arg2 + _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None + x = l_x_.clone(); l_x_ = None + x[l_b_] = l_y_; setitem = x; l_b_ = l_y_ = setitem = None + _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None + return pytree.tree_unflatten([x], self._out_spec)""", + ) + + gm, _ = torch._dynamo.export(fn_inference_mode)(x, b, y) + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, x, b, y): + arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec) + l_x_ = arg0 + l_b_ = arg1 + l_y_ = arg2 + _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True) + x = l_x_.clone(); l_x_ = None + x[l_b_] = l_y_; setitem = x; l_b_ = l_y_ = setitem = None + _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None + return pytree.tree_unflatten([x], self._out_spec)""", + ) + + gm, _ = torch._dynamo.export(fn)(x, b, y) + + def test_dynamo_list_index(self): + def fn(x, in_list): + return x + in_list.index(2) + + inputs = (torch.ones(2, 2), [1, 2]) 
+ graph, _ = torch._dynamo.export(fn)(*inputs) + out = graph(*inputs) + self.assertEqual(out, torch.ones(2, 2) + 1) + + def test_dynamo_enum_in_tuple(self): + class IntEnum(int, Enum): + X = 0 + + def fn(tensor): + return tensor[..., IntEnum.X] + + tensor = torch.rand((5, 5)) + graph, _ = torch._dynamo.export(fn)(tensor) + out = graph(tensor) + self.assertEqual(out, tensor[:, 0]) + + def test_subclass_parameters(self): + from torch.testing._internal.two_tensor import TwoTensor + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.p1 = torch.nn.Parameter(torch.ones(3, 4)) + self.p2 = torch.nn.Parameter( + TwoTensor(torch.zeros(3, 4), torch.zeros(3, 4)) + ) + + def forward(self, x): + return x + 2 * self.p1 + self.p2 + + m = M() + ref_x = torch.randn(3, 4) + ref_out = m(ref_x) + + from torch._functorch._aot_autograd.subclass_parametrization import ( + unwrap_tensor_subclass_parameters, + ) + + unwrap_tensor_subclass_parameters(m) + ref_x2 = ref_x.detach().clone() + ref_out2 = m(ref_x2) + self.assertEqual(ref_out2, ref_out) + + x = ref_x.detach().clone() + graph, _ = torch._dynamo.export(m)(x) + out = graph(x) + self.assertEqual(ref_out, out) + + def test_strict_fake_tensor_prop_real_tensors(self): + class Foo(torch.nn.Module): + def forward(self, x): + return bool(x.eq(0.1).any().item()) + + model = Foo() + inputs = (torch.randn(64),) + ref = model(*inputs) + with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True): + ep = torch.export.export(model, inputs, strict=True) + res = ep.module()(*inputs) + + self.assertEqual(ref, res) + + +class ExportTestsSubprocess(torch._dynamo.test_case.TestCase): + def test_strict_export_under_pythonoptimize(self): + env = dict(os.environ) + env["PYTHONOPTIMIZE"] = "1" + code = """\ +import torch +model = torch.nn.Linear(2, 3) +example_input = torch.randn(1, 2) +ep = torch.export.export(model, args=(example_input,), strict=True) +out_export = ep.module()(example_input) +out_orig = 
model(example_input) +torch.testing.assert_close(out_export, out_orig) +""" + result = subprocess.run( + [sys.executable, "-c", code], + env=env, + capture_output=True, + text=True, + ) + self.assertEqual( + result.returncode, + 0, + msg=f"strict export under PYTHONOPTIMIZE=1 failed: stdout={result.stdout!r} stderr={result.stderr!r}", + ) + + +class ExportTestsDevice(torch._dynamo.test_case.TestCase): + def test_export_with_parameters(self, device): + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.features = torch.nn.Sequential( + torch.nn.Conv2d( + 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) + ), + torch.nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.features(x) + + model = MyModule().eval().to(device) + random_inputs = (torch.rand([32, 3, 32, 32]).to(device),) + dim_x = torch.export.Dim("dim_x", min=1, max=32) + exp_program = torch.export.export( + model, random_inputs, dynamic_shapes={"x": {0: dim_x}}, strict=True + ) + output_buffer = io.BytesIO() + # Tests if we can restore saved nn.Parameters when we load them again + torch.export.save(exp_program, output_buffer) + loaded_model = torch.export.load(output_buffer) + self.assertTrue( + isinstance( + loaded_model.module().get_parameter("features.0.weight"), + torch.nn.Parameter, + ) + ) + + def test_export_fast_binary_broadcast_check(self, device): + # This test looks at the case where we erroneously create a guard + # when checking the equality of the operands' shape and the output + # shape during FakeTensor's binary op fast path. 
+ + class MyModel(torch.nn.Module): + def forward(self, a, b): + # final shape is (dim0, 4, 8) + # order matters since a & the output have the same shape + return b + a + + a = torch.randn(100, 4, 8) + b = torch.randn(4, 8) + model = MyModel().eval().to(device) + batchsize = torch.export.Dim("dim0", min=3, max=1024) + dynamic_shape_spec = {"a": [batchsize, None, None], "b": [None, None]} + + torch.export.export( + model, (a, b), dynamic_shapes=dynamic_shape_spec, strict=True + ) + + def test_export_fast_binary_broadcast_check_unbacked(self, device): + class MyModel(torch.nn.Module): + def forward(self, numel, scalar): + u0 = numel.item() + x = torch.ones(u0 + 1) + return scalar - x + + model = MyModel().eval().to(device) + numel = torch.tensor(10) + scalar = torch.randn(1) + torch.export.export(model, (numel, scalar), strict=True) + + +common_utils.instantiate_parametrized_tests(ExportTests) +devices = ["cuda", "hpu", "xpu"] +instantiate_device_type_tests( + ExportTestsDevice, globals(), only_for=devices, allow_xpu=True +) + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/xpu/dynamo/test_structured_trace_xpu.py b/test/xpu/dynamo/test_structured_trace_xpu.py new file mode 100644 index 0000000000..b6090bf713 --- /dev/null +++ b/test/xpu/dynamo/test_structured_trace_xpu.py @@ -0,0 +1,1616 @@ +# Owner(s): ["module: dynamo"] +import copy +import functools +import io +import json +import logging +import os +import re +import shutil +import subprocess +import tempfile +import unittest.mock +from contextlib import contextmanager +from unittest import skipIf + +import torch +import torch._dynamo.test_case +import torch._dynamo.testing +import torch._logging.structured +import torch.distributed as dist +import torch.fx as fx +from torch._inductor.test_case import TestCase +from torch._logging._internal import TorchLogsFormatter +from torch.nn.parallel import DistributedDataParallel as DDP +from 
torch.testing._internal.common_utils import find_free_port +from torch.testing._internal.inductor_utils import HAS_XPU_AND_TRITON +from torch.testing._internal.triton_utils import requires_gpu_and_triton + +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" + + +if torch.distributed.is_available(): + from torch.testing._internal.distributed.fake_pg import FakeStore + +HAS_TLPARSE = shutil.which("tlparse") is not None +requires_tlparse = unittest.skipUnless(HAS_TLPARSE, "requires tlparse") +requires_distributed = functools.partial( + unittest.skipIf, not dist.is_available(), "requires distributed" +) + + +def example_fn(a): + output = a.mul(torch.ones(1000, 1000)) + output = output.add(torch.ones(1000, 1000)) + return output + + +def example_training_fn(a): + output = a.mul(torch.ones(1000, 1000, requires_grad=True)) + output = output.add(torch.ones(1000, 1000)) + output.sum().backward() + return output + + +def dynamo_error_fn(a): + output = a.mul(torch.ones(1000, 1000)) + output = output.add(torch.ones(10, 10)) + return output + + +def inductor_error_fn(a): + output = torch.round(a) + return output + + +def inductor_schedule_fn(a): + output = a.add(torch.ones(1000, 1000, device=device_type)) + return output + + +ARGS = (torch.ones(1000, 1000, requires_grad=True),) + + +def replace_dynamic(buffer, key): + return re.sub(r'("' + key + r'":\s*)(\d+\.\d+)', r"\1", buffer) + + +class StructuredTraceTestingFilter(logging.Filter): + def __init__(self, match_name=None): + self.match_name = match_name + + def filter(self, record): + if "str" in record.metadata: + return False + if self.match_name is not None: + if "artifact" in record.metadata: + if self.match_name != record.metadata["artifact"]["name"]: + return False + elif self.match_name not in record.metadata: + return False + return True + + +class ChromiumEventFilter(logging.Filter): + def filter(self, record): + return "chromium_event" not in record.metadata + + +class 
StructuredTracePayloadFormatter(logging.Formatter): + def format(self, record): + return record.payload.strip() + + +class _DescribeIdNormalizer: + def __init__(self): + self._tensor_id_remap = {} + self._storage_id_remap = {} + self._next_tensor_id = 0 + self._next_storage_id = 0 + + def normalize(self, metadata): + if "describe_storage" in metadata: + storage_meta = metadata["describe_storage"] + if (storage_id := storage_meta.get("id")) is not None: + storage_meta["id"] = self._normalize_storage_id(storage_id) + storage_meta["describer_id"] = "ID" + if "describe_tensor" in metadata: + tensor_meta = metadata["describe_tensor"] + if (tensor_id := tensor_meta.get("id")) is not None: + tensor_meta["id"] = self._normalize_tensor_id(tensor_id) + if (storage_id := tensor_meta.get("storage")) is not None: + tensor_meta["storage"] = self._normalize_storage_id(storage_id) + tensor_meta["describer_id"] = "ID" + if "view_func" in tensor_meta: + tensor_meta["view_func"] = "VIEW_FUNC" + if "describe_source" in metadata: + source_meta = metadata["describe_source"] + if (source_id := source_meta.get("id")) is not None: + source_meta["id"] = self._normalize_tensor_id(source_id) + source_meta["describer_id"] = "ID" + return metadata + + def _normalize_tensor_id(self, original_id): + if original_id not in self._tensor_id_remap: + self._tensor_id_remap[original_id] = self._next_tensor_id + self._next_tensor_id += 1 + return self._tensor_id_remap[original_id] + + def _normalize_storage_id(self, original_id): + if original_id not in self._storage_id_remap: + self._storage_id_remap[original_id] = self._next_storage_id + self._next_storage_id += 1 + return self._storage_id_remap[original_id] + + +class StructuredTraceTestingFormatter(logging.Formatter): + def __init__(self): + super().__init__() + self._id_normalizer = _DescribeIdNormalizer() + + def format(self, record): + metadata = copy.deepcopy(record.metadata) + + # Stub out values that are not stable across runs + # TODO: Check 
that these match schema + if "has_payload" in metadata: + metadata["has_payload"] = "HASH" + if "dynamo_start" in metadata: + metadata["dynamo_start"]["stack"] = "STACK" + if "inductor_output_code" in metadata: + metadata["inductor_output_code"]["filename"] = "FILENAME" + if "file_path" in metadata["inductor_output_code"]: + metadata["inductor_output_code"]["file_path"] = "FILENAME" + if "stack" in metadata: + metadata["stack"] = "STACK" + if "compilation_metrics" in metadata: + metadata["compilation_metrics"] = "METRICS" + if "bwd_compilation_metrics" in metadata: + metadata["bwd_compilation_metrics"] = "METRICS" + if "compilation_metrics_runtime" in metadata: + metadata["compilation_metrics_runtime"] = "METRICS" + if "bwd_compilation_metrics_runtime" in metadata: + metadata["bwd_compilation_metrics_runtime"] = "METRICS" + metadata = self._id_normalizer.normalize(metadata) + if ( + (k := "create_symbol") in metadata + or (k := "guard_added_fast") in metadata + or (k := "create_unbacked_symbol") in metadata + ): + metadata[k]["user_stack"] = "STACK" + metadata[k]["stack"] = "STACK" + + if "dump_file" in metadata: + # Don't include the actually key number, that's sensitive to other + # test runs + metadata["dump_file"]["name"] = "" + return ( + json.dumps(metadata) + + "\n" + + "\n".join(l.rstrip() for l in record.payload.splitlines()) + ) + + return json.dumps(metadata) + + +trace_log = logging.getLogger("torch.__trace") + +chrome_event_filter = ChromiumEventFilter() + + +def show_chrome_events(fn): + """ + Don't hide chrome events for this test + """ + + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + self.handler.removeFilter(chrome_event_filter) + return fn(self, *args, **kwargs) + + return wrapper + + +class StructuredTraceTest(TestCase): + def setUp(self): + super().setUp() + torch._dynamo.reset() + torch._logging.structured.INTERN_TABLE.clear() + self.buffer = io.StringIO() + self.old_level = trace_log.level + trace_log.setLevel(logging.DEBUG) + 
+ self.handler = logging.StreamHandler(self.buffer) + self.handler.setFormatter(StructuredTraceTestingFormatter()) + self.handler.addFilter(StructuredTraceTestingFilter()) + self.handler.addFilter(chrome_event_filter) + trace_log.addHandler(self.handler) + + self.raw_file = tempfile.NamedTemporaryFile( # noqa: SIM115 + mode="w", delete=True + ) # set this to False to keep temporary files + self.raw_handler = logging.StreamHandler(self.raw_file) + self.raw_handler.setFormatter(TorchLogsFormatter(trace=True)) + trace_log.addHandler(self.raw_handler) + + def tearDown(self): + trace_log.removeHandler(self.handler) + trace_log.removeHandler(self.raw_handler) + self.raw_file.close() + trace_log.setLevel(self.old_level) + super().tearDown() + + def assertExpectedInline(self, actual, expected): + super().assertExpectedInline( + self._normalize_rank_field(self._normalize_describe_ids(actual)), + self._normalize_rank_field(self._normalize_describe_ids(expected)), + ) + + @staticmethod + def _normalize_rank_field(text): + if not isinstance(text, str): + return text + text = text.replace(', "rank": 0', "") + text = text.replace('"rank": 0, ', "") + text = text.replace('"rank": 0', "") + return text + + @staticmethod + def _normalize_describe_ids(text): + if not isinstance(text, str): + return text + normalizer = _DescribeIdNormalizer() + trailing_newline = text.endswith("\n") + normalized_lines = [] + for line in text.splitlines(): + if not line: + normalized_lines.append(line) + continue + try: + metadata = json.loads(line) + except json.JSONDecodeError: + normalized_lines.append(line) + continue + normalized_lines.append(json.dumps(normalizer.normalize(metadata))) + result = "\n".join(normalized_lines) + if trailing_newline: + result += "\n" + return result + + def assertParses(self): + if not HAS_TLPARSE: + self.skipTest("requires tlparse") + out = tempfile.mkdtemp() + try: + subprocess.check_call( + [ + "tlparse", + "-o", + out, + "--overwrite", + "--no-browser", + 
"--strict", + self.raw_file.name, + ] + ) + finally: + shutil.rmtree(out, ignore_errors=True) + + def test_compile_id_serialization_deserialization(self): + cid = torch._guards.CompileId( + frame_id=1, + frame_compile_id=2, + ) + if cid != torch._guards.CompileId.from_string(str(cid)): + raise AssertionError("CompileId round-trip failed") + + cid = torch._guards.CompileId( + compiled_autograd_id=1, + frame_id=2, + frame_compile_id=3, + ) + if cid != torch._guards.CompileId.from_string(str(cid)): + raise AssertionError("CompileId round-trip failed") + + cid = torch._guards.CompileId( + compiled_autograd_id=1, + frame_id=None, + frame_compile_id=None, + ) + if cid != torch._guards.CompileId.from_string(str(cid)): + raise AssertionError("CompileId round-trip failed") + + for bad_cid in ["-/-", "-/1", "1/-", "!1/2", "!1/-/-"]: + with self.assertRaises(ValueError): + torch._guards.CompileId.from_string(bad_cid) + + @requires_gpu_and_triton + def test_schedule(self): + fn_opt = torch.compile(inductor_schedule_fn, backend="inductor") + fn_opt(torch.ones(1000, 1000, device=device_type)) + self.assertExpectedInline( + self.buffer.getvalue(), + f"""\ +{{"dynamo_start": {{"stack": "STACK"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}} +{{"describe_storage": {{"id": 0, "describer_id": "ID", "size": 4000000}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}} +{{"describe_tensor": {{"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='{device_type}', index=0)", "size": [1000, 1000], "dynamo_hint_overrides": {{}}, "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}} +{{"describe_source": {{"describer_id": "ID", "id": 0, "source": "L['a']"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}} +{{"dynamo_output_graph": {{"sizes": {{"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 
0, "has_payload": "HASH"}} +{{"artifact": {{"name": "aotautograd_cache_miss", "encoding": "json"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "before_pre_grad_graph", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "after_pre_grad_graph", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "aot_forward_graph_fw_metadata", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"aot_inference_graph": {{}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "torch._functorch.config", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "before_joint_graph", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "after_joint_graph", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "fx_graph_runnable", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "before_post_grad_graph", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "inductor_post_grad_graph", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"inductor_output_code": {{"filename": "FILENAME", "file_path": "FILENAME"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "triton_kernel_info", "encoding": "json"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "fx_graph_cache_miss", "encoding": 
"json"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "runtime_wrapper_orchestration", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"dynamo_cpp_guards_str": {{}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}} +{{"compilation_metrics_runtime": "METRICS", "frame_id": 0, "frame_compile_id": 0}} +""", + ) + + self.assertParses() + + @requires_gpu_and_triton + def test_gpugraphs(self): + fn_opt = torch.compile(mode="reduce-overhead")(inductor_schedule_fn) + fn_opt(torch.ones(1000, 1000, device=device_type)) + self.assertExpectedInline( + self.buffer.getvalue(), + f"""\ +{{"dynamo_start": {{"stack": "STACK"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}} +{{"describe_storage": {{"id": 0, "describer_id": "ID", "size": 4000000}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}} +{{"describe_tensor": {{"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='{device_type}', index=0)", "size": [1000, 1000], "dynamo_hint_overrides": {{}}, "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}} +{{"describe_source": {{"describer_id": "ID", "id": 0, "source": "L['a']"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}} +{{"dynamo_output_graph": {{"sizes": {{"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "aotautograd_cache_miss", "encoding": "json"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "before_pre_grad_graph", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": 
"after_pre_grad_graph", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "aot_forward_graph_fw_metadata", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"aot_inference_graph": {{}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "torch._functorch.config", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "before_joint_graph", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "after_joint_graph", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "fx_graph_runnable", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "before_post_grad_graph", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "inductor_post_grad_graph", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"inductor_output_code": {{"filename": "FILENAME", "file_path": "FILENAME"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "triton_kernel_info", "encoding": "json"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "fx_graph_cache_miss", "encoding": "json"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "runtime_wrapper_orchestration", "encoding": "string"}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"dynamo_cpp_guards_str": {{}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} 
+{{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}} +{{"compilation_metrics_runtime": "METRICS", "frame_id": 0, "frame_compile_id": 0}} +""", + ) + + self.assertParses() + + @requires_tlparse + def test_recompiles(self): + def fn(x, y): + return torch.add(x, y) + + fn_opt = torch.compile(fn, backend="inductor") + fn_opt(torch.ones(1000, 1000), torch.ones(1000, 1000)) + fn_opt(torch.ones(1000, 1000), 1) + + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1000, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['y']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "l_y_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, 
"attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "after_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "inductor_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME", "file_path": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "runtime_wrapper_orchestration", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, 
"has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "recompile_reasons", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"create_symbol": {"symbol": "s48", "val": "1", "vr": "[-int_oo, int_oo]", "source": "L['y']", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "torch._functorch.config", "encoding": 
"string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "after_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "inductor_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME", "file_path": "FILENAME"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "runtime_wrapper_orchestration", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +""", + ) + + self.assertParses() + + @requires_tlparse + def test_recompile_backend_match(self): + def fn(x): + return x.sin() + x.cos() + + x = torch.ones(10) + + torch.compile(fn, backend="eager", dynamic=True)(x) + + payload_buffer = io.StringIO() + payload_handler = logging.StreamHandler(payload_buffer) + payload_handler.setLevel(logging.DEBUG) + payload_handler.setFormatter(StructuredTracePayloadFormatter()) + payload_handler.addFilter(StructuredTraceTestingFilter("recompile_reasons")) + trace_log.addHandler(payload_handler) + try: + 
torch.compile(fn, backend="eager", dynamic=False)(x) + finally: + trace_log.removeHandler(payload_handler) + + payload = payload_buffer.getvalue() + self.assertIn("BACKEND_MATCH", payload) + self.assertIn("Cached backend:", payload) + self.assertIn("New backend:", payload) + self.assertIn("_TorchCompileWrapper", payload) + self.assertIn("dynamic=True", payload) + self.assertIn("dynamic=False", payload) + + self.assertParses() + + @requires_tlparse + def test_example_fn(self): + fn_opt = torch.compile(example_fn, backend="inductor") + fn_opt(torch.ones(1000, 1000)) + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000], "ones_1": [1000, 1000], "output_1": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", 
"encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "after_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "inductor_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME", "file_path": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "runtime_wrapper_orchestration", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +""", + ) + + self.assertParses() + + @requires_tlparse + def test_example_training_fn(self): + fn_opt = torch.compile(example_training_fn, backend="inductor") + fn_opt(torch.ones(1000, 1000, requires_grad=True)) + buffer = self.buffer.getvalue() + buffer = replace_dynamic(buffer, 
"inductor_compile_time_s") + buffer = replace_dynamic(buffer, "code_gen_time_s") + buffer = replace_dynamic(buffer, "structured_logging_overhead_s") + self.assertExpectedInline( + buffer, + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} +{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 1} +{"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "torch_dynamo_resume_in_example_training_fn_at_49_ORIGINAL_BYTECODE", "encoding": "string"}, 
"frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 1, "frame_compile_id": 0, "attempt": 1} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 1} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['___stack1']"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 1} +{"artifact": {"name": "torch_dynamo_resume_in_example_training_fn_at_50_MODIFIED_BYTECODE", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"dynamo_cpp_guards_str": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 1, "frame_compile_id": 0, "attempt": 1} +{"dynamo_start": {"stack": "STACK"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "torch_dynamo_resume_in_example_training_fn_at_49_ORIGINAL_BYTECODE", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['___stack0']"}, 
"frame_id": 2, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['___stack0']"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1} +{"dynamo_output_graph": {"sizes": {"l_stack0_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000], "sum_1": []}}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"aot_joint_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"aot_forward_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"aot_backward_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, 
"has_payload": "HASH"} +{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "inductor_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME", "file_path": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "backward_prologue", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "backward_epilogue", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "compiled_function_forward", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "compiled_function_backward", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "compiled_fn_wrapper", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "runtime_wrapper_orchestration", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "torch_dynamo_resume_in_example_training_fn_at_45_MODIFIED_BYTECODE", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"dynamo_cpp_guards_str": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 2, 
"frame_compile_id": 0, "attempt": 1} +{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "inductor_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME", "file_path": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"bwd_compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1} +{"dynamo_start": {"stack": "STACK"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "torch_dynamo_resume_in_example_training_fn_at_52_ORIGINAL_BYTECODE", "encoding": "string"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['output']"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} +{"compilation_metrics": "METRICS", "frame_id": 3, "frame_compile_id": 0, "attempt": 0} +""", + ) + + self.assertParses() + + @requires_tlparse + def test_dynamo_error(self): + try: + fn_opt = torch.compile(dynamo_error_fn, backend="inductor") + fn_opt(*ARGS) + except Exception: + pass + 
self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "dynamo_error", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +""", + ) + + self.assertParses() + + @requires_tlparse + def test_inductor_error(self): + import torch._inductor.lowering + + def throw(x): + raise AssertionError + + # inject an error in the lowerings + dict_entries = {} + for x in list(torch._inductor.lowering.lowerings.keys()): + if "round" in x.__name__: + dict_entries[x] = throw + + with unittest.mock.patch.dict(torch._inductor.lowering.lowerings, dict_entries): + try: + fn_opt = torch.compile(inductor_error_fn, backend="inductor") + fn_opt(*ARGS) + except Exception: + pass + + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, 
"view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_joint_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_backward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "inductor_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "dynamo_error", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": 
"HASH"} +{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +""", + ) + + self.assertParses() + + @skipIf(HAS_XPU_AND_TRITON, "No backend type associated with device type xpu") + @requires_distributed() + @requires_gpu_and_triton + @unittest.skip("https://github.com/pytorch/pytorch/issues/176188") + def test_ddp_graphs(self): + import torch._dynamo.convert_frame as convert_frame + + convert_frame.FRAME_COUNTER = 0 + convert_frame.FRAME_COMPILE_COUNTER.clear() + + class ToyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.layers = torch.nn.Sequential( + torch.nn.Linear(1024, 1024), + torch.nn.Linear(1024, 1024), + ) + + def forward(self, x): + return self.layers(x) + + # TODO: this isn't safely bracketed, will leak + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(find_free_port()) + dist.init_process_group("gloo", rank=0, world_size=1) + + model = DDP(ToyModel().to(f"{device_type}:0"), device_ids=[0], bucket_cap_mb=4) + ddp_model = torch.compile(model, backend="inductor") + + ddp_model(torch.randn(1024, 1024, device=f"{device_type}:0")) + + dist.destroy_process_group() + + self.assertExpectedInline( + self.buffer.getvalue(), + f"""\ +{{"dynamo_start": {{"stack": "STACK"}}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}} +{{"artifact": {{"name": "dynamo_graph_break_reason", "encoding": "string"}}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"describe_storage": {{"id": 0, "describer_id": "ID", "size": 4194304}}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1}} +{{"describe_tensor": {{"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='{device_type}', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {{}}, "is_leaf": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1}} 
+{{"describe_source": {{"describer_id": "ID", "id": 0, "source": "L['args'][0]"}}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1}} +{{"dynamo_cpp_guards_str": {{}}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}} +{{"compilation_metrics": "METRICS", "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1}} +{{"dynamo_start": {{"stack": "STACK"}}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0}} +{{"artifact": {{"name": "dynamo_graph_break_reason", "encoding": "string"}}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"dynamo_cpp_guards_str": {{}}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}} +{{"compilation_metrics": "METRICS", "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 1}} +{{"dynamo_start": {{"stack": "STACK"}}, "rank": 0, "frame_id": 2, "frame_compile_id": 0, "attempt": 0}} +{{"artifact": {{"name": "dynamo_graph_break_reason", "encoding": "string"}}, "rank": 0, "frame_id": 2, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"dynamo_cpp_guards_str": {{}}, "rank": 0, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}} +{{"compilation_metrics": "METRICS", "rank": 0, "frame_id": 2, "frame_compile_id": 0, "attempt": 1}} +{{"dynamo_start": {{"stack": "STACK"}}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0}} +{{"artifact": {{"name": "torch_dynamo_resume_in___init___at_103_ORIGINAL_BYTECODE", "encoding": "string"}}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"compilation_metrics": "METRICS", "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0}} +{{"dynamo_start": {{"stack": "STACK"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_storage": {{"id": 0, "describer_id": "ID", "size": 4194304}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} 
+{{"describe_tensor": {{"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='{device_type}', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {{}}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_source": {{"describer_id": "ID", "id": 0, "source": "L['self']._modules['layers']._modules['0']._parameters['weight']"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_storage": {{"id": 1, "describer_id": "ID", "size": 4096}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_tensor": {{"id": 1, "ndim": 1, "dtype": "torch.float32", "device": "device(type='{device_type}', index=0)", "size": [1024], "dynamo_hint_overrides": {{}}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_source": {{"describer_id": "ID", "id": 1, "source": "L['self']._modules['layers']._modules['0']._parameters['bias']"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_storage": {{"id": 2, "describer_id": "ID", "size": 4194304}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_tensor": {{"id": 2, "ndim": 2, "dtype": "torch.float32", "device": "device(type='{device_type}', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {{}}, "is_leaf": true, "stride": [1024, 1], "storage": 2, "view_func": "VIEW_FUNC", "describer_id": "ID"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_source": {{"describer_id": "ID", "id": 2, "source": "L['x']"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_storage": {{"id": 3, "describer_id": "ID", "size": 4194304}}, "rank": 0, 
"frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_tensor": {{"id": 8, "ndim": 2, "dtype": "torch.float32", "device": "device(type='{device_type}', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {{}}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 3, "view_func": "VIEW_FUNC", "describer_id": "ID"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_source": {{"describer_id": "ID", "id": 8, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_storage": {{"id": 4, "describer_id": "ID", "size": 4096}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_tensor": {{"id": 9, "ndim": 1, "dtype": "torch.float32", "device": "device(type='{device_type}', index=0)", "size": [1024], "dynamo_hint_overrides": {{}}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 4, "view_func": "VIEW_FUNC", "describer_id": "ID"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_source": {{"describer_id": "ID", "id": 9, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"dynamo_output_graph": {{"sizes": {{"l_self_modules_layers_modules_0_parameters_weight_": [1024, 1024], "l_self_modules_layers_modules_0_parameters_bias_": [1024], "l_x_": [1024, 1024], "l_self_modules_layers_modules_1_parameters_weight_": [1024, 1024], "l_self_modules_layers_modules_1_parameters_bias_": [1024], "input_1": [1024, 1024], "input_2": [1024, 1024]}}}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"optimize_ddp_split_graph": {{}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"optimize_ddp_split_child": {{"name": "submod_0"}}, "rank": 0, 
"frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"optimize_ddp_split_child": {{"name": "submod_1"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"describe_storage": {{"id": 0, "describer_id": "ID", "size": 4194304}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_tensor": {{"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='{device_type}', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {{}}, "is_leaf": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_source": {{"describer_id": "ID", "id": 0, "source": "L['x']"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_storage": {{"id": 1, "describer_id": "ID", "size": 4194304}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_tensor": {{"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='{device_type}', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {{}}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_source": {{"describer_id": "ID", "id": 1, "source": "L['self']._modules['layers']._modules['0']._parameters['weight']"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_storage": {{"id": 2, "describer_id": "ID", "size": 4096}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_tensor": {{"id": 2, "ndim": 1, "dtype": "torch.float32", "device": "device(type='{device_type}', index=0)", "size": [1024], "dynamo_hint_overrides": {{}}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 2, "view_func": "VIEW_FUNC", "describer_id": 
"ID"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_source": {{"describer_id": "ID", "id": 2, "source": "L['self']._modules['layers']._modules['0']._parameters['bias']"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"artifact": {{"name": "before_pre_grad_graph", "encoding": "string"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "after_pre_grad_graph", "encoding": "string"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "aotautograd_cache_bypass", "encoding": "json"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"aot_joint_graph": {{}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "torch._functorch.config", "encoding": "string"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "aot_forward_graph_fw_metadata", "encoding": "string"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"aot_forward_graph": {{}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"aot_backward_graph": {{}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "fx_graph_runnable", "encoding": "string"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "before_post_grad_graph", "encoding": "string"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "inductor_post_grad_graph", "encoding": "string"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"inductor_output_code": {{"filename": "FILENAME", "file_path": "FILENAME"}}, 
"rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "fx_graph_cache_miss", "encoding": "json"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"describe_storage": {{"id": 16, "describer_id": "ID", "size": 4194304}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_tensor": {{"id": 28, "ndim": 2, "dtype": "torch.float32", "device": "device(type='{device_type}', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {{}}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_source": {{"describer_id": "ID", "id": 28, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_storage": {{"id": 17, "describer_id": "ID", "size": 4096}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_tensor": {{"id": 29, "ndim": 1, "dtype": "torch.float32", "device": "device(type='{device_type}', index=0)", "size": [1024], "dynamo_hint_overrides": {{}}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 17, "view_func": "VIEW_FUNC", "describer_id": "ID"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"describe_source": {{"describer_id": "ID", "id": 29, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +{{"artifact": {{"name": "before_pre_grad_graph", "encoding": "string"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "after_pre_grad_graph", "encoding": "string"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": 
"HASH"}} +{{"artifact": {{"name": "aotautograd_cache_bypass", "encoding": "json"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"aot_joint_graph": {{}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "torch._functorch.config", "encoding": "string"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "aot_forward_graph_fw_metadata", "encoding": "string"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"aot_forward_graph": {{}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"aot_backward_graph": {{}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "fx_graph_runnable", "encoding": "string"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "before_post_grad_graph", "encoding": "string"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "inductor_post_grad_graph", "encoding": "string"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"inductor_output_code": {{"filename": "FILENAME", "file_path": "FILENAME"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"artifact": {{"name": "fx_graph_cache_miss", "encoding": "json"}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"dynamo_cpp_guards_str": {{}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}} +{{"compilation_metrics": "METRICS", "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}} +""", + ) + + self.assertParses() + + @requires_tlparse + 
@unittest.skip("https://github.com/pytorch/pytorch/issues/176188") + def test_graph_breaks(self): + @torch.compile(backend="inductor") + def fn(x): + torch._dynamo.graph_break() + return x + 1 + + fn(torch.ones(1)) + + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 1} +{"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "torch_dynamo_resume_in_fn_at_808_ORIGINAL_BYTECODE", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_x_": [1], "add": [1]}}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "after_pre_grad_graph", "encoding": 
"string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_joint_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "after_joint_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "inductor_post_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME", "file_path": "FILENAME"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "torch_dynamo_resume_in_fn_at_808_MODIFIED_BYTECODE", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_cpp_guards_str": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +""", + ) + + self.assertParses() + + # TODO: bring in the trace_source tests once we start emitting bytecode + + 
@requires_tlparse + def test_graph_sizes_dynamic(self): + def fn(a, b): + return a @ b + + fn_opt = torch.compile(fn, backend="eager", dynamic=False) + fn_opt(torch.randn(10, 20), torch.randn(20, 30)) + + fn_opt2 = torch.compile(fn, backend="eager", dynamic=True) + fn_opt2(torch.randn(5, 10), torch.randn(10, 15)) + + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 800}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [10, 20], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [20, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 2400}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [20, 30], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [30, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['b']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_a_": [10, 20], "l_b_": [20, 30], "matmul": [10, 30]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "recompile_reasons", "encoding": "json"}, "frame_id": 
0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 200}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [5, 10], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [10, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"create_symbol": {"symbol": "s97", "val": "5", "vr": "[2, int_oo]", "source": "L['a'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"create_symbol": {"symbol": "s98", "val": "10", "vr": "[2, int_oo]", "source": "L['a'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 600}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [10, 15], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [15, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['b']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"create_symbol": {"symbol": "s52", "val": "10", "vr": "[2, int_oo]", "source": "L['b'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"create_symbol": {"symbol": "s20", "val": "15", "vr": "[2, int_oo]", "source": "L['b'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, 
"frame_compile_id": 1, "attempt": 0} +{"guard_added_fast": {"expr": "Eq(s98, s52)", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_a_": ["s97", "s52"], "l_b_": ["s52", "s20"], "matmul": ["s97", "s20"]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +""", + ) + + self.assertParses() + + @requires_tlparse + def test_guards_recompiles(self): + def fn(x, ys, zs): + return inner(x, ys, zs) + + def inner(x, ys, zs): + for y, z in zip(ys, zs): + x += y * z + return x + + ys = [1.0, 2.0] + zs = [3.0] + x = torch.tensor([1.0]) + + fn_opt = torch.compile(fn, backend="eager") + fn_opt(x, ys, zs) + fn_opt(x, ys[:1], zs) + + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "recompile_reasons", "encoding": 
"json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +""", + ) + + self.assertParses() + + def test_dump_file(self): + def f(x, y): + return x.add(y) + + gm = fx.symbolic_trace(f) + torch.compile(gm, backend="eager")(torch.randn(3), torch.randn(3)) + + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dump_file": {"name": ""}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} + + + +def forward(self, x, y): + add = x.add(y); x = y = None + return add + +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 12}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [3], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": 
{"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 12}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [3], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['y']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_x_": [3], "l_y_": [3], "add": [3]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +""", + ) + + @requires_tlparse + @torch._inductor.config.patch("fx_graph_cache", True) + def test_codecache(self): + def fn(a): + return a.sin() + + x = torch.tensor([1.0]) + fn_opt = torch.compile(fn, backend="inductor") + fn_opt(x) + torch._dynamo.reset() + # Trigger a cache hit + fn_opt(x) + + # Should print twice, including inductor_output_code + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, 
"frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_a_": [1], "sin": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "after_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "inductor_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME", "file_path": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, 
"attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "runtime_wrapper_orchestration", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_a_": [1], "sin": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME", "file_path": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "inductor_provenance_tracking_node_mappings", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": 
"inductor_provenance_tracking_kernel_stack_traces", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "fx_graph_cache_hit", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "runtime_wrapper_orchestration", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_hit", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +""", + ) + self.assertParses() + + @requires_tlparse + def test_make_fx_fail_partial(self): + from torch.fx.experimental.proxy_tensor import make_fx + + payload_buffer = io.StringIO() + payload_handler = logging.StreamHandler(payload_buffer) + payload_handler.setFormatter(StructuredTracePayloadFormatter()) + payload_handler.addFilter(StructuredTraceTestingFilter("make_fx_fail_partial")) + trace_log.addHandler(payload_handler) + + def f(x): + y = x + 1 # noqa: F841 + raise RuntimeError("boo") + + try: + make_fx(f)(torch.randn(2)) + except RuntimeError: + pass + + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"artifact": {"name": "make_fx_fail_partial", "encoding": "string"}, "stack": "STACK", "has_payload": "HASH"} +""", + ) + + self.assertExpectedInline( + payload_buffer.getvalue(), + """\ +def forward(self, x_1: "f32[2][1]cpu"): + # No stacktrace found for following nodes + add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(x_1, 1); x_1 = add = None +""", + ) + + @requires_tlparse + @torch._inductor.config.patch("fx_graph_cache", True) + @show_chrome_events + def test_chromium_event(self): + def fn(a): + return a.sin() + + x = torch.tensor([1.0]) + fn_opt = torch.compile(fn, backend="inductor") + fn_opt(x) 
+ torch._dynamo.reset() + # Trigger a cache hit + fn_opt(x) + # Should print twice, including inductor_output_code + self.assertParses() + chromium_event = ( + '{"chromium_event": {}, "frame_id": 0, "frame_compile_id": 0, ' + '"attempt": 0, "has_payload": "HASH"}' + ) + self.assertTrue(chromium_event in self.buffer.getvalue()) + + @requires_tlparse + @torch._dynamo.config.patch("compiled_autograd", True) + @torch._inductor.config.patch("fx_graph_cache", True) + @show_chrome_events + def test_compiled_autograd_id(self): + def fn(a): + return a.sin().sum().backward() + + x = torch.tensor([1.0], requires_grad=True) + fn_opt = torch._dynamo.optimize("inductor")(fn) + fn_opt(x) + torch._dynamo.reset() + # Trigger a cache hit + fn_opt(x) + # Should print twice, including inductor_output_code + self.assertParses() + chromium_events = [ + ( + '{"chromium_event": {}, "frame_id": 0, "frame_compile_id": 0, ' + '"attempt": 0, "has_payload": "HASH"}' + ), + ( + '{"compiled_autograd_graph": {}, "compiled_autograd_id": 0, ' + '"attempt": 0, "has_payload": "HASH"}' + ), + ( + '{"chromium_event": {}, "compiled_autograd_id": 0, "frame_id": 1, "frame_compile_id": 0, ' + '"attempt": 0, "has_payload": "HASH"}' + ), + ] + logs = self.buffer.getvalue() + self.assertTrue(all(event in logs for event in chromium_events)) + + @requires_tlparse + @torch._dynamo.config.patch("compiled_autograd", True) + def test_compiled_autograd_attribution(self): + # multiple dynamo recompiles should still be attributed to the parent compiled autograd id + def fn(): + class MySin(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return torch.sin(x) + + @staticmethod + def backward(ctx, gO): + print("graph break") + (x,) = ctx.saved_tensors + print("graph break") + return gO * torch.cos(x) + + grads = [] + for i in [10, 100, 10, 15, 20, 25]: + x = torch.arange(0.0, i, requires_grad=True) + out = MySin.apply(x) + loss = out.sum() + loss.backward() + 
grads.append(x.grad) + + return grads + + fn_opt = torch.compile(fn) + fn_opt() + self.assertParses() + expected = [ + '{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 5, "frame_compile_id": 0, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 6, "frame_compile_id": 0, "attempt": 0}', + ] + logs = self.buffer.getvalue() + self.assertTrue(all(event in logs for event in expected)) + + @requires_tlparse + @show_chrome_events + def test_compiled_autograd_chromium(self): + with torch._dynamo.compiled_autograd._enable(torch.compile): + for i in [10, 100, 10, 15, 20, 25]: + x = torch.arange(0.0, i, requires_grad=True) + loss = x.sum() + loss.backward() + + self.assertParses() + expected = [ + '{"chromium_event": {}, "compiled_autograd_id": 0, "attempt": 0, "has_payload": "HASH"}', + '{"chromium_event": {}, "compiled_autograd_id": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, ' + '"has_payload": "HASH"}', + '{"chromium_event": {}, "compiled_autograd_id": 0, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, ' + '"has_payload": "HASH"}', + ] + logs = self.buffer.getvalue() + self.assertTrue(all(event in logs for event in expected)) + + def test_recompile_user_contexts(self): + # test that user_context is called only once per recompile + num_calls = 0 + + def f(x): + return x + 1 + + f = torch.compile(f) + + def user_context() -> str: + nonlocal num_calls + num_calls += 1 + 
return "user_context: " + str(num_calls) + + torch._dynamo.register_hook_for_recompile_user_context(user_context) + + for _ in range(10): + f(torch.randn(1, 5)) + + # first compile + self.assertEqual(num_calls, 1) + + for i in range(2, 10): + f(torch.randn(i, 5)) + + # first compile + recompile once + self.assertEqual(num_calls, 2) + + def test_recompile_user_contexts_iteration(self): + class Step: + def __init__(self): + self.step = 0 + + def next_step(self): + self.step += 1 + + step = Step() + + def f(x): + return x + 1 + + f = torch.compile(f) + + def user_context() -> str: + return "user_context: " + str(step.step) + + torch._dynamo.register_hook_for_recompile_user_context(user_context) + + for i in range(10): + f(torch.randn(i + 2 // 3, 5)) + step.next_step() + + @contextmanager + def _setup_collective_schedule_capture(self): + """Helper to turn on and capture the 'inductor_collective_schedule' structured trace.""" + payload_buffer = io.StringIO() + payload_handler = logging.StreamHandler(payload_buffer) + payload_handler.setLevel(logging.DEBUG) + payload_handler.setFormatter(StructuredTracePayloadFormatter()) + payload_handler.addFilter( + StructuredTraceTestingFilter("inductor_collective_schedule") + ) + trace_log.addHandler(payload_handler) + try: + yield payload_buffer + finally: + trace_log.removeHandler(payload_handler) + + @requires_tlparse + def test_collective_schedule_empty(self): + """Verify logging when no collective kernels are present (empty schedule).""" + with self._setup_collective_schedule_capture() as payload_buffer: + from torch._inductor.debug import log_collective_schedule + + log_collective_schedule([]) + + # With no collectives, artifact should not be logged and payload should be empty + self.assertNotIn('"inductor_collective_schedule"', self.buffer.getvalue()) + self.assertEqual(payload_buffer.getvalue().strip(), "") + + @requires_tlparse + @requires_distributed() + @torch._inductor.config.patch("fx_graph_cache", False) + def 
test_collective_schedule_real(self): + """Test collective schedule with _c10d_functional ops that work with FakeStore.""" + import torch.distributed as dist + + store = FakeStore() + dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) + + class CollectiveModule(torch.nn.Module): + def forward(self, x): + # Use _c10d_functional ops that actually trigger collective kernels + y = torch.ops._c10d_functional.all_reduce.default(x, "sum", "0") + y = torch.ops._c10d_functional.wait_tensor.default(y) + return y * 2 + + try: + with self._setup_collective_schedule_capture() as payload_buffer: + torch._dynamo.reset() + + mod = CollectiveModule() + compiled = torch.compile(mod, backend="inductor") + + compiled(torch.randn(4, 4)) + + # Verify collective schedule artifact was logged + self.assertIn('"inductor_collective_schedule"', self.buffer.getvalue()) + + payload_content = payload_buffer.getvalue().strip() + schedule = json.loads(payload_content) + self.assertIsInstance(schedule, list) + + # Verify expected collective operations are present + self.assertExpectedInline( + str(schedule), + """\ +['torch.ops._c10d_functional.all_reduce_.default', 'torch.ops._c10d_functional.wait_tensor.default']\ +""", + ) + self.assertParses() + finally: + dist.destroy_process_group() + + @contextmanager + def _setup_runtime_estimates_capture(self): + """Helper to turn on and capture the combined 'inductor_runtime_and_tensor_meta' structured trace.""" + payload_buffer = io.StringIO() + payload_handler = logging.StreamHandler(payload_buffer) + payload_handler.setLevel(logging.DEBUG) + payload_handler.setFormatter(StructuredTracePayloadFormatter()) + payload_handler.addFilter( + StructuredTraceTestingFilter("inductor_runtime_and_tensor_meta") + ) + trace_log.addHandler(payload_handler) + try: + yield payload_buffer + finally: + trace_log.removeHandler(payload_handler) + + @requires_tlparse + @requires_distributed() + @requires_gpu_and_triton + 
@torch._inductor.config.patch("fx_graph_cache", False) + @torch._inductor.config.patch("log_tlparse", True) + def test_runtime_estimates_simple(self): + """Test runtime estimates logging with simple compute and collective ops.""" + import torch.distributed as dist + + store = FakeStore() + dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) + + class SimpleModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + h = self.linear(x) + h = torch.relu(h) + + h = torch.ops._c10d_functional.all_reduce.default(h, "sum", "0") + h = torch.ops._c10d_functional.wait_tensor.default(h) + return h + + try: + with self._setup_runtime_estimates_capture() as payload_buffer: + torch._dynamo.reset() + + mod = SimpleModule().to(device_type) + compiled = torch.compile(mod, backend="inductor") + compiled(torch.randn(4, 4, device=device_type)) + + # Verify runtime + tensor meta artifact was logged + self.assertIn( + '"inductor_runtime_and_tensor_meta"', self.buffer.getvalue() + ) + + payload_content = payload_buffer.getvalue().strip() + if payload_content: + data = json.loads(payload_content) + self.assertIn("ops", data) + ops = data["ops"] + + # Verify runtime estimates + compute_ops = [op for op in ops if op["type"] == "compute"] + collective_ops = [op for op in ops if op["type"] == "collective"] + + self.assertTrue(len(compute_ops) > 0 or len(collective_ops) > 0) + + # Just check each op has an estimated runtime value (any value, including 0) + for op in ops: + self.assertIn("estimated_runtime_ns", op) + self.assertIsNotNone(op["estimated_runtime_ns"]) + + self.assertParses() + finally: + dist.destroy_process_group() + + @requires_tlparse + @requires_distributed() + @requires_gpu_and_triton + @torch._inductor.config.patch("fx_graph_cache", False) + @torch._inductor.config.patch("log_tlparse", True) + def test_runtime_estimates_mixed(self): + """Test runtime estimates logging with mixed 
compute and collective sequence.""" + import torch.distributed as dist + + store = FakeStore() + dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) + + class MixedModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm = torch.nn.LayerNorm(4) + + def forward(self, x): + h = self.norm(x) + h = torch.nn.functional.gelu(h) + + h = torch.ops._c10d_functional.all_reduce.default(h, "sum", "0") + h = torch.ops._c10d_functional.wait_tensor.default(h) + + h = h * 0.5 + + gathered = torch.ops._c10d_functional.all_gather_into_tensor.default( + h, 2, "0" + ) + gathered = torch.ops._c10d_functional.wait_tensor.default(gathered) + + return gathered.sum(dim=0) + + try: + with self._setup_runtime_estimates_capture() as payload_buffer: + torch._dynamo.reset() + + mod = MixedModule().to(device_type) + compiled = torch.compile(mod, backend="inductor") + compiled(torch.randn(4, 4, device=device_type)) + + # Verify artifact was logged + self.assertIn( + '"inductor_runtime_and_tensor_meta"', self.buffer.getvalue() + ) + + payload_content = payload_buffer.getvalue().strip() + if payload_content: + data = json.loads(payload_content) + self.assertIn("ops", data) + ops = data["ops"] + + # Should have both compute and collective ops + op_types = {op["type"] for op in ops} + self.assertIn("compute", op_types) + self.assertIn("collective", op_types) + + # Just check each op has an estimated runtime value (any value, including 0) + for op in ops: + self.assertIn("estimated_runtime_ns", op) + self.assertIsNotNone(op["estimated_runtime_ns"]) + + self.assertParses() + finally: + dist.destroy_process_group() + + @requires_tlparse + @requires_distributed() + @requires_gpu_and_triton + @torch._inductor.config.patch("fx_graph_cache", False) + @torch._inductor.config.patch("log_tlparse", True) + def test_tensor_metadata_logging_multiple_ops(self): + import torch.distributed as dist + + store = FakeStore() + dist.init_process_group(backend="fake", rank=0, 
world_size=2, store=store) + + class Mixed(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + y = torch.relu(self.linear(x)) + y = torch.ops._c10d_functional.all_reduce.default(y, "sum", "0") + y = torch.ops._c10d_functional.wait_tensor.default(y) + return y + 1 + + try: + with self._setup_runtime_estimates_capture() as payload_buffer: + torch._dynamo.reset() + mod = Mixed().to(device_type) + compiled = torch.compile(mod, backend="inductor") + compiled(torch.randn(4, 4, device=device_type)) + payload = payload_buffer.getvalue().strip() + if payload: + data = json.loads(payload) + types = sorted({op.get("type") for op in data.get("ops", [])}) + self.assertExpectedInline( + str(types), """['collective', 'compute']""" + ) + self.assertParses() + finally: + dist.destroy_process_group() + + @requires_tlparse + @torch._inductor.config.patch("log_tlparse", True) + def test_tensor_metadata_logging(self): + """Emit unified runtime+tensor-metadata artifact and assert a stable simplified JSON inline.""" + with self._setup_runtime_estimates_capture() as payload_buffer: + + def f(x): + y = x.transpose(0, 1) + z = y.mean(dim=0) + w = z.to(torch.float16) + return w + + compiled = torch.compile(f, backend="inductor", fullgraph=True) + compiled(torch.ones(2, 3)) + + # Verify artifact was logged + self.assertIn('"inductor_runtime_and_tensor_meta"', self.buffer.getvalue()) + + payload = payload_buffer.getvalue().strip() + if payload: + data = json.loads(payload) + ops = data.get("ops", []) + + simplified_ops = [] + for op in ops: + outs = [ + { + "shape": out.get("shape", []), + "stride": out.get("stride", []), + "dtype": out.get("dtype", None), + } + for out in op.get("outputs", []) + ] + if outs: + simplified_ops.append( + { + "type": op.get("type", ""), + "outputs": outs, + } + ) + + self.assertExpectedInline( + {"ops": simplified_ops[-1:]} if simplified_ops else {"ops": []}, + """{'ops': [{'type': 
'compute', 'outputs': [{'shape': [2], 'stride': [1], 'dtype': 'float16'}]}]}""", + ) + + self.assertParses() + + @requires_tlparse + @torch._inductor.config.patch("log_tlparse", True) + def test_tensor_metadata_logging_dynamic_shapes(self): + """Same as test_tensor_metadata_logging, but with dynamic shapes enabled to cover to_size_hints.""" + with self._setup_runtime_estimates_capture() as payload_buffer: + + def f(x): + y = x.transpose(0, 1) + z = y.mean(dim=0) + w = z.to(torch.float16) + return w + + compiled = torch.compile(f, backend="inductor", dynamic=True) + compiled(torch.ones(2, 3)) + + # Verify artifact was logged + self.assertIn('"inductor_runtime_and_tensor_meta"', self.buffer.getvalue()) + + payload = payload_buffer.getvalue().strip() + if payload: + data = json.loads(payload) + ops = data.get("ops", []) + + simplified_ops = [] + for op in ops: + outs = [ + { + "shape": out.get("shape", []), + "stride": out.get("stride", []), + "dtype": out.get("dtype", None), + } + for out in op.get("outputs", []) + ] + if outs: + simplified_ops.append( + { + "type": op.get("type", ""), + "outputs": outs, + } + ) + + self.assertExpectedInline( + {"ops": simplified_ops[-1:]} if simplified_ops else {"ops": []}, + ( + "{'ops': [{'type': 'compute', 'outputs': [" + "{'shape': [2], 'stride': [1], 'dtype': 'float32'}, " + "{'shape': [2], 'stride': [1], 'dtype': 'float16'}]}]}" + ), + ) + + self.assertParses() + + @contextmanager + def _setup_graph_execution_capture(self): + """Helper to capture the 'graph_execution' structured trace.""" + payload_buffer = io.StringIO() + payload_handler = logging.StreamHandler(payload_buffer) + payload_handler.setLevel(logging.DEBUG) + payload_handler.setFormatter(StructuredTracePayloadFormatter()) + payload_handler.addFilter(StructuredTraceTestingFilter("graph_execution")) + trace_log.addHandler(payload_handler) + try: + yield payload_buffer + finally: + trace_log.removeHandler(payload_handler) + + @requires_tlparse + 
@torch._inductor.config.patch(force_disable_caches=True) + def test_graph_execution_order(self): + """Verify graph execution order is aggregated into a single artifact.""" + torch._dynamo.reset() + with self._setup_graph_execution_capture() as payload_buffer: + + def fn(x): + y = x + 1 + torch._dynamo.graph_break() + return y + 2 + + compiled = torch.compile(fn, backend="inductor") + from torch._inductor.debug import record_and_log_graph_execution_order + + with record_and_log_graph_execution_order(): + compiled(torch.randn(1)) + + payload_content = payload_buffer.getvalue().strip() + payload = json.loads(payload_content) + executions = payload["graph_execution_order"] + self.assertTrue(all(isinstance(e["compile_id"], str) for e in executions)) + self.assertExpectedInline( + json.dumps(payload), + """{"graph_execution_order": [{"compile_id": "0/0"}, {"compile_id": "1/0"}]}""", + ) + self.assertParses() + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/xpu/dynamo/test_subclasses_xpu.py b/test/xpu/dynamo/test_subclasses_xpu.py new file mode 100644 index 0000000000..7d57715102 --- /dev/null +++ b/test/xpu/dynamo/test_subclasses_xpu.py @@ -0,0 +1,4394 @@ +# Owner(s): ["module: dynamo"] +import functools +import itertools +import unittest +from functools import partial + +import torch +import torch._dynamo.test_case +import torch._dynamo.testing +import torch._functorch.config +import torch.utils._pytree as pytree +import torch.utils.checkpoint +from torch._dynamo.backends.common import aot_autograd +from torch._dynamo.testing import CompileCounterWithBackend, normalize_gm +from torch._functorch._aot_autograd.utils import make_boxed_compiler +from torch._functorch.compilers import min_cut_rematerialization_partition +from torch._higher_order_ops.wrap import wrap +from torch.fx.experimental.symbolic_shapes import ( + DimDynamic, + ShapeEnv, + StatelessSymbolicContext, +) +from 
torch.nested._internal.nested_tensor import ( + jagged_from_list, + jagged_from_tensor_and_lengths, + nested_view_from_values_offsets, +) +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + NestedTensorTestCase, + parametrize, + subtest, +) +from torch.testing._internal.triton_utils import requires_gpu_and_triton +from torch.testing._internal.two_tensor import TwoTensor +from torch.utils._python_dispatch import return_and_correct_aliasing + + +def nontraceable_subclass(c): + return torch._dynamo.config.patch("nontraceable_tensor_subclasses", {c}) + + +def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): + actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2) + self.assertEqual(actual_recompiles, expected_recompiles) + + +def get_jagged_tensor(nested_size, offsets, requires_grad=True): + # Makes a jagged tensor with N constituent tensors with size + # as specified ((S0, S1, S2), D) + D = nested_size[1] + out = [] + for s in nested_size[0]: + out.append(torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64)) + return jagged_from_list(out, offsets) + + +def get_view_test_cases(): + # Test all cases with both an NT base and a dense base + # Subclass -> Subclass + # Dense -> Subclass + + # NB: Don't close over loop variables, they will not get copied into the + # closure + # + # NB: These return functions so we don't generate tensors during test + # collection time + + def mk_basic(base_is_nt): + # There are three cases to consider here based on the logic in + # meta_utils.py + # + # (1) basic case: + # view is not a leaf and has the same requires grad as its basic case + x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True) + x = x.clone() if base_is_nt else x + if x.is_leaf: + raise AssertionError("Expected x to not be a leaf") + return x.unsqueeze(-1) + + def mk_leaf(base_is_nt, requires_grad_1, requires_grad_2): + x, _ = get_jagged_tensor(((2, 3, 4), 3), None, 
requires_grad=requires_grad_1) + x = x.clone() if base_is_nt else x + with torch.no_grad(): + x_view = x.unsqueeze(-1) + # The issue is this doesn't quite work + x_view.requires_grad_(requires_grad_2) + + return x_view + + def mk_obscure(base_is_nt): + x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False) + x = x.clone() if base_is_nt else x + # intermediate leaf view + with torch.no_grad(): + x_view = x.unsqueeze(-1) + x_view.requires_grad_(True) + x_view_view = x_view.unsqueeze(-1) + return x_view_view + + for base_is_nt in [False, True]: + prefix = f"base_is_nt_{base_is_nt}" + + yield partial(mk_basic, base_is_nt), f"{prefix}_basic" + + # (2) leaf view case: + # the view has to be a leaf (w/ requires_grad True or requires_grad False) + # base w/ requires_grad True or requires_grad False + for requires_grad_1, requires_grad_2 in itertools.product( + [True, False], repeat=2 + ): + yield ( + partial(mk_leaf, base_is_nt, requires_grad_1, requires_grad_2), + f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}", + ) + + # (3) obscure case: + # view is not a leaf (implies requires_grad True) + # base w/ requires_grad False) + yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure" + + # Subclass -> Dense + yield ( + lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone(), + "subclass_dense", + ) + + # Dense -> Subclass -> Dense -> Subclass + def mk_dense_subclass_dense_subclass(): + values = torch.randn(10, 5) + offsets = torch.tensor([0, 3, 6, 10]) + return nested_view_from_values_offsets( + nested_view_from_values_offsets(values, offsets).values(), offsets + ) + + yield mk_dense_subclass_dense_subclass, "dense_subclass_dense_subclass" + + def mk_subclass_dense_subclass_dense(): + x = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone() + offsets2 = x.offsets().detach().clone() + nested_view_from_values_offsets(x.values(), offsets2).values() + + yield mk_subclass_dense_subclass_dense, 
"subclass_dense_subclass_dense" + + +VIEW_TEST_CASES = {k: v for v, k in get_view_test_cases()} + + +compile_full_eager = torch.compile(backend="eager", fullgraph=True) + + +class BaseTorchFunction(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + return super().__torch_function__(func, types, args, kwargs) + + +class MockSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + return super().__torch_function__(func, types, args, kwargs) + + +class AttrSubclass(torch.Tensor): + x: int = 10 + size: int = 10 + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + return super().__torch_function__(func, types, args, kwargs) + + +class DummyNDim(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if func == torch.Tensor.ndim.__get__: + return 10 + + return super().__torch_function__(func, types, args, kwargs) + + +class WrapperSubclass: + def __init__(self, tensor): + self.tensor = tensor + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + args = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, args) + kwargs = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, kwargs) + + return func(*args, **kwargs) + + +class SigmoidToExpSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if func == torch.Tensor.sigmoid: + return super().__torch_function__(torch.Tensor.exp, types, args, kwargs) + + return super().__torch_function__(func, types, args, kwargs) + + +# Wrapper subclass with two inner tensors: data and scale +# data has same shape as outer, and scale has single dim size +class ScaledTensor(torch.Tensor): + def __new__( + cls, + data: torch.Tensor, + scale: torch.Tensor, + *, + constant: int = 0, + ): + return torch.Tensor._make_wrapper_subclass( + cls, + 
data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + dtype=data.dtype, + layout=data.layout, + requires_grad=data.requires_grad, + device=data.device, + ) + + def __init__(self, data: torch.Tensor, scale: torch.Tensor, constant: int = 0): + self._data = data + self._scale = scale + self._constant = constant + + def __tensor_flatten__(self): + ctx = {"_constant": self._constant} + return ["_data", "_scale"], ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): + if len(inner_tensors) != 2: + raise AssertionError(f"Expected 2 inner tensors, got {len(inner_tensors)}") + return ScaledTensor( + inner_tensors["_data"], + inner_tensors["_scale"], + constant=metadata["_constant"], + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + scaled_tensor = args[0] + out = func(scaled_tensor._data, *args[1:], **kwargs) + return ScaledTensor(out, scaled_tensor._scale, constant=scaled_tensor._constant) + + def __repr__(self): + return f"{self._data.__repr__()}\n{self._scale.__repr__()}" + + +class OptionalScaledTensor(torch.Tensor): + def __new__( + cls, + data, + scale, + *, + constant: int = 0, + ): + return torch.Tensor._make_wrapper_subclass( + cls, + data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + dtype=data.dtype, + layout=data.layout, + requires_grad=data.requires_grad, + device=data.device, + ) + + def __init__(self, data: torch.Tensor, scale, constant: int = 0): + self._data = data + self._scale = scale + self._constant = constant + + def __tensor_flatten__(self): + ctx = {"_constant": self._constant} + if self._scale is not None: + return ["_data", "_scale"], ctx + else: + return ["_data"], ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): + return OptionalScaledTensor( + inner_tensors["_data"], + inner_tensors.get("_scale", None), + constant=metadata["_constant"], + ) + + @classmethod 
+ def __torch_dispatch__(cls, func, types, args, kwargs=None): + scaled_tensor = args[0] + out = func(scaled_tensor._data, *args[1:], **kwargs) + if scaled_tensor._scale is not None: + out = out * scaled_tensor._scale + return OptionalScaledTensor( + out, scaled_tensor._scale, constant=scaled_tensor._constant + ) + + def __repr__(self): + return ( + f"OptionalScaledTensor({self._data.__repr__()}\n{self._scale.__repr__()})" + ) + + +class CtxSubclassTensor(torch.Tensor): + """ + Class used to verify guarding on the subclass metadata + """ + + @staticmethod + def __new__(cls, a, constant): + shape = a.shape + kwargs = {} + kwargs["strides"] = a.stride() + kwargs["storage_offset"] = a.storage_offset() + kwargs["device"] = a.device + kwargs["layout"] = a.layout + kwargs["requires_grad"] = a.requires_grad + kwargs["dtype"] = a.dtype + out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + return out + + def __init__(self, a, constant): + self.a = a + self.constant = constant + + def __repr__(self): + a_repr = repr(self.a) + return f"CtxSubclassTensor({a_repr})" + + def __tensor_flatten__(self): + return ["a"], (self.constant,) + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, sizes, strides): + constant = meta[0] + a = inner_tensors["a"] + return CtxSubclassTensor(a, constant) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + biggest_constant = max( + [ + x.constant + for x in pytree.tree_flatten(args)[0] + if isinstance(x, CtxSubclassTensor) + ] + ) + args_a = pytree.tree_map( + lambda x: x.a if isinstance(x, CtxSubclassTensor) else x, args + ) + kwargs_a = pytree.tree_map( + lambda x: x.a if isinstance(x, CtxSubclassTensor) else x, kwargs + ) + out_a = func(*args_a, **kwargs_a) + out = pytree.tree_map( + lambda x: ( + CtxSubclassTensor(x, biggest_constant) + if isinstance(x, torch.Tensor) + else x + ), + out_a, + ) + + if func == torch.ops.aten.mul.Tensor: + out = out + 
out.constant + + return return_and_correct_aliasing(func, args, kwargs, out) + + +class DeferredInitSubclass(torch.Tensor): + """ + A traceable wrapper subclass that calls super().__init__() BEFORE + setting instance attributes (similar to torchao patterns). + """ + + @staticmethod + def __new__(cls, data, scale): + return torch.Tensor._make_wrapper_subclass( + cls, data.shape, dtype=data.dtype, device=data.device + ) + + def __init__(self, data, scale): + super().__init__() + self._data = data + self._scale = scale + + def __tensor_flatten__(self): + return ["_data"], {"_scale": self._scale} + + @classmethod + def __tensor_unflatten__(cls, inner_tensors, ctx, outer_size, outer_stride): + return cls(inner_tensors["_data"], ctx["_scale"]) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + if kwargs is None: + kwargs = {} + args_inner = pytree.tree_map_only(DeferredInitSubclass, lambda x: x._data, args) + out = func(*args_inner, **kwargs) + if isinstance(out, torch.Tensor): + return DeferredInitSubclass(out, args[0]._scale) + return out + + +def func(a): + return a.sin() + + +class EagerRecordGraphAndInputs: + def __init__(self) -> None: + self.graphs = [] + self.example_inputs = [] + + def __call__(self, gm: torch.fx.GraphModule, example_inputs): + self.graphs.append(gm) + self.example_inputs.append(example_inputs) + return gm + + +GLOBAL_TEST_SUBCLASSES = { + MockSubclass, + DummyNDim, + SigmoidToExpSubclass, + BaseTorchFunction, +} + + +# Returns True if the function recompiles between inputs1 and inputs2 with the +# specified dynamic setting. 
+def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True): + compile_count = [0] + + def counter(gm, example_inputs): + compile_count[0] += 1 + return gm + + compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic) + compiled_f(*inputs1) + compiled_f(*inputs2) + return compile_count[0] > 1 + + +class SubclassTests(torch._dynamo.test_case.TestCase): + @classmethod + def tearDownClass(cls): + cls._exit_stack.close() + + def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): + _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles) + + def test_no_call_to_new(self): + class BadNewTorchFunction(torch.Tensor): + def __new__(cls, *args, **kwargs): + raise RuntimeError("Oops!") + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return torch.add(x, 1) + + input = torch.ones(2, 2).as_subclass(BadNewTorchFunction) + + res = fn(input) + self.assertIsInstance(res, BadNewTorchFunction) + + def test_no_torch_function_recompiles(self): + class NJT: + def __repr__(self): + return f"NJT(shape={self.shape})" + + def __init__(self, values, offsets): + self._values = values + self._offsets = offsets + + def sin(self): + return torch.sin(self) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func == torch.sin: + self = args[0] + return NJT(func(self._values), self._offsets) + raise AssertionError("should not get here") + + values1 = torch.randn(10, 3, 4, requires_grad=True) + values2 = torch.randn(10, 3, 4, requires_grad=True) + offsets = torch.tensor([0, 3, 10]) + njt1 = NJT(values1, offsets) + njt2 = NJT(values2, offsets) + + @torch.compile(backend="eager", fullgraph=True) + def f(x): + return torch.sin(x) + + with 
unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): + f(njt1) + f(njt2) + + def test_base_torch_function_tracing(self): + def fn(x): + return torch.add(x, 1) + + input = torch.ones(2, 2).as_subclass(BaseTorchFunction) + out = fn(input) + out_opt = compile_full_eager(fn)(input) + self.assertIsInstance(out, BaseTorchFunction) + self.assertEqual(out, out_opt) + + def test_torch_function_subclass_with_mode(self): + # Subclass __torch_function__ must still be inlined when a + # TorchFunctionMode is active, otherwise the runtime wrapper + # (DisableTorchFunctionSubclass) is not applied and the subclass + # dispatch fires twice. + class ScaledTensor(torch.Tensor): + @staticmethod + def __new__(cls, data): + return torch.Tensor._make_subclass(cls, data) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + raw = super().__torch_function__(func, types, args, kwargs) + if isinstance(raw, torch.Tensor) and func is torch.add: + return raw * 2.0 + return raw + + class NoopMode(torch.overrides.TorchFunctionMode): + def __torch_function__(self, func, types, args=(), kwargs=None): + return func(*args, **(kwargs or {})) + + a = ScaledTensor(torch.tensor([1.0, 2.0])) + b = ScaledTensor(torch.tensor([3.0, 4.0])) + + with NoopMode(): + eager = torch.add(a, b) + torch._dynamo.reset() + compiled = torch.compile(torch.add, backend="eager")(a, b) + self.assertEqual(eager, compiled) + + def test_torch_function_reentrant_dispatch(self): + class ScaledTensor(torch.Tensor): + @staticmethod + def __new__(cls, data): + return torch.Tensor._make_subclass(cls, data) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + raw = super().__torch_function__(func, types, args, kwargs) + if isinstance(raw, torch.Tensor) and func is torch.add: + return raw * 2.0 + return raw + + a = ScaledTensor(torch.tensor([1.0, 2.0])) + b = ScaledTensor(torch.tensor([3.0, 4.0])) + + eager = 
torch.add(a, b) + torch._dynamo.reset() + compiled = torch.compile(torch.add, backend="eager")(a, b) + self.assertEqual(eager, compiled) + + def test_tensorify_under_disabled_torch_function(self): + # Fixes #180906 + # The checks tensorify_python_scalars works under dispatch + # as it relies on MetaProxy's __torch_function__ to intercept calls + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.fx.experimental.symbolic_shapes import ShapeEnv + from torch.fx.passes._tensorify_python_scalars import tensorify_python_scalars + + # Build a minimal graph containing _local_scalar_dense (i.e. .item()) + # on a floating-point placeholder — just enough for tensorify to act. + shape_env = ShapeEnv() + with FakeTensorMode(shape_env=shape_env) as fake_mode: + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(4) + + scale_ph = graph.placeholder("scale") + scale_ph.meta["val"] = torch.tensor(1.0) + + item_node = graph.call_function( + torch.ops.aten._local_scalar_dense.default, (scale_ph,) + ) + # tensorify needs a backed SymFloat with a sympy expression + item_node.meta["val"] = shape_env.create_unbacked_symfloat() + + mul_node = graph.call_function(torch.ops.aten.mul.Tensor, (x, item_node)) + mul_node.meta["val"] = torch.randn(4) + graph.output(mul_node) + + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + # Run tensorify_python_scalars with __torch_function__ + # disabled — without the _EnableTorchFunction fix, this raises: + # RuntimeError: prims::convert_element_type() Expected + # a value of type 'Tensor' ... found type 'MetaProxy'. + with torch._C.DisableTorchFunctionSubclass(): + tensorify_python_scalars(gm, shape_env, fake_mode) + + # The pass should have inserted a convert_element_type node + # that upcasts the scale placeholder to float64. 
+ convert_nodes = [ + n + for n in gm.graph.nodes + if n.op == "call_function" + and n.target is torch.ops.prims.convert_element_type.default + ] + self.assertTrue(len(convert_nodes) > 0) + # The first convert_element_type is the float64 upcast; + # verify it carries correct metadata. + self.assertEqual(convert_nodes[0].meta["val"].dtype, torch.float64) + + def test_torch_function_state_graph_break(self): + @torch.compile(backend="eager") + def fn(x): + with torch._C.DisableTorchFunctionSubclass(): + torch._dynamo.graph_break() + return torch._C._is_torch_function_enabled(), torch.add(x, 1.0) + + input = torch.ones(2, 2) + res, _ = fn(input) + self.assertFalse(res) + + def test_disable_all_torch_function(self): + @torch.compile(backend="eager") + def fn(x): + with torch._C.DisableTorchFunction(): + torch._dynamo.graph_break() + return ( + torch._C._is_torch_function_enabled(), + torch._C._is_torch_function_all_disabled(), + torch.add(x, 1.0), + ) + + input = torch.ones(2, 2) + res1, res2, _ = fn(input) + self.assertFalse(res1) + self.assertTrue(res2) + + def test_disable_all_torch_function_restore_values(self): + @torch.compile(backend="eager") + def fn(x): + with torch._C.DisableTorchFunction(): + x = torch._C._is_torch_function_all_disabled() + + return ( + x, + torch._C._is_torch_function_all_disabled(), + torch.add(x, 1.0), + ) + + input = torch.ones(2, 2) + res1, res2, _ = fn(input) + self.assertTrue(res1) + self.assertFalse(res2) + + def test_disable_all_torch_function_restore_values_graph_break(self): + @torch.compile(backend="eager") + def fn(x): + with torch._C.DisableTorchFunction(): + torch._dynamo.graph_break() + x = torch._C._is_torch_function_all_disabled() + + return ( + x, + torch._C._is_torch_function_all_disabled(), + torch.add(x, 1.0), + ) + + input = torch.ones(2, 2) + res1, res2, _ = fn(input) + self.assertTrue(res1) + self.assertFalse(res2) + + def test_torch_function_state_nested(self): + @torch.compile(backend="eager") + def fn(x): + with 
torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunctionSubclass(): + x = x + 1 + # Should reset to the outer state (disabled) after exiting ctx manager + return torch._C._is_torch_function_enabled(), torch.add(x, 1.0) + + input = torch.ones(2, 2) + res, _ = fn(input) + self.assertFalse(res) + + def test_torch_function_state_tracing(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + with torch._C.DisableTorchFunctionSubclass(): + torch.add(x, 1.0) + + input = torch.ones(2, 2) + + fn(input) + + def test_torch_function_state_guards(self): + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt, fullgraph=True) + def fn(x): + torch.add(x, 1.0) + + input = torch.ones(2, 2) + + with torch._C.DisableTorchFunctionSubclass(): + fn(input) + + fn(input) + + self.assertEqual(cnt.frame_count, 2) + + def test_return_subclass(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return MockSubclass(torch.add(x, 1.0)) * 2 + + input = torch.ones(2, 2) + + res = fn(input) + self.assertIsInstance(res, MockSubclass) + + def test_return_as_subclass(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return torch.add(x, 1.0).as_subclass(MockSubclass) * 2 + + input = torch.ones(2, 2) + + res = fn(input) + self.assertIsInstance(res, MockSubclass) + + def test_return_local_subclass(self): + class LocalSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + return super().__torch_function__(func, types, args, kwargs) + + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return LocalSubclass(torch.add(x, 1.0)) * 2 + + input = torch.ones(2, 2) + + res = fn(input) + self.assertIsInstance(res, LocalSubclass) + + def test_torch_function_list_args(self): + HANDLED_FUNCTIONS = {} + + class MyClass: + def __init__(self, foo): + self.foo = foo + + @classmethod + def __torch_function__( + cls, + func, + types, + args=(), + kwargs=None, 
+ ): + if kwargs is None: + kwargs = {} + if func not in HANDLED_FUNCTIONS or not all( + [ # noqa: C419 + issubclass(t, (torch.Tensor, MyClass)) for t in types + ] + ): + return NotImplemented + return HANDLED_FUNCTIONS[func](*args, **kwargs) + + def _stack(input, dim=0, *, out=None): + return MyClass(sum([x.foo for x in input])) + + HANDLED_FUNCTIONS[torch.stack] = _stack + + @torch.compile(backend="eager", fullgraph=True) + def fn(v0, v1): + return torch.stack([v0, v1]) + + ret = fn(MyClass(1), MyClass(1)) + self.assertEqual(ret.foo, 2) + + @parametrize( + "comparison", + [ + subtest(isinstance, "isinstance"), + subtest(lambda instance, type_: type(instance) is type_, "equality"), + subtest(lambda instance, type_: type(instance) is type_, "identity"), + ], + ) + @parametrize( + "input_type", + [ + subtest(torch.Tensor, "tensor"), + subtest(DummyNDim, "subclass"), + ], + ) + def test_type_check(self, comparison, input_type): + def fn(x): + if comparison(x, DummyNDim): + return torch.ones(1, 1) + else: + return torch.zeros(2, 2) + + input = torch.ones(2, 2).as_subclass(input_type) + exp_res = fn(input) + act_res = torch.compile(backend="eager", fullgraph=True)(fn)(input) + self.assertEqual(exp_res, act_res) + + def test_torch_function_call_on_method(self): + x = torch.ones(2, 2) + y = torch.ones(2, 2) + z = torch.ones(2, 2) + wrapped = x.as_subclass(SigmoidToExpSubclass) + wrapped2 = y.as_subclass(SigmoidToExpSubclass) + + def fn(w): + return w.exp() + + fn_opt = compile_full_eager(fn) + + res_exp = fn(wrapped) + res_act = fn_opt(wrapped2) + res_exp2 = z.exp() + + self.assertEqual(res_exp, res_act) + self.assertEqual(res_exp, res_exp2) + + def test_torch_function_call_on_method_arg(self): + class LocalSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if func == torch._C.TensorBase.add_: + func = torch._C.TensorBase.sub_ + + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, 
args, kwargs) + + def sigmoid(self): + return None + + x = torch.ones(2, 2) + y = torch.ones(2, 2) + z = torch.ones(2, 2) + wrapped = y.as_subclass(LocalSubclass) + wrapped2 = z.as_subclass(LocalSubclass) + + def fn(a, w): + a.add_(w) + return a + + fn_opt = torch.compile(fn, backend="eager") + + res_exp = fn(x, wrapped) + res_act = fn_opt(y, wrapped2) + + self.assertEqual(res_exp, res_act) + + def test_user_overridden_method_unsupported(self): + class LocalSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + return super().__torch_function__(func, types, args, kwargs) + + def sigmoid(self): + return None + + def fn(x): + x.sigmoid() + + x = torch.ones(2, 2).as_subclass(LocalSubclass) + fn_opt = compile_full_eager(fn) + + res_exp = fn(x) + res_act = fn_opt(x) + + self.assertEqual(res_exp, res_act) + + def test_user_overridden_attr_unsupported(self): + class LocalSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + + ndim = 10 + + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return x.ndim + + msg = "`torch.compile` only support tracing certain types of overridden tensor subclass attributes" + with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): + x = torch.ones(2, 2).as_subclass(LocalSubclass) + fn(x) + + def test_user_overridden_property_unsupported(self): + class LocalSubclass(torch.Tensor): + def __init__(self, *args, **kwargs) -> None: + self._ndim = 10 + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + return super().__torch_function__(func, types, args, kwargs) + + @property + def ndim(self): + return self._ndim + + @ndim.setter + def ndim(self, value): + self._ndim = value + + def fn(x): + return x + x.ndim + + x = LocalSubclass(torch.ones(2, 2)) + fn_opt = compile_full_eager(fn) + + res_exp = 
fn(x) + res_act = fn_opt(x) + + self.assertEqual(res_exp, res_act) + + def test_overridden_method_guarding(self): + class LocalSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + + @torch.compile(backend="eager") + def fn(x): + return x.sigmoid() + + with torch._dynamo.config.patch(error_on_recompile=True): + x = torch.ones(2, 2).as_subclass(LocalSubclass) + fn(x) + fn(x) + x = torch.ones(2, 2).as_subclass(LocalSubclass) + fn(x) + + with self.assertRaisesRegex(TypeError, "'bool' object is not callable"): + LocalSubclass.sigmoid = False + fn(x) + + def test_torch_function_call_on_attr(self): + x = torch.ones(2, 2) + wrapped = x.as_subclass(DummyNDim) + + def fn(w): + return w.ndim + torch.ones(2) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(wrapped) + res_act = fn_opt(wrapped) + + self.assertEqual(res_exp, res_act) + self.assertEqual(res_exp, torch.ones(2) + 10) + + def test_torch_function_wrapper_class(self): + x = torch.ones(2, 2) + wrapped = WrapperSubclass(x) + + def fn(w): + return torch.add(w, 1.0) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(wrapped) + res_act = fn_opt(wrapped) + self.assertEqual(res_exp, res_act) + + def test_no_torch_function_on_size_bytecode(self): + class TestTensor(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + with torch._C.DisableTorchFunctionSubclass(): + out = func(*args, **kwargs) + + if func == torch.clone: + return out * 2 + else: + return out + + def fn(x): + return torch.clone(x) + + inp = torch.ones(4, 4) + x = inp.as_subclass(TestTensor) + torch._dynamo.mark_dynamic(x, 0) + compiled_fn = torch.compile(fn, fullgraph=True) + out = compiled_fn(x) + self.assertEqual(out, torch.ones(4, 4) * 2) + + def test_tensor_subclass_unpack(self): + class Foo(torch.Tensor): + pass + + 
torch._dynamo.config.traceable_tensor_subclasses.add(Foo) + try: + + @torch.compile(backend="eager", fullgraph=True) + def fn_list(x): + return list(x) + + x = torch.ones(3).as_subclass(Foo) + res_list = fn_list(x) + + for elem in res_list: + self.assertIsInstance(elem, Foo) + self.assertEqual(len(res_list), 3) + finally: + torch._dynamo.config.traceable_tensor_subclasses.discard(Foo) + + def test_torch_function_wrapper_class_with_kwargs(self): + x = torch.ones(2, 2) + wrapped = WrapperSubclass(x) + + def fn(w): + return torch.add(w, 1.0, alpha=2.0) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(wrapped) + res_act = fn_opt(wrapped) + self.assertEqual(res_exp, res_act) + + def test_tensor_subclass_with_non_classmethod_torch_function(self): + class MySubclass(torch.Tensor): + def __torch_function__(self, func, types, args, kwargs=None): + if kwargs is None: + kwargs = {} + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + def fn(x): + return x + 1 + + fn_opt = compile_full_eager(fn) + + x = torch.randn(2, 2).as_subclass(MySubclass) + res_exp = fn(x) + res_act = fn_opt(x) + self.assertEqual(res_exp, res_act) + + def test_tensor_subclass_custom_attr(self): + class AttrSubclass(torch.Tensor): + x: int = 10 + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + return super().__torch_function__(func, types, args, kwargs) + + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return x.x + torch.ones(2, 2) + + input = torch.ones(2, 2).as_subclass(AttrSubclass) + fn_opt = compile_full_eager(fn) + + res_exp = fn(input) + res_act = fn_opt(input) + self.assertEqual(res_exp, res_act) + + def test_make_subclass(self): + # Make sure `torch.Tensor._make_subclass` is traceable, and Dynamo + # models its aliasing relationships correctly. 
+ class MySubclass(torch.Tensor): + pass + + def fn(x): + # Downcast then upcast + y = torch.Tensor._make_subclass(MySubclass, x) + z = torch.Tensor._make_subclass(torch.Tensor, x) + # Now `x, y, z` should have the same underlying data. + x += 1 + y += 2 + z += 3 + res = x * y + z + return res + + x0 = torch.randn(2, 2) + x1 = x0.clone() + + fn_opt = compile_full_eager(fn) + + res_exp = fn(x0) + res_act = fn_opt(x1) + self.assertEqual(res_exp, res_act) + self.assertEqual(x0, x1) + + def test_subclass_override_shape_and_to(self): + # This is a slight variabtion of + # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435 + class MySubclass(torch.Tensor): + def to(self, *args, **kwargs): + new = super().to(*args, **kwargs) + new.tensor_shape = getattr(self, "tensor_shape", new.data.shape) + return new + + @property + def shape(self): + if not hasattr(self, "tensor_shape"): + self.tensor_shape = self.size() + return self.tensor_shape + + def fn(x): + x_shape = x.shape + y = x.to("cpu") + return x + 1, y + 2, x_shape, x.tensor_shape, y.tensor_shape + + x0 = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass)) + x1 = torch.nn.Parameter(x0.clone().as_subclass(MySubclass)) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(x0) + res_act = fn_opt(x1) + self.assertEqual(res_exp, res_act) + self.assertEqual(x0, x1) + self.assertEqual(x0.tensor_shape, x1.tensor_shape) + + def test_subclass_dont_invoke_torch_function_on_overridden_method(self): + # We shouldn't fire `__torch_function__` for overridden tensor methods. 
+ class MySubclass(torch.Tensor): + def to(self, device): + return self * len(device) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if func is torch.Tensor.to: + torch._dynamo.graph_break() + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return x.to("cpu") + + x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass)) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(x) + res_act = fn_opt(x) + self.assertEqual(res_exp, res_act) + + def test_subclass_dont_invoke_torch_function_on_overridden_attr(self): + from types import MethodWrapperType + + # We shouldn't fire `__torch_function__` for overridden tensor attrs. + class MySubclass(torch.Tensor): + def ndim(self): + return 42 + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if type(func) is MethodWrapperType and func.__name__ == "ndim": + torch._dynamo.graph_break() + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return x + x.ndim() + + x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass)) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(x) + res_act = fn_opt(x) + self.assertEqual(res_exp, res_act) + + def test_parameter_subclass_with_old_torch_function(self): + class MySubclass(torch.nn.Parameter): + pass + + def fn(x): + x = x.t() + x = x.T + return x + 1 + + fn_opt = compile_full_eager(fn) + + x = torch.randn(2, 2).as_subclass(MySubclass) + res_exp = fn(x) + res_act = fn_opt(x) + self.assertEqual(res_exp, res_act) + + def test_subclass_with_disabled_torch_function(self): + class MySubclass(torch.Tensor): + __torch_function__ = torch._C._disabled_torch_function_impl + + def fn(x): + x = x.t() + x = x.T + return x + 1 + + fn_opt = compile_full_eager(fn) + + x = torch.randn(2, 2).as_subclass(MySubclass) + res_exp = fn(x) + res_act = fn_opt(x) + self.assertEqual(res_exp, res_act) + + def test_parameter_subclass_custom_torch_func_and_dynamic_attr(self): + # 
This is a slight variation of + # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435 + # which basically + # 1. uses tensor subclass to attach quantization metadata onto tensors + # 2. preserve them across torch ops + # 3. use the metadata to dequantize the tensor + # 4. convert it to a regular tensor. + # + # The test is meant to make sure Dynamo won't graph break over it. + class GGUFParameter(torch.nn.Parameter): + def __new__(cls, data, requires_grad=False, quant_type=None): + data = data if data is not None else torch.empty(0) + self = torch.Tensor._make_subclass(cls, data, requires_grad) + return self + + def __init__(self, *args, quant_type=None, **kwargs): + self.quant_type = quant_type + + def as_tensor(self): + return torch.Tensor._make_subclass( + torch.Tensor, self, self.requires_grad + ) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + result = super().__torch_function__(func, types, args, kwargs) + + quant_type = None + for arg in args: + if isinstance(arg, list) and isinstance(arg[0], GGUFParameter): + quant_type = arg[0].quant_type + break + if isinstance(arg, GGUFParameter): + quant_type = arg.quant_type + break + if isinstance(result, torch.Tensor): + return cls(result, quant_type=quant_type) + # Handle tuples and lists + elif isinstance(result, (tuple, list)): + # Preserve the original type (tuple or list) + wrapped = [ + ( + cls(x, quant_type=quant_type) + if isinstance(x, torch.Tensor) + else x + ) + for x in result + ] + return type(result)(wrapped) + else: + return result + + def f(x): + tmp = x * 2 + tmp = tmp + tmp.quant_type + tmp = tmp.as_tensor() + return tmp * 3 + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + + x = GGUFParameter(torch.ones(2), quant_type=42) + res = f(x) + ref = opt_f(x) + self.assertEqual(res, ref) + + def 
test_newly_constructed_tensor_subclass_attr_mutation(self): + # Make sure the attribute mutation for newly constructed tensor subclass + # object (from constructor call) is handled both during Dynamo tracing + # and codegen-ed to be visible outside `torch.compile`. + class MySubclass(torch.Tensor): + pass + + def f(): + x = MySubclass(torch.ones(2)) + x.bar = 42 + return x, x * x.bar + + opt_f = compile_full_eager(f) + + res = f() + ref = opt_f() + + self.assertEqual(res, ref) + self.assertEqual(res[0].bar, ref[0].bar) + + def test_as_subclass_attr_mutation(self): + # Make sure the attribute mutation for newly constructed tensor subclass + # object (from as_subclass call) is handled both during Dynamo tracing + # and codegen-ed to be visible outside `torch.compile`. + class MySubclass(torch.Tensor): + pass + + def f(): + x = torch.ones(2).as_subclass(MySubclass) + x.bar = 42 + return x, x * x.bar + + opt_f = compile_full_eager(f) + + res = f() + ref = opt_f() + + self.assertEqual(res, ref) + self.assertEqual(res[0].bar, ref[0].bar) + + def test_subclass_method_override(self): + class MyTensor(torch.Tensor): + def sum(self, dim=None, keepdim=False): + return super().sum(dim=dim, keepdim=keepdim).as_subclass(MyTensor) + + def fn(x): + y = x.as_subclass(MyTensor) + return y.sum(dim=1) + + x = torch.randn(4, 10) + fn_opt = torch.compile(fn, backend="eager", fullgraph=False) + + res_exp = fn(x) + res_act = fn_opt(x) + self.assertEqual(res_exp, res_act) + self.assertIsInstance(res_act, MyTensor) + + def test_tensor_subclass_attr_codegen_tos(self): + # This repros a very subtle interaction between + # `TensorWithTFOverrideVariable` attribute mutation codegen and + # `PyCodegen.top_of_stack`. It was uncovered from + # `test_tensor_subclass_deepcopy`. 
+ class MySubclass(torch.Tensor): + def __new__(cls, elem, *args, **kwargs): + r = torch.Tensor._make_subclass(cls, torch.ones(0)) + r.elem = elem + return r + + def f(t): + return MySubclass(t.elem.clone()) + + opt_f = compile_full_eager(f) + + t = MySubclass(torch.ones(2)) + res = f(t) + ref = opt_f(t) + + self.assertEqual(res, ref) + self.assertEqual(res.elem, ref.elem) + self.assertEqual(type(res), type(ref)) + + def test_nontraceable_tensor_subclass(self): + # This will error if Dynamo tries to wrap it as a tensor variable, + # because that involves calling certain methods to inspect the tensor + # property, which will blow up in the overridden `__torch_function__`. + class MySubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + raise RuntimeError("one shall not pass") + + def f(t): + return t.foo + torch.ones(10) + + opt_f = torch.compile(f, backend="eager", fullgraph=False) + + t = MySubclass(torch.ones(2)) + t.foo = 42 + # Make sure the `nontraceable_tensor_subclasses` config prevents Dynamo + # from wrapping `t`. 
+ with nontraceable_subclass(MySubclass): + res = f(t) + ref = opt_f(t) + + self.assertEqual(res, ref) + + def test_compile_with_fake_tensor_dynamic_dim(self): + x = torch.randn([3, 4]) + + def f(x): + return torch.sin(x) + + def test_dynamic_dim(f, x, dim_dynamic, exp_frame_count, exp_op_count): + torch._dynamo.reset() + cnt = torch._dynamo.testing.CompileCounter() + + opt_f = torch.compile(f, backend=cnt, fullgraph=True) + + x1 = torch.rand_like(x) + f(x) + f(torch.randn([4, 3])) + shape_env = ShapeEnv() + with torch._subclasses.fake_tensor.FakeTensorMode( + shape_env=shape_env + ) as fake_mode: + x_fake = fake_mode.from_tensor( + x, + symbolic_context=StatelessSymbolicContext( + dynamic_sizes=[dim_dynamic for i in range(x.dim())] + ), + ) + x1_fake = fake_mode.from_tensor( + x1, + symbolic_context=StatelessSymbolicContext( + dynamic_sizes=[dim_dynamic for i in range(x.dim())] + ), + ) + opt_f(x_fake) + opt_f(x1_fake) + + self.assertEqual(cnt.frame_count, exp_frame_count) + self.assertEqual(cnt.op_count, exp_op_count) + + test_dynamic_dim(f, x, DimDynamic.DYNAMIC, 1, 1) + test_dynamic_dim(f, x, DimDynamic.DUCK, 1, 1) + test_dynamic_dim(f, x, DimDynamic.STATIC, 1, 1) + + def test_compile_with_fake_tensor_automatic_dynamic(self): + def f(x): + return torch.sin(x) + + def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count): + torch._dynamo.reset() + cnt = torch._dynamo.testing.CompileCounter() + opt_f = torch.compile(f, backend=cnt, fullgraph=True) + + shape_env = ShapeEnv() + with torch._subclasses.fake_tensor.FakeTensorMode( + shape_env=shape_env + ) as fake_mode: + for inp in inps: + fake_inp = fake_mode.from_tensor( + inp, + symbolic_context=StatelessSymbolicContext( + [dim_dynamic for i in range(x.dim())] + ), + ) + opt_f(fake_inp) + self.assertEqual(cnt.frame_count, exp_frame_count) + self.assertEqual(cnt.op_count, exp_op_count) + + x = torch.randn([3, 4]) + y = torch.randn([4, 5]) + z = torch.randn([5, 6]) + a = torch.randn([3, 5]) + b 
= torch.randn([4, 4]) + # When inputs' DimDynamic is DYNAMIC or DUCK, the inputs + # to opt_f will be tensors with SymInt sizes. Dynamo will treat input + # as dynamic automatically and will only compile once + for dim_dynamic in [DimDynamic.DYNAMIC, DimDynamic.DUCK]: + test_automatic_dynamic(f, [x, y, z], dim_dynamic, 1, 1) + test_automatic_dynamic(f, [x, a, z], dim_dynamic, 1, 1) + test_automatic_dynamic(f, [x, b, z], dim_dynamic, 1, 1) + + for dim_dynamic in [DimDynamic.STATIC]: + # Recompile once, first with dim 0 and 1 become Dynamic + test_automatic_dynamic(f, [x, y, z], dim_dynamic, 2, 2) + # Recompile 2 times, first with dim 1 become Dynamic, second with dim 0 becomes Dynamic. + test_automatic_dynamic(f, [x, a, z], dim_dynamic, 3, 3) + # Recompile 2 times, first with dim 0 become Dynamic, second with dim 1 becomes Dynamic. + test_automatic_dynamic(f, [x, b, z], dim_dynamic, 3, 3) + + def test_compile_with_functionalization(self): + x = torch.randn([3, 4]) + x_clone = x.clone() + x_clone2 = x.clone() + backend = EagerRecordGraphAndInputs() + cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) + + @torch.compile(backend=cnt, fullgraph=True) + def f(x): + return x.add_(1.0) + torch.nn.functional.relu_(x) + + f_out = f(x) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.op_count, 3) + self.assertEqual(len(backend.graphs), 1) + self.assertEqual(len(backend.example_inputs), 1) + + actual = normalize_gm(backend.graphs[0].print_readable(print_output=False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 4]"): + l_x_ = L_x_ + + add_: "f32[3, 4]" = l_x_.add_(1.0) + relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None + add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None + return (add,) +""", + ) + + ff = torch.func.functionalize(f) + ff_out = ff(x_clone) + + self.assertEqual(cnt.frame_count, 2) + self.assertEqual(cnt.op_count, 6) + self.assertEqual(len(backend.graphs), 2) + 
self.assertEqual(len(backend.example_inputs), 2) + actual = normalize_gm(backend.graphs[1].print_readable(print_output=False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 4]"): + l_x_ = L_x_ + + add_: "f32[3, 4]" = l_x_.add_(1.0) + relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None + add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None + return (add,) +""", + ) + self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0])) + + # Cannot reuse the version from AOTAutograd, since that uses python functional tensors. + def to_fun(x): + x_functional = torch._to_functional_tensor(x) + torch._mirror_autograd_meta_to(x, x_functional) + return x_functional + + def aot_f_wrapper(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + torch._enable_functionalization(reapply_views=False) + try: + func_args = pytree.tree_map(to_fun, args) + func_kwargs = pytree.tree_map(to_fun, kwargs) + return func(*func_args, **func_kwargs) + finally: + torch._disable_functionalization() + + return wrapper + + aot_ff = aot_f_wrapper(f) + aot_ff_out = aot_ff(x_clone2) + + self.assertEqual(cnt.frame_count, 3) + self.assertEqual(cnt.op_count, 9) + self.assertEqual(len(backend.graphs), 3) + self.assertEqual(len(backend.example_inputs), 3) + actual = normalize_gm(backend.graphs[2].print_readable(print_output=False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 4]"): + l_x_ = L_x_ + + add_: "f32[3, 4]" = l_x_.add_(1.0) + relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None + add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None + return (add,) +""", + ) + self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0])) + + self.assertEqual(f_out, ff_out) + self.assertEqual(f_out, aot_ff_out) + + try: + torch._enable_functionalization(reapply_views=False) + xf = pytree.tree_map(to_fun, x) + x_view = xf.t() + with 
self.assertRaisesRegex(RuntimeError, "Cannot safely fakify a view"): + f(x_view) + finally: + torch._disable_functionalization() + + def test_compile_higher_order_with_functionalization(self): + backend = EagerRecordGraphAndInputs() + cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) + + @torch.compile(backend=cnt, fullgraph=True) + def f(x): + return wrap(lambda x: x.add_(1.0), x) + + def check_count_and_graph( + exp_frame_count, exp_op_count, exp_n_graph, exp_graph + ): + self.assertEqual(cnt.frame_count, exp_frame_count) + self.assertEqual(cnt.op_count, exp_op_count) + self.assertEqual(len(backend.graphs), exp_n_graph) + actual = normalize_gm( + backend.graphs[exp_n_graph - 1].print_readable(print_output=False) + ) + self.assertExpectedInline(actual, exp_graph, skip=1) + + t = torch.randn([3, 4]) + t_clone = t.clone() + t_clone2 = t.clone() + f(t) + + check_count_and_graph( + 1, + 2, + 1, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 4]"): + l_x_ = L_x_ + + wrap_body_0 = self.wrap_body_0 + wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None + getitem: "f32[3, 4]" = wrap[0]; wrap = None + return (getitem,) + + class wrap_body_0(torch.nn.Module): + def forward(self, l_x_: "f32[3, 4]"): + add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None + return (add_,) +""", + ) + + ff = torch.func.functionalize(f) + ff_out = ff(t_clone) # noqa: F841 + # frame count and op count are incremented due to re-compilation + check_count_and_graph( + 2, + 4, + 2, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 4]"): + l_x_ = L_x_ + + wrap_body_0 = self.wrap_body_0 + wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None + getitem: "f32[3, 4]" = wrap[0]; wrap = None + return (getitem,) + + class wrap_body_0(torch.nn.Module): + def forward(self, l_x_: "f32[3, 4]"): + add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None + return (add_,) +""", + ) + + try: + x = 
torch._to_functional_tensor(t_clone2) + torch._mirror_autograd_meta_to(t_clone2, x) + torch._enable_functionalization(reapply_views=False) + aot_f_out = f(x) # noqa: F841 + finally: + torch._disable_functionalization() + + # frame count and op count are incremented due to re-compilation + check_count_and_graph( + 3, + 6, + 3, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 4]"): + l_x_ = L_x_ + + wrap_body_0 = self.wrap_body_0 + wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None + getitem: "f32[3, 4]" = wrap[0]; wrap = None + return (getitem,) + + class wrap_body_0(torch.nn.Module): + def forward(self, l_x_: "f32[3, 4]"): + add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None + return (add_,) +""", + ) + + def test_has_torch_function(self): + class MyTensor: + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func is torch.max: + return torch.tensor(123) + return func(*args, **kwargs) + + class LocalSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + return func(*args, **kwargs) + + def fn(x): + return torch.overrides.has_torch_function_unary( + x + ), torch.overrides.has_torch_function_variadic(x) + + for test_class in [MyTensor, LocalSubclass]: + x = test_class() + ref0 = fn(x) + ref1 = fn(4) + opt_fn = torch.compile(fn, backend="eager") + res0 = opt_fn(x) + res1 = opt_fn(4) + self.assertEqual(ref0, res0) + self.assertEqual(ref1, res1) + + def test_wrapper_subclass_guards_on_inner_tensor(self): + # Holds an inner tensor, that has a distinct shape from the outer wrapper tensor. + # Also adds additional guards on the inner tensor's sizes. + # When the first input to an op has x.shape[0] > 5, we insert an extra add node. 
+ class DoubleSizeMaybeAddGeThreeTensor(torch.Tensor): + @staticmethod + def __new__(cls, inner): + # Double the outer-most dimension + outer_shape = (inner.shape[0] * 2,) + inner.shape[1:] + return torch.Tensor._make_wrapper_subclass( + # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great. + # Calling the overload that has kwargs causes us to go down the first overload path, + # which will **always** specialize sizes. + # We should probably eventually fix this so that the first overload can just handle dynamic shapes. + cls, + outer_shape, + inner.stride(), + None, + None, + inner.dtype, + inner.layout, + inner.device, + False, + inner.requires_grad, + ) + + def __init__(self, inner): + self.inner_elem = inner + + def __tensor_flatten__(self): + return ["inner_elem"], None + + @staticmethod + def __tensor_unflatten__(inner_tensors, _, outer_size, outer_stride): + return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"]) + + def __repr__(self): + return f"DoubleSizeMayberAddGeThreeTensor({repr(self.inner_elem)})" + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + args_inner = torch.utils._pytree.tree_map_only( + DoubleSizeMaybeAddGeThreeTensor, lambda x: x.inner_elem, args + ) + out_inner = func(*args_inner, **kwargs) + + # Add guards on the inner tensor's sizes + if args_inner[0].shape[0] > 3: + out_inner += 2 + + return DoubleSizeMaybeAddGeThreeTensor(out_inner) + + curr_var_to_val = None + curr_var_to_sources = None + guards = None + + def backend(gm, args): + context = torch._guards.TracingContext.get() + + # Grab info on sources and guards from the shapeenv + nonlocal curr_var_to_val + nonlocal curr_var_to_sources + nonlocal guards + + guards = [str(g.expr) for g in context.fake_mode.shape_env.guards] + curr_var_to_val = { + str(k): v + for k, v in context.fake_mode.shape_env.backed_var_to_val.items() + } + curr_var_to_sources = { + str(k): v[0].name 
+ for k, v in context.fake_mode.shape_env.var_to_sources.items() + } + return gm + + @torch.compile(backend=backend) + def fn(x): + if x.shape[0] < 13: + return torch.mul(x, x) + else: + return torch.div(x, x) + + inp = torch.ones(4, 4) + + x = DoubleSizeMaybeAddGeThreeTensor(inp) + torch._dynamo.mark_dynamic(x, 0) + res = fn(x) # noqa: F841 + # During fakeifying, we end up allocating a separate symint + # for the outer and inner tensor (in this test, s0 is unused). + expected_var_to_val = { + "s50": 4, + "s77": 8, + } + expected_var_to_sources = { + "s50": "L['x'].inner_elem.size()[0]", + "s77": "L['x'].size()[0]", + } + self.assertEqual(curr_var_to_val, expected_var_to_val) + self.assertEqual(curr_var_to_sources, expected_var_to_sources) + self.assertExpectedInline( + "\n".join(guards), + """\ +Eq(2*s50, s77) +2*s50 < 13 +s50 > 3""", + ) + + def test_wrapper_subclass_with_same_sized_inner_tensor(self): + # shouldn't recompile for different sizes when dynamic=True + sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) + sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(7)) + self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=True)) + + # should recompile for different data size when dynamic=False + sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) + sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6)) + self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) + + # avoid recompile using manual mark_dynamic() for different data size + sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) + # NB: mark_dynamic() on outer tensor should translate to inner tensors of the same size + torch._dynamo.mark_dynamic(sub1, 0) + torch._dynamo.mark_dynamic(sub1, 1) + sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6)) + self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) + + def test_wrapper_subclass_with_differently_sized_inner_tensor(self): + # should recompile for different scale size when 
dynamic=False + sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3)) + sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5)) + self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) + + # still recompiles using manual mark_dynamic() on outer for different scale size + sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3)) + # NB: mark_dynamic() on outer tensor doesn't translate to inner tensors of different size + torch._dynamo.mark_dynamic(sub1, 0) + torch._dynamo.mark_dynamic(sub1, 1) + sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5)) + self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) + + def test_recompiles_with_optional_inner_tensor(self): + def f(x): + return x + 1 + + # sub1 does not have the optional tensor specified while sub2 does + sub1 = OptionalScaledTensor(torch.randn(2, 4), None) + sub2 = OptionalScaledTensor(torch.randn(2, 4), torch.randn(2, 4)) + + # sanity check; don't recompile for same input + self.assertFalse(_recompiles_for_inputs(f, (sub1,), (sub1,), dynamic=True)) + self.assertFalse(_recompiles_for_inputs(f, (sub2,), (sub2,), dynamic=True)) + + # these should recompile; optional tensor changes between specified and unspecified + self.assertTrue(_recompiles_for_inputs(f, (sub1,), (sub2,), dynamic=True)) + self.assertTrue(_recompiles_for_inputs(f, (sub2,), (sub1,), dynamic=True)) + + f_compiled = torch.compile(f, backend="aot_eager") + self.assertEqual(f(sub1)._data, f_compiled(sub1)._data) + self.assertEqual(f(sub2)._data, f_compiled(sub2)._data) + + def test_torch_dispatch_subclass_guard_recompile(self): + x = torch.ones(2, 2) + x_two = TwoTensor(x.clone(), x.clone()) + + def fn(w): + return torch.add(w, 1.0) + + fn_opt = torch.compile(backend="eager")(fn) + + ref = fn(x_two) + res = fn_opt(x_two) + self.assertEqual(ref, res) + + # ensure no recompilation on same input type + with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): + fn_opt(TwoTensor(x + 1, x + 
2)) + + # recompile! + ref = fn(x) + res = fn_opt(x) + self.assertEqual(ref, res) + + def test_tensor_subclass_ctx_guards(self): + x = CtxSubclassTensor(torch.ones(2), 3) + x2 = CtxSubclassTensor(torch.ones(2), 3) + x3 = CtxSubclassTensor(torch.ones(2), 4) + _check_recompiles(self, lambda x: x * x, (x,), (x2,), False) + _check_recompiles(self, lambda x: x * x, (x,), (x3,), True) + + def test_tensor_subclass_ctx_recursive_guards(self): + x0 = torch.ones(2, 2) + x1 = CtxSubclassTensor(x0.clone(), 2) + x2 = CtxSubclassTensor(x0.clone(), 3) + tt0 = TwoTensor(x0.clone(), x1) + tt1 = TwoTensor(x0.clone(), x2) + + _check_recompiles(self, lambda x: x * x, (tt0,), (tt1,), True) + + def test_tensor_subclass_ctx_custom_guards_override(self): + class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor): + @classmethod + def __metadata_guard__(cls, orig_data, other): + return orig_data[0] <= other[0] + + x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 2) + x2 = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3) + x3 = CtxSubclassTensorCustomGuardFn(torch.ones(2), 1) + _check_recompiles(self, lambda x: x * x, (x,), (x2,), False) + _check_recompiles(self, lambda x: x * x, (x,), (x3,), True) + + def test_tensor_subclass_ctx_custom_guards_error_arg_num(self): + import torch._dynamo.exc + + class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor): + @classmethod + def __metadata_guard__(cls, y): + # Shouldn't reach here + return False + + x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3) + self.assertRaisesRegex( + torch._dynamo.exc.InternalTorchDynamoError, + "Tensor subclass method __metadata_guard__ must take exactly two subclass metadata arguments", + lambda: torch.compile(lambda x: x * x, backend="eager")(x), + ) + + def test_tensor_subclass_ctx_custom_guards_error_not_classmethod(self): + import torch._dynamo.exc + + class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor): + def __metadata_guard__(self, x, y): + return False + + x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 
3) + self.assertRaisesRegex( + torch._dynamo.exc.InternalTorchDynamoError, + "Tensor subclass method __metadata_guard__ must be a classmethod", + lambda: torch.compile(lambda x: x * x, backend="eager")(x), + ) + + def test_tensor_subclass_metadata_with_symint(self): + # TENSOR_SUBCLASS_METADATA_MATCH replaces SymInts in metadata with + # _AnyCompare sentinels so that (a) deepcopy doesn't pull in the + # ShapeEnv, (b) dynamic dims aren't over-guarded, and (c) unbacked + # SymInts don't cause errors. + from torch._subclasses.fake_tensor import FakeTensorMode + + shape_env = ShapeEnv() + fake_mode = FakeTensorMode(shape_env=shape_env) + sym_ctx = StatelessSymbolicContext( + dynamic_sizes=[DimDynamic.DYNAMIC], + ) + fake_t = fake_mode.from_tensor(torch.randn(8), symbolic_context=sym_ctx) + sym_int = fake_t.shape[0] # SymInt with hint 8 + + # Construct a real tensor whose metadata contains a SymInt. + x1 = CtxSubclassTensor(torch.randn(8, 8), sym_int) + + @torch.compile(backend="eager", dynamic=True, fullgraph=True) + def f(x): + return x * x + + f(x1) + + # Without the fix, the guard stores the SymInt's hint value (8) + # and recompiles when it sees a different constant. With the fix, + # the SymInt position is replaced by _AnyCompare so the guard + # passes for any constant value. 
+ x2 = CtxSubclassTensor(torch.randn(8, 8), 3) + _check_recompiles(self, f, (x1,), (x2,), False) + + def test_subclass_constructor_proxying(self): + import dataclasses + from collections import namedtuple + from typing import Any + + @dataclasses.dataclass(frozen=True) + class SubclassTensorArgs: + original_shape: torch.Size + device: torch.device + inner_meta: Any + + SubclassTensorArgs2 = namedtuple( + "SubclassTensorArgs2", + [ + "original_shape", + "device", + "inner_meta", + ], + ) + + class SubclassTensor(torch.Tensor): + @staticmethod + def __new__(cls, a, meta): + shape = a.shape + kwargs = {} + kwargs["strides"] = a.stride() + kwargs["storage_offset"] = a.storage_offset() + kwargs["device"] = a.device + kwargs["layout"] = a.layout + kwargs["requires_grad"] = a.requires_grad + kwargs["dtype"] = a.dtype + out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + return out + + def __init__(self, a, meta): + self.a = a + self.meta = meta + + def __repr__(self): + a_repr = repr(self.a) + return f"SubclassTensor({a_repr})" + + def __tensor_flatten__(self): + return ["a"], self.meta + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, _, __): + a = inner_tensors["a"] + return SubclassTensor(a, meta) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + args_a = pytree.tree_map( + lambda x: x.a if isinstance(x, SubclassTensor) else x, args + ) + kwargs_a = pytree.tree_map( + lambda x: x.a if isinstance(x, SubclassTensor) else x, kwargs + ) + out_a = func(*args_a, **kwargs_a) + out = pytree.tree_map( + lambda x: ( + SubclassTensor(x, SubclassTensorArgs2(x.shape, x.device, None)) + if isinstance(x, torch.Tensor) + else x + ), + out_a, + ) + return return_and_correct_aliasing(func, args, kwargs, out) + + @torch.compile(fullgraph=True, backend="eager") + def f1(x): + meta = SubclassTensorArgs( + x.shape, x.device, SubclassTensorArgs(x.shape, x.device, None) + ) + out = SubclassTensor(x, 
meta) + return out * out + + x = torch.randn(3, 3) + f1(x) + + @torch.compile(fullgraph=True, backend="eager") + def f1(x): + meta = SubclassTensorArgs2( + x.shape, x.device, SubclassTensorArgs2(x.shape, x.device, None) + ) + out = SubclassTensor(x, meta) + return out * out + + x = torch.randn(3, 3) + f1(x) + + def test_torch_function_subclass_survives_into_aot_autograd(self): + # If you have a tensor subclass that relies on dispatch into the same op + # without unwrapping and calling torch._C.DisableTorchFunctionSubclass(), + # the torch function-ness will survive into AOTAutograd. Today, NestedTensor + # actually relies on this behavior! Because that torch function logic + # runs during AOTAutograd, this test tests that there is no logic below + # that relies torch function that gets unexpectedly disabled after we + # redispatch from the subclass's torch function. + class SubTensor(torch.Tensor): + @staticmethod + def __new__(cls, t): + return torch.Tensor._make_wrapper_subclass( + cls, + t.shape, + t.stride(), + t.storage_offset(), + torch.contiguous_format, + t.dtype, + torch.strided, + t.device, + False, + t.requires_grad, + "sizes", + False, + False, + None, + ) + + def __init__(self, t): + super().__init__() + self._t = t + + def __tensor_flatten__(self): + return ["_t"], {} + + @staticmethod + def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): + t = inner_tensors["_t"] + return SubTensor(t) + + def __repr__(self): + return f"SubTensor({self._t})" + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + new_args = pytree.tree_map_only(SubTensor, lambda s: s._t, args) + output = func(*new_args, **kwargs) + output = pytree.tree_map_only( + torch.Tensor, lambda t: 
SubTensor(t), output + ) + return output + + @torch.compile(dynamic=True, backend="eager") + def f(x): + return x.unflatten(-1, [2, 5]) + + s = SubTensor(torch.randn(3, 10)) + f(s) + + # Guard validation upsets the guard + # https://github.com/pytorch/pytorch/issues/129936 + @unittest.expectedFailure + def test_recompile_with_symbool_inputs(self): + def f(pred: bool): + if pred: + return torch.ones([3, 4]) + else: + return torch.ones([4, 3]) + + def test_recompilation( + f, x, sizes, exp_graphs, exp_frame_count, exp_shape_env_guards + ): + torch._dynamo.reset() + shape_env = ShapeEnv() + backend = torch._dynamo.testing.EagerAndRecordGraphs() + cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) + f_cond = torch.compile(f, backend=cnt, fullgraph=True) + with torch._subclasses.fake_tensor.FakeTensorMode( + shape_env=shape_env + ) as fake_mode: + fake_inp = fake_mode.from_tensor( + x, + symbolic_context=StatelessSymbolicContext( + dynamic_sizes=[DimDynamic.DYNAMIC for i in range(x.dim())] + ), + ) + for i, size in enumerate(sizes): + pred = fake_inp.size(0) == size + f_cond(pred) + actual = normalize_gm( + backend.graphs[exp_frame_count[i] - 1].print_readable( + print_output=False + ) + ) + actual_guard_str = [str(guard.expr) for guard in shape_env.guards] + self.assertExpectedInline(actual, exp_graphs[i]) + self.assertEqual(cnt.frame_count, exp_frame_count[i]) + self.assertEqual(actual_guard_str, exp_shape_env_guards[i]) + + true_graph = """\ +class GraphModule(torch.nn.Module): + def forward(self): + ones: "f32[3, 4]" = torch.ones([3, 4]) + return (ones,) +""" + false_graph = """\ +class GraphModule(torch.nn.Module): + def forward(self): + ones: "f32[4, 3]" = torch.ones([4, 3]) + return (ones,) +""" + test_recompilation( + f, + torch.randn([3, 4]), + [3, 3, 4, 5], + exp_graphs=[true_graph, true_graph, false_graph, false_graph], + exp_frame_count=[1, 1, 2, 2], + exp_shape_env_guards=[ + [], + # s0 is specialized and guarded in outer shape_env when dynamo 
checks the guards + ["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"], + [ + "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", + "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)", + ], + [ + "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", + "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)", + "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", + ], + ], + ) + + test_recompilation( + f, + torch.randn([3, 4]), + [4, 5, 3, 3], + exp_graphs=[false_graph, false_graph, true_graph, true_graph], + exp_frame_count=[1, 1, 2, 2], + exp_shape_env_guards=[ + [], + # s0 is specialized and guarded in outer shape_env when dynamo checks the guards + ["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"], + [ + "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", + "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", + ], + [ + "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", + "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", + "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", + ], + ], + ) + + def test_wrapper_subclass_dynamo_attribute_access_on_intermediate(self): + def f(x_subclass): + tmp_subclass = torch.add(x, 1) + return torch.mul(tmp_subclass._scale, tmp_subclass._constant) + + x = ScaledTensor(torch.randn(2, 4), torch.randn(3), constant=2) + out_ref = f(x) + out_test = torch.compile(f, backend="aot_eager", fullgraph=True)(x) + self.assertEqual(out_ref, out_test) + + def test_support_bases(self): + import abc + + import torch.fx._symbolic_trace + + class Meta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta): + def __new__(cls, name, bases, dct): + x = super().__new__(cls, name, bases, dct) + x.attr = 100 + return x + + class Multistreamable(abc.ABC): # noqa: B024 + pass + + class Foo(Multistreamable, metaclass=Meta): + pass + + @torch.compile(backend="eager", fullgraph=True) + def f(x): + typ = type(Foo()) + typ.__bases__ + return typ.__bases__ + + self.assertEqual(f(torch.randn(1)), (Multistreamable,)) + + @torch.compile(backend="eager", fullgraph=True) + def g(x): + typ = type(Foo()) + typ.__base__ + return 
typ.__base__ + + self.assertEqual(g(torch.randn(1)), Multistreamable) + + @parametrize("dynamic", [False, True]) + def test_subclass_views(self, dynamic): + def _get_views(t): # returns (view: Tensor, expects_raises_false) + # Note that any closed-over SymInts will be symbolicized during fake-ification. + yield t.narrow(dim=-1, start=3, length=8), False + yield t.split(5, -1)[2], False + yield t.split_with_sizes([9, 6], -1)[1], False + yield t.unsqueeze(-1).expand(4, 15, 10), False + yield t.select(-1, 6), False + # https://github.com/pytorch/pytorch/issues/128649 + yield t[2:3, 5:9], dynamic + yield t.view(-1, 15), False + + def f(x): + return x * 2 + + compiled_f = torch.compile( + f, backend="aot_eager", fullgraph=True, dynamic=dynamic + ) + + # Take a view of a subclass to pass as input. + t = TwoTensor(torch.randn(4, 15), torch.randn(4, 15)) + for view, expects_raises in _get_views(t): + torch._dynamo.reset() + out_ref = f(view) + if expects_raises: + with self.assertRaises(AssertionError): + out_test = compiled_f(view) + else: + out_test = compiled_f(view) + self.assertEqual(out_ref, out_test) + + @parametrize("dynamic", [True, False]) + def test_mark_static_with_subclass_desugaring(self, dynamic): + from collections.abc import Callable + from typing import Any + + from torch._dynamo.decorators import mark_static_address + from torch._inductor.compile_fx import compile_fx + from torch._inductor.cudagraph_utils import BoxedDeviceIndex + from torch._inductor.utils import BoxedBool + + x_inner = torch.ones(4) + x = TwoTensor(x_inner, x_inner) + mark_static_address(x, guard=False) + + def inner_compile( + gm: torch.fx.GraphModule, + example_inputs: list[torch.Tensor], + cudagraphs: BoxedBool | None = None, + static_input_idxs: list[int] | None = None, + is_backward: bool = False, + graph_id: int | None = None, + cpp_wrapper: bool = False, + aot_mode: bool = False, + is_inference: bool = False, + boxed_forward_device_index: BoxedDeviceIndex | None = None, + 
layout_opt: bool | None = None, + extern_node_serializer: Callable[[list[Any]], Any] | None = None, + **kwargs: Any, + ): + if dynamic: + self.assertEqual(static_input_idxs, [2, 3, 4]) + else: + self.assertEqual(static_input_idxs, [1, 2]) + return gm + + compiler = functools.partial(compile_fx, inner_compile=inner_compile) + + @torch.compile(backend=compiler, dynamic=dynamic) + def fn(t0, t1, t2): + return t0 + t1 + t2 + 2 + + fn(torch.ones(4), x, torch.ones(4)) + + def test_subclass_parameters_are_static_under_training(self): + from collections.abc import Callable + from typing import Any + + from torch._inductor.compile_fx import compile_fx + from torch._inductor.cudagraph_utils import BoxedDeviceIndex + from torch._inductor.utils import BoxedBool + + def inner_compile( + gm: torch.fx.GraphModule, + example_inputs: list[torch.Tensor], + cudagraphs: BoxedBool | None = None, + static_input_idxs: list[int] | None = None, + is_backward: bool = False, + graph_id: int | None = None, + cpp_wrapper: bool = False, + aot_mode: bool = False, + is_inference: bool = False, + boxed_forward_device_index: BoxedDeviceIndex | None = None, + layout_opt: bool | None = None, + extern_node_serializer: Callable[[list[Any]], Any] | None = None, + **kwargs: Any, + ): + # Important bit: there are 3 params: linear.weight.a, linear.weight.b, linear.bias, + # which are the first 3 args of the graph. 
+ self.assertEqual(static_input_idxs, [0, 1, 2]) + return gm + + compiler = functools.partial(compile_fx, inner_compile=inner_compile) + + mod = torch.nn.Linear(4, 4) + w_a = torch.randn(4, 4) + w_b = torch.randn(4, 4) + w = torch.nn.Parameter(TwoTensor(w_a, w_b).requires_grad_()) + mod.weight = w + + mod = torch.compile(mod, backend=compiler) + + mod(torch.randn(4)) + + # copied from common_utils.py::NestedTensorTestCase + def assertEqualIgnoringNestedInts(self, a, b): + # unbinding NJTs allows us to compare them as essentially equal without + # caring about exact nested int comparison + def _unbind_njts(x): + if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.jagged: + return x.unbind() + else: + return x + + self.assertEqual( + pytree.tree_map(_unbind_njts, a), pytree.tree_map(_unbind_njts, b) + ) + + def _compile_check( + self, + fn, + inps, + *, + dynamic=True, + fullgraph=True, + call_backward=False, + ): + def call_backward_fn(t): + if t.is_nested: + from torch.nested._internal.nested_tensor import buffer_from_jagged + + t = buffer_from_jagged(t) + return t.sum().backward(retain_graph=True) + + torch.manual_seed(0) + fw_compiler = EagerRecordGraphAndInputs() + bw_compiler = EagerRecordGraphAndInputs() + compiler_fn = aot_autograd( + fw_compiler=make_boxed_compiler(fw_compiler), + bw_compiler=make_boxed_compiler(bw_compiler), + partition_fn=min_cut_rematerialization_partition, + keep_inference_input_mutations=True, + ) + + c = torch.compile(backend=compiler_fn, dynamic=dynamic, fullgraph=fullgraph)(fn) + for inp in inps: + expected = fn(*inp) + # reset the seed for randn to generate the same tensor + torch.manual_seed(0) + got = c(*inp) + self.assertEqualIgnoringNestedInts(expected, got) + + if call_backward: + re = pytree.tree_map_only( + lambda x: isinstance(x, torch.Tensor) and x.requires_grad, + call_backward_fn, + expected, + ) + rg = pytree.tree_map_only( + lambda x: isinstance(x, torch.Tensor) and x.requires_grad, + call_backward_fn, 
+ got, + ) + self.assertEqualIgnoringNestedInts(re, rg) + + if call_backward: + return fw_compiler.graphs, bw_compiler.graphs + return fw_compiler.graphs, None + + def test_tensor_subclass_TwoTensor_simple(self): + def f(tt): + return tt * tt.size()[0] + + a = torch.ones(3, 4, requires_grad=True) + b = a.detach().clone().requires_grad_(True) + tt = TwoTensor(a, b) + + fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True) + + self.assertExpectedInline( + normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_1: "Sym(s47)", # PlainAOTInput(idx=0) + primals_2: "Sym(s16)", # PlainAOTInput(idx=1) + primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a') + primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b') + primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0) + primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1) + primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) + ): + mul: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None + mul_3: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None + return ( + mul, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a') + mul_3, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b') + primals_5, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0) + primals_7, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=1) + primals_7, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=0) + primals_1, # SavedForBackwardsAOTOutput(idx=0) + primals_5, # SavedForBackwardsAOTOutput(idx=1) + primals_7, # SavedForBackwardsAOTOutput(idx=2) + ) +""", + ) + + self.assertExpectedInline( + normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class 
GraphModule(torch.nn.Module): + def forward( + self, + primals_1: "Sym(s47)", # PlainAOTInput(idx=0) + primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0) + primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) + tangents_1: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a') + tangents_2: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b') + ): + mul_8: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None + mul_9: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None + return ( + None, # None + None, # None + mul_8, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a') + mul_9, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b') + primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) + primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) + primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) + ) +""", + ) + + def test_tensor_subclass_TwoTensor_clone_view(self): + def f(tt): + y = tt.clone() + return y.view(y.shape[1], y.shape[0]) + + a = torch.ones(3, 4, requires_grad=True) + b = a.clone() + tt = TwoTensor(a, b) + + fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True) + + self.assertExpectedInline( + normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_1: "Sym(s47)", # PlainAOTInput(idx=0) + primals_2: "Sym(s16)", # PlainAOTInput(idx=1) + primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a') + primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b') + 
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0) + primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1) + primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) + ): + clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None + clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None + + view: "f32[s16, s47]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None + view_1: "f32[s16, s47]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None + return ( + view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a') + view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b') + primals_2, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0) + primals_5, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=1) + primals_5, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=0) + primals_5, # SavedForBackwardsAOTOutput(idx=0) + primals_7, # SavedForBackwardsAOTOutput(idx=1) + ) +""", + ) + + self.assertExpectedInline( + normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0) + primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) + tangents_1: "f32[s16, s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a') + tangents_2: "f32[s16, s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b') + ): + view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None + view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None + return ( + None, # None + None, # None + view_2, # 
SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a') + view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b') + primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) + primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) + primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) + ) +""", + ) + + def test_tensor_subclass_TwoTensor_mul(self): + def f(tt, a, b): + s0, s1 = a.size() + s2, s3 = b.size() + # return tt * a.size()[1] + return tt * s0 * s1 * s2 * s3 + + a = torch.ones(3, 4, requires_grad=True) + b = a.clone() + tt = TwoTensor(a, b) + + fw, bw = self._compile_check(f, [(tt, a, b)], dynamic=True, call_backward=True) + + self.assertExpectedInline( + normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_1: "Sym(s97)", # PlainAOTInput(idx=0) + primals_2: "Sym(s98)", # PlainAOTInput(idx=1) + primals_3: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a') + primals_4: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b') + primals_5: "Sym(s97)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0) + primals_6: "Sym(s98)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1) + primals_7: "Sym(s98)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) + ): + mul: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None + mul_3: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None + mul_8: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul, primals_2); mul = None + mul_11: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_3, primals_2); mul_3 = None + mul_16: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_8, primals_1); mul_8 = None + mul_19: "f32[s97, s98]" = 
torch.ops.aten.mul.Tensor(mul_11, primals_1); mul_11 = None + mul_24: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_16, primals_2); mul_16 = None + mul_27: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_19, primals_2); mul_19 = None + return ( + mul_24, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a') + mul_27, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b') + primals_5, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0) + primals_7, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=1) + primals_7, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=0) + primals_1, # SavedForBackwardsAOTOutput(idx=0) + primals_2, # SavedForBackwardsAOTOutput(idx=1) + primals_5, # SavedForBackwardsAOTOutput(idx=2) + primals_7, # SavedForBackwardsAOTOutput(idx=3) + ) +""", + ) + + self.assertExpectedInline( + normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_1: "Sym(s97)", # PlainAOTInput(idx=0) + primals_2: "Sym(s98)", # PlainAOTInput(idx=1) + primals_5: "Sym(s97)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0) + primals_7: "Sym(s98)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) + tangents_1: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a') + tangents_2: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b') + ): + mul_32: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); tangents_1 = None + mul_33: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_2, primals_2); tangents_2 = None + mul_34: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_32, primals_1); mul_32 = None + mul_35: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_33, primals_1); mul_33 = None + mul_36: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_34, primals_2); mul_34 = None + mul_37: "f32[s97, s98]" = 
torch.ops.aten.mul.Tensor(mul_35, primals_2); mul_35 = primals_2 = None + mul_38: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_36, primals_1); mul_36 = None + mul_39: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_37, primals_1); mul_37 = primals_1 = None + return ( + None, # None + None, # None + mul_38, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a') + mul_39, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b') + primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) + primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) + primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) + ) +""", + ) + + def test_tensor_subclass_TwoTensor_view(self): + def f(tt): + y = tt.clone() + return y.view(y.shape[0], y.shape[1]) + + a = torch.ones(3, 4, requires_grad=True) + b = a.clone() + tt = TwoTensor(a, b) + + fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True) + + self.assertExpectedInline( + normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_1: "Sym(s47)", # PlainAOTInput(idx=0) + primals_2: "Sym(s16)", # PlainAOTInput(idx=1) + primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a') + primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b') + primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0) + primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1) + primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) + ): + clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None + clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None + + view: "f32[s47, s16]" = 
torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None + view_1: "f32[s47, s16]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None + return ( + view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a') + view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b') + primals_5, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0) + primals_7, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=1) + primals_7, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=0) + primals_5, # SavedForBackwardsAOTOutput(idx=0) + primals_7, # SavedForBackwardsAOTOutput(idx=1) + ) +""", + ) + + self.assertExpectedInline( + normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0) + primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) + tangents_1: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a') + tangents_2: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b') + ): + view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None + view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None + return ( + None, # None + None, # None + view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a') + view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b') + primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) + primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) + primals_7, # 
SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) + ) +""", + ) + + def test_tensor_subclass_TwoTensor_view_mul(self): + def f(tt): + y = tt.clone() + return y.view(y.shape[0] * y.shape[1]) + + a = torch.ones(3, 4, requires_grad=True) + b = a.clone() + tt = TwoTensor(a, b) + + fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True) + + self.assertExpectedInline( + normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_1: "Sym(s47)", # PlainAOTInput(idx=0) + primals_2: "Sym(s16)", # PlainAOTInput(idx=1) + primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a') + primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b') + primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0) + primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1) + primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) + ): + clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None + clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None + + mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None + view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None + view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None + return ( + view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a') + view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b') + mul_6, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0) + primals_5, # SavedForBackwardsAOTOutput(idx=0) + primals_7, # SavedForBackwardsAOTOutput(idx=1) + ) +""", + ) + + self.assertExpectedInline( + normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class 
GraphModule(torch.nn.Module): + def forward( + self, + primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0) + primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) + tangents_1: "f32[s16*s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a') + tangents_2: "f32[s16*s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b') + ): + view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None + view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None + return ( + None, # None + None, # None + view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a') + view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b') + primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) + primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) + primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) + ) +""", + ) + + def test_tensor_subclass_TwoTensor_return_tensor_and_subclass(self): + def f(tt): + y = tt.clone() + return y.a, y.view(y.shape[0] * y.shape[1]) + + a = torch.ones(3, 4, requires_grad=True) + b = a.clone() + tt = TwoTensor(a, b) + + fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True) + + self.assertExpectedInline( + normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_1: "Sym(s47)", # PlainAOTInput(idx=0) + primals_2: "Sym(s16)", # PlainAOTInput(idx=1) + primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a') + primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b') + primals_5: 
"Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0) + primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1) + primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) + ): + clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None + clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None + + mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None + view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6]) + view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None + return ( + clone, # PlainAOTOutput(idx=0) + view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a') + view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b') + mul_6, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=1), idx=0) + primals_5, # SavedForBackwardsAOTOutput(idx=0) + primals_7, # SavedForBackwardsAOTOutput(idx=1) + ) +""", + ) + + self.assertExpectedInline( + normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0) + primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) + tangents_1: "f32[s16*s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='a') + tangents_2: "f32[s16*s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='b') + ): + view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None + view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None + return ( + None, # None + None, # None + view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a') + view_3, # 
SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b') + primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) + primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) + primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) + ) +""", + ) + + @unittest.expectedFailure + def test_tensor_subclass_TwoTensor_return_multiple(self): + def f(tt): + y = tt.clone() + z = tt.clone() + return y.a, y.view(y.shape[0] * y.shape[1]), y.b, z.view(-1) + + a = torch.ones(3, 4, requires_grad=True) + b = a.clone() + tt = TwoTensor(a, b) + + fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True) + + self.assertExpectedInline( + normalize_gm(fw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[3, 4]", primals_2: "f32[3, 4]", primals_3: "Sym(3)", primals_4: "Sym(4)", primals_5: "Sym(3)", primals_6: "Sym(4)"): + clone: "f32[3, 4]" = torch.ops.aten.clone.default(primals_1); primals_1 = None + clone_1: "f32[3, 4]" = torch.ops.aten.clone.default(primals_2); primals_2 = None + + mul: "Sym(12)" = primals_5 * primals_6 + view: "f32[12]" = torch.ops.aten.view.default(clone, [mul]) + view_1: "f32[12]" = torch.ops.aten.view.default(clone_1, [mul]); clone_1 = None + return [clone, view, view_1, mul, primals_5, primals_6] +""", + ) + + self.assertExpectedInline( + normalize_gm(bw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_5: "Sym(3)", primals_6: "Sym(4)", tangents_1: "f32[12]", tangents_2: "f32[12]"): + view_2: "f32[3, 4]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_6]); tangents_1 = None + view_3: "f32[3, 4]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_6]); tangents_2 = primals_5 = primals_6 = None + return [view_2, view_3, None, None] +""", + ) + + def 
test_tensor_subclass_TwoTensor_automatic_dynamic_shapes(self): + def f(tt): + y = tt.clone() + return y.a, y.view(-1), y.b + + a = torch.ones(3, 4, requires_grad=True) + b = a.clone() + tt1 = TwoTensor(a, b) + + a = torch.ones(3, 5, requires_grad=True) + b = a.clone() + tt2 = TwoTensor(a, b) + + fw, bw = self._compile_check( + f, [(tt1,), (tt2,)], dynamic=None, call_backward=True + ) + + self.assertExpectedInline( + normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_1: "f32[3, 4]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='a') + primals_2: "f32[3, 4]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='b') + ): + clone: "f32[3, 4]" = torch.ops.aten.clone.default(primals_1); primals_1 = None + clone_1: "f32[3, 4]" = torch.ops.aten.clone.default(primals_2); primals_2 = None + + view: "f32[12]" = torch.ops.aten.view.default(clone, [-1]) + view_1: "f32[12]" = torch.ops.aten.view.default(clone_1, [-1]) + return ( + clone, # PlainAOTOutput(idx=0) + view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a') + view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b') + clone_1, # PlainAOTOutput(idx=2) + ) +""", + ) + + self.assertExpectedInline( + normalize_gm(fw[1].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_1: "Sym(s16)", # PlainAOTInput(idx=0) + primals_2: "f32[3, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=1), attr='a') + primals_3: "f32[3, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=1), attr='b') + primals_4: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=1), idx=1) + primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0) + ): + clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None + clone_1: "f32[3, s16]" = 
torch.ops.aten.clone.default(primals_3); primals_3 = None + + view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1]) + sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0) + view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1]) + return ( + clone, # PlainAOTOutput(idx=0) + view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a') + view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b') + sym_size_int_2, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=1), idx=0) + clone_1, # PlainAOTOutput(idx=2) + primals_5, # SavedForBackwardsAOTOutput(idx=0) + ) +""", + ) + + self.assertExpectedInline( + normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + tangents_1: "f32[12]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='a') + tangents_2: "f32[12]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='b') + ): + view_2: "f32[3, 4]" = torch.ops.aten.view.default(tangents_1, [3, 4]); tangents_1 = None + view_3: "f32[3, 4]" = torch.ops.aten.view.default(tangents_2, [3, 4]); tangents_2 = None + return ( + view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='a') + view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='b') + ) +""", + ) + + self.assertExpectedInline( + normalize_gm(bw[1].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0) + tangents_1: "f32[3*s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='a') + tangents_2: "f32[3*s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='b') + ): + view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, 
[3, primals_5]); tangents_1 = None + view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None + return ( + None, # None + view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), attr='a') + view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), attr='b') + primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=1) + primals_5, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=0) + ) +""", + ) + + def test_tensor_subclass_TwoTensor_mark_dynamic_shapes(self): + def f(tt): + y = tt.clone() + return y.a, y.view(-1), y.b + + a = torch.ones(3, 4, requires_grad=True) + b = a.clone() + tt = TwoTensor(a, b) + torch._dynamo.mark_dynamic(tt, 1) + + fw, bw = self._compile_check( + f, + [ + (tt,), + ], + dynamic=None, + call_backward=True, + ) + + self.assertExpectedInline( + normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_1: "Sym(s16)", # PlainAOTInput(idx=0) + primals_2: "f32[3, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=1), attr='a') + primals_3: "f32[3, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=1), attr='b') + primals_4: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=1), idx=1) + primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0) + ): + clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None + clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None + + view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1]) + sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0) + view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1]) + return ( + clone, # PlainAOTOutput(idx=0) + view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a') + view_1, # 
SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b') + sym_size_int_2, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=1), idx=0) + clone_1, # PlainAOTOutput(idx=2) + primals_5, # SavedForBackwardsAOTOutput(idx=0) + ) +""", + ) + + self.assertExpectedInline( + normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0) + tangents_1: "f32[3*s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='a') + tangents_2: "f32[3*s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='b') + ): + view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None + view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None + return ( + None, # None + view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), attr='a') + view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), attr='b') + primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=1) + primals_5, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=0) + ) +""", + ) + + def test_tensor_subclass_TwoTensor_different_shape(self): + def f(tt): + y = tt.clone() + return y.view(3, 2, 4) + + a = torch.ones((2 * 4 * 3), requires_grad=True) + b = a.clone() + tt = TwoTensor(a, b) + + fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True) + + self.assertExpectedInline( + normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_1: "f32[24]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='a') + primals_2: "f32[24]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='b') + 
): + clone: "f32[24]" = torch.ops.aten.clone.default(primals_1); primals_1 = None + clone_1: "f32[24]" = torch.ops.aten.clone.default(primals_2); primals_2 = None + + view: "f32[3, 2, 4]" = torch.ops.aten.view.default(clone, [3, 2, 4]); clone = None + view_1: "f32[3, 2, 4]" = torch.ops.aten.view.default(clone_1, [3, 2, 4]); clone_1 = None + return ( + view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a') + view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b') + ) +""", + ) + + self.assertExpectedInline( + normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + tangents_1: "f32[3, 2, 4]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a') + tangents_2: "f32[3, 2, 4]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b') + ): + view_2: "f32[24]" = torch.ops.aten.view.default(tangents_1, [24]); tangents_1 = None + view_3: "f32[24]" = torch.ops.aten.view.default(tangents_2, [24]); tangents_2 = None + return ( + view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='a') + view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='b') + ) +""", + ) + + def test_tensor_subclass_TwoTensor_return_shape(self): + @torch.compile(backend="aot_eager", dynamic=True) + def fn(x): + return x.clone().view(x.shape[0] * x.shape[1]) + + a = torch.ones(2, 3) + b = a.clone() + tt = TwoTensor(a, b) + out = fn(tt) + self.assertEqual(tt.view(2 * 3), out) + self.assertEqual(out.shape, (6,)) + + def test_tensor_subclass_TwoTensor_nested(self): + @torch.compile(backend="aot_eager", dynamic=True) + def f(x, i, y): + out1 = x.sin() + i.sin() + y.sin() + val1 = x.shape[0] * i.shape[1] * y.shape[0] + return out1 * val1 + + i = torch.randn(2, 2, requires_grad=True) + x = TwoTensor(i, i.clone()) + y = TwoTensor(x.clone(), x.clone()) + + out 
= f(x, i, y) + + x_test = x.detach().clone().requires_grad_(True) + i_test = i.detach().clone().requires_grad_(True) + y_test = y.detach().clone().requires_grad_(True) + + out_test = f(x_test, i_test, y_test) + torch.allclose(out, out_test) + + out.sum().backward() + out_test.sum().backward() + torch.allclose(x.grad, x_test.grad) + torch.allclose(i.grad, i_test.grad) + torch.allclose(y.grad, y_test.grad) + + def test_subclass_TwoTensor_TwoTensor_TwoTensor(self): + @torch.compile(backend="aot_eager", dynamic=True) + def f(x): + return x.sin() + + data = torch.randn(2, 3) + s = TwoTensor(data, data.clone()) + y = TwoTensor(s, s.clone()) + z = TwoTensor(s, y) + out = f(z) + self.assertEqual(out, z.sin()) + + def test_subclass_TwoTensor_nested_diff_sizes(self): + class TT(TwoTensor): + @staticmethod + def __new__(cls, a, b, outer_size=None, outer_stride=None): + if outer_size is None: + outer_size = a.size() + if outer_stride is None: + outer_stride = a.stride() + + assert ( # noqa: S101 + a.device == b.device + and a.layout == b.layout + and a.requires_grad == b.requires_grad + and a.dtype == b.dtype + ) + shape = outer_size + kwargs = {} + kwargs["strides"] = outer_stride + kwargs["storage_offset"] = a.storage_offset() + kwargs["device"] = a.device + kwargs["layout"] = a.layout + kwargs["requires_grad"] = a.requires_grad + kwargs["dtype"] = a.dtype + out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + return out + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + assert meta is None # noqa: S101 + a, b = inner_tensors["a"], inner_tensors["b"] + if type(a) is torch.Tensor: + assert outer_size is not None # noqa: S101 + assert outer_stride is not None # noqa: S101 + return TT(a, b, outer_size, outer_stride) + + @torch.compile(dynamic=True, backend="eager") + def f(x, y): + tmp1 = x.sin() + tmp2 = y.sin() + return tmp1.sum(), tmp2.sum() + + x = TT( + TT( + torch.randn(3, 4), + torch.randn(5, 6, 7), + ), + TT( + 
torch.randn(4), + torch.randn(2, 3), + ), + ) + + y = TT( + torch.randn(2, 3, 4, 5), + TT( + torch.randn(3, 4), + torch.randn(5), + ), + ) + + out = f(x, y) + self.assertEqual(out, (x.sin().sum(), y.sin().sum())) + + def test_njt_subclass_simple(self): + def f(nt): + y = nt.clone() + return y * y.size(0) + + nt, _ = get_jagged_tensor(((2, 3, 4), 5), None, True) + + fw, bw = self._compile_check(f, [(nt,)], dynamic=True, call_backward=True) + + self.assertExpectedInline( + normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_1: "Sym(s51)", # PlainAOTInput(idx=0) + primals_2: "Sym(s71)", # PlainAOTInput(idx=1) + primals_3: "Sym(s55)", # PlainAOTInput(idx=2) + primals_4: "f64[s64, s55]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_values') + primals_5: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_offsets') + primals_6: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_min_seqlen_tensor') + primals_7: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_max_seqlen_tensor') + primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0) + primals_9: "Sym(s55)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=2) + primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1) + ): + clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None + + mul: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None + return ( + mul, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values') + primals_5, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets') + primals_6, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_min_seqlen_tensor') + primals_7, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_max_seqlen_tensor') + primals_8, # 
SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0) + primals_10, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2) + primals_10, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1) + primals_1, # SavedForBackwardsAOTOutput(idx=0) + primals_8, # SavedForBackwardsAOTOutput(idx=1) + primals_10, # SavedForBackwardsAOTOutput(idx=2) + ) +""", + ) + + self.assertExpectedInline( + normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_1: "Sym(s51)", # PlainAOTInput(idx=0) + primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0) + primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1) + tangents_1: "f64[s64, s55]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_values') + tangents_2: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_offsets') + tangents_3: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_min_seqlen_tensor') + tangents_4: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_max_seqlen_tensor') + ): + mul_1: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None + return ( + None, # None + None, # None + None, # None + mul_1, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_values') + tangents_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_offsets') + tangents_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_min_seqlen_tensor') + tangents_4, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_max_seqlen_tensor') + primals_8, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=0) + primals_10, # 
SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2) + primals_10, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=1) + ) +""", + ) + + def test_njt_subclass_from_cat(self): + # create from an existing NJT + def f(nt): + y = nt.clone() + z = torch.cat([y, y], dim=-1) + return z + + nt, _ = get_jagged_tensor(((2, 3, 4), 5), None, True) + + fw, bw = self._compile_check(f, [(nt,)], dynamic=True, call_backward=True) + + self.assertExpectedInline( + normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_1: "Sym(s51)", # PlainAOTInput(idx=0) + primals_2: "Sym(s71)", # PlainAOTInput(idx=1) + primals_3: "Sym(s55)", # PlainAOTInput(idx=2) + primals_4: "f64[s64, s55]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_values') + primals_5: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_offsets') + primals_6: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_min_seqlen_tensor') + primals_7: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_max_seqlen_tensor') + primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0) + primals_9: "Sym(s55)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=2) + primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1) + ): + clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None + + cat: "f64[s64, 2*s55]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None + add_2: "Sym(2*s55)" = primals_10 + primals_10 + return ( + cat, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values') + primals_5, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets') + primals_6, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_min_seqlen_tensor') + primals_7, # 
SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_max_seqlen_tensor') + primals_8, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0) + add_2, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2) + add_2, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1) + primals_8, # SavedForBackwardsAOTOutput(idx=0) + primals_10, # SavedForBackwardsAOTOutput(idx=1) + add_2, # SavedForBackwardsAOTOutput(idx=2) + ) +""", + ) + + self.assertExpectedInline( + normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class GraphModule(torch.nn.Module): + def forward( + self, + primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0) + primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1) + add_2: "Sym(2*s55)", + tangents_1: "f64[s64, 2*s55]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_values') + tangents_2: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_offsets') + tangents_3: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_min_seqlen_tensor') + tangents_4: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_max_seqlen_tensor') + ): + slice_1: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10) + slice_2: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None + add_4: "f64[s64, s55]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None + return ( + None, # None + None, # None + None, # None + add_4, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_values') + tangents_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_offsets') + tangents_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), 
attr='_min_seqlen_tensor') + tangents_4, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_max_seqlen_tensor') + primals_8, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=0) + primals_10, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2) + primals_10, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=1) + ) +""", + ) + + def test_njt_subclass_from_buffer(self): + # create the NJT from a buffer(?) + def f(nt): + nested_size = ((2, 3, 4), 5) + offsets = None + nt2, _ = get_jagged_tensor(nested_size, offsets, requires_grad=False) + nt3 = torch.cat([nt2, nt], dim=-1) + return nt3.sin() * nt3.size(0) + + nested_size = ((2, 3, 4), 5) + offsets = None + nt, _ = get_jagged_tensor(nested_size, offsets, requires_grad=False) + + fw, _ = self._compile_check( + f, + [(nt,)], + dynamic=True, + call_backward=False, # we cannot set requires_grad=True inside a compile region + ) + + self.assertExpectedInline( + normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), + """\ +class (torch.nn.Module): + def forward( + self, + arg0_1: "Sym(s51)", # PlainAOTInput(idx=0) + arg1_1: "Sym(s71)", # PlainAOTInput(idx=1) + arg2_1: "Sym(s55)", # PlainAOTInput(idx=2) + arg3_1: "f64[9, s55]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_values') + arg4_1: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_offsets') + arg5_1: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_min_seqlen_tensor') + arg6_1: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_max_seqlen_tensor') + arg7_1: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0) + arg8_1: "Sym(s55)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=2) + arg9_1: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1) + ): + randn: "f64[2, 5]" = torch.ops.aten.randn.default([2, 
5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False) + randn_1: "f64[3, 5]" = torch.ops.aten.randn.default([3, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False) + randn_2: "f64[4, 5]" = torch.ops.aten.randn.default([4, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False) + + cat: "f64[9, 5]" = torch.ops.aten.cat.default([randn, randn_1, randn_2]); randn = randn_1 = randn_2 = None + zeros: "i64[1]" = torch.ops.aten.zeros.default([1], dtype = torch.int64, device = device(type='cpu'), pin_memory = False) + _tensor_constant0: "i64[3]" = self._tensor_constant0 + lift_fresh_copy: "i64[3]" = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None + cumsum: "i64[3]" = torch.ops.aten.cumsum.default(lift_fresh_copy, 0); lift_fresh_copy = None + cat_1: "i64[4]" = torch.ops.aten.cat.default([zeros, cumsum]); zeros = cumsum = None + zeros_1: "f32[2, 0]" = torch.ops.aten.zeros.default([2, 0], device = device(type='cpu'), pin_memory = False) + zeros_2: "f32[4, 0]" = torch.ops.aten.zeros.default([4, 0], device = device(type='cpu'), pin_memory = False) + + cat_2: "f64[9, s55 + 5]" = torch.ops.aten.cat.default([cat, arg3_1], 1); cat = arg3_1 = None + + sin: "f64[9, s55 + 5]" = torch.ops.aten.sin.default(cat_2) + mul: "f64[9, s55 + 5]" = torch.ops.aten.mul.Tensor(sin, 3); sin = None + + sym_size_int: "Sym(s55 + 5)" = torch.ops.aten.sym_size.int(cat_2, 1); cat_2 = None + sym_stride_int: "Sym(s55 + 5)" = torch.ops.aten.sym_stride.int(mul, 0) + return ( + mul, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values') + cat_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets') + zeros_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_min_seqlen_tensor') + zeros_2, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_max_seqlen_tensor') + sym_size_int, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2) + sym_stride_int, # 
SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1) + ) +""", + ) + + def test_deferred_init_subclass_init_not_traced(self): + """ + Tracing a function that constructs a DeferredInitSubclass must not crash. + + The bug: when Dynamo's frame hook intercepts __init__ as a root frame, + self is partially initialised (attributes not yet set by __init__). + Previously, wrap_tensor would call __tensor_flatten__ on this + partially-initialised self and raise AttributeError. + + The fix skips tracing __init__ of traceable wrapper subclasses at the + frame level (convert_frame.py), so __init__ runs eagerly like + @torch._disable_dynamo would. + """ + # Compile __init__ directly, simulating the root-frame interception + # scenario that occurs in practice (e.g. Diffusers + TorchAO + Dynamo). + compiled_init = torch.compile( + DeferredInitSubclass.__init__, backend="eager", fullgraph=False + ) + data = torch.randn(4, 4) + shell = DeferredInitSubclass.__new__(DeferredInitSubclass, data, 2.0) + + # Should not raise AttributeError from __tensor_flatten__ on partial self + compiled_init(shell, data, 2.0) + + self.assertEqual(shell._data, data) + self.assertEqual(shell._scale, 2.0) + + +instantiate_parametrized_tests(SubclassTests) + + +class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase): + def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True): + return get_jagged_tensor(nested_size, offsets, requires_grad) + + def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True): + # Makes a jagged tensor with N constituent tensors with size + # as specified ((S0, S1, S2), D) + max_dim = (starts + lengths).max() + values_tensor = torch.randn( + starts.shape[0], + max_dim.item(), + inner_dim, + requires_grad=requires_grad, + dtype=torch.float64, + ) + return jagged_from_tensor_and_lengths(values_tensor, starts, lengths) + + def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): + _check_recompiles(self, 
fn, inputs1, inputs2, expected_recompiles) + + def test_unary_does_not_recompile(self): + nt1, _ = self._get_jagged_tensor(((2, 3, 4), 3), None) + nt2, _ = self._get_jagged_tensor(((3, 4, 5, 6), 4), None) + self._check_recompiles(lambda nt1: nt1.sin(), (nt1,), (nt2,), False) + + def test_binary_does_not_recompile(self): + def binary(nt1, nt2): + if nt1.shape == nt2.shape: + return nt1 + nt2 + else: + return nt1.sin() + + # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0). + # This causes a recompile later on when it realizes the batch and last dim + # should not always be equal. To avoid that, we use (3, j0, 5) here. + nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None) + nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets) + nt3, offsets = self._get_jagged_tensor(((3, 4, 5), 4), None) + nt4, _ = self._get_jagged_tensor(((3, 4, 5), 4), offsets) + self._check_recompiles(binary, (nt1, nt2), (nt3, nt4), False) + + def test_binary_recompiles(self): + def binary(nt1, nt2): + if nt1.shape == nt2.shape: + return nt1 + nt2 + else: + return nt1.sin() + + # Binary recompiles because singleton ints no longer match + nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None) + nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets) + nt3, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) + self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True) + + def _validate_compile(self, fn, arg_fn): + def _gen_grad_outputs(out_val): + if isinstance(out_val, (list, tuple)): + return tuple(torch.ones_like(c) for c in out_val) + else: + return (torch.ones_like(out_val),) + + with self.branch_nested_state(): + from torch.nested._internal.nested_tensor import _tensor_symint_registry + + # Validate that compilation does not modify eager state + registry_before = list(_tensor_symint_registry.items()) + count_before = torch.nested._internal.nested_tensor._tensor_id_counter + + guards_exported = [] + guards_failed = [] + + def 
append_guard_export(guards): + for g in guards: + if g.code_list is not None: + guards_exported.append(g.code_list[0]) + + def append_guard_fail(guards): + guards_failed.extend(guards) + + compiled = torch._dynamo.optimize( + nopython=True, + backend="aot_eager", + guard_export_fn=append_guard_export, + guard_fail_fn=append_guard_fail, + )(fn) + registry_after = list(_tensor_symint_registry.items()) + count_after = torch.nested._internal.nested_tensor._tensor_id_counter + self.assertEqual(registry_before, registry_after) + self.assertEqual(count_before, count_after) + + args = arg_fn() + compile_out = compiled(*args) + compile_grads = [] + g_args = [arg for arg in args if arg.requires_grad] + if len(g_args) > 0: + compile_grad_outputs = _gen_grad_outputs(compile_out) + compile_grads = torch.autograd.grad( + compile_out, inputs=g_args, grad_outputs=compile_grad_outputs + ) + + with self.branch_nested_state(): + args = arg_fn() + ref_out = fn(*args) + ref_grads = [] + g_args = [arg for arg in args if arg.requires_grad] + if len(g_args) > 0: + ref_grad_outputs = _gen_grad_outputs(ref_out) + ref_grads = torch.autograd.grad( + ref_out, inputs=g_args, grad_outputs=ref_grad_outputs + ) + + # Validate correctness forward + if isinstance(compile_out, (list, tuple)): + # TODO: Fix assertEqual() to support NJTs so this isn't necessary + self.assertEqual(len(compile_out), len(ref_out)) + for c, r in zip(compile_out, ref_out): + self.assertEqualIgnoringNestedInts(c, r) + else: + self.assertEqualIgnoringNestedInts(compile_out, ref_out) + + # Validate correctness backward + for compile_grad, ref_grad in zip(compile_grads, ref_grads): + self.assertEqualIgnoringNestedInts(compile_grad, ref_grad) + + return guards_exported, guards_failed + + def test_in_graph_is_nested_call(self): + def f(nt): + if nt.is_nested: + return nt + 2 + else: + return nt + 1 + + cnt = CompileCounterWithBackend("aot_eager") + compiled_f = torch.compile(f, backend=cnt, fullgraph=True) + nt, offsets = 
self._get_jagged_tensor(((2, 3, 4), 5), None) + output = compiled_f(nt) + output.backward(torch.ones_like(output)) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(len(cnt.graphs), 1) + graph = cnt.graphs[0] + norm_graph = normalize_gm(graph.print_readable(print_output=False)) + + # expect -no- is_nested calls within the graph + self.assertExpectedInline( + norm_graph, + """\ +class GraphModule(torch.nn.Module): + def forward(self, s71: "Sym(s71)", L_nt_: "NestedTensor(f64[3, s71, 5])"): + l_nt_ = L_nt_ + + add: "NestedTensor(f64[3, s71, 5])" = l_nt_ + 2; l_nt_ = None + return (add,) +""", + ) + + # Note: [What kind of guards are involved in nested tensor compilation] + # + # Until we implement UnionFind, dynamic shapes guards are not involved. + # we rely only on dynamo's tensor aliasing guards. + # + # This is possible because dynamo able to generate tensor aliasing guards + # not only for the outer tensor, but also for the inner tensor. + # + # The case where dynamic shapes guards would eventually come into play is + # when my inputs are (1) two non-aliased tensors, but (2) declared as + # equal using a "trust me assert equal" API. + + # Note: [Compiling nested tensor global state] + # + # Today there are two pieces of global eager state that NJTs deals with: + # - tensor_id_counter: a global counter that assigns unique ids to tensors + # - tensor_symint_registry: maps tensor to nested int + # - this is used in eager only (we should get rid of this because it is + # not necessary to cache nested int in eager) + # - during tracing, we DO need to cache nested int, but we do so on + # the FakeTensor. 
+ # + # Ideally we would like to satisfy the following: + # - (1) The eager state is not mutated during tracing + # - (2) Running the compiled function should mutate the eager state in the + # same way that running the eager function would + # (a) The global counter should be incremented + # (b) The registry is updated in the same way + # + # Today we can satisfy (1) and (2a) but cannot satisfy (2b) + # + # Today, (1) is satisfied because we maintain a separate counter during + # tracing, and cache nested int on FakeTensor instead of relying on + # tensor_symint_registry. + # + # (2) cannot be completely satisfied because we trace away the + # side-effectful operations (which we can fix by wrapping the + # side-effectful operations in a custom op, and threading through effect + # tokens.) The current plan is to do that in the UnionFind impl. + # + # Interestingly, despite this, the state is mutated in a way that is somewhat + # close to what we want, e.g. if I construct a nested tensor using an + # offsets in the compiled region and return it, AOTAutograd runtime wrapper + # must rewrap the inner->inner graph outputs back into subclass. This + # triggers the eager logic to run, updating the counter and registry. + # + # Notably however, compile differs in two ways from eager: + # (1) The order in which the offsets are assigned ids is different + # the registry would be set in the order the offsets are returned + # which is not necessarily the same order as they were constructed. + # (2) If a NestedTensor is not returned, then the AOTAutograd wrapping + # logic will not be triggered. + # + # I claim that correctness is not affected by these differences today. + # e.g. there is never the case where two distinct offsets silently share + # the same id. + # + # (1) is clearly not a problem, and (2) should only be a problem if + # the nested int is returned on its own, without the corresponding NJT + # being returned.
This is not a problem in the current implementation + # because returning only a shape is not supported! + + # Note: [Creating symbolic nested int] + # + # We must create a symbolic nested int when we construct a nested tensor + # from a tensor. There are two main cases: + # + # 1. The offsets has NOT been used to construct a NJT + # - Create a new plain nested int with current val of fake nt id counter + # - Increment the fake nt id counter + # - Create a new symint with plain nested int as hint + # 2. The offsets HAS been used to construct a NJT + # - Create a new symint with plain nested int as hint + # + # More details on case 2: + # - During fakification of the offsets, we check the eager registry, and + # if the tensor HAS been used to construct a NJT, + # we create a symint, with the existing nested int as hint, and cache + # it onto the FakeTensor. + # + # [ Always use ephemeral source ] + # + # We create the new symint ALWAYS with ephemeral source whether that is + # in case (1) or (2) even though we could've had a proper source for case (2). + # Using a proper source would enable a few more (edge) cases, but since + # we plan to handle things more holistically in the future anyway, we don't + # bother doing so today. + # + # Using an ephemeral source has some consequences. But we are happy if + # - We do not silently miss recompiles, i.e. we guard when necessary. + # We know that this is true, because dynamo guards alone are already + # sufficient. + # - We are not producing errors for the cases we care about + # + # The main case we care about is when we guard that two shapes are equal. + # In this case, the replacements logic would simplify away the ephemeral + # symbol, and there is no error produced. + # The unsupported case is when we guard that two shapes are not equal, in + # which case we will try, and fail, to generate a guard.
+ + # + # Case 1: in-graph construction where the offsets are passed as inputs + # + def test_in_graph_construction_from_input(self): + # The offsets is passed as an input + def fn(values, offsets): + return torch.nested.nested_tensor_from_jagged(values * 2, offsets) * 2 + + values = torch.randn(10, 5, requires_grad=True) + offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) + self._validate_compile(fn, arg_fn=lambda: (values, offsets)) + + # Do not specialize on the offsets + with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): + different_offsets = torch.tensor([0, 1, 5, 10], dtype=torch.int64) + self._validate_compile(fn, arg_fn=lambda: (values, different_offsets)) + + def test_in_graph_construction_from_input_2(self): + # Construct two NJTs, both are passed as inputs + def fn(values, offsets1, offsets2): + nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets1) + nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2) + return nt2, nt1 + + values = torch.randn(10, 5, requires_grad=True) + offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) + offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64) + # 1. Offsets are different + guards_exported, guards_failed = self._validate_compile( + fn, arg_fn=lambda: (values, offsets, offsets2) + ) + self.assertEqual(len(guards_failed), 0) + self.assertNotIn("L['offsets1'] is L['offsets2']", guards_exported) + + # TODO + # 2. 
Offsets are the same + new_guards_exported, _ = self._validate_compile( + fn, arg_fn=lambda: (values, offsets, offsets) + ) + self.assertTrue(any("Duplicate tensors found" in g for g in guards_failed)) + self.assertIn("L['offsets1'] is L['offsets2']", new_guards_exported) + + with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): + offsets3 = offsets.clone() + self._validate_compile(fn, arg_fn=lambda: (values, offsets3, offsets3)) + + # Do a binary op + def fn(values, offsets, offsets2): + nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets) + nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2) + return nt1 * nt2 + + self._validate_compile(fn, arg_fn=lambda: (values, offsets, offsets)) + + def test_in_graph_construction_from_input_4(self): + # The offsets is taken from an NJT input + def fn(nt, other_values): + nt2 = torch.nested.nested_tensor_from_jagged(other_values, nt.offsets()) + return nt + nt2 + + values = torch.randn(9, 5, requires_grad=True) + other_values = torch.randn(9, 5, requires_grad=True) + offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64) + + def arg_fn(values=values, other_values=other_values, offsets=offsets): + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + return nt, other_values + + self._validate_compile(fn, arg_fn=arg_fn) + + # Do not specialize on the offsets + with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): + different_offsets = offsets.clone() + + def arg_fn( + values=values, other_values=other_values, offsets=different_offsets + ): + nt = torch.nested.nested_tensor_from_jagged(values, different_offsets) + return nt, other_values + + self._validate_compile(fn, arg_fn=arg_fn) + + def test_in_graph_construction_from_input_5(self): + # Construct from lengths instead of offsets + def fn(values, lengths): + nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths) + return nt.sin() + + values = torch.randn(9, 5, requires_grad=True) + 
lengths = torch.tensor([2, 4, 3]) + self._validate_compile(fn, arg_fn=lambda: (values, lengths)) + + def test_in_graph_construction_from_input_6(self): + # Construct with symbolic int. + def fn(values, offsets, max_seqlen): + t = torch.nested.nested_tensor_from_jagged( + values, offsets, max_seqlen=max_seqlen + ) + return torch.nested.nested_tensor_from_jagged( + values, t.offsets(), max_seqlen=t._maybe_max_seqlen + ) + + opt_fn = torch.compile(fn, fullgraph=True, dynamic=True, backend="eager") + values = torch.randn(10, 5) + offsets = torch.tensor([0, 2, 4, 7, 10]) + max_seqlen = 5 + + ref = fn(values, offsets, max_seqlen) + res = opt_fn(values, offsets, max_seqlen) + self.assertEqualIgnoringNestedInts(ref, res) + + # + # Case 2: in-graph construction where offsets are graph intermediates + # + def test_in_graph_construction_from_intermediate(self): + # offsets is an intermediate computed from lengths + def fn(values, lengths): + offsets = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)]) + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + nt2 = torch.nested.nested_tensor_from_jagged(values, offsets) + return (nt * nt2).sin() + + values = torch.randn(9, 5, requires_grad=True) + lengths = torch.tensor([2, 4, 3]) + self._validate_compile(fn, arg_fn=lambda: (values, lengths)) + + # Do not specialize on the lengths + with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): + different_lengths = lengths.clone() + self._validate_compile(fn, arg_fn=lambda: (values, different_lengths)) + + def test_in_graph_construction_from_intermediate_2(self): + def fn(values, offsets): + return torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone()) + + values = torch.randn(10, 5, requires_grad=True) + offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) + self._validate_compile(fn, arg_fn=lambda: (values, offsets)) + + def test_in_graph_construction_from_intermediate_3(self): + # Note that due to CSE, clone is not necessarily called 
twice! + def fn(values, offsets): + nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone()) + nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets.clone()) + return nt2, nt1 + + values = torch.randn(10, 5, requires_grad=True) + offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) + self._validate_compile(fn, arg_fn=lambda: (values, offsets)) + + def test_in_graph_construction_from_intermediate_4(self): + # Shared intermediate (should be same as case #1) + def fn(values): + offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + values2 = torch.ones_like(values) + nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets) + return nt * nt2 + + values = torch.randn(10, 5).requires_grad_(True) + self._validate_compile(fn, arg_fn=lambda: (values,)) + + # AssertionError: s2 (could be from ['', + @unittest.expectedFailure + def test_in_graph_construction_from_intermediate_5(self): + # non-shared intermediate + def fn(values): + offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + values2 = torch.ones_like(values) + nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets.clone()) + if nt2.shape[1] != nt.shape[1]: + return nt * 2 + else: + return nt * 3 + + values = torch.randn(10, 5).requires_grad_(True) + self._validate_compile(fn, arg_fn=lambda: (values,)) + + # + # Case 3: in-graph construction where offsets are both direct graph inputs + # and passed in as part of an NJT's offsets. 
+ # + def test_in_graph_construction_mixed(self): + def fn(nt, values, offsets): + nt2 = torch.nested.nested_tensor_from_jagged(values, offsets) + return nt * nt2 + + values = torch.randn(10, 5, requires_grad=True) + offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) + + def arg_fn(values=values, offsets=offsets): + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + return nt, values, offsets + + self._validate_compile(fn, arg_fn) + + # See Note: [Creating symbolic nested int] + # AssertionError: s2 (could be from ['', + @unittest.expectedFailure + def test_in_graph_construction_mixed_2(self): + def fn(nt, values, offsets, nt2): + # Intermediate offsets has ephemeral source + intermediate_nt = torch.nested.nested_tensor_from_jagged( + values, offsets.clone() + ) + # This creates a dynamic shapes neq guard + if nt2.shape[1] != intermediate_nt.shape[1]: + # We should always go here. + nt = nt * 2 + return nt + + values = torch.randn(10, 5, requires_grad=True) + offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) + offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64) + + def arg_fn(values=values, offsets=offsets, offsets2=offsets2): + # Values is shared, but it shouldn't matter + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + nt2 = torch.nested.nested_tensor_from_jagged(values, offsets2) + return nt, values, offsets, nt2 + + self._validate_compile(fn, arg_fn) + + def test_in_graph_construction_mixed_3(self): + # More involved mixed case + def fn(nt, values, offsets): + nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets) + nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets) + return nt1 + nt2 + nt + + values = torch.randn(9, 5, requires_grad=True) + offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64) + + def arg_fn(values=values, offsets=offsets): + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + return nt, values, offsets + + self._validate_compile(fn, arg_fn) + + def 
test_return_shape(self): + nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) + + def fn(nt): + return (nt * 2).shape + + compiled = torch.compile(fn, fullgraph=True, backend="aot_eager") + compiled(nt) + + def test_inference_tensor(self): + with torch.inference_mode(): + nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) + + def fn(n): + return n * 2 + + torch.compile(fn, backend="eager")(nt) + + # TODO: cannot parametrize this test class with device for some reason + def _test_autograd(self, backend): + a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64) + b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64) + c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64) + nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) + # TODO: Switch to public API when it exists + nt2, _ = jagged_from_list([a, b, c], nt.offsets()) + + def fn1(nt1, nt2): + return (nt1 + nt2).sin().cos() + + compiled_f = torch.compile(fn1, fullgraph=True, backend=backend, dynamic=True) + out = compiled_f(nt, nt2) + out_buffer = out.values() + ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c)) + + out_ref = fn1(nt, nt2) + out_buffer_ref = out_ref.values() + ga_ref, gb_ref, gc_ref = torch.autograd.grad(out_buffer_ref.sum(), (a, b, c)) + + self.assertTrue(torch.allclose(ga, ga_ref)) + self.assertTrue(torch.allclose(gb, gb_ref)) + self.assertTrue(torch.allclose(gc, gc_ref)) + + def test_basic_autograd(self): + self._test_autograd("aot_eager") + + @requires_gpu_and_triton + def test_basic_autograd_inductor(self): + self._test_autograd("inductor") + + def test_subclass_with_mutation_in_graph(self): + # In this graph, we have an in-graph mutation, i.e. a mutation that is allowed + # to remain in the graph. Normally this is allowed, but it's not allowed if + # the graph handles subclasses at all. + # Whether the mutation is allowed or not allowed in the graph alters the number + # of outputs from the forward graph. 
Previously, a bug in this handling meant + # that sometimes the expected number and actual number of outputs from the + # joint graph did not match, causing assertion failures. + def fn(x, y): + z = x.sin() + y.sin_() + return z.cos(), y.cos() + + fn_c = torch.compile(fn, backend="inductor") + + values = [torch.rand((i, 8), requires_grad=True) for i in range(1, 6)] + values_copy = [x.detach().clone().requires_grad_(True) for x in values] + + nt, offsets = jagged_from_list(values, None) + nt_copy, offsets = jagged_from_list(values_copy, offsets) + y = torch.rand((4, 8)) + y_copy = y.clone() + + ret = fn_c(nt, y)[0] + ref = fn(nt_copy, y_copy)[0] + + self.assertEqual(ret.values(), ref.values()) + + ret.values().sum().backward() + ref.values().sum().backward() + for ref_v, res_v in zip(values_copy, values): + self.assertEqual(ref_v.grad, res_v.grad) + + @torch._dynamo.config.patch({"capture_scalar_outputs": True}) + def test_unbind(self): + # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0). + # This causes a recompile later on when it realizes the batch and last dim + # should not always be equal. To avoid that, we use (3, j0, 5) here. + nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) + nt2, _ = self._get_jagged_tensor(((2, 3, 5), 2), None) + nt3, _ = self._get_jagged_tensor(((2, 3, 4, 5), 3), None) + + def fn(x): + return x.unbind() + + compiled_f = torch.compile(fn, fullgraph=True, backend="eager", dynamic=True) + out = compiled_f(nt) + + out_ref = fn(nt) + + # correctness + self.assertEqual(len(out), len(out_ref)) + for x, x_ref in zip(out, out_ref): + self.assertTrue(torch.allclose(x, x_ref)) + + # We specialize on the length of offsets, e.g. (1) we recompile if the + # length of the offsets is different. (2) we don't recompile if the + # length of the offsets is the same, even if the size of the constituent + # tensors are different. 
+ self._check_recompiles(fn, (nt,), (nt2,), False) + self._check_recompiles(fn, (nt,), (nt3,), True) + + def test_inline_nested_tensor_from_jagged(self): + nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) + + def fn(x): + return torch.nested.nested_tensor_from_jagged(x.values() * 2, x.offsets()) + + torch.compile(fn, fullgraph=True, backend="aot_eager")(nt) + + # The test here: nn.Parameters that are secretly subclasses + # have a metaclass that overrides __isinstance__, + # that dynamo needs to respect when it inlines the if statement. + def test_param_subclass_isinstance_input(self): + x_inner = torch.randn(16, 16, requires_grad=True) + x = torch.nn.Parameter(TwoTensor(x_inner, x_inner)) + m = torch.nn.Linear(16, 16) + m.weight = x + + def fn(): + if isinstance(m.weight, torch.nn.Parameter): + return m.weight + 1 + else: + return m.weight + 2 + + out_ref = fn() + out_test = torch.compile(fn, backend="aot_eager")() + self.assertEqual(out_ref, out_test) + + def test_buffer_subclass_isinstance_input(self): + from torch.nn.parameter import Buffer + + buf = Buffer(torch.ones(5)) + + def fn(b): + if isinstance(b, torch.nn.Buffer): + return b + 1 + else: + return b + 2 + + out_ref = fn(buf) + out_test = torch.compile(fn, backend="aot_eager", fullgraph=True)(buf) + self.assertEqual(out_ref, out_test) + + torch._dynamo.reset() + + tensor = torch.ones(5) + out_ref_tensor = fn(tensor) + out_test_tensor = torch.compile(fn, backend="aot_eager", fullgraph=True)(tensor) + self.assertEqual(out_ref_tensor, out_test_tensor) + + def test_buffer_subclass_check(self): + from torch.nn.parameter import Buffer + + def check_buffer(x): + return isinstance(x, torch.nn.Buffer) + + buf = Buffer(torch.ones(5)) + compiled_fn = torch.compile(check_buffer, fullgraph=True, backend="inductor") + self.assertTrue(compiled_fn(buf)) + + torch._dynamo.reset() + + tensor = torch.ones(5) + compiled_fn = torch.compile(check_buffer, fullgraph=True, backend="inductor") + 
self.assertFalse(compiled_fn(tensor)) + + def _input_view_test(self, nt_view_name): + nt_view = VIEW_TEST_CASES[nt_view_name]() + + def fn(x): + return x.sin() + + out_ref = fn(nt_view) + torch._dynamo.reset() + compile_fn = torch.compile( + fn, fullgraph=True, backend="aot_eager", dynamic=True + ) + out = compile_fn(nt_view) + + # Check metadata and values are correct + self.assertTrue(out.size() == out_ref.size()) + self.assertTrue(out.stride() == out_ref.stride()) + if out.is_nested: + self.assertTrue(torch.allclose(out.values(), out_ref.values())) + else: + self.assertTrue(torch.allclose(out, out_ref)) + + # Check that no upper/lower bound guards are incurred + def backend(gm, args): + context = torch._guards.TracingContext.get() + guards = [str(g.expr) for g in context.fake_mode.shape_env.guards] + + # varies based on the type of view + guard_str = "\n".join(guards) + + if nt_view_name == "base_is_nt_False_basic": + self.assertExpectedInline( + guard_str, + """\ +Eq(s85 - 1, s64) +Eq(s20, s64) +Eq(s80 - 1, s77) +Eq(s72, s71)""", + ) + elif nt_view_name == "base_is_nt_False_leaf_False_False": + self.assertExpectedInline( + guard_str, + """\ +Eq(s85 - 1, s64) +Eq(s80 - 1, s77) +Eq(s72, s71)""", + ) + elif nt_view_name == "base_is_nt_False_leaf_False_True": + self.assertExpectedInline( + guard_str, + """\ +Eq(s85 - 1, s64) +Eq(s20, s64) +Eq(s80 - 1, s77) +Eq(s72, s71)""", + ) + elif nt_view_name == "base_is_nt_False_leaf_True_False": + self.assertExpectedInline( + guard_str, + """\ +Eq(s85 - 1, s64) +Eq(s20, s64) +Eq(s80 - 1, s77) +Eq(s72, s71)""", + ) + elif nt_view_name == "base_is_nt_False_leaf_True_True": + self.assertExpectedInline( + guard_str, + """\ +Eq(s85 - 1, s64) +Eq(s20, s64) +Eq(s80 - 1, s77) +Eq(s72, s71)""", + ) + elif nt_view_name == "base_is_nt_False_obscure": + self.assertExpectedInline( + guard_str, + """\ +Eq(s85 - 1, s64) +Eq(s20, s64) +Eq(s80 - 1, s77) +Eq(s72, s71)""", + ) + elif nt_view_name == "base_is_nt_True_basic": + 
self.assertExpectedInline( + guard_str, + """\ +Eq(s17 - 1, s83) +Eq(s20, s83)""", + ) + elif nt_view_name == "base_is_nt_True_leaf_False_False": + self.assertExpectedInline( + guard_str, + """Eq(s17 - 1, s83)""", + ) + elif nt_view_name == "base_is_nt_True_leaf_False_True": + self.assertExpectedInline( + guard_str, + """\ +Eq(s17 - 1, s83) +Eq(s20, s83)""", + ) + elif nt_view_name == "base_is_nt_True_leaf_True_False": + self.assertExpectedInline( + guard_str, + """\ +Eq(s17 - 1, s83) +Eq(s20, s83)""", + ) + elif nt_view_name == "base_is_nt_True_leaf_True_True": + self.assertExpectedInline( + guard_str, + """\ +Eq(s17 - 1, s83) +Eq(s20, s83)""", + ) + elif nt_view_name == "base_is_nt_True_obscure": + self.assertExpectedInline( + guard_str, + """\ +Eq(s17 - 1, s83) +Eq(s20, s83)""", + ) + elif nt_view_name == "dense_subclass_dense_subclass": + self.assertExpectedInline( + guard_str, + """\ +Eq(s85 - 1, s77) +Eq(s80 - 1, s78) +Eq(s72, s71)""", + ) + elif nt_view_name == "subclass_dense": + self.assertExpectedInline( + guard_str, + """\ +Eq(s85 - 1, s77) +Eq(s20, s77)""", + ) + else: + raise NotImplementedError + return gm + + torch._dynamo.reset() + compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True) + out = compile_fn(nt_view) + + @parametrize( + "nt_view_name", + [k for k in VIEW_TEST_CASES if k != "subclass_dense_subclass_dense"], + ) + def test_inputs_to_compiled_fn_are_views(self, nt_view_name): + self._input_view_test(nt_view_name) + + def test_subclass_gives_static_shapes_when_dynamic_false(self): + def check_graph(gm, *args): + first_node_example_val = next(iter(gm.graph.nodes)).meta["example_value"] + # We compiled with dynamic=False, expect no SymInt sizes on our placeholders + self.assertTrue( + all(isinstance(x, int) for x in first_node_example_val.shape) + ) + return gm + + @torch.compile(backend=check_graph, dynamic=False) + def f(x): + return x + 1 + + x_inner = torch.ones(4) + x = TwoTensor(x_inner, x_inner) + x_view = 
x.view(2, 2) + out = f(x_view) # noqa: F841 + + # NJT1 -> Dense -> NJT2 -> Dense view + # During view replay, the Dense -> NJT2 part will construct an intermediate, + # symbolically-sized NJT that is immediately deconstructed to return the final dense + # view. To construct this intermediate properly, we need the associated nested int + # to be symbolic. This view is expected to fail compilation until symbolic nested ints + # are cached onto fake offsets to solve this problem. + @unittest.expectedFailure + def test_subclass_dense_subclass_dense_view(self): + self._input_view_test("subclass_dense_subclass_dense") + + +instantiate_parametrized_tests(TestNestedTensor) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/xpu/dynamo/test_trace_rules_xpu.py b/test/xpu/dynamo/test_trace_rules_xpu.py new file mode 100644 index 0000000000..37592546a0 --- /dev/null +++ b/test/xpu/dynamo/test_trace_rules_xpu.py @@ -0,0 +1,562 @@ +# Owner(s): ["module: dynamo"] +import dataclasses +import importlib +import inspect +import math +import types +import unittest +import warnings +from typing import Any + +import torch +import torch._dynamo.config as config +import torch._dynamo.test_case +import torch._functorch.deprecated as deprecated_func +from torch._dynamo.testing import CompileCounter +from torch._dynamo.trace_rules import ( + LEGACY_MOD_INLINELIST, + load_object, + lookup_inner, + manual_torch_name_rule_map, + MOD_INLINELIST, + torch_c_binding_in_graph_functions, + torch_non_c_binding_in_graph_functions, +) +from torch._dynamo.utils import hashable, is_safe_constant, istype +from torch._dynamo.variables import ( + SkipFunctionVariable, + TorchInGraphFunctionVariable, + UserFunctionVariable, +) +from torch.testing._internal.common_utils import skipIfWindows + +try: + from .utils import create_dummy_module_and_function +except ImportError: + from utils import create_dummy_module_and_function + + 
+ignored_c_binding_in_graph_function_names = { + # Ignored because they have manual rules defined at `trace_rules.manual_torch_name_rule_map`. + "torch._nested_tensor_from_mask", + "torch._nested_from_padded", + "torch.sparse_compressed_tensor", + "torch.sparse_bsc_tensor", + "torch.sparse_bsr_tensor", + "torch.sparse_coo_tensor", + "torch.sparse_csc_tensor", + "torch.sparse_csr_tensor", + "torch.cuda._get_device_properties", + "torch.xpu._get_device_properties", + # Ignored and go through rules defined at `trace_rules.check`. + "torch._functionalize_are_all_mutations_under_no_grad_or_inference_mode", + "torch._cslt_sparse_mm_search", + "torch._C._abort", + "torch._C._mps_is_on_macos_or_newer", + "torch._C._swap_tensor_impl", + "torch._C._unsafe_reset_storage", + "torch._dynamo.eval_frame.reset_code", + "torch._C.autocast_decrement_nesting", + "torch._C.autocast_increment_nesting", + "torch._C.clear_autocast_cache", + "torch._C.set_anomaly_enabled", + "torch._C.set_autocast_cache_enabled", + "torch._C.set_autocast_cpu_dtype", + "torch._C.set_autocast_cpu_enabled", + "torch._C.set_autocast_enabled", + "torch._C.set_autocast_gpu_dtype", + "torch._C.set_autocast_ipu_dtype", + "torch._C.set_autocast_ipu_enabled", + "torch._C.set_autocast_xla_dtype", + "torch._C.set_autocast_xla_enabled", + "torch.resize_as_", + "torch.resize_as_sparse_", + "torch._C._data_address", + "torch._C._is_cow_tensor", + "torch._lazy_clone", + "torch._test_parallel_materialize", + "torch._C._storage_address", + "torch._C._pickle_save", + "torch._validate_sparse_compressed_tensor_args", + "torch._validate_sparse_csr_tensor_args", + "torch._validate_sparse_bsr_tensor_args", + "torch._validate_sparse_csc_tensor_args", + "torch._validate_sparse_coo_tensor_args", + "torch._validate_sparse_bsc_tensor_args", + "torch._validate_compressed_sparse_indices", +} +if torch._C._llvm_enabled(): + ignored_c_binding_in_graph_function_names |= { + "torch._C._te.set_llvm_aot_workflow", + 
"torch._C._te.set_llvm_target_cpu", + "torch._C._te.set_llvm_target_attrs", + "torch._C._te.set_llvm_target_triple", + } + + +# Helper function to dump the torch name rule map generated based on +# the heuristic defined in gen_allowed_objs_and_ids. +def dump_allowed_torch_name_rule_map() -> None: + m = gen_allowed_objs_and_ids(record=True, c_binding_only=False).name_rule_map + for k, v in m.items(): + print(f'"{k}": {v.__name__},') + + +@dataclasses.dataclass +class AllowedObjects: + """ + Track the objects, object id - name pairs, and name - dynamo wrapping rule pairs + from the heuristic defined in `gen_allowed_objs_and_ids`. + """ + + object_ids: dict[int, str] + c_binding_in_graph_functions: set[Any] + non_c_binding_in_graph_functions: set[Any] + name_rule_map: dict[str, Any] + + +def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObjects: + """ + Walk torch.* and get the ids of all the stuff in it + """ + + warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed") + torch_object_ids = {} + c_binding_in_graph_functions = set() + non_c_binding_in_graph_functions = set() + torch_name_rule_map = {} + + # In some platforms, these functions were loaded as classes instead of functions. + # To mitigate these weird cases, we need this special check. + def is_special_functions(obj): + return hashable(obj) and obj in { + torch._C._cuda_isCurrentStreamCapturing, + torch._C._graph_pool_handle, + } + + # Add obj to c_binding_in_graph_functions set or non_c_binding_in_graph_functions set + # if it's a torch function or method. + # This is used to generate the in graph function list based on heuristic. 
+ def heuristic_record_if_in_graph_function(obj, module, name): + try: + if hasattr(obj, "__wrapped__"): + obj = obj.__wrapped__ + except Exception: + pass + if isinstance( + obj, + ( + types.FunctionType, + types.BuiltinFunctionType, + types.MethodDescriptorType, + types.WrapperDescriptorType, + ), + ) or is_special_functions(obj): + torch_name_rule_map[f"{module.__name__}.{name}"] = ( + TorchInGraphFunctionVariable + ) + if c_binding_only: + if not hasattr(obj, "__code__"): + c_binding_in_graph_functions.add(obj) + else: + if hasattr(obj, "__code__"): + non_c_binding_in_graph_functions.add(obj) + else: + c_binding_in_graph_functions.add(obj) + + def _is_allowed_module_prefix(obj): + allowed_modules = ("torch", "math") + # torch.nn.modules.rnn is disallowed because these modules internally + # flatten their parameters. This flattening process will call + # Tensor.set_ with a Storage, and Storages cannot be traced with + # AOTAutograd; so we need to graph-break. To ensure this, we inline + # these functions, rather than keep them opaque-ly in the graph. 
+ disallowed_modules = [ + "torch.optim.", + "torch.nn.modules.rnn.", + "torch._dynamo.", + "torch._C._dynamo.", + "torch._inductor.", + "torch._C.inductor.", + "torch.fx.", + "torch._C._autograd", + "torch._C._cudart", + "torch._C._distributed_autograd", + "torch._C._distributed_c10d", + "torch._C._distributed_rpc", + "torch._C._functorch", + "torch._C._monitor", + "torch._C._nvtx", + "torch._C._lazy", + "torch._C._profiler", + "torch.__config__", + "torch._custom_op", + "torch._decomp", + "torch._dispatch", + "torch._export", + "torch._functorch.make_functional", + "torch._functorch.compile_utils", + "torch._functorch.partitioners", + "torch._functorch.aot_autograd", + "torch._functorch.compilers", + "torch._functorch.fx_minifier", + "torch.autograd.profiler_util", + "torch.autograd.profiler", + "torch._jit_internal", + "torch._library", + "torch._lobpcg", + "torch._logging", + "torch._meta_registrations", + "torch._namedtensor_internals", + "torch._numpy", + "torch._sources", + "torch._subclasses", + "torch._tensor", + "torch._tensor_str", + "torch._utils", + "torch._utils_internal", + "torch._vmap_internals", + "torch.compiler", + "torch.distributed", + "torch.export", + "torch.hub", + "torch.jit", + "torch.library", + "torch.masked.maskedtensor", + "torch.nn.init", + "torch.nn.modules.module", + "torch.nn.parallel", + "torch.nn.utils", + "torch.multiprocessing", + "torch.onnx", + "torch.overrides", + "torch.package", + "torch.profiler", + "torch.serialization", + "torch.storage", + "torch.utils", + "torch.distributed.", + ] + + allowed_modules_dot = tuple([x + "." 
for x in allowed_modules]) + module = inspect.getmodule(obj) + if module is None: + return False + + mod_name = module.__name__ + + if any(mod_name.startswith(m) for m in disallowed_modules): + return False + + return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot) + + def _find_torch_objects(module): + if any( + module.__name__.startswith(mod_name) + for mod_name in config.allowed_functions_module_string_ignorelist + ): + return + torch_object_ids[id(module)] = module.__name__ + for name, obj in list(module.__dict__.items()): + if id(obj) not in torch_object_ids: + # Dynamo allows all builtins into the graph and does not attempt + # to introspect into them. We don't want to allow instances of + # HigherOrderOperator into the graph all the time (Dynamo needs + # to introspect the body functions of these HigherOrderOperator + # first, decide they are safe, and then allow them into the graph). + # So we exclude HigherOrderOperator from being a builtin. + import torch._ops + + if isinstance(obj, torch._ops.HigherOrderOperator): + continue + + # We want to trace through `grad` and `vmap` + if obj in ( + torch.func.grad, + deprecated_func.grad, + torch.func.vmap, + deprecated_func.vmap, + torch.nn.functional.triplet_margin_with_distance_loss, + torch.cond, + ): + continue + + if isinstance(obj, types.ModuleType): + if obj.__name__.startswith("torch.") and _is_allowed_module_prefix( + obj + ): + torch_object_ids[id(obj)] = f"{module.__name__}.{name}" + _find_torch_objects(obj) + elif _is_allowed_module_prefix(obj): + if record: + heuristic_record_if_in_graph_function(obj, module, name) + torch_object_ids[id(obj)] = f"{module.__name__}.{name}" + elif inspect.getmodule(obj) is None and not is_safe_constant(obj): + if record: + heuristic_record_if_in_graph_function(obj, module, name) + torch_object_ids[id(obj)] = f"{module.__name__}.{name}" + + _find_torch_objects(torch) + _find_torch_objects(math) + + return AllowedObjects( + torch_object_ids, + 
c_binding_in_graph_functions, + non_c_binding_in_graph_functions, + torch_name_rule_map, + ) + + +class TraceRuleTests(torch._dynamo.test_case.TestCase): + def _check_set_equality(self, generated, used, rule_map, ignored_set): + x = generated - used + y = used - generated + msg1 = ( + f"New torch objects: {x} " + f"were not added to `trace_rules.{rule_map}` or `test_trace_rules.{ignored_set}`. " + "Refer the instruction in `torch/_dynamo/trace_rules.py` for more details." + ) + msg2 = ( + f"Existing torch objects: {y} were removed. " + f"Please remove them from `trace_rules.{rule_map}` or `test_trace_rules.{ignored_set}`. " + "Refer the instruction in `torch/_dynamo/trace_rules.py` for more details." + ) + self.assertTrue(len(x) == 0, msg1) + self.assertTrue(len(y) == 0, msg2) + + # We are using python function and module string names for these inlinelist, + # this unit test is to make sure the functions/modules can be correctly imported + # or loaded in case there is typo in the strings. + def test_skipfiles_inlinelist(self): + for m in LEGACY_MOD_INLINELIST.union(MOD_INLINELIST): + try: + mod = importlib.import_module(m) + except ImportError: + continue + else: + self.assertTrue( + isinstance(mod, types.ModuleType), + f"{m} from trace_rules.MOD_INLINELIST/LEGACY_MOD_INLINELIST " + "is not a python module, please check and correct it.", + ) + + @unittest.skip( + "This test keeps getting broken and our disable infra is not handling well. see #120627" + ) + def test_torch_name_rule_map_updated(self): + # Generate the allowed objects based on heuristic defined in `allowed_functions.py`, + objs = gen_allowed_objs_and_ids(record=True, c_binding_only=True) + # Test C binding in graph functions are updated in torch_name_rule_map. 
+ generated = objs.c_binding_in_graph_functions + used = set() + for x in ( + set(torch_c_binding_in_graph_functions.keys()) + | ignored_c_binding_in_graph_function_names + ): + obj = load_object(x) + if obj is not None: + used.add(obj) + self._check_set_equality( + generated, + used, + "torch_c_binding_in_graph_functions", + "ignored_c_binding_in_graph_function_names", + ) + # For non C binding in graph functions, we only test if they can be loaded successfully. + for f in torch_non_c_binding_in_graph_functions: + self.assertTrue( + isinstance( + load_object(f), + ( + types.FunctionType, + types.BuiltinFunctionType, + types.MethodDescriptorType, + types.WrapperDescriptorType, + ), + ) + ) + + def test_force_inline_torch_function(self): + # `torch._dynamo.utils.istype` is skipped by default + def fn(x): + if istype(x, torch.Tensor): + return x + 1 + else: + return x - 1 + + _manual_torch_name_rule_map = manual_torch_name_rule_map.copy() + # Force inline `torch._dynamo.utils.istype` by setting trace rule. 
+ _manual_torch_name_rule_map["torch._dynamo.utils.istype"] = UserFunctionVariable + + _torch_name_rule_map = [ + _manual_torch_name_rule_map, + torch_c_binding_in_graph_functions, + torch_non_c_binding_in_graph_functions, + ] + + self.assertTrue( + "torch._dynamo" not in torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST + ) + self.assertTrue("torch._dynamo" not in torch._dynamo.trace_rules.MOD_INLINELIST) + + with ( + unittest.mock.patch( + "torch._dynamo.trace_rules.torch_name_rule_map", + _torch_name_rule_map, + ), + unittest.mock.patch( + "torch._dynamo.trace_rules.get_torch_obj_rule_map", + torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, # bypass functools.lru_cache + ), + ): + x = torch.rand(3) + opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_force_inline_custom_function(self): + mod, func = create_dummy_module_and_function() + + def fn(x): + return func(x) + + _manual_torch_name_rule_map = manual_torch_name_rule_map.copy() + # Force inline `mod.func` by setting trace rule. + _manual_torch_name_rule_map[f"{mod.__name__}.{func.__name__}"] = ( + UserFunctionVariable + ) + + _torch_name_rule_map = [ + _manual_torch_name_rule_map, + torch_c_binding_in_graph_functions, + torch_non_c_binding_in_graph_functions, + ] + + with ( + unittest.mock.patch( + "torch._dynamo.trace_rules.torch_name_rule_map", + _torch_name_rule_map, + ), + unittest.mock.patch( + "torch._dynamo.trace_rules.get_torch_obj_rule_map", + torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, + ), + ): + # First adding the module to SKIP_DIRS so that it will be skipped by default. 
+ skip_dirs_backup = torch._dynamo.trace_rules.SKIP_DIRS.copy() + skip_dirs_re_backup = torch._dynamo.trace_rules.SKIP_DIRS_RE + try: + torch._dynamo.trace_rules.add(mod.__name__) + x = torch.rand(3) + opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + finally: + torch._dynamo.trace_rules.SKIP_DIRS = skip_dirs_backup + torch._dynamo.trace_rules.SKIP_DIRS_RE = skip_dirs_re_backup + + def test_no_special_handlers_for_torch_non_c_bindings(self): + handlers = TorchInGraphFunctionVariable._get_handlers() + # These handlers are manually audited to be safe + safe_handlers = ( + "handle_tracing_state_functions", # No global state (constant) + "handle_radians", # No global state (constant) + "handle_is_tensor", # No global state + "handle_torch_compile", # No global state, constant + "handle_ntuple", # No global state + "handle_is_grad_enabled", # Safely implemented + "handle_use_deterministic_algorithms", # Guarded variable + "handle_are_deterministic_algorithms_enabled", # Guarded constant + "handle_device_interface_stream", # No global state + "handle_cudnn_is_acceptable", # No global state + "handle_assert", # No global state (constant) + "handle_nested_tensor", # No global state + "handle_current_stream", # Safely implemented + "handle_synchronize", # Device type from function identity or arg + "handle_functorch_autograd_grad", # Only inspects placeholder metadata + ) + for fn in handlers: + if isinstance(fn, staticmethod) or inspect.ismethod(fn): + fn_name = f"{fn.__module__}#{fn.__name__}" + else: + fn_name = f"{fn.__module__}.{fn.__name__}" + if handlers[fn].__name__ in safe_handlers: + continue + self.assertFalse( + fn_name in torch_non_c_binding_in_graph_functions, + ( + f"torch function {fn_name} has a special handler {handlers[fn].__name__}.\n" + "We expected all functions in `torch_non_c_binding_in_graph_functions` to be safe to cache.\n" + "Functions with special handlers may not be safe 
to cache, since they can close over global state.\n" + "If your handler/function is safe to cache, please add it to the list of safe handlers above.\n" + "Otherwise, add it to `manual_torch_name_rule_map` instead." + ), + ) + + def test_almost_impossible_missing_name(self): + class weird: + def __getattribute__(self, name): + if name == "__name__": + raise AttributeError("test") + + w = weird() + o = set() + with self.assertRaises(AttributeError): + w.__name__ + self.assertEqual(lookup_inner(w, name=None, reasons=o), SkipFunctionVariable) + + +class TestModuleSurviveSkipFiles(torch._dynamo.test_case.TestCase): + @unittest.skipIf( + not torch.distributed.is_available(), + "need to import MLP module from distributed", + ) + @skipIfWindows( + msg="AssertionError: False is not true : MLP did not survive skip files" + ) + def test_module_survive_skip_files(self): + from torch.testing._internal.common_fsdp import MLP + + model = MLP(3) + inp = torch.randn((2, 3)) + frame_count_before = torch._dynamo.convert_frame.FRAME_COUNTER + model.compile(backend="eager") + model(inp) + frame_count_after = torch._dynamo.convert_frame.FRAME_COUNTER + self.assertTrue( + frame_count_after > frame_count_before, "MLP did not survive skip files" + ) + + +class SingleOpCompileTests(torch._dynamo.test_case.TestCase): + def test_top_level_torch_exp_compiles_through_dynamo(self): + x = torch.randn(4) + + # Sanity: lambda version should go through Dynamo + lambda_counter = CompileCounter() + opt_lambda = torch.compile(lambda t: torch.exp(t), backend=lambda_counter) + y_lambda = opt_lambda(x) + self.assertEqual( + lambda_counter.frame_count, + 1, + "Sanity check failed: lambda version did not compile through Dynamo exactly once.", + ) + # Regression target: torch.compile(torch.exp) + top_level_counter = CompileCounter() + opt_exp = torch.compile(torch.exp, backend=top_level_counter) + y_exp = opt_exp(x) + self.assertEqual( + top_level_counter.frame_count, + 1, + "Expected 
torch.compile(torch.exp) to compile through Dynamo exactly once.", + ) + # Numerical results should match + self.assertTrue(torch.allclose(y_lambda, y_exp)) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index eb82beeaf0..8dcad4fcaa 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -255,6 +255,13 @@ "profiler/test_execution_trace_xpu.py": None, "profiler/test_profiler_xpu.py": None, "export/test_hop_xpu.py": None, + "tracked_process_pool_xpu.py": None, + "dynamo/test_trace_rules_xpu.py": None, + "dynamo/test_subclasses_xpu.py": None, + "dynamo/test_structured_trace_xpu.py": None, + "dynamo/test_export_xpu.py": None, + "dynamo/test_debug_utils_xpu.py": None, + "dynamo/test_activation_checkpointing_xpu.py": None, "export/test_export_opinfo_xpu.py": None, "export/test_converter_xpu.py": None, "export/test_cpp_serdes_xpu.py": None, diff --git a/test/xpu/tracked_process_pool_xpu.py b/test/xpu/tracked_process_pool_xpu.py new file mode 100644 index 0000000000..deeee011d0 --- /dev/null +++ b/test/xpu/tracked_process_pool_xpu.py @@ -0,0 +1,113 @@ +import atexit +import concurrent +import dataclasses +import logging +import threading +from collections.abc import Callable +from concurrent.futures import Future, ProcessPoolExecutor +from dataclasses import dataclass +from multiprocessing.context import BaseContext +from time import time +from typing import Any, TypeVar + +# _thread_safe_fork is needed because the subprocesses in the pool can read +# justknobs, e.g., in the Triton compiler. For internal, the import installs +# functionality to destroy singletons before forking and re-enable them after. 
+import torch._thread_safe_fork +from typing_extensions import ParamSpec + +device_type = ( + acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" +) +_P = ParamSpec("_P") +_R = TypeVar("_R") +log = logging.getLogger(__name__) + + +@dataclass +class _QueueStats: + # Mapping from id(future) -> start time + pending: dict[int, float] = dataclasses.field(default_factory=dict) + timing: list[float] = dataclasses.field(default_factory=list) + enqueue_count: int = 0 + dequeue_count: int = 0 + max_queue_depth: int = 0 + pool_count: int = 0 + + +# The queue statistics tracked by TrackedProcessPoolExecutor. Always grab +# _queue_stats_lock before touching. +_queue_stats = _QueueStats() +_queue_stats_lock = threading.Lock() + + +class TrackedProcessPoolExecutor(ProcessPoolExecutor): + def __init__( + self, + max_workers: int | None = None, + mp_context: BaseContext | None = None, + initializer: Callable[[], object] | None = None, + ) -> None: + with _queue_stats_lock: + _queue_stats.pool_count += 1 + super().__init__(max_workers, mp_context, initializer) + + def _record_dequeue(self, f: Future[Any]) -> None: + now = time() + with _queue_stats_lock: + stats = _queue_stats + if (start_time := stats.pending.pop(id(f), None)) is None: + return + stats.dequeue_count += 1 + duration = now - start_time + stats.timing.append(duration) + + def _record_enqueue(self, f: Future[Any]) -> None: + # Monkeypatch the set_running_or_notify_cancel so we can track when the Future moves out of PENDING. 
+ saved_running_or_notify_cancel = f.set_running_or_notify_cancel + + def set_running_or_notify_cancel() -> Any: + self._record_dequeue(f) + return saved_running_or_notify_cancel() + + now = time() + with _queue_stats_lock: + stats = _queue_stats + stats.pending[id(f)] = now + stats.enqueue_count += 1 + stats.max_queue_depth = max(stats.max_queue_depth, len(stats.pending)) + f.set_running_or_notify_cancel = set_running_or_notify_cancel # type: ignore[method-assign] + + if f._state != concurrent.futures._base.PENDING: + self._record_dequeue(f) + + def submit( + self, fn: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs + ) -> Future[_R]: + # pyrefly: ignore [bad-argument-type] + f = super().submit(fn, *args, **kwargs) + self._record_enqueue(f) + return f + + +@atexit.register +def _queue_stats_report() -> None: + stats = _queue_stats + if stats.pool_count == 0: + return + + timing = stats.timing + timing.sort() + + log.info("AsyncCompile Metrics:") + log.info(" Pools %s", stats.pool_count) + log.info( + " Items %d enqueued / %d dequeued", stats.enqueue_count, stats.dequeue_count + ) + log.info(" Max Queue Depth: %d", stats.max_queue_depth) + n = len(timing) + if n > 0: + log.info(" Longest queue time: %0.2fs", timing[-1]) + log.info(" P50: %0.2fs", timing[n // 2]) + if n >= 20: + log.info(" P95: %0.2fs", timing[n * 95 // 100])