From d825f22456a194695fc3eab422c7b0e6feb60486 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Wed, 7 Jan 2026 10:34:34 -0800 Subject: [PATCH] Require JAX>=0.7.0 ---- Directly use jax.extend.backend for register_backend_cache. This change updates pathwaysutils/lru_cache.py to import and use `jax.extend.backend.register_backend_cache` directly. The re-export of this function from `pathwaysutils.jax` is removed, along with version-specific compatibility code for older JAX versions. PiperOrigin-RevId: 853322003 --- pathwaysutils/jax/__init__.py | 18 ------------------ pathwaysutils/lru_cache.py | 4 ++-- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/pathwaysutils/jax/__init__.py b/pathwaysutils/jax/__init__.py index b5fdd5f..c74d9a8 100644 --- a/pathwaysutils/jax/__init__.py +++ b/pathwaysutils/jax/__init__.py @@ -18,7 +18,6 @@ """ import functools -from typing import Any import jax @@ -47,22 +46,6 @@ def __call__(self, *args, **kwargs): raise ImportError(self.error_message) -try: - # jax>=0.7.0 - from jax.extend import backend # pylint: disable=g-import-not-at-top - - register_backend_cache = backend.register_backend_cache - - del backend -except AttributeError: - # jax<0.7.0 - from jax._src import util # pylint: disable=g-import-not-at-top - - def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable=unused-argument - return util.cache_clearing_funs.add(cache.cache_clear) - - del util - try: # jax>=0.7.1 from jax.extend import backend # pylint: disable=g-import-not-at-top @@ -130,6 +113,5 @@ def ifrt_reshard_available() -> bool: del jax -del Any del _FakeJaxFunction del functools diff --git a/pathwaysutils/lru_cache.py b/pathwaysutils/lru_cache.py index 1670ef9..6608704 100644 --- a/pathwaysutils/lru_cache.py +++ b/pathwaysutils/lru_cache.py @@ -16,7 +16,7 @@ import functools from typing import Any, Callable -from pathwaysutils import jax as pw_jax +from jax.extend import backend def lru_cache( @@ -38,7 +38,7 @@ def wrap(f): wrapper.cache_clear = cached.cache_clear wrapper.cache_info = cached.cache_info - pw_jax.register_backend_cache(wrapper, "Pathways LRU cache") + backend.register_backend_cache(wrapper, "Pathways LRU cache") return wrapper return wrap