From e7d15ccbc1f6abc96568f0ce160477239476afcd Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Fri, 16 Jan 2026 21:28:36 -0800 Subject: [PATCH] Refactor resharding logic into helper functions. This change introduces `_reshard_with_sidechannel` and `_reshard_with_ifrt` to encapsulate the different resharding mechanisms used by the `reshard` function. These are internal APIs and should not be depended on. PiperOrigin-RevId: 857415781 --- pathwaysutils/experimental/reshard.py | 81 ++++++++-- .../test/experimental/reshard_test.py | 147 ++++++++++++++++++ 2 files changed, 218 insertions(+), 10 deletions(-) create mode 100644 pathwaysutils/test/experimental/reshard_test.py diff --git a/pathwaysutils/experimental/reshard.py b/pathwaysutils/experimental/reshard.py index 6bbd51e..6806244 100644 --- a/pathwaysutils/experimental/reshard.py +++ b/pathwaysutils/experimental/reshard.py @@ -17,6 +17,7 @@ import collections import json from typing import Any, Callable, Dict, Mapping, Sequence +import warnings import jax from pathwaysutils import jax as pw_jax @@ -104,7 +105,7 @@ def _reshard( x: Any, sharding: jax.sharding.Sharding | Any, *, - donate: bool = False, + donate: bool, may_alias: bool | None, jax_array_reshard_fn: Callable[..., Any], **kwargs, @@ -198,6 +199,61 @@ def _ifrt_jax_array_reshard( ) +def _reshard_with_sidechannel( + x: Any, + sharding: jax.sharding.Sharding | Any, + *, + donate: bool, + may_alias: bool | None, + cache_resharding_plans: bool, +) -> Any: + """Reshards `x` to `sharding` using sidechannel.""" + return _reshard( + x, + sharding, + donate=donate, + may_alias=may_alias, + jax_array_reshard_fn=_sidechannel_jax_array_reshard, + cache_resharding_plans=cache_resharding_plans, + ) + + +def _reshard_with_ifrt( + x: Any, + sharding: jax.sharding.Sharding | Any, + *, + donate: bool, + may_alias: bool | None, +) -> Any: + """Reshards `x` to `sharding` using IFRT. + + Note: Resharding plan caching is not applicable to the IFRT implementation + and is not supported by this function. + + Args: + x: An array, scalar, or (nested) standard Python container thereof. + sharding: A `Sharding` or a (nested) `Sharding` in standard Python container + (must be a tree prefix of `x`), representing the device(s) and sharding to + which `x` should be sharded to. The result will be committed to the + device(s) of the sharding. + donate: If `True`, donate all input arrays, which may reduce the amount of + memory needed for resharding. Buffers donated to resharding should not be + reused. + may_alias: If `True`, may alias the input array with the output array. May + reduce the amount of memory needed for resharding. Not used at the moment. + + Returns: + A copy of `x` whose sharding is `sharding`. + """ + return _reshard( + x, + sharding, + donate=donate, + may_alias=may_alias, + jax_array_reshard_fn=_ifrt_jax_array_reshard, + ) + + def reshard( x: Any, sharding: jax.sharding.Sharding | Any, @@ -221,29 +277,34 @@ def reshard( reduce the amount of memory needed for resharding. Not used at the moment. cache_resharding_plans: If `True`, uses a resharding plan cache to avoid recreating plans for the same resharding operation. May improve - performance for use cases where the same resharding operation is done many - times. May degrade performance if most reshardings operations are - different, since the cache will cause Pathways Components to remain loaded - for each cached plan. `False` by default. Only used when IFRT resharding - is not available. + performance for use cases where the same resharding operation is done + many times. May degrade performance if most reshardings operations are + different, since the cache will cause Pathways Components to remain + loaded for each cached plan. `False` by default. This parameter is only + used when `pw_jax.ifrt_reshard_available()` is false. Returns: A copy of `x` whose sharding is `sharding`. """ if pw_jax.ifrt_reshard_available(): - return _reshard( + if cache_resharding_plans: + warnings.warn( + "`cache_resharding_plans` is only applicable when using the" + " sidechannel resharding implementation, but IFRT resharding is" + " available and will be used. The `cache_resharding_plans` argument" + " will be ignored." + ) + return _reshard_with_ifrt( x, sharding, donate=donate, may_alias=may_alias, - jax_array_reshard_fn=_ifrt_jax_array_reshard, ) else: - return _reshard( + return _reshard_with_sidechannel( x, sharding, donate=donate, may_alias=may_alias, - jax_array_reshard_fn=_sidechannel_jax_array_reshard, cache_resharding_plans=cache_resharding_plans, ) diff --git a/pathwaysutils/test/experimental/reshard_test.py b/pathwaysutils/test/experimental/reshard_test.py new file mode 100644 index 0000000..9cecd99 --- /dev/null +++ b/pathwaysutils/test/experimental/reshard_test.py @@ -0,0 +1,147 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Mapping +import json +from typing import Any +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +from pathwaysutils import jax as pw_jax +from pathwaysutils import plugin_executable +from pathwaysutils.experimental import reshard + + +class ReshardTest(parameterized.TestCase): + + @parameterized.parameters( + dict(reshard_kwargs={"donate": True}, expected_donate=True), + dict(reshard_kwargs={"donate": False}, expected_donate=False), + dict(reshard_kwargs={}, expected_donate=False), + ) + def test_ifrt_reshard_donate( + self, reshard_kwargs: Mapping[str, Any], expected_donate: bool + ): + x = jnp.array([1, 2]) + devices = jax.devices() + sharding = jax.sharding.SingleDeviceSharding(devices[0]) + + mock_transfer = self.enter_context( + mock.patch.object(pw_jax, "transfer_to_shardings", autospec=True) + ) + self.enter_context( + mock.patch.object( + pw_jax, "ifrt_reshard_available", return_value=True, autospec=True + ) + ) + + reshard.reshard(x, sharding, **reshard_kwargs) + + # Signature: transfer_to_shardings(arrays, shardings, donate) + mock_transfer.assert_called_with(mock.ANY, mock.ANY, expected_donate) + + @parameterized.parameters( + dict(reshard_kwargs={"donate": True}, expected_donate=True), + dict(reshard_kwargs={"donate": False}, expected_donate=False), + dict(reshard_kwargs={}, expected_donate=False), + ) + def test_sidechannel_reshard_donate( + self, reshard_kwargs: Mapping[str, Any], expected_donate: bool + ): + x = jnp.array([1, 2]) + devices = jax.devices() + sharding = jax.sharding.SingleDeviceSharding(devices[0]) + + self.enter_context( + mock.patch.object( + pw_jax, "ifrt_reshard_available", return_value=False, autospec=True + ) + ) + mock_pe = self.enter_context( + mock.patch.object(plugin_executable, "PluginExecutable", autospec=True) + ) + mock_pe.return_value.call.return_value = ([mock.Mock()], mock.Mock()) + + reshard.reshard(x, sharding, **reshard_kwargs) + + mock_pe.assert_called() + (json_request,), _ = mock_pe.call_args + request = json.loads(json_request) + self.assertEqual(request["reshardRequest"]["donateInput"], expected_donate) + + @parameterized.parameters(True, False, None) + def test_ifrt_reshard_cache_resharding_plans(self, cache: bool | None): + x = jnp.array([1, 2]) + devices = jax.devices() + sharding = jax.sharding.SingleDeviceSharding(devices[0]) + + mock_transfer = self.enter_context( + mock.patch.object(pw_jax, "transfer_to_shardings") + ) + self.enter_context( + mock.patch.object(pw_jax, "ifrt_reshard_available", return_value=True) + ) + + if cache is None: + reshard.reshard(x, sharding) + elif cache: + with self.assertWarnsRegex( + UserWarning, "cache_resharding_plans` is only applicable" + ): + reshard.reshard(x, sharding, cache_resharding_plans=cache) + else: + reshard.reshard(x, sharding, cache_resharding_plans=cache) + + mock_transfer.assert_called_once() + + @parameterized.parameters( + dict(cache=True, expected_cache=True), + dict(cache=False, expected_cache=False), + dict(cache=None, expected_cache=False), + ) + def test_sidechannel_reshard_cache_resharding_plans( + self, cache, expected_cache + ): + x = jnp.array([1, 2]) + devices = jax.devices() + sharding = jax.sharding.SingleDeviceSharding(devices[0]) + + self.enter_context( + mock.patch.object(pw_jax, "ifrt_reshard_available", return_value=False) + ) + mock_pe = self.enter_context( + mock.patch.object(plugin_executable, "PluginExecutable") + ) + mock_pe.return_value.call.return_value = ([mock.Mock()], mock.Mock()) + + mock_get_resharding_plan_cached = self.enter_context( + mock.patch.object(reshard, "_get_resharding_plan_cached") + ) + + if cache is None: + reshard.reshard(x, sharding) + else: + reshard.reshard(x, sharding, cache_resharding_plans=cache) + + self.assertEqual(mock_pe.call_count, 0 if expected_cache else 1) + + self.assertEqual( + mock_get_resharding_plan_cached.call_count, + 1 if expected_cache else 0, + ) + +if __name__ == "__main__": absltest.main()