🐛 Bug
Calling torch_xla.device() or setting the random seed before use_spmd() produces a SIGSEGV once the computation involves tensors that were not marked with mark_sharding().
To Reproduce
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh
xm.set_rng_state(42)
torch_xla.device()
# Enable XLA SPMD execution mode.
xr.use_spmd()
# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))
t = torch.randn(8192, 4096).to(torch_xla.device())
# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = ('data', 'model')
xs.mark_sharding(t, mesh, partition_spec)
x = torch.randn(1024, 8192).to(torch_xla.device())
mult = torch.matmul(x, t)
torch_xla.sync()
print("Result shape:", mult.shape)
print("total sum of result:", mult.sum().item())
Steps to reproduce the behavior:
- Comment out the set_rng_state(...) and torch_xla.device() calls that precede use_spmd().
- Run the script and observe that it works.
- Uncomment either one of them and run again; the process crashes with:
*** SIGSEGV (@0x1f0), see go/stacktraces#s15 received by PID 55852 (TID 57319) on cpu 126; stack trace: ***
PC: @ 0x71a5473605d9 (unknown) std::_Function_handler<>::_M_invoke()
@ 0x71a4c9e9abc5 1904 (unknown)
@ 0x71a709042520 3184 (unknown)
@ 0x71a550c3c4de 32 std::_Function_handler<>::_M_invoke()
@ 0x71a547fb7072 320 Eigen::ThreadPoolDevice::parallelFor()
@ 0x71a550c402c5 608 tsl::thread::ThreadPool::ParallelFor()
@ 0x71a547e96b4d 1376 torch_xla::runtime::PjRtComputationClient::ExecuteReplicated()
@ 0x71a547c3ce75 816 torch_xla::XLAGraphExecutor::ScheduleSyncTensorsGraph()::{lambda()#1}::operator()()
@ 0x71a64003f4b8 192 torch::lazy::MultiWait::Complete()
@ 0x71a550c3c488 64 absl::lts_20250512::internal_any_invocable::RemoteInvoker<>()
@ 0x71a550c322c2 96 tsl::(anonymous namespace)::PThread::ThreadFn()
@ 0x71a709094ac3 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=71a5473605d9,71a4c9e9abc4,71a70904251f,71a550c3c4dd,71a547fb7071,71a550c402c4,71a547e96b4c,71a547c3ce74,71a64003f4b7,71a550c3c487,71a550c322c1,71a709094ac2&map=
E0120 20:49:43.195168 57319 coredump_hook.cc:301] RAW: Remote crash data gathering hook invoked.
E0120 20:49:43.195187 57319 coredump_hook.cc:340] RAW: Skipping coredump since rlimit was 0 at process start.
E0120 20:49:43.195191 57319 client.cc:270] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0120 20:49:43.195195 57319 coredump_hook.cc:396] RAW: Sending fingerprint to remote end.
E0120 20:49:43.195228 57319 coredump_hook.cc:405] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0120 20:49:43.195234 57319 coredump_hook.cc:457] RAW: Dumping core locally.
E0120 20:49:43.452738 57319 process_state.cc:808] RAW: Raising signal 11 with default behavior
Segmentation fault (core dumped)
Expected behavior
Either use_spmd() should raise an error to prevent confusing failures later on, or it should handle earlier calls to device() and similar methods more gracefully.
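To illustrate the first option, here is a minimal user-side sketch of a fail-fast check. SpmdOrderGuard, wrap, and the stand-in callables are hypothetical names invented for this example and are not part of torch_xla; the real functions (for instance xr.use_spmd and torch_xla.device) would be passed in where the stand-ins are used.

```python
from typing import Any, Callable


class SpmdOrderGuard:
    """Hypothetical helper: refuse to enable SPMD after device-touching calls."""

    def __init__(self, enable_spmd: Callable[[], Any]) -> None:
        self._enable_spmd = enable_spmd  # e.g. xr.use_spmd in real code
        self._spmd_enabled = False
        self._early_calls = []  # names of wrapped calls made before SPMD

    def wrap(self, name: str, fn: Callable[..., Any]) -> Callable[..., Any]:
        """Return a version of fn that records calls made before SPMD is on."""

        def wrapped(*args: Any, **kwargs: Any) -> Any:
            if not self._spmd_enabled:
                self._early_calls.append(name)
            return fn(*args, **kwargs)

        return wrapped

    def use_spmd(self) -> None:
        """Enable SPMD, or raise if a wrapped call already ran."""
        if self._early_calls:
            raise RuntimeError(
                "use_spmd() called after: "
                + ", ".join(self._early_calls)
                + "; enable SPMD before touching the XLA device."
            )
        self._enable_spmd()
        self._spmd_enabled = True


# Stand-ins so the sketch runs without torch_xla installed.
guard = SpmdOrderGuard(enable_spmd=lambda: None)
device = guard.wrap("torch_xla.device", lambda: "xla:0")

device()  # touches the device before SPMD is enabled
try:
    guard.use_spmd()
except RuntimeError as err:
    print("caught:", err)
```

With a check like this, the ordering mistake surfaces as a Python exception at the use_spmd() call instead of as a segmentation fault during a later sync.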
Environment
- Reproducible on XLA backend [CPU/TPU]: Libtpu version: 0.0.21, Accelerator type: v6e, 8 chips 1 node.
- torch_xla version: 2.9.0
Additional Details
It took me multiple days to understand that this was caused by setting the seed before calling use_spmd():
from accelerate.utils import set_seed
set_seed(42)
This calls xm.set_rng_state(seed) for XLA devices. I suspect the underlying torch_xla._XLAC._xla_get_default_device() call is the cause. Somehow, some tensors end up on the virtual device and others end up on an actual device.
When I tried to debug this issue, I ended up marking every created tensor as sharded. For example, doing
x = torch.randn(1024, 8192).to(torch_xla.device())
xs.mark_sharding(x, mesh, partition_spec)
fixes this issue. However, whenever the underlying library creates a new tensor that is not marked, the crash comes back. I inspected the torch_xla._XLAC._get_xla_sharding_spec values in the non-buggy and buggy versions, but they look similar:
Before sharding:
After sharding: {devices=[8,1]0,1,2,3,4,5,6,7}
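Since the crash only appears when an unmarked tensor is involved, a small diagnostic can help locate such tensors. The sketch below assumes, based on the output above, that an empty sharding spec means the tensor was never marked. find_unmarked is a hypothetical helper; in real use, get_spec would be torch_xla._XLAC._get_xla_sharding_spec, and a stand-in is used here so the example runs without torch_xla.

```python
from typing import Any, Callable, Dict, List


def find_unmarked(named_tensors: Dict[str, Any],
                  get_spec: Callable[[Any], str]) -> List[str]:
    """Return the names of tensors whose sharding spec is empty.

    Assumption: an empty spec string means mark_sharding() was never
    applied to the tensor, as in the "Before sharding:" output above.
    """
    return [name for name, tensor in named_tensors.items()
            if not get_spec(tensor)]


# Stand-in spec lookup; the strings mimic the values shown in this report.
fake_specs = {"t": "{devices=[8,1]0,1,2,3,4,5,6,7}", "x": ""}
print(find_unmarked({"t": "t", "x": "x"}, fake_specs.get))
```

Running this over the tensors a library creates internally might narrow down which ones need an explicit mark_sharding() call.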