From 4e6e6b7aa5bcbe42c11a0b700e95b334a1a3ec7d Mon Sep 17 00:00:00 2001 From: Akanksha Gupta Date: Fri, 20 Mar 2026 14:19:38 -0700 Subject: [PATCH] Update environment variables for JAX backend PiperOrigin-RevId: 886971839 --- .../shared_pathways_service/isc_pathways.py | 43 ++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py index edd5715..b25df7e 100644 --- a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py +++ b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py @@ -123,6 +123,18 @@ def _deploy_pathways_proxy_server( _logger.info("Successfully deployed Pathways proxy.") +def _restore_env_var(key: str, original_value: str | None) -> None: + """Restores an environment variable to its original value or unsets it.""" + if original_value is None: + _logger.info("Unsetting environment variable: %s", key) + os.environ.pop(key, None) + else: + _logger.info( + "Restoring environment variable '%s' to '%s'", key, original_value + ) + os.environ[key] = original_value + + class _ISCPathways: """Class for managing TPUs for interactive supercomputing. @@ -163,6 +175,10 @@ def __init__( self._proxy_port = None self.proxy_server_image = proxy_server_image self.proxy_options = proxy_options or ProxyOptions() + self._old_jax_platforms = None + self._old_jax_backend_target = None + self._old_jax_platforms_config = None + self._old_jax_backend_target_config = None def __repr__(self): return ( @@ -176,6 +192,15 @@ def __repr__(self): def __enter__(self): """Enters the context manager, ensuring cluster exists.""" + self._old_jax_platforms = os.environ.get(_JAX_PLATFORMS_KEY) + self._old_jax_backend_target = os.environ.get(_JAX_BACKEND_TARGET_KEY) + self._old_jax_platforms_config = getattr( + jax.config, _JAX_PLATFORMS_KEY, None + ) + self._old_jax_backend_target_config = getattr( + jax.config, _JAX_BACKEND_TARGET_KEY, None + ) + try: _deploy_pathways_proxy_server( pathways_service=self.pathways_service, @@ -199,11 +224,17 @@ def __enter__(self): ) # Update the JAX backend to use the proxy. + os.environ[_JAX_PLATFORMS_KEY] = _JAX_PLATFORM_PROXY + os.environ[ + _JAX_BACKEND_TARGET_KEY + ] = f"{_JAX_BACKEND_TARGET_HOSTNAME}:{self._proxy_port}" + jax.config.update(_JAX_PLATFORMS_KEY, _JAX_PLATFORM_PROXY) jax.config.update( _JAX_BACKEND_TARGET_KEY, f"{_JAX_BACKEND_TARGET_HOSTNAME}:{self._proxy_port}", ) + pathwaysutils.initialize() _logger.info( "Interactive supercomputing proxy client ready for cluster '%s'.", @@ -221,7 +252,7 @@ def __exit__(self, exc_type, exc_value, traceback): _logger.info("Exiting ISCPathways context.") self._cleanup() - def _cleanup(self): + def _cleanup(self) -> None: """Cleans up resources created by the ISCPathways context.""" # 1. Clear JAX caches and run garbage collection. _logger.info("Starting Pathways proxy cleanup.") @@ -248,6 +279,16 @@ def _cleanup(self): gke_utils.delete_gke_job(self._proxy_job_name) _logger.info("Pathways proxy GKE job deletion complete.") + # 4. Restore JAX variables. + _logger.info("Restoring JAX env and config variables...") + _restore_env_var(_JAX_PLATFORMS_KEY, self._old_jax_platforms) + _restore_env_var(_JAX_BACKEND_TARGET_KEY, self._old_jax_backend_target) + jax.config.update(_JAX_PLATFORMS_KEY, self._old_jax_platforms_config) + jax.config.update( + _JAX_BACKEND_TARGET_KEY, self._old_jax_backend_target_config + ) + _logger.info("JAX variables restored.") + @contextlib.contextmanager def connect(