diff --git a/pathwaysutils/jax/__init__.py b/pathwaysutils/jax/__init__.py index b5fdd5f..e5bc106 100644 --- a/pathwaysutils/jax/__init__.py +++ b/pathwaysutils/jax/__init__.py @@ -17,10 +17,8 @@ `pathwaysutils`'s compatibility window. """ -import functools -from typing import Any -import jax +import functools class _FakeJaxFunction: @@ -47,36 +45,6 @@ def __call__(self, *args, **kwargs): raise ImportError(self.error_message) -try: - # jax>=0.7.0 - from jax.extend import backend # pylint: disable=g-import-not-at-top - - register_backend_cache = backend.register_backend_cache - - del backend -except AttributeError: - # jax<0.7.0 - from jax._src import util # pylint: disable=g-import-not-at-top - - def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable=unused-argument - return util.cache_clearing_funs.add(cache.cache_clear) - - del util - -try: - # jax>=0.7.1 - from jax.extend import backend # pylint: disable=g-import-not-at-top - - ifrt_proxy = backend.ifrt_proxy - del backend -except AttributeError: - # jax<0.7.1 - from jax.lib import xla_extension # pylint: disable=g-import-not-at-top - - ifrt_proxy = xla_extension.ifrt_proxy - del xla_extension - - try: # jax>=0.8.0 from jaxlib import _pathways # pylint: disable=g-import-not-at-top @@ -129,7 +97,5 @@ def ifrt_reshard_available() -> bool: del jax -del jax -del Any del _FakeJaxFunction del functools diff --git a/pathwaysutils/lru_cache.py b/pathwaysutils/lru_cache.py index 1670ef9..6608704 100644 --- a/pathwaysutils/lru_cache.py +++ b/pathwaysutils/lru_cache.py @@ -16,7 +16,7 @@ import functools from typing import Any, Callable -from pathwaysutils import jax as pw_jax +from jax.extend import backend def lru_cache( @@ -38,7 +38,7 @@ def wrap(f): wrapper.cache_clear = cached.cache_clear wrapper.cache_info = cached.cache_info - pw_jax.register_backend_cache(wrapper, "Pathways LRU cache") + backend.register_backend_cache(wrapper, "Pathways LRU cache") return wrapper return wrap diff --git a/pathwaysutils/proxy_backend.py b/pathwaysutils/proxy_backend.py index cc9ccce..42a52ad 100644 --- a/pathwaysutils/proxy_backend.py +++ b/pathwaysutils/proxy_backend.py @@ -15,15 +15,15 @@ import jax from jax.extend import backend -from pathwaysutils import jax as pw_jax +from jax.extend.backend import ifrt_proxy def register_backend_factory(): backend.register_backend_factory( "proxy", - lambda: pw_jax.ifrt_proxy.get_client( + lambda: ifrt_proxy.get_client( jax.config.read("jax_backend_target"), - pw_jax.ifrt_proxy.ClientConnectionOptions(), + ifrt_proxy.ClientConnectionOptions(), ), priority=-1, ) diff --git a/pathwaysutils/test/proxy_backend_test.py b/pathwaysutils/test/proxy_backend_test.py index 50bf314..50ecc8f 100644 --- a/pathwaysutils/test/proxy_backend_test.py +++ b/pathwaysutils/test/proxy_backend_test.py @@ -17,7 +17,7 @@ import jax from jax.extend import backend -from pathwaysutils import jax as pw_jax +from jax.extend.backend import ifrt_proxy from pathwaysutils import proxy_backend from absl.testing import absltest @@ -38,7 +38,7 @@ def test_no_proxy_backend_registration_raises_error(self): def test_proxy_backend_registration(self): self.enter_context( mock.patch.object( - pw_jax.ifrt_proxy, + ifrt_proxy, "get_client", return_value=mock.MagicMock(), )