
[Bug] Calling torch_xla.device() or setting random seed before use_spmd() produces SIGSEGV for unmarked tensors #9735

@Dogacel

Description


🐛 Bug

Calling torch_xla.device() or setting the random seed before use_spmd() produces a SIGSEGV for tensors that were not marked with mark_sharding().

To Reproduce

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

from torch_xla.distributed.spmd import Mesh

# Either of the next two calls, made before xr.use_spmd(), triggers the crash.
xm.set_rng_state(42)

torch_xla.device()

# Enable XLA SPMD execution mode.
xr.use_spmd()

# Device mesh: together with the partition spec and the input tensor shape, it defines the shape of each shard.
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))

t = torch.randn(8192, 4096).to(torch_xla.device())

# Mesh partitioning: with 8 devices, each device holds 1/8 of the input along the 'data' axis.
partition_spec = ('data', 'model')
xs.mark_sharding(t, mesh, partition_spec)

# x is not marked for sharding; the crash shows up for unmarked tensors like this one.
x = torch.randn(1024, 8192).to(torch_xla.device())

mult = torch.matmul(x, t)
torch_xla.sync()

print("Result shape:", mult.shape)
print("total sum of result:", mult.sum().item())

Steps to reproduce the behavior:

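  1. Comment out both the xm.set_rng_state(42) and the torch_xla.device() calls that precede xr.use_spmd().
  2. Run the script and observe that it completes successfully.
  3. Uncomment either one of them and run again. The process crashes with: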
*** SIGSEGV (@0x1f0), see go/stacktraces#s15 received by PID 55852 (TID 57319) on cpu 126; stack trace: ***
PC: @     0x71a5473605d9  (unknown)  std::_Function_handler<>::_M_invoke()
    @     0x71a4c9e9abc5       1904  (unknown)
    @     0x71a709042520       3184  (unknown)
    @     0x71a550c3c4de         32  std::_Function_handler<>::_M_invoke()
    @     0x71a547fb7072        320  Eigen::ThreadPoolDevice::parallelFor()
    @     0x71a550c402c5        608  tsl::thread::ThreadPool::ParallelFor()
    @     0x71a547e96b4d       1376  torch_xla::runtime::PjRtComputationClient::ExecuteReplicated()
    @     0x71a547c3ce75        816  torch_xla::XLAGraphExecutor::ScheduleSyncTensorsGraph()::{lambda()#1}::operator()()
    @     0x71a64003f4b8        192  torch::lazy::MultiWait::Complete()
    @     0x71a550c3c488         64  absl::lts_20250512::internal_any_invocable::RemoteInvoker<>()
    @     0x71a550c322c2         96  tsl::(anonymous namespace)::PThread::ThreadFn()
    @     0x71a709094ac3  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=71a5473605d9,71a4c9e9abc4,71a70904251f,71a550c3c4dd,71a547fb7071,71a550c402c4,71a547e96b4c,71a547c3ce74,71a64003f4b7,71a550c3c487,71a550c322c1,71a709094ac2&map=
E0120 20:49:43.195168   57319 coredump_hook.cc:301] RAW: Remote crash data gathering hook invoked.
E0120 20:49:43.195187   57319 coredump_hook.cc:340] RAW: Skipping coredump since rlimit was 0 at process start.
E0120 20:49:43.195191   57319 client.cc:270] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0120 20:49:43.195195   57319 coredump_hook.cc:396] RAW: Sending fingerprint to remote end.
E0120 20:49:43.195228   57319 coredump_hook.cc:405] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0120 20:49:43.195234   57319 coredump_hook.cc:457] RAW: Dumping core locally.
E0120 20:49:43.452738   57319 process_state.cc:808] RAW: Raising signal 11 with default behavior
Segmentation fault (core dumped)

Expected behavior

Either use_spmd() should raise an error when it is called after torch_xla.device() or xm.set_rng_state(), so that the misconfiguration surfaces immediately instead of as a confusing crash later on, or it should handle those earlier calls gracefully.
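A minimal sketch of the fail-fast option, assuming it can be modeled as a simple ordering check. RuntimeState, touch_device and SpmdOrderError are hypothetical names used only for illustration; they are not part of torch_xla, and a real fix would live inside the runtime itself.

```python
class SpmdOrderError(RuntimeError):
    """Hypothetical error: SPMD was requested after the device was already used."""


class RuntimeState:
    """Hypothetical model of the proposed ordering check (not torch_xla's API).

    It only records whether a device-initializing call, such as
    torch_xla.device() or xm.set_rng_state(), happened before SPMD
    mode was enabled.
    """

    def __init__(self) -> None:
        self.device_initialized = False
        self.spmd_enabled = False

    def touch_device(self) -> None:
        # Stands in for any call that initializes the default device.
        self.device_initialized = True

    def use_spmd(self) -> None:
        # Fail fast here instead of letting a later graph execution segfault.
        if self.device_initialized and not self.spmd_enabled:
            raise SpmdOrderError(
                "use_spmd() must be called before the XLA device is initialized "
                "(e.g. before torch_xla.device() or xm.set_rng_state())")
        self.spmd_enabled = True
```

With this check, the reproduction script above would fail at the xr.use_spmd() line with a readable message rather than crashing inside ExecuteReplicated().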

Environment

  • Reproducible on XLA backend [CPU/TPU]: TPU (libtpu version 0.0.21, accelerator type v6e, 8 chips, 1 node)
  • torch_xla version: 2.9.0

Additional Details

It took me several days to figure out that the crash was caused by setting the seed before calling use_spmd(). In my code the seed was set with:

from accelerate.utils import set_seed
set_seed(42)

For XLA devices, this calls xm.set_rng_state(seed). I suspect the underlying torch_xla._XLAC._xla_get_default_device() call is what causes the problem: somehow some tensors end up on the virtual device and others on the actual device.
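A minimal sketch of the ordering that, per this report, avoids the crash: enable SPMD before anything seeds the RNG or initializes the device. The helper init_spmd_then_seed is hypothetical, and the runtime calls are passed in as callables so the ordering can be checked without a TPU. It assumes that seeding after use_spmd() is safe, which the report implies but does not state outright.

```python
from typing import Callable, TypeVar

D = TypeVar("D")


def init_spmd_then_seed(
    seed: int,
    use_spmd: Callable[[], None],
    set_seed: Callable[[int], None],
    get_device: Callable[[], D],
) -> D:
    """Hypothetical helper: enable SPMD first, then seed, then fetch the device."""
    use_spmd()           # must run before any device-initializing call
    set_seed(seed)       # e.g. accelerate's set_seed or xm.set_rng_state
    return get_device()  # e.g. torch_xla.device
```

On a TPU host it would be called as init_spmd_then_seed(42, xr.use_spmd, set_seed, torch_xla.device), with set_seed imported from accelerate.utils as above.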

While debugging, I ended up explicitly marking every tensor created afterwards as sharded. For example, doing

x = torch.randn(1024, 8192).to(torch_xla.device())

xs.mark_sharding(x, mesh, partition_spec)

fixes the crash. However, whenever the underlying library creates a new tensor that is not marked, the crash comes back. I inspected the torch_xla._XLAC._get_xla_sharding_spec values in both the working and the crashing setups, but they look the same:

Before sharding:
After sharding: {devices=[8,1]0,1,2,3,4,5,6,7}
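The "mark every tensor" workaround described above can be sketched as a helper that marks only tensors whose sharding spec is still empty. It assumes an unmarked tensor reports an empty spec string, as the "Before sharding:" line appears to show. The name mark_unsharded is hypothetical, and the spec lookup and marking function are injected so the sketch runs without torch_xla. As noted above, this only covers tensors you hold a handle to, not ones created inside a library.

```python
from typing import Callable, Iterable, List, TypeVar

T = TypeVar("T")


def mark_unsharded(
    tensors: Iterable[T],
    get_spec: Callable[[T], str],
    mark: Callable[[T], None],
) -> List[T]:
    """Hypothetical helper: call `mark` on each tensor with an empty sharding spec.

    Returns the tensors that were marked, in input order.
    """
    marked: List[T] = []
    for t in tensors:
        if get_spec(t) == "":  # empty spec: not yet marked for sharding
            mark(t)
            marked.append(t)
    return marked
```

With torch_xla, get_spec would be torch_xla._XLAC._get_xla_sharding_spec and mark could be lambda t: xs.mark_sharding(t, mesh, partition_spec), which assumes every tensor is rank 2 so that ('data', 'model') applies.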
