Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 1 addition & 35 deletions pathwaysutils/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
`pathwaysutils`'s compatibility window.
"""

import functools
from typing import Any

import jax
import functools


class _FakeJaxFunction:
Expand All @@ -47,36 +45,6 @@ def __call__(self, *args, **kwargs):
raise ImportError(self.error_message)


try:
# jax>=0.7.0
from jax.extend import backend # pylint: disable=g-import-not-at-top

register_backend_cache = backend.register_backend_cache

del backend
except AttributeError:
# jax<0.7.0
from jax._src import util # pylint: disable=g-import-not-at-top

def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable=unused-argument
return util.cache_clearing_funs.add(cache.cache_clear)

del util

try:
# jax>=0.7.1
from jax.extend import backend # pylint: disable=g-import-not-at-top

ifrt_proxy = backend.ifrt_proxy
del backend
except AttributeError:
# jax<0.7.1
from jax.lib import xla_extension # pylint: disable=g-import-not-at-top

ifrt_proxy = xla_extension.ifrt_proxy
del xla_extension


try:
# jax>=0.8.0
from jaxlib import _pathways # pylint: disable=g-import-not-at-top
Expand Down Expand Up @@ -129,7 +97,5 @@ def ifrt_reshard_available() -> bool:
del jax


del jax
del Any
del _FakeJaxFunction
del functools
4 changes: 2 additions & 2 deletions pathwaysutils/lru_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import functools
from typing import Any, Callable

from pathwaysutils import jax as pw_jax
from jax.extend import backend


def lru_cache(
Expand All @@ -38,7 +38,7 @@ def wrap(f):

wrapper.cache_clear = cached.cache_clear
wrapper.cache_info = cached.cache_info
pw_jax.register_backend_cache(wrapper, "Pathways LRU cache")
backend.register_backend_cache(wrapper, "Pathways LRU cache")
return wrapper

return wrap
6 changes: 3 additions & 3 deletions pathwaysutils/proxy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@

import jax
from jax.extend import backend
from pathwaysutils import jax as pw_jax
from jax.extend.backend import ifrt_proxy


def register_backend_factory():
backend.register_backend_factory(
"proxy",
lambda: pw_jax.ifrt_proxy.get_client(
lambda: ifrt_proxy.get_client(
jax.config.read("jax_backend_target"),
pw_jax.ifrt_proxy.ClientConnectionOptions(),
ifrt_proxy.ClientConnectionOptions(),
),
priority=-1,
)
4 changes: 2 additions & 2 deletions pathwaysutils/test/proxy_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import jax
from jax.extend import backend
from pathwaysutils import jax as pw_jax
from jax.extend.backend import ifrt_proxy
from pathwaysutils import proxy_backend

from absl.testing import absltest
Expand All @@ -38,7 +38,7 @@ def test_no_proxy_backend_registration_raises_error(self):
def test_proxy_backend_registration(self):
self.enter_context(
mock.patch.object(
pw_jax.ifrt_proxy,
ifrt_proxy,
"get_client",
return_value=mock.MagicMock(),
)
Expand Down
Loading