@@ -70,6 +70,34 @@ def _create_replicator_file(
7070 os .rename (temp_file , replicator_file )
7171
7272
def _initialize_jax_from_mtc(
    local_checkpoint_directory: epath.Path,
    jax_initialization_timeout_seconds: int = 900,
) -> str:
  """Starts the JAX distributed runtime from the retrieved JAX init info.

  Args:
    local_checkpoint_directory: Directory passed (as an `epath.Path`) to
      `_retrieve_jax_init_info` to obtain the process id and the coordinator
      address.
    jax_initialization_timeout_seconds: Forwarded to
      `jax.distributed.initialize` as `initialization_timeout`.

  Returns:
    The process id exactly as retrieved, i.e. before the integer conversion
    applied for `jax.distributed.initialize`.

  Raises:
    ValueError: If either the process id or the coordinator address is falsy.
  """
  init_info = _retrieve_jax_init_info(epath.Path(local_checkpoint_directory))
  process_id, coordinator_address = init_info

  # Both values are required; refuse to start the runtime with partial info.
  if not (process_id and coordinator_address):
    raise ValueError(
        'Data is missing from the JAX init info file: Current values:'
        f' process_id: { process_id } , coordinator_address: { coordinator_address } '
    )

  logging.info(
      'Using process_id %s and coordinator_address %s to initialize JAX'
      ' distributed runtime...',
      process_id,
      coordinator_address,
  )

  # The integer conversion deliberately happens after the log line above, so
  # a malformed id is still visible in the logs before the conversion fails.
  jax.distributed.initialize(
      coordinator_address=coordinator_address,
      process_id=int(process_id),
      initialization_timeout=jax_initialization_timeout_seconds,
  )
  return process_id
99+
100+
73101def initialize_multi_tier_checkpointing (
74102 local_checkpoint_directory : epath .Path ,
75103 * ,
@@ -78,6 +106,7 @@ def initialize_multi_tier_checkpointing(
78106 run_name : Optional [str ] = None ,
79107 data_parallelism : Optional [int ] = None ,
80108 jax_initialization_timeout_seconds : int = 900 ,
109+ use_mtc_process_ids : bool = True ,
81110):
82111 """Initializes multi-tier checkpointing.
83112
@@ -91,27 +120,18 @@ def initialize_multi_tier_checkpointing(
91120 equal to ICI data parallelism * DCN data parallelism. If not provided, it
92121 will be inferred from the number of slices.
93122 jax_initialization_timeout_seconds: The timeout for JAX initialization.
123+ use_mtc_process_ids: Use the MTC rank server to calculate process ids.
94124 """
95- local_checkpoint_directory = epath .Path (local_checkpoint_directory )
96- process_id , coordinator_address = _retrieve_jax_init_info (
97- local_checkpoint_directory
98- )
99- if not process_id or not coordinator_address :
100- raise ValueError (
101- 'Data is missing from the JAX init info file: Current values:'
102- f' process_id: { process_id } , coordinator_address: { coordinator_address } '
125+ if use_mtc_process_ids :
126+ process_id = _initialize_jax_from_mtc (
127+ local_checkpoint_directory , jax_initialization_timeout_seconds
103128 )
104- logging .info (
105- 'Using process_id %s and coordinator_address %s to initialize JAX'
106- ' distributed runtime...' ,
107- process_id ,
108- coordinator_address ,
109- )
110- jax .distributed .initialize (
111- process_id = int (process_id ),
112- coordinator_address = coordinator_address ,
113- initialization_timeout = jax_initialization_timeout_seconds ,
114- )
129+ else :
130+ process_id = None
131+ jax .distributed .initialize (
132+ initialization_timeout = jax_initialization_timeout_seconds ,
133+ )
134+
115135 multihost .initialize_runtime_to_distributed_ids ()
116136 multihost .initialize_distributed_to_device_ids ()
117137 _wait_for_replicator_file_to_disappear (local_checkpoint_directory )
@@ -127,14 +147,24 @@ def initialize_multi_tier_checkpointing(
127147 process_index_to_node_rank = (
128148 multihost .runtime_to_distributed_ids ()
129149 )
130- logging .info (
131- 'Mapping of IDs: jax-init-info.txt=%s, NodeRank=%s, ProcessIndex=%s,'
132- ' ProcessIndex->NodeRank=%s' ,
133- process_id ,
134- node_rank ,
135- my_process_index ,
136- process_index_to_node_rank ,
137- )
150+ if use_mtc_process_ids :
151+ logging .info (
152+ 'Mapping of IDs: jax-init-info.txt=%s, NodeRank=%s, ProcessIndex=%s,'
153+ ' ProcessIndex->NodeRank=%s' ,
154+ process_id ,
155+ node_rank ,
156+ my_process_index ,
157+ process_index_to_node_rank ,
158+ )
159+ else :
160+ logging .info (
161+ 'Mapping of IDs (jax-init-info not used): NodeRank=%s, ProcessIndex=%s,'
162+ ' ProcessIndex->NodeRank=%s' ,
163+ node_rank ,
164+ my_process_index ,
165+ process_index_to_node_rank ,
166+ )
167+
138168 my_in_pipeline_index = my_process_index % nodes_per_slice
139169 peer_ranks = []
140170 for i in range (num_slices ):
0 commit comments