Skip to content

Commit 64886f7

Browse files
BlaziusMaximusOrbax Authors
authored andcommitted
Internal change.
PiperOrigin-RevId: 868729588
1 parent 9cd8fc6 commit 64886f7

2 files changed

Lines changed: 179 additions & 0 deletions

File tree

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Merging utility for Orbax checkpoints."""
16+
17+
import asyncio
18+
from collections.abc import Sequence
19+
20+
from absl import app
21+
from absl import flags
22+
from etils import epath
23+
import jax
24+
from orbax.checkpoint.experimental.v1._src.layout import orbax_layout
25+
from orbax.checkpoint.experimental.v1._src.partial import merging
26+
27+
28+
FLAGS = flags.FLAGS
29+
30+
_IN_PATHS = flags.DEFINE_multi_string(
31+
'in_paths',
32+
None,
33+
'Paths of checkpoints to merge.',
34+
)
35+
_OUT_PATH = flags.DEFINE_string(
36+
'out_path',
37+
None,
38+
'Output checkpoint path.',
39+
)
40+
_PER_HOST_MEMORY_LIMIT_BYTES = flags.DEFINE_integer(
41+
'per_host_memory_limit_bytes',
42+
None,
43+
'Memory limit in bytes per CPU host for partial loading and saving.'
44+
' Non-uniform memory limits are not supported.',
45+
)
46+
47+
48+
def main(argv: Sequence[str]) -> None:
49+
if len(argv) > 1:
50+
raise app.UsageError('Too many command-line arguments.')
51+
52+
if not _IN_PATHS.value:
53+
raise app.UsageError('Flag --in_paths must be specified.')
54+
if _OUT_PATH.value is None:
55+
raise app.UsageError('Flag --out_path must be specified.')
56+
if _PER_HOST_MEMORY_LIMIT_BYTES.value is None:
57+
raise app.UsageError(
58+
'Flag --per_host_memory_limit_bytes must be specified.'
59+
)
60+
61+
if _PER_HOST_MEMORY_LIMIT_BYTES.value <= 0:
62+
raise ValueError('per_host_memory_limit_bytes must be positive.')
63+
64+
# Validate input checkpoints.
65+
layout = orbax_layout.OrbaxLayout()
66+
for path_str in _IN_PATHS.value:
67+
path = epath.Path(path_str)
68+
if not path.exists():
69+
raise FileNotFoundError(f'Input path {path_str} does not exist.')
70+
# OrbaxLayout.validate is async.
71+
try:
72+
asyncio.run(layout.validate(path))
73+
except Exception as e:
74+
raise ValueError(
75+
f'Input path {path_str} is not a valid checkpoint.'
76+
) from e
77+
78+
# Validate output path.
79+
out_path = epath.Path(_OUT_PATH.value)
80+
if out_path.exists():
81+
if not out_path.is_dir():
82+
raise ValueError(
83+
f'Output path {_OUT_PATH.value} exists but is not a directory.'
84+
)
85+
if list(out_path.iterdir()):
86+
raise ValueError(
87+
f'Output path {_OUT_PATH.value} exists and is not empty.'
88+
)
89+
90+
if jax.process_index() == 0:
91+
out_path.mkdir(parents=True, exist_ok=True)
92+
93+
merging.merge_checkpoints(
94+
_IN_PATHS.value,
95+
_OUT_PATH.value,
96+
_PER_HOST_MEMORY_LIMIT_BYTES.value,
97+
)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest import mock
16+
17+
from absl.testing import absltest
18+
from absl.testing import flagsaver
19+
from etils import epath
20+
import jax
21+
from orbax.checkpoint.experimental.v1._src.layout import orbax_layout
22+
from orbax.checkpoint.experimental.v1._src.partial import merging
23+
from orbax.checkpoint.experimental.v1._src.partial import run_merging
24+
25+
26+
class RunMergingTest(absltest.TestCase):
27+
28+
def setUp(self):
29+
super().setUp()
30+
self.out_path = self.create_tempdir().full_path
31+
self.in_paths = [self.create_tempdir().full_path]
32+
33+
@mock.patch.object(
34+
orbax_layout.OrbaxLayout, 'validate', new_callable=mock.AsyncMock
35+
)
36+
@mock.patch.object(merging, 'merge_checkpoints', autospec=True)
37+
@mock.patch.object(jax, 'process_index', return_value=0)
38+
def test_main_success(self, _, mock_merge, mock_validate):
39+
with flagsaver.flagsaver(
40+
in_paths=self.in_paths,
41+
out_path=self.out_path,
42+
per_host_memory_limit_bytes=1024,
43+
):
44+
run_merging.main([])
45+
46+
mock_validate.assert_called()
47+
mock_merge.assert_called_once()
48+
49+
@mock.patch.object(
50+
orbax_layout.OrbaxLayout, 'validate', new_callable=mock.AsyncMock
51+
)
52+
@mock.patch.object(jax, 'process_index', return_value=0)
53+
def test_main_invalid_output_not_empty(self, _):
54+
out_path = epath.Path(self.out_path)
55+
(out_path / 'some_file').write_text('content')
56+
57+
with flagsaver.flagsaver(
58+
in_paths=self.in_paths,
59+
out_path=self.out_path,
60+
per_host_memory_limit_bytes=1024,
61+
):
62+
with self.assertRaisesRegex(ValueError, 'not empty'):
63+
run_merging.main([])
64+
65+
@mock.patch.object(
66+
orbax_layout.OrbaxLayout, 'validate', new_callable=mock.AsyncMock
67+
)
68+
@mock.patch.object(jax, 'process_index', return_value=0)
69+
def test_main_invalid_input(self, _, mock_validate):
70+
mock_validate.side_effect = ValueError('Invalid checkpoint')
71+
72+
with flagsaver.flagsaver(
73+
in_paths=self.in_paths,
74+
out_path=self.out_path,
75+
per_host_memory_limit_bytes=1024,
76+
):
77+
with self.assertRaisesRegex(ValueError, 'is not a valid checkpoint'):
78+
run_merging.main([])
79+
80+
81+
if __name__ == '__main__':
82+
absltest.main()

0 commit comments

Comments
 (0)