diff --git a/tests/unit/utilities/test_multi_gpu_unit.py b/tests/unit/utilities/test_multi_gpu_unit.py index 082f4ffae..5924c7476 100644 --- a/tests/unit/utilities/test_multi_gpu_unit.py +++ b/tests/unit/utilities/test_multi_gpu_unit.py @@ -1,7 +1,9 @@ """Tests for multi-GPU utilities.""" +from types import SimpleNamespace from unittest.mock import Mock +import pytest import torch from transformer_lens.utilities import ( @@ -9,6 +11,7 @@ determine_available_memory_for_available_devices, sort_devices_based_on_available_memory, ) +from transformer_lens.utilities.multi_gpu import get_device_for_block_index def mock_available_devices(memory_stats: list[tuple[int, int]]): @@ -66,3 +69,63 @@ def test_sort_devices_based_on_available_memory(): (2, 40), (0, 20), ] + + +def _cuda_cfg(n_layers: int, n_devices: int) -> SimpleNamespace: + return SimpleNamespace(n_layers=n_layers, n_devices=n_devices, device="cuda") + + +class TestGetDeviceForBlockIndex: + """Regression tests for the layer-to-device index math. + + Issue #1356: the previous formula ``index // (n_layers // n_devices)`` + overshot ``n_devices - 1`` whenever ``n_layers`` was not a multiple of + ``n_devices``, and divided by zero when ``n_layers < n_devices``. 
+ """ + + @pytest.mark.parametrize( + "n_layers,n_devices", + [(62, 8), (12, 8), (32, 4), (24, 8), (8, 8), (1, 8), (7, 8)], + ) + def test_device_index_stays_in_bounds(self, n_layers: int, n_devices: int): + cfg = _cuda_cfg(n_layers, n_devices) + for index in range(n_layers): + result = get_device_for_block_index(index, cfg) + assert 0 <= result.index < n_devices, ( + f"index={index} mapped to device {result.index} which is outside " + f"[0, {n_devices - 1}] for n_layers={n_layers}, n_devices={n_devices}" + ) + + @pytest.mark.parametrize("n_layers,n_devices", [(62, 8), (32, 4), (24, 8), (8, 8)]) + def test_layer_distribution_is_balanced(self, n_layers: int, n_devices: int): + """Every device sees ``floor(n_layers / n_devices)`` or that plus 1 layers + — never more than 1 layer off, and the counts sum to ``n_layers``.""" + cfg = _cuda_cfg(n_layers, n_devices) + counts = [0] * n_devices + for index in range(n_layers): + counts[get_device_for_block_index(index, cfg).index] += 1 + assert sum(counts) == n_layers + assert max(counts) - min(counts) <= 1, f"unbalanced distribution: {counts}" + + def test_first_index_lands_on_first_device(self): + cfg = _cuda_cfg(n_layers=62, n_devices=8) + result = get_device_for_block_index(0, cfg) + assert result.index == 0 + + def test_last_index_lands_on_last_device(self): + cfg = _cuda_cfg(n_layers=62, n_devices=8) + result = get_device_for_block_index(61, cfg) + assert result.index == 7 + + def test_starting_device_offset_is_honored(self): + """When ``device`` carries an explicit index, layer offsets are added on top.""" + cfg = _cuda_cfg(n_layers=32, n_devices=4) + result = get_device_for_block_index(0, cfg, device=torch.device("cuda", 2)) + assert result.index == 2 + result = get_device_for_block_index(31, cfg, device=torch.device("cuda", 2)) + assert result.index == 5 # 2 (starting offset) + 3 (last layer on 4 devices) + + def test_cpu_device_is_returned_unchanged(self): + cfg = _cuda_cfg(n_layers=62, n_devices=8) + result = 
get_device_for_block_index(30, cfg, device="cpu") + assert result.type == "cpu" diff --git a/transformer_lens/utilities/multi_gpu.py b/transformer_lens/utilities/multi_gpu.py index f957877de..256cc2b04 100644 --- a/transformer_lens/utilities/multi_gpu.py +++ b/transformer_lens/utilities/multi_gpu.py @@ -133,13 +133,16 @@ def get_device_for_block_index( This will be removed in 3.0 """ assert cfg.device is not None - layers_per_device = cfg.n_layers // cfg.n_devices if device is None: device = cfg.device device = torch.device(device) if device.type == "cpu": return device - device_index = (device.index or 0) + (index // layers_per_device) + # Multiplying first keeps the layer offset in [0, n_devices - 1] for 0 <= index < n_layers + # (any starting device.index is added on top) and avoids dividing by zero when + # n_layers < n_devices. The naive `index // (n_layers // n_devices)` floors the divisor and + # overshoots when n_layers is not a multiple of n_devices (e.g. 62 layers / 8 devices → 8). + device_index = (device.index or 0) + (index * cfg.n_devices) // cfg.n_layers return torch.device(device.type, device_index)