Skip to content

Commit 50f4f0a

Browse files
angel-coreOrbax Authors
authored andcommitted
Add generated v0 and v1 static checkpoints to testing directory third_party/py/orbax/checkpoint/experimental/v1/_src/testing/compatibility
PiperOrigin-RevId: 872203447
1 parent 27777d9 commit 50f4f0a

704 files changed

Lines changed: 1613 additions & 390 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

checkpoint/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ devices anyway.
3636
`load_checkpointables()` each with their own dedicated loading logic
3737
- Refactor v0 Pytree validation and metadata resolution and add `OrbaxV0Layout`
3838
tests
39+
- Refactor logic for handler resolution and loading checkpointables for
40+
`OrbaxLayout` and `OrbaxV0Layout`, adding additional fallback capabilities for
41+
non-standard checkpoint formats.
3942

4043
## [0.11.32] - 2026-01-20
4144

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,16 @@ def _get_possible_handlers(
392392
return possible_handlers
393393

394394

395+
def get_registered_handler_by_name(
396+
registry: CheckpointableHandlerRegistry,
397+
name: str,
398+
) -> CheckpointableHandler | None:
399+
"""Returns the handler for the given name if registered."""
400+
if registry.has(name):
401+
return _construct_handler_instance(name, registry.get(name))
402+
return None
403+
404+
395405
def resolve_handler_for_save(
396406
registry: CheckpointableHandlerRegistry,
397407
checkpointable: Any,
@@ -444,7 +454,7 @@ def resolve_handler_for_load(
444454
abstract_checkpointable: Any | None,
445455
*,
446456
name: str,
447-
handler_typestr: str,
457+
handler_typestr: str | None = None,
448458
) -> CheckpointableHandler:
449459
"""Resolves a :py:class:`~.v1.handlers.CheckpointableHandler` for loading.
450460
@@ -471,7 +481,9 @@ def resolve_handler_for_load(
471481
abstract_checkpointable: An abstract checkpointable to resolve.
472482
name: The name of the checkpointable.
473483
handler_typestr: A :py:class:`~.v1.handlers.CheckpointableHandler` typestr
474-
to guide resolution.
484+
to guide resolution. We allow a None value for handler_typestr as its
485+
possible to find the last registered handler given a specified
486+
abstract_checkpointable.
475487
476488
Returns:
477489
A :py:class:`~.v1.handlers.CheckpointableHandler` instance.
@@ -492,15 +504,16 @@ def is_handleable_fn(
492504
handler_types.typestr(type(handler)) for handler in possible_handlers
493505
]
494506

495-
try:
496-
idx = possible_handler_typestrs.index(handler_typestr)
497-
return possible_handlers[idx]
498-
except ValueError:
499-
logging.warning(
500-
'No handler found for typestr %s. The checkpointable may be restored'
501-
' with different handler logic than was used for saving.',
502-
handler_typestr,
503-
)
507+
if handler_typestr:
508+
try:
509+
idx = possible_handler_typestrs.index(handler_typestr)
510+
return possible_handlers[idx]
511+
except ValueError:
512+
logging.warning(
513+
'No handler found for typestr %s. The checkpointable may be restored'
514+
' with different handler logic than was used for saving.',
515+
handler_typestr,
516+
)
504517

505-
# Prefer the first handler in the absence of any other information.
518+
# Prefer the last handler in the absence of any other information.
506519
return possible_handlers[-1]

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,17 @@ def test_resolve_handler_for_load_checkpointable(self):
284284
handler_typestr='unused',
285285
)
286286

287+
def test_resolve_handler_for_load_no_handler_typestr(self):
288+
local_registry = registration.local_registry()
289+
local_registry.add(handler_utils.FooHandler)
290+
with self.assertRaises(registration.NoEntryError):
291+
registration.resolve_handler_for_load(
292+
local_registry,
293+
handler_utils.Foo(1, 'hi'),
294+
name='unregistered_name',
295+
handler_typestr=None,
296+
)
297+
287298

288299
if __name__ == '__main__':
289300
absltest.main()

0 commit comments

Comments
 (0)