Skip to content
This repository was archived by the owner on Oct 31, 2025. It is now read-only.

Commit a3b55dd

Browse files
danielsuocopybara-github
authored andcommitted
Disable e2e tests when jax_pmap_shmap_merge=True.
`trainer._multi_device_update_fn` uses `jax.pmap` and when `jax_pmap_shmap_merge=True`, `jax.pmap` requires inputs be explicitly sharded as the underlying `jax.jit` expects. This would need to be fixed for when `jax_pmap_shmap_merge=True` by default. PiperOrigin-RevId: 811399636
1 parent 3d4a276 commit a3b55dd

3 files changed

Lines changed: 17 additions & 0 deletions

File tree

trax/models/reformer/reformer_e2e_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from absl.testing import absltest
2121
import gin
22+
import jax
2223

2324
from trax import test_utils
2425
from trax.models.reformer import reformer # pylint: disable=unused-import
@@ -36,6 +37,11 @@ def setUp(self):
3637
super().setUp()
3738
gin.clear_config()
3839
gin.add_config_file_search_path(_CONFIG_DIR)
40+
# NOTE(dsuo): trainer._multi_device_update_fn uses `jax.pmap` and when
41+
# `jax_pmap_shmap_merge=True`, `jax.pmap` requires inputs be explicitly
42+
# sharded as the underlying `jax.jit` expects. This would need to be fixed.
43+
if jax.config.jax_pmap_shmap_merge:
44+
self.skipTest('Skipping test because jax_pmap_shmap_merge is enabled.')
3945
test_utils.ensure_flag('test_tmpdir')
4046

4147
def test_reformer_wmt_ende(self):

trax/models/research/terraformer_e2e_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from absl.testing import absltest
2121
import gin
22+
import jax
2223

2324
from trax import test_utils
2425
from trax.models.research import terraformer # pylint: disable=unused-import
@@ -36,6 +37,11 @@ def setUp(self):
3637
super().setUp()
3738
gin.clear_config()
3839
gin.add_config_file_search_path(_CONFIG_DIR)
40+
# NOTE(dsuo): trainer._multi_device_update_fn uses `jax.pmap` and when
41+
# `jax_pmap_shmap_merge=True`, `jax.pmap` requires inputs be explicitly
42+
# sharded as the underlying `jax.jit` expects. This would need to be fixed.
43+
if jax.config.jax_pmap_shmap_merge:
44+
self.skipTest('Skipping test because jax_pmap_shmap_merge is enabled.')
3945
test_utils.ensure_flag('test_tmpdir')
4046

4147
def test_terraformer_wmt_ende(self):

trax/supervised/trainer_lib_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ def setUp(self):
156156
super().setUp()
157157
test_utils.ensure_flag('test_tmpdir')
158158
self._old_is_allow_float64 = tf_np.is_allow_float64()
159+
# NOTE(dsuo): trainer._multi_device_update_fn uses `jax.pmap` and when
160+
# `jax_pmap_shmap_merge=True`, `jax.pmap` requires inputs be explicitly
161+
# sharded as the underlying `jax.jit` expects. This would need to be fixed.
162+
if jax.config.jax_pmap_shmap_merge:
163+
self.skipTest('Skipping test because jax_pmap_shmap_merge is enabled.')
159164
tf_np.set_allow_float64(False)
160165

161166
def tearDown(self):

0 commit comments

Comments
 (0)