🐛 Bug
In SPMD mode, optimized_mod in dynamo_bridge.py retraces the graph on every step because freshly created input tensors report "" from _get_xla_sharding_specs, while the captured spec from the previous dispatch holds {replicated}.
In xla_sharding_util.cpp, when dispatching sharded data for a tensor whose sharding_spec is nullptr, the runtime assigns xla::HloSharding::Replicate(). So after the first dispatch, those tensors carry a {replicated} annotation.
extract_graph_helper captures xla_args_sharding_spec by calling _get_xla_sharding_specs. At capture time the input tensors have no annotation (""). After the graph is compiled and dispatched, the same tensor objects now carry {replicated}.
On the next call to optimized_mod, the check:
if torch_xla._XLAC._get_xla_sharding_specs(xla_args_tensor_only) != xla_args_sharding_spec:
compares "" (fresh runtime tensor) against {replicated} (captured post-dispatch) and treats it as a sharding change, triggering a full retrace via extract_graph_helper every single step.
Note that we have a PR on our fork of pytorch_xla, tenstorrent#24, which wraps and normalizes "" to <replicated> to solve this issue. We would appreciate any input: is this fix upstreamable?
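For discussion, here is a minimal sketch of the normalization idea in plain Python. It is not the code from the fork PR: `normalize_sharding_specs` and `sharding_changed` are hypothetical helper names, and the spec strings ("" for no annotation, "{replicated}" for replicated) are assumed from the values quoted above. The idea is to treat an empty spec as equivalent to replicated on both sides of the comparison, so a fresh input no longer looks like a sharding change.

```python
from typing import List, Sequence

# Assumption: the replicated spec is reported as this string,
# as quoted in the description above.
REPLICATED = "{replicated}"


def normalize_sharding_specs(specs: Sequence[str]) -> List[str]:
    """Map the empty spec "" to the replicated spec; leave other specs unchanged."""
    return [REPLICATED if spec == "" else spec for spec in specs]


def sharding_changed(current: Sequence[str], captured: Sequence[str]) -> bool:
    """Compare two spec lists after normalizing both sides."""
    return normalize_sharding_specs(current) != normalize_sharding_specs(captured)


# A fresh input tensor reports "", while the captured spec holds "{replicated}".
fresh_inputs = [""]
captured = ["{replicated}"]

print(fresh_inputs != captured)                  # True: raw comparison forces a retrace
print(sharding_changed(fresh_inputs, captured))  # False: normalized comparison does not
```

In `optimized_mod` this would presumably mean normalizing both the result of `_get_xla_sharding_specs` and the captured `xla_args_sharding_spec` before comparing them. A real change from one sharding to another would still be detected, because only "" is rewritten.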
To Reproduce
This is tt-specific code, but any custom backend should have the same issue:
import os
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"
import numpy as np
import torch
import torch.nn as nn
import torch_xla
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh
xr.use_spmd()
mesh = Mesh(np.array(range(2)), (1, 2), ("batch", "model"))
device = torch_xla.device()
model = nn.Linear(8, 8, bias=False).to(device)
xs.mark_sharding(model.weight, mesh, ("model", None))
compiled = torch.compile(model, backend="tt")
with torch.no_grad():
    for i in range(5):
        x = torch.randn(4, 8).to(device)  # CPU → device: no XLA IR sharding
        out = compiled(x)
        out.to("cpu")  # sync
        print(f"step {i}")
Expected behavior
The graph should be traced only once. Instead, it is traced 5 times, once per step.