Skip to content

Commit 0fdf0ff

Browse files
author
Orbax Authors
committed
Add escape hatch for jax_init_info\
PiperOrigin-RevId: 869850598
1 parent cf5a1ce commit 0fdf0ff

3 files changed

Lines changed: 113 additions & 27 deletions

File tree

checkpoint/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010
### Added
1111

1212
- `uvloop` dependency for improved event loop performance
13+
- Add option to disable for multi-tier checkpointing process id initialization
1314

1415
### Removed
1516

checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/initialization.py

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,34 @@ def _create_replicator_file(
7070
os.rename(temp_file, replicator_file)
7171

7272

73+
def _initialize_jax_from_mtc(
74+
local_checkpoint_directory: epath.Path,
75+
jax_initialization_timeout_seconds: int = 900,
76+
) -> str:
77+
"""Initialize jax with jax_init_info."""
78+
local_checkpoint_directory = epath.Path(local_checkpoint_directory)
79+
process_id, coordinator_address = _retrieve_jax_init_info(
80+
local_checkpoint_directory
81+
)
82+
if not process_id or not coordinator_address:
83+
raise ValueError(
84+
'Data is missing from the JAX init info file: Current values:'
85+
f' process_id: {process_id}, coordinator_address: {coordinator_address}'
86+
)
87+
logging.info(
88+
'Using process_id %s and coordinator_address %s to initialize JAX'
89+
' distributed runtime...',
90+
process_id,
91+
coordinator_address,
92+
)
93+
jax.distributed.initialize(
94+
process_id=int(process_id),
95+
coordinator_address=coordinator_address,
96+
initialization_timeout=jax_initialization_timeout_seconds,
97+
)
98+
return process_id
99+
100+
73101
def initialize_multi_tier_checkpointing(
74102
local_checkpoint_directory: epath.Path,
75103
*,
@@ -78,6 +106,7 @@ def initialize_multi_tier_checkpointing(
78106
run_name: Optional[str] = None,
79107
data_parallelism: Optional[int] = None,
80108
jax_initialization_timeout_seconds: int = 900,
109+
use_mtc_process_ids: Optional[bool] = True,
81110
):
82111
"""Initializes multi-tier checkpointing.
83112
@@ -91,27 +120,18 @@ def initialize_multi_tier_checkpointing(
91120
equal to ICI data parallelism * DCN data parallelism. If not provided, it
92121
will be inferred from the number of slices.
93122
jax_initialization_timeout_seconds: The timeout for JAX initialization.
123+
use_mtc_process_ids: Use the MTC rank server to calculate process ids.
94124
"""
95-
local_checkpoint_directory = epath.Path(local_checkpoint_directory)
96-
process_id, coordinator_address = _retrieve_jax_init_info(
97-
local_checkpoint_directory
98-
)
99-
if not process_id or not coordinator_address:
100-
raise ValueError(
101-
'Data is missing from the JAX init info file: Current values:'
102-
f' process_id: {process_id}, coordinator_address: {coordinator_address}'
125+
if use_mtc_process_ids:
126+
process_id = _initialize_jax_from_mtc(
127+
local_checkpoint_directory, jax_initialization_timeout_seconds
103128
)
104-
logging.info(
105-
'Using process_id %s and coordinator_address %s to initialize JAX'
106-
' distributed runtime...',
107-
process_id,
108-
coordinator_address,
109-
)
110-
jax.distributed.initialize(
111-
process_id=int(process_id),
112-
coordinator_address=coordinator_address,
113-
initialization_timeout=jax_initialization_timeout_seconds,
114-
)
129+
else:
130+
process_id = None
131+
jax.distributed.initialize(
132+
initialization_timeout=jax_initialization_timeout_seconds,
133+
)
134+
115135
multihost.initialize_runtime_to_distributed_ids()
116136
multihost.initialize_distributed_to_device_ids()
117137
_wait_for_replicator_file_to_disappear(local_checkpoint_directory)
@@ -127,14 +147,24 @@ def initialize_multi_tier_checkpointing(
127147
process_index_to_node_rank = (
128148
multihost.runtime_to_distributed_ids()
129149
)
130-
logging.info(
131-
'Mapping of IDs: jax-init-info.txt=%s, NodeRank=%s, ProcessIndex=%s,'
132-
' ProcessIndex->NodeRank=%s',
133-
process_id,
134-
node_rank,
135-
my_process_index,
136-
process_index_to_node_rank,
137-
)
150+
if use_mtc_process_ids:
151+
logging.info(
152+
'Mapping of IDs: jax-init-info.txt=%s, NodeRank=%s, ProcessIndex=%s,'
153+
' ProcessIndex->NodeRank=%s',
154+
process_id,
155+
node_rank,
156+
my_process_index,
157+
process_index_to_node_rank,
158+
)
159+
else:
160+
logging.info(
161+
'Mapping of IDs (jax-init-info not used): NodeRank=%s, ProcessIndex=%s,'
162+
' ProcessIndex->NodeRank=%s',
163+
node_rank,
164+
my_process_index,
165+
process_index_to_node_rank,
166+
)
167+
138168
my_in_pipeline_index = my_process_index % nodes_per_slice
139169
peer_ranks = []
140170
for i in range(num_slices):

checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/initialization_test.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,61 @@ def test_initialize_multi_tier_checkpointing_run_name_not_set(
288288
mock_initialize_distributed_to_device_ids.assert_called_once()
289289
self.assertEqual(mock_wait_for_replicator_file_to_disappear.call_count, 1)
290290

291+
@mock.patch.object(
292+
initialization, "_wait_for_replicator_file_to_disappear", autospec=True
293+
)
294+
@mock.patch.object(initialization, "_create_replicator_file", autospec=True)
295+
@mock.patch.object(jax.distributed, "initialize", autospec=True)
296+
@mock.patch.object(
297+
multihost, "initialize_runtime_to_distributed_ids", autospec=True
298+
)
299+
@mock.patch.object(
300+
multihost, "initialize_distributed_to_device_ids", autospec=True
301+
)
302+
@mock.patch.object(multihost, "runtime_to_distributed_ids", autospec=True)
303+
def test_initialize_multi_tier_checkpointing_skip_init_info(
304+
self,
305+
mock_runtime_to_distributed_ids,
306+
mock_initialize_distributed_to_device_ids,
307+
mock_initialize_runtime_to_distributed_ids,
308+
mock_jax_distributed_initialize,
309+
mock_create_replicator_file,
310+
mock_wait_for_replicator_file_to_disappear,
311+
):
312+
mock_runtime_to_distributed_ids.return_value = [0, 1]
313+
mock_jax_distributed_initialize.return_value = None
314+
mock_initialize_runtime_to_distributed_ids.return_value = [None, None]
315+
mock_initialize_distributed_to_device_ids.return_value = None
316+
mock_create_replicator_file.return_value = [None, None]
317+
mock_wait_for_replicator_file_to_disappear.return_value = False
318+
319+
with tempfile.TemporaryDirectory() as tmp_dir:
320+
epath.Path(tmp_dir).mkdir(parents=True, exist_ok=True)
321+
replicator_file = epath.Path(tmp_dir) / initialization._REPLICATOR_FILE
322+
replicator_file.write_text("replicator.yaml")
323+
self.assertTrue(replicator_file.exists())
324+
325+
restore_dir = epath.Path(tmp_dir) / "test-run-s1-n0-w0.restore"
326+
restore_dir.write_text("restore_dir")
327+
self.assertTrue(restore_dir.exists())
328+
329+
initialization.initialize_multi_tier_checkpointing(
330+
epath.Path(tmp_dir),
331+
num_slices=1,
332+
run_name="test-run",
333+
data_parallelism=1,
334+
use_mtc_process_ids=False,
335+
)
336+
mock_jax_distributed_initialize.assert_called_once_with(
337+
initialization_timeout=900,
338+
)
339+
mock_initialize_runtime_to_distributed_ids.assert_called_once()
340+
mock_initialize_distributed_to_device_ids.assert_called_once()
341+
self.assertEqual(mock_wait_for_replicator_file_to_disappear.call_count, 2)
342+
mock_create_replicator_file.assert_called_once()
343+
expected_restore_dir = epath.Path(tmp_dir) / "1"
344+
self.assertTrue(expected_restore_dir.exists())
345+
291346

292347
if __name__ == "__main__":
293348
absltest.main()

0 commit comments

Comments
 (0)