
Torchax registering jax as PyTorch accelerator #9730

@ajakovljevicTT


🐛 Bug

When torchax is imported, it registers "jax" as a PyTorch accelerator, which causes torch.compile(backend='inductor') to fail for functions/models with no tensor inputs.

torchax/__init__.py:82 unconditionally calls:
torch.utils.rename_privateuse1_backend('jax')

This makes torch.accelerator.is_available() return True.

In torch/_inductor/codecache.py:812-813, FxGraphHashDetails.__init__ does:

if no_tensor_inputs and torch.accelerator.is_available():
    self.default_cuda_device_index = torch.accelerator.current_device_index()

Since PyTorch isn't actually linked with jax device support, current_device_index() raises:
RuntimeError: PyTorch is not linked with support for jax devices
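The chain above can be modeled without torch to show which branch fails and why tensor inputs avoid it. This is a minimal sketch: `FakeAccelerator`, `default_device_index` and `default_device_index_guarded` are hypothetical stand-ins, not PyTorch APIs. `is_available()` encodes the behavior reported in this issue (a renamed PrivateUse1 backend reports as available) rather than PyTorch's actual check, and the guarded variant is only an illustrative hardening, not an upstream patch.

```python
# Pure-Python model of the failure chain. Nothing here imports torch;
# every name is a hypothetical stand-in for the PyTorch internals quoted above.


class FakeAccelerator:
    """Hypothetical stand-in for torch.accelerator after the PrivateUse1 rename."""

    def __init__(self, backend_name: str, linked: bool):
        self.backend_name = backend_name  # e.g. "jax" after rename_privateuse1_backend
        self.linked = linked              # whether a real device implementation exists

    def is_available(self) -> bool:
        # Models the reported behavior: a renamed backend reports as available.
        return self.backend_name != "privateuseone"

    def current_device_index(self) -> int:
        if not self.linked:
            raise RuntimeError(
                f"PyTorch is not linked with support for {self.backend_name} devices"
            )
        return 0


def default_device_index(accel, no_tensor_inputs: bool):
    """Mirrors the guard in FxGraphHashDetails.__init__ quoted above."""
    if no_tensor_inputs and accel.is_available():
        return accel.current_device_index()
    return None


def default_device_index_guarded(accel, no_tensor_inputs: bool):
    """Hypothetical hardening: treat an unusable accelerator as absent."""
    try:
        return default_device_index(accel, no_tensor_inputs)
    except RuntimeError:
        return None


if __name__ == "__main__":
    jax = FakeAccelerator("jax", linked=False)
    print(default_device_index(jax, no_tensor_inputs=False))         # None: branch skipped
    print(default_device_index_guarded(jax, no_tensor_inputs=True))  # None: error swallowed
    try:
        default_device_index(jax, no_tensor_inputs=True)
    except RuntimeError as e:
        print(e)  # PyTorch is not linked with support for jax devices
```

Catching `RuntimeError` this broadly is only for illustration; it would also hide genuine device errors.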

torchax is imported indirectly: torch_xla/distributed/spmd/xla_sharding.py:725 calls maybe_get_torchax() inside mark_sharding(), so any SPMD code that shards a tensor triggers the rename.
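To confirm this import path in a given process, a small diagnostic can report the relevant state. This is a sketch: `accelerator_state` is a hypothetical helper, and `torch._C._get_privateuse1_backend_name()` is a private API that may change between releases. It only inspects modules that are already loaded, so calling it does not itself import torch or torchax.

```python
import sys


def accelerator_state() -> dict:
    """Hypothetical diagnostic: is torchax loaded, and what does torch report?

    Only inspects modules already present in sys.modules, so calling it
    never triggers a torch or torchax import by itself.
    """
    state = {"torchax_imported": "torchax" in sys.modules}
    torch = sys.modules.get("torch")
    if torch is not None:
        # Name currently assigned to the PrivateUse1 backend
        # ("privateuseone" unless something renamed it).
        state["privateuse1_name"] = torch._C._get_privateuse1_backend_name()
        accelerator = getattr(torch, "accelerator", None)
        if accelerator is not None:
            state["accelerator_available"] = accelerator.is_available()
    return state


if __name__ == "__main__":
    print(accelerator_state())
```

Calling it before and after `xs.mark_sharding(...)` in the repro should, if the analysis above holds, show `torchax_imported` and `accelerator_available` flipping to `True` and `privateuse1_name` becoming `'jax'`.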

To Reproduce

A minimal Python repro:

import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Setup SPMD mesh and mark tensor sharding (standard torch_xla SPMD usage)
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(num_devices)), (num_devices,), ('data',))
t = torch.randn(4, 4).to('xla')
xs.mark_sharding(t, mesh, (0, None))

# Now inductor fails for no-tensor-input functions
@torch.compile(backend='inductor')
def make_grid():
    return torch.zeros(3, 3)

make_grid()  # RuntimeError: PyTorch is not linked with support for jax devices

This will fail with the following error:

  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 845, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 2196, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 2171, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/__init__.py", line 2392, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 2681, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/backends/common.py", line 117, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/aot_autograd.py", line 1106, in aot_module_simplified
    compiled_fn, _ = aot_stage2_compile(aot_state, aot_graph_capture)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 242, in aot_stage2_compile
    return aot_stage2_inference(aot_state, aot_graph_capture)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 315, in aot_stage2_inference
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/schemas.py", line 1251, in __call__
    return self.compiler_fn(gm, example_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 2558, in fw_compiler_base
    return compile_fx_forward(
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 2275, in compile_fx_forward
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 782, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/repro/after_aot.py", line 144, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 900, in _compile_fx_inner
    (key_info, cache_info) = FxGraphCache.prepare_key(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/codecache.py", line 1489, in prepare_key
    key, debug_lines = compiled_fx_graph_hash(
                       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/codecache.py", line 955, in compiled_fx_graph_hash
    details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/codecache.py", line 845, in __init__
    self.default_cuda_device_index = torch.accelerator.current_device_index()
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/accelerator/__init__.py", line 132, in current_device_index
    return torch._C._accelerator_getDeviceIndex()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: PyTorch is not linked with support for jax devices
