
SPMD: spurious retrace every step due to sharding spec mismatch for unsharded tensors #9755

@ajakovljevicTT

Description


🐛 Bug

In SPMD mode, optimized_mod in dynamo_bridge.py retraces the graph on every step because freshly created input tensors report "" from _get_xla_sharding_specs, while the captured spec from the previous dispatch holds {replicated}.
In xla_sharding_util.cpp, when dispatching sharded data for a tensor whose sharding_spec is nullptr, the runtime assigns xla::HloSharding::Replicate(). So after the first dispatch, those tensors carry a {replicated} annotation.
extract_graph_helper captures xla_args_sharding_spec by calling _get_xla_sharding_specs. At capture time the input tensors have no annotation (""). After the graph is compiled and dispatched, the same tensor objects now carry {replicated}.

On the next call to optimized_mod, the check:
if torch_xla._XLAC._get_xla_sharding_specs(xla_args_tensor_only) != xla_args_sharding_spec:

compares "" (fresh runtime tensor) against {replicated} (captured post-dispatch) and treats it as a sharding change, triggering a full retrace via extract_graph_helper every single step.
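The loop described above can be modeled without torch_xla. This is only a toy sketch of the behavior as reported here, not the code in dynamo_bridge.py: `count_traces` and `REPLICATED` are hypothetical names, and the assumption that the stored spec always reflects the post-dispatch state is taken from the description above.

```python
# Toy model of the sharding-spec check; plain Python, no torch_xla required.
# All names here are hypothetical and exist only for illustration.
REPLICATED = "{replicated}"


def count_traces(num_steps):
    """Count how often the graph is traced when each step gets a fresh input."""
    captured_specs = None  # nothing captured before the first call
    traces = 0
    for _ in range(num_steps):
        fresh_specs = [""]  # a new input tensor carries no annotation yet
        if captured_specs is None or fresh_specs != captured_specs:
            traces += 1
            # Assumption: the stored spec reflects the tensor after dispatch,
            # when the runtime has already attached the replicated sharding.
            captured_specs = [REPLICATED]
    return traces


print(count_traces(5))
```

Under these assumptions the comparison never succeeds, so the model traces once per step, which matches the five traces seen in the reproduction.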

Note that we have a PR on our fork of pytorch_xla, tenstorrent#24, which wraps and normalizes "" to <replicated> to solve this issue. We would appreciate any input on whether this is upstreamable.
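For discussion, a minimal sketch of what such a normalization could look like, written against plain lists of spec strings. This is not the code from the fork PR: `normalize_sharding_specs` and `sharding_changed` are hypothetical helpers, and using "{replicated}" as the canonical form is an assumption based on the spec string quoted earlier in this issue.

```python
# Hypothetical helpers, not part of torch_xla or the fork PR.
REPLICATED = "{replicated}"  # assumed canonical spelling of a replicated spec


def normalize_sharding_specs(specs):
    """Treat an unannotated tensor ("") as replicated."""
    return [REPLICATED if spec == "" else spec for spec in specs]


def sharding_changed(current_specs, captured_specs):
    """Compare two spec lists after normalizing both sides."""
    return normalize_sharding_specs(current_specs) != normalize_sharding_specs(
        captured_specs
    )


# A fresh, unannotated input no longer looks like a sharding change.
print(sharding_changed([""], ["{replicated}"]))
# A genuinely different sharding is still detected.
print(sharding_changed(["{devices=[1,2]0,1}"], ["{replicated}"]))
```

Normalizing both sides keeps the comparison symmetric, so it does not matter whether the captured spec was recorded before or after dispatch.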

To Reproduce

This is tt-specific code, but any custom backend should have the same issue:

import os
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"

import numpy as np
import torch
import torch.nn as nn
import torch_xla
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh

xr.use_spmd()

mesh = Mesh(np.array(range(2)), (1, 2), ("batch", "model"))
device = torch_xla.device()

model = nn.Linear(8, 8, bias=False).to(device)
xs.mark_sharding(model.weight, mesh, ("model", None))
compiled = torch.compile(model, backend="tt")

with torch.no_grad():
    for i in range(5):
        x = torch.randn(4, 8).to(device)  # CPU → device: no XLA IR sharding
        out = compiled(x)
        out.to("cpu")  # sync
        print(f"step {i}")

Expected behavior

The graph should be traced only once. Instead, it is traced 5 times, once per step.
