Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 71 additions & 10 deletions pathwaysutils/experimental/reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import collections
import json
from typing import Any, Callable, Dict, Mapping, Sequence
import warnings

import jax
from pathwaysutils import jax as pw_jax
Expand Down Expand Up @@ -104,7 +105,7 @@ def _reshard(
x: Any,
sharding: jax.sharding.Sharding | Any,
*,
donate: bool = False,
donate: bool,
may_alias: bool | None,
jax_array_reshard_fn: Callable[..., Any],
**kwargs,
Expand Down Expand Up @@ -198,6 +199,61 @@ def _ifrt_jax_array_reshard(
)


def _reshard_with_sidechannel(
    x: Any,
    sharding: jax.sharding.Sharding | Any,
    *,
    donate: bool,
    may_alias: bool | None,
    cache_resharding_plans: bool,
) -> Any:
  """Reshards `x` to `sharding` through the sidechannel implementation.

  Thin wrapper that delegates to `_reshard`, selecting
  `_sidechannel_jax_array_reshard` as the per-array resharding function.

  Args:
    x: The value to reshard; passed through to `_reshard` unchanged.
    sharding: The target sharding(s); passed through to `_reshard` unchanged.
    donate: Whether the input arrays are donated; forwarded to `_reshard`.
    may_alias: Whether the output may alias the input; forwarded to
      `_reshard`.
    cache_resharding_plans: Whether a resharding plan cache should be used;
      forwarded to `_reshard` as an extra keyword argument.

  Returns:
    Whatever `_reshard` returns for these arguments.
  """
  resharded = _reshard(
      x,
      sharding,
      jax_array_reshard_fn=_sidechannel_jax_array_reshard,
      donate=donate,
      may_alias=may_alias,
      cache_resharding_plans=cache_resharding_plans,
  )
  return resharded


def _reshard_with_ifrt(
    x: Any,
    sharding: jax.sharding.Sharding | Any,
    *,
    donate: bool,
    may_alias: bool | None,
) -> Any:
  """Reshards `x` to `sharding` through the IFRT implementation.

  Thin wrapper that delegates to `_reshard`, selecting
  `_ifrt_jax_array_reshard` as the per-array resharding function.

  Note: Resharding plan caching does not apply to the IFRT implementation, so
  this function deliberately takes no caching option.

  Args:
    x: An array, scalar, or (nested) standard Python container thereof.
    sharding: A `Sharding`, or a (nested) standard Python container of
      `Sharding`s that is a tree prefix of `x`, describing the device(s) and
      sharding `x` should be moved to. The result is committed to the
      device(s) of the sharding.
    donate: If `True`, all input arrays are donated, which may lower the
      memory needed for resharding. Donated buffers should not be reused.
    may_alias: If `True`, the output array may alias the input array, which
      may lower the memory needed for resharding. Not used at the moment.

  Returns:
    A copy of `x` whose sharding is `sharding`.
  """
  resharded = _reshard(
      x,
      sharding,
      jax_array_reshard_fn=_ifrt_jax_array_reshard,
      donate=donate,
      may_alias=may_alias,
  )
  return resharded


def reshard(
x: Any,
sharding: jax.sharding.Sharding | Any,
Expand All @@ -221,29 +277,34 @@ def reshard(
reduce the amount of memory needed for resharding. Not used at the moment.
cache_resharding_plans: If `True`, uses a resharding plan cache to avoid
recreating plans for the same resharding operation. May improve
performance for use cases where the same resharding operation is done many
times. May degrade performance if most reshardings operations are
different, since the cache will cause Pathways Components to remain loaded
for each cached plan. `False` by default. Only used when IFRT resharding
is not available.
performance for use cases where the same resharding operation is done
many times. May degrade performance if most resharding operations are
different, since the cache will cause Pathways Components to remain
loaded for each cached plan. `False` by default. This parameter is only
used when `pw_jax.ifrt_reshard_available()` is false.

Returns:
A copy of `x` whose sharding is `sharding`.
"""
if pw_jax.ifrt_reshard_available():
return _reshard(
if cache_resharding_plans:
warnings.warn(
"`cache_resharding_plans` is only applicable when using the"
" sidechannel resharding implementation, but IFRT resharding is"
" available and will be used. The `cache_resharding_plans` argument"
" will be ignored."
)
return _reshard_with_ifrt(
x,
sharding,
donate=donate,
may_alias=may_alias,
jax_array_reshard_fn=_ifrt_jax_array_reshard,
)
else:
return _reshard(
return _reshard_with_sidechannel(
x,
sharding,
donate=donate,
may_alias=may_alias,
jax_array_reshard_fn=_sidechannel_jax_array_reshard,
cache_resharding_plans=cache_resharding_plans,
)
147 changes: 147 additions & 0 deletions pathwaysutils/test/experimental/reshard_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Mapping
import json
from typing import Any
from unittest import mock
import warnings

from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from pathwaysutils import jax as pw_jax
from pathwaysutils import plugin_executable
from pathwaysutils.experimental import reshard


class ReshardTest(parameterized.TestCase):
  """Tests for `reshard.reshard` on the IFRT and sidechannel code paths.

  Each test selects one implementation by patching
  `pw_jax.ifrt_reshard_available` and mocks the entry point that
  implementation is expected to reach (`pw_jax.transfer_to_shardings` or
  `plugin_executable.PluginExecutable`).
  """

  def setUp(self):
    super().setUp()
    # Shared fixture: a small array and a single-device target sharding.
    self.x = jnp.array([1, 2])
    self.sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])

  @parameterized.parameters(
      dict(reshard_kwargs={"donate": True}, expected_donate=True),
      dict(reshard_kwargs={"donate": False}, expected_donate=False),
      dict(reshard_kwargs={}, expected_donate=False),
  )
  def test_ifrt_reshard_donate(
      self, reshard_kwargs: Mapping[str, Any], expected_donate: bool
  ):
    """Checks that `donate` reaches `transfer_to_shardings` on the IFRT path."""
    mock_transfer = self.enter_context(
        mock.patch.object(pw_jax, "transfer_to_shardings", autospec=True)
    )
    self.enter_context(
        mock.patch.object(
            pw_jax, "ifrt_reshard_available", return_value=True, autospec=True
        )
    )

    reshard.reshard(self.x, self.sharding, **reshard_kwargs)

    # Signature: transfer_to_shardings(arrays, shardings, donate)
    mock_transfer.assert_called_with(mock.ANY, mock.ANY, expected_donate)

  @parameterized.parameters(
      dict(reshard_kwargs={"donate": True}, expected_donate=True),
      dict(reshard_kwargs={"donate": False}, expected_donate=False),
      dict(reshard_kwargs={}, expected_donate=False),
  )
  def test_sidechannel_reshard_donate(
      self, reshard_kwargs: Mapping[str, Any], expected_donate: bool
  ):
    """Checks that `donate` is encoded in the sidechannel JSON request."""
    self.enter_context(
        mock.patch.object(
            pw_jax, "ifrt_reshard_available", return_value=False, autospec=True
        )
    )
    mock_pe = self.enter_context(
        mock.patch.object(plugin_executable, "PluginExecutable", autospec=True)
    )
    mock_pe.return_value.call.return_value = ([mock.Mock()], mock.Mock())

    reshard.reshard(self.x, self.sharding, **reshard_kwargs)

    mock_pe.assert_called()
    # The executable is built from a single positional JSON string.
    (json_request,), _ = mock_pe.call_args
    request = json.loads(json_request)
    self.assertEqual(request["reshardRequest"]["donateInput"], expected_donate)

  @parameterized.parameters(True, False, None)
  def test_ifrt_reshard_cache_resharding_plans(self, cache: bool | None):
    """Checks the `cache_resharding_plans` warning behavior on the IFRT path.

    A truthy value must trigger the "only applicable" warning; `False` or an
    omitted argument must not. In every case the transfer still happens once.

    Args:
      cache: Value for `cache_resharding_plans`; `None` omits the argument.
    """
    mock_transfer = self.enter_context(
        mock.patch.object(pw_jax, "transfer_to_shardings")
    )
    self.enter_context(
        mock.patch.object(pw_jax, "ifrt_reshard_available", return_value=True)
    )
    # `None` means "do not pass the argument" so the default is exercised.
    reshard_kwargs = {} if cache is None else {"cache_resharding_plans": cache}

    if cache:
      with self.assertWarnsRegex(
          UserWarning, "cache_resharding_plans` is only applicable"
      ):
        reshard.reshard(self.x, self.sharding, **reshard_kwargs)
    else:
      # Record every warning, then look only at the one under test so that
      # unrelated warnings from dependencies cannot fail this test.
      with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        reshard.reshard(self.x, self.sharding, **reshard_kwargs)
      unexpected = [
          str(w.message)
          for w in caught
          if "cache_resharding_plans" in str(w.message)
      ]
      self.assertEqual(unexpected, [])

    mock_transfer.assert_called_once()

  @parameterized.parameters(
      dict(cache=True, expected_cache=True),
      dict(cache=False, expected_cache=False),
      dict(cache=None, expected_cache=False),
  )
  def test_sidechannel_reshard_cache_resharding_plans(
      self, cache: bool | None, expected_cache: bool
  ):
    """Checks that the plan cache is used only when caching is requested.

    Args:
      cache: Value for `cache_resharding_plans`; `None` omits the argument.
      expected_cache: Whether `_get_resharding_plan_cached` should be used.
    """
    self.enter_context(
        mock.patch.object(pw_jax, "ifrt_reshard_available", return_value=False)
    )
    mock_pe = self.enter_context(
        mock.patch.object(plugin_executable, "PluginExecutable")
    )
    mock_pe.return_value.call.return_value = ([mock.Mock()], mock.Mock())

    mock_get_resharding_plan_cached = self.enter_context(
        mock.patch.object(reshard, "_get_resharding_plan_cached")
    )
    # `None` means "do not pass the argument" so the default is exercised.
    reshard_kwargs = {} if cache is None else {"cache_resharding_plans": cache}

    reshard.reshard(self.x, self.sharding, **reshard_kwargs)

    # With the cached-plan helper mocked out, `PluginExecutable` is only
    # constructed directly when caching is off.
    self.assertEqual(mock_pe.call_count, 0 if expected_cache else 1)

    self.assertEqual(
        mock_get_resharding_plan_cached.call_count,
        1 if expected_cache else 0,
    )

# Run the absl test runner when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()
Loading