Commit cefda62

samos123 authored and Orbax Authors committed
Normalize device kind strings when looking up HBM memory.
This change removes spaces from device kind strings before looking up HBM memory values. This ensures that device kinds like "TPU 7x" and "TPU7x" are treated the same, preventing lookup failures due to inconsistent spacing.

PiperOrigin-RevId: 870223149
1 parent c339d50 commit cefda62

2 files changed

Lines changed: 9 additions & 1 deletion

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -22,6 +22,11 @@ handlers using `StepMetadata.item_handlers` and the global `HandlerTypeRegistry`
 if no args are provided.
 - `CompositeCheckpointHandler.metadata()` now returns `StepMetadata`.
 
+### Fixed
+
+- Fixed `get_device_memory` issue on TPU 7x devices where the device kind string
+was consistently reported without a space, causing a ValueError.
+
 ## [0.1.7] - 2022-03-29
 
 ### Added

checkpoint/orbax/checkpoint/_src/multihost/multislice.py

Lines changed: 4 additions & 1 deletion
@@ -159,7 +159,10 @@ def get_device_memory() -> int:
       'NVIDIA B200': int(183e9),
       'NVIDIA B300 SXM6 AC': int(275e9),
   }
-  memory = hbm_memory.get(device.device_kind, None)
+  # Remove spaces from the device kind to make the lookup robust.
+  # For example, "TPU 7x" and "TPU7x" should both map to the same value.
+  normalized_hbm_memory = {k.replace(' ', ''): v for k, v in hbm_memory.items()}
+  memory = normalized_hbm_memory.get(device.device_kind.replace(' ', ''), None)
   if memory is None:
     raise ValueError(
         f'get_device_memory is not supported for {device.device_kind}.'
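
The normalization in this change can be tried in isolation with a small sketch. It is not the Orbax implementation: `lookup_hbm_memory` and `_normalize` are hypothetical names, the table holds only the two NVIDIA entries visible in the diff, and the `'TPU 7x'` value is a placeholder rather than a real HBM size.

```python
# Illustrative table. The NVIDIA entries are copied from the diff context;
# the 'TPU 7x' value is a placeholder for this sketch, not a real figure.
HBM_MEMORY = {
    'NVIDIA B200': int(183e9),
    'NVIDIA B300 SXM6 AC': int(275e9),
    'TPU 7x': 1,  # placeholder value (assumption)
}


def _normalize(device_kind: str) -> str:
    """Drops spaces so 'TPU 7x' and 'TPU7x' compare equal."""
    return device_kind.replace(' ', '')


def lookup_hbm_memory(device_kind: str, table=HBM_MEMORY) -> int:
    """Looks up memory by device kind, ignoring spaces on both sides."""
    # Normalize the keys as well as the query, as the commit does.
    normalized = {_normalize(k): v for k, v in table.items()}
    memory = normalized.get(_normalize(device_kind))
    if memory is None:
        raise ValueError(
            f'get_device_memory is not supported for {device_kind}.'
        )
    return memory


# Both spellings resolve to the same entry.
assert lookup_hbm_memory('TPU 7x') == lookup_hbm_memory('TPU7x')
```

One consequence of this approach is that two table keys differing only in spacing would collapse into a single entry, with the later one winning. The diff context shows no such pair, but it is worth keeping in mind when adding new device kinds.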
