diff --git a/pathwaysutils/jax/__init__.py b/pathwaysutils/jax/__init__.py index b5fdd5f..c74d9a8 100644 --- a/pathwaysutils/jax/__init__.py +++ b/pathwaysutils/jax/__init__.py @@ -18,7 +18,6 @@ """ import functools -from typing import Any import jax @@ -47,22 +46,6 @@ def __call__(self, *args, **kwargs): raise ImportError(self.error_message) -try: - # jax>=0.7.0 - from jax.extend import backend # pylint: disable=g-import-not-at-top - - register_backend_cache = backend.register_backend_cache - - del backend -except AttributeError: - # jax<0.7.0 - from jax._src import util # pylint: disable=g-import-not-at-top - - def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable=unused-argument - return util.cache_clearing_funs.add(cache.cache_clear) - - del util - try: # jax>=0.7.1 from jax.extend import backend # pylint: disable=g-import-not-at-top @@ -130,6 +113,5 @@ def ifrt_reshard_available() -> bool: del jax -del Any del _FakeJaxFunction del functools diff --git a/pathwaysutils/lru_cache.py b/pathwaysutils/lru_cache.py index 1670ef9..6608704 100644 --- a/pathwaysutils/lru_cache.py +++ b/pathwaysutils/lru_cache.py @@ -16,7 +16,7 @@ import functools from typing import Any, Callable -from pathwaysutils import jax as pw_jax +from jax.extend import backend def lru_cache( @@ -38,7 +38,7 @@ def wrap(f): wrapper.cache_clear = cached.cache_clear wrapper.cache_info = cached.cache_info - pw_jax.register_backend_cache(wrapper, "Pathways LRU cache") + backend.register_backend_cache(wrapper, "Pathways LRU cache") return wrapper return wrap