Skip to content

Commit a5d9a76

Browse files
author
Orbax Authors
committed
#p2p Add Grain data iterator checkpointing to P2P
PiperOrigin-RevId: 863828755
1 parent a118d85 commit a5d9a76

14 files changed

Lines changed: 1930 additions & 47 deletions

File tree

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2025 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""P2P composite checkpoint argument."""
16+
17+
from typing import Any, final
18+
19+
from orbax.checkpoint import args as args_lib
20+
from orbax.checkpoint.experimental.emergency.p2p import constants
21+
from orbax.checkpoint.experimental.emergency.p2p import utils
22+
23+
24+
def _validate_data_iter_arg(value: Any) -> None:
  """Checks that `value` is an acceptable argument for the data iterator key.

  Args:
    value: The candidate argument stored under `constants.DATA_ITER_KEY`.

  Raises:
    ImportError: If `utils.pygrain()` returns None.
    TypeError: If `value` is neither a `PyGrainCheckpointSave` nor a
      `PyGrainCheckpointRestore` instance.
  """
  # Resolve the module once instead of calling `utils.pygrain()` per use.
  pygrain = utils.pygrain()
  if pygrain is None:
    raise ImportError(
        'grain library is not available. Please install grain'
        ' to use data_iter.'
    )
  if not isinstance(
      value,
      (pygrain.PyGrainCheckpointSave, pygrain.PyGrainCheckpointRestore),
  ):
    raise TypeError(f'Unsupported type for data_iter: {type(value)}')


@final
class Composite(args_lib.Composite):
  """Composite argument that supports 'state' and 'data_iter' keys.

  The only keys accepted are `constants.STATE_SUBDIR` (required at
  construction) and `constants.DATA_ITER_KEY` (optional). A value stored under
  `constants.DATA_ITER_KEY` must be a PyGrain checkpoint save/restore argument.
  """

  def __init__(self, *args, **kwargs):
    """Builds the composite and validates its keys and data iterator value.

    Args:
      *args: Forwarded unchanged to `args_lib.Composite.__init__`.
      **kwargs: Forwarded unchanged to `args_lib.Composite.__init__`.

    Raises:
      ValueError: If `constants.STATE_SUBDIR` is missing, or if any key other
        than `constants.STATE_SUBDIR` / `constants.DATA_ITER_KEY` is present.
      ImportError: If a data iterator entry is present but grain is
        unavailable.
      TypeError: If the data iterator entry has an unsupported type.
    """
    super().__init__(*args, **kwargs)
    if constants.STATE_SUBDIR not in self:
      raise ValueError(
          f'Composite must contain "{constants.STATE_SUBDIR}" key:'
          f' {list(self.keys())}'
      )
    for key in self:
      if key not in [constants.STATE_SUBDIR, constants.DATA_ITER_KEY]:
        raise ValueError(f'Unsupported key in Composite: {key}')
      if key == constants.DATA_ITER_KEY:
        _validate_data_iter_arg(self[key])

  def __setitem__(self, key: str, value: Any):
    """Validates `key` and `value`, then stores the item via the parent class.

    Args:
      key: Must be `constants.STATE_SUBDIR` or `constants.DATA_ITER_KEY`.
      value: The argument to store. For `constants.DATA_ITER_KEY` it must be a
        PyGrain checkpoint save/restore argument.

    Raises:
      KeyError: If `key` is not one of the two supported keys.
      ImportError: If `key` is the data iterator key but grain is unavailable.
      TypeError: If the data iterator value has an unsupported type.
    """
    if key not in [constants.STATE_SUBDIR, constants.DATA_ITER_KEY]:
      raise KeyError(
          f'Invalid key: {key}. Only "{constants.STATE_SUBDIR}" and'
          f' "{constants.DATA_ITER_KEY}" are supported.'
      )
    if key == constants.DATA_ITER_KEY:
      _validate_data_iter_arg(value)
    # Delegate to the parent implementation. Writing `self[key] = value` here
    # would re-enter this very method and recurse until RecursionError.
    # NOTE(review): assumes `args_lib.Composite` defines `__setitem__` — TODO
    # confirm against the base class.
    super().__setitem__(key, value)

0 commit comments

Comments
 (0)