Skip to content

Commit 30f1c85

Browse files
author
Orbax Authors
committed
Add benchmarks for P2P CheckpointManager.
PiperOrigin-RevId: 873854861
1 parent daca635 commit 30f1c85

3 files changed

Lines changed: 295 additions & 1 deletion

File tree

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/emergency_checkpoint_manager_benchmark.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@ mesh_configs:
1515
ici_parallelism: {"data": 1, "model": 4}
1616
dcn_parallelism: {"data": 2, "model": 1}
1717
- mesh_axes: ["data", "model", "tensor", "fsdp"]
18-
ici_parallelism: {"data": 1, "model": 16}
18+
ici_parallelism: {"data": 2, "model": 8}
1919
dcn_parallelism: {"data": 2, "model": 1}
20+
allow_split_physical_axes: true
2021
- mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]
2122
ici_parallelism: {"fsdp": 16, "tensor": 2, "data": 1}
2223
dcn_parallelism: {"data": 2}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# The name for the entire test suite run.
2+
suite_name: "P2P CheckpointManager Benchmark"
3+
4+
mesh_configs:
5+
- mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]
6+
# ICI: Within a slice. Assuming 8 devices per slice.
7+
# DCN: Across slices.
8+
ici_parallelism: {"fsdp": 1, "tensor": 1, "data": 1}
9+
dcn_parallelism: {"data": 1} # num_slices on the axis at replica_axis_index
10+
process_is_granule: true
11+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
12+
ici_parallelism: {"data": 1, "model": 1}
13+
dcn_parallelism: {"data": 4, "model": 1}
14+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
15+
ici_parallelism: {"data": 2, "model": 4}
16+
dcn_parallelism: {"data": 2, "model": 1}
17+
allow_split_physical_axes: true
18+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
19+
ici_parallelism: {"data": 1, "model": 16}
20+
dcn_parallelism: {"data": 2, "model": 1}
21+
- mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]
22+
ici_parallelism: {"fsdp": 16, "tensor": 2, "data": 1}
23+
dcn_parallelism: {"data": 2}
24+
25+
checkpoint_config:
26+
spec:
27+
a_1d: {dtype: "float32", shape: [32], sharding: [null]}
28+
b_1d: {dtype: "float32", shape: [32], sharding: ["tensor"]}
29+
c_2d: {dtype: "float32", shape: [32, 32], sharding: [null, "tensor"]}
30+
d_2d: {dtype: "float32", shape: [32, 32], sharding: ["tensor", null]}
31+
e_2d: {dtype: "float32", shape: [32, 32], sharding: ["tensor", "fsdp"]}
32+
f_2d: {dtype: "float32", shape: [32, 32], sharding: ["fsdp", "tensor"]}
33+
g_2d: {dtype: "float32", shape: [32, 32], sharding: [null, null]}
34+
h_3d: {dtype: "float32", shape: [32, 32, 32], sharding: ["tensor", null, "fsdp"]}
35+
i_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, "tensor"]}
36+
j_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, "fsdp"]}
37+
k_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, null]}
38+
custom_array: {dtype: "float32", shape: [8192, 64], sharding: ["tensor", null]}
39+
40+
benchmarks:
41+
- generator: "orbax.checkpoint._src.testing.benchmarks.p2p_checkpoint_manager_benchmark.P2pCheckpointManagerBenchmark"
42+
options:
43+
persistent_save_interval_steps: [2]
44+
persistent_max_to_keep: [5]
45+
local_save_interval_steps: [2]
46+
local_max_to_keep: 2
47+
replica_axis_index: 0
48+
train_steps: 5
49+
experimental_orbax_use_distributed_process_id: true
50+
experimental_use_distributed_id_for_mesh_consistency: true
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Benchmarks for orbax.checkpoint.experimental.emergency.p2p.checkpoint_manager.CheckpointManager."""
16+
17+
from collections.abc import Sequence
18+
import dataclasses
19+
from typing import Any
20+
from absl import logging
21+
from etils import epath
22+
import jax
23+
from orbax.checkpoint import checkpoint_utils
24+
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
25+
from orbax.checkpoint._src.multihost import multihost
26+
from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core
27+
from orbax.checkpoint._src.testing.benchmarks.core import mesh_utils
28+
from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib
29+
from orbax.checkpoint._src.testing.benchmarks.core import pytree_utils
30+
from orbax.checkpoint._src.tree import utils
31+
from orbax.checkpoint.experimental.emergency.p2p import args as p2p_args_lib
32+
from orbax.checkpoint.experimental.emergency.p2p import checkpoint_manager as p2p_checkpoint_manager
33+
from orbax.checkpoint.experimental.emergency.p2p import options as p2p_options
34+
35+
36+
# ==============================================================================
37+
# 1. Define the Options Dataclass
38+
# ==============================================================================
39+
@dataclasses.dataclass(frozen=True)
40+
class P2pBenchmarkOptions(benchmarks_core.BenchmarkOptions):
41+
"""Configuration options for benchmarks targeting P2P CheckpointManager.
42+
43+
Attributes:
44+
persistent_save_interval_steps: The interval at which persistent checkpoints
45+
should be saved.
46+
persistent_max_to_keep: The maximum number of persistent checkpoints to
47+
keep.
48+
local_save_interval_steps: The interval at which local checkpoints should be
49+
saved.
50+
local_max_to_keep: The maximum number of local checkpoints to keep.
51+
replica_axis_index: The index of the replica axis in the global mesh.
52+
train_steps: The number of training steps to run.
53+
"""
54+
55+
persistent_save_interval_steps: int | Sequence[int] = 5
56+
persistent_max_to_keep: int | Sequence[int] = 5
57+
local_save_interval_steps: int | Sequence[int] = 2
58+
local_max_to_keep: int | Sequence[int] = 2
59+
replica_axis_index: int | Sequence[int] = 0
60+
train_steps: int | Sequence[int] = 10
61+
experimental_use_distributed_id_for_mesh_consistency: (
62+
bool | Sequence[bool]
63+
) = True
64+
experimental_orbax_use_distributed_process_id: bool | Sequence[bool] = True
65+
66+
67+
# ==============================================================================
68+
# 2. Implement the Benchmark Generator
69+
# ==============================================================================
70+
def _create_checkpoint_manager(
71+
local_directory: epath.Path,
72+
persistent_directory: epath.Path,
73+
global_mesh: jax.sharding.Mesh,
74+
abstract_state: Any,
75+
options: P2pBenchmarkOptions,
76+
) -> p2p_checkpoint_manager.CheckpointManager:
77+
"""Creates an P2P CheckpointManager."""
78+
return p2p_checkpoint_manager.CheckpointManager(
79+
local_directory=local_directory,
80+
persistent_directory=persistent_directory,
81+
global_mesh=global_mesh,
82+
abstract_state=abstract_state,
83+
options=p2p_options.CheckpointManagerOptions(
84+
local=p2p_options.LocalCheckpointOptions(
85+
save_interval_steps=options.local_save_interval_steps,
86+
max_to_keep=options.local_max_to_keep,
87+
),
88+
persistent=p2p_options.PersistentCheckpointOptions(
89+
save_interval_steps=options.persistent_save_interval_steps,
90+
max_to_keep=options.persistent_max_to_keep,
91+
),
92+
replica_axis_index=options.replica_axis_index,
93+
),
94+
)
95+
96+
97+
def _restore_and_validate(
98+
manager: p2p_checkpoint_manager.CheckpointManager,
99+
metrics: metric_lib.Metrics,
100+
pytree: Any,
101+
step: int,
102+
local_directory: epath.Path,
103+
restore_args: Any,
104+
):
105+
"""Restores a checkpoint and validates it."""
106+
# Wait for save to complete on all hosts.
107+
with metrics.measure(f"sync_global_processes_{step}"):
108+
multihost.sync_global_processes(f"save_completed_{step}")
109+
110+
step_dir = local_directory / str(step)
111+
step_dir_backup = local_directory / f"backup_{step}"
112+
if multihost.process_index() == 0 and step_dir.exists():
113+
logging.info("Process 0: removing local checkpoint to trigger P2P restore.")
114+
step_dir.rename(step_dir_backup)
115+
116+
with metrics.measure(f"restore_{step}"):
117+
restored = manager.restore(
118+
step,
119+
args=p2p_args_lib.Composite(
120+
state=pytree_checkpoint_handler.PyTreeRestoreArgs(
121+
restore_args=restore_args
122+
)
123+
),
124+
)["state"]
125+
pytree_utils.log_pytree("Restored Pytree", restored)
126+
logging.info("Assert Restored Pytree")
127+
pytree_utils.assert_pytree_equal(pytree, restored)
128+
129+
if multihost.process_index() == 0 and step_dir_backup.exists():
130+
logging.info("Process 0: restoring local checkpoint.")
131+
step_dir_backup.rename(step_dir)
132+
133+
with metrics.measure(f"reload_after_restore_{step}"):
134+
manager.reload()
135+
136+
137+
@benchmarks_core.benchmark_options(P2pBenchmarkOptions)
138+
class P2pCheckpointManagerBenchmark(benchmarks_core.BenchmarksGenerator):
139+
"""A generator for benchmarking P2P CheckpointManager."""
140+
141+
def test_fn(
142+
self, context: benchmarks_core.TestContext
143+
) -> benchmarks_core.TestResult:
144+
"""The core test logic for a single save/restore cycle."""
145+
metrics = metric_lib.Metrics()
146+
pytree = context.pytree
147+
persistent_directory = context.path / "persistent_p2p_ckpt"
148+
if context.local_path is not None:
149+
local_path = epath.Path(context.local_path) / "local_p2p_ckpt"
150+
local_directory = epath.Path(local_path)
151+
local_directory.mkdir(parents=True, exist_ok=True)
152+
else:
153+
local_directory = (
154+
context.path
155+
/ "local_p2p_ckpt"
156+
/ f"process_{multihost.process_index()}"
157+
)
158+
options = context.options
159+
mesh = context.mesh
160+
assert isinstance(options, P2pBenchmarkOptions)
161+
162+
if mesh is None:
163+
raise ValueError(
164+
"Mesh must be provided for P2pCheckpointManagerBenchmark"
165+
)
166+
# flags.FLAGS.experimental_use_distributed_id_for_mesh_consistency = (
167+
# options.experimental_use_distributed_id_for_mesh_consistency
168+
# )
169+
# flags.FLAGS.experimental_orbax_use_distributed_process_id = (
170+
# options.experimental_orbax_use_distributed_process_id
171+
# )
172+
if not multihost.is_runtime_to_distributed_ids_initialized():
173+
multihost.initialize_runtime_to_distributed_ids()
174+
175+
if not multihost.is_distributed_to_device_ids_initialized():
176+
multihost.initialize_distributed_to_device_ids()
177+
178+
mesh_utils.pretty_log_mesh("Global Mesh: ", mesh)
179+
180+
with metrics.measure("create_directories"):
181+
if jax.process_index() == 0:
182+
persistent_directory.mkdir(parents=True)
183+
local_directory.mkdir(parents=True, exist_ok=True)
184+
multihost.sync_global_processes("create directories")
185+
186+
with metrics.measure("create_abstract_pytree"):
187+
abstract_pytree = jax.tree.map(utils.to_shape_dtype_struct, pytree)
188+
logging.info("abstract_pytree: %r", abstract_pytree)
189+
190+
with metrics.measure("create_restore_args"):
191+
restore_args = checkpoint_utils.construct_restore_args(abstract_pytree)
192+
logging.info("restore_args: %r", restore_args)
193+
194+
with metrics.measure("create_checkpoint_manager"):
195+
manager = _create_checkpoint_manager(
196+
local_directory=local_directory,
197+
persistent_directory=persistent_directory,
198+
global_mesh=mesh,
199+
abstract_state=abstract_pytree,
200+
options=options,
201+
)
202+
203+
step = manager.latest_step()
204+
if step is not None:
205+
logging.info("Latest step: %d", step)
206+
207+
with metrics.measure(f"restore_and_validate_{step}"):
208+
_restore_and_validate(
209+
manager,
210+
metrics,
211+
pytree,
212+
step,
213+
local_directory,
214+
restore_args,
215+
)
216+
217+
start_step = step + 1 if step is not None else 0
218+
with metrics.measure("train_loop"):
219+
for step in range(start_step, options.train_steps):
220+
logging.info("Training step %d", step)
221+
with metrics.measure(f"save_{step}"):
222+
manager.save(
223+
step,
224+
args=p2p_args_lib.Composite(
225+
state=pytree_checkpoint_handler.PyTreeSaveArgs(pytree)
226+
),
227+
)
228+
with metrics.measure(f"wait_until_finished_{step}"):
229+
manager.wait_until_finished()
230+
231+
if step % options.local_save_interval_steps == 0:
232+
with metrics.measure(f"restore_and_validate_{step}"):
233+
_restore_and_validate(
234+
manager,
235+
metrics,
236+
pytree,
237+
step,
238+
local_directory,
239+
restore_args,
240+
)
241+
242+
manager.close()
243+
return benchmarks_core.TestResult(metrics=metrics)

0 commit comments

Comments
 (0)