Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions tests/unit/utilities/test_multi_gpu_unit.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
"""Tests for multi-GPU utilities."""

from types import SimpleNamespace
from unittest.mock import Mock

import pytest
import torch

from transformer_lens.utilities import (
calculate_available_device_cuda_memory,
determine_available_memory_for_available_devices,
sort_devices_based_on_available_memory,
)
from transformer_lens.utilities.multi_gpu import get_device_for_block_index


def mock_available_devices(memory_stats: list[tuple[int, int]]):
Expand Down Expand Up @@ -66,3 +69,63 @@ def test_sort_devices_based_on_available_memory():
(2, 40),
(0, 20),
]


def _cuda_cfg(n_layers: int, n_devices: int) -> SimpleNamespace:
return SimpleNamespace(n_layers=n_layers, n_devices=n_devices, device="cuda")


class TestGetDeviceForBlockIndex:
"""Regression tests for the layer-to-device index math.

Issue #1356: the previous formula ``index // (n_layers // n_devices)``
overshot ``n_devices - 1`` whenever ``n_layers`` was not a multiple of
``n_devices``, and divided by zero when ``n_layers < n_devices``.
"""

@pytest.mark.parametrize(
"n_layers,n_devices",
[(62, 8), (12, 8), (32, 4), (24, 8), (8, 8), (1, 8), (7, 8)],
)
def test_device_index_stays_in_bounds(self, n_layers: int, n_devices: int):
cfg = _cuda_cfg(n_layers, n_devices)
for index in range(n_layers):
result = get_device_for_block_index(index, cfg)
assert 0 <= result.index < n_devices, (
f"index={index} mapped to device {result.index} which is outside "
f"[0, {n_devices - 1}] for n_layers={n_layers}, n_devices={n_devices}"
)

@pytest.mark.parametrize("n_layers,n_devices", [(62, 8), (32, 4), (24, 8), (8, 8)])
def test_layer_distribution_is_balanced(self, n_layers: int, n_devices: int):
"""Every device sees ``floor(n_layers / n_devices)`` or that plus 1 layers
— never more than 1 layer off, and the counts sum to ``n_layers``."""
cfg = _cuda_cfg(n_layers, n_devices)
counts = [0] * n_devices
for index in range(n_layers):
counts[get_device_for_block_index(index, cfg).index] += 1
assert sum(counts) == n_layers
assert max(counts) - min(counts) <= 1, f"unbalanced distribution: {counts}"

def test_first_index_lands_on_first_device(self):
cfg = _cuda_cfg(n_layers=62, n_devices=8)
result = get_device_for_block_index(0, cfg)
assert result.index == 0

def test_last_index_lands_on_last_device(self):
cfg = _cuda_cfg(n_layers=62, n_devices=8)
result = get_device_for_block_index(61, cfg)
assert result.index == 7

def test_starting_device_offset_is_honored(self):
"""When ``device`` carries an explicit index, layer offsets are added on top."""
cfg = _cuda_cfg(n_layers=32, n_devices=4)
result = get_device_for_block_index(0, cfg, device=torch.device("cuda", 2))
assert result.index == 2
result = get_device_for_block_index(31, cfg, device=torch.device("cuda", 2))
assert result.index == 5 # 2 (starting offset) + 3 (last layer on 4 devices)

def test_cpu_device_is_returned_unchanged(self):
cfg = _cuda_cfg(n_layers=62, n_devices=8)
result = get_device_for_block_index(30, cfg, device="cpu")
assert result.type == "cpu"
7 changes: 5 additions & 2 deletions transformer_lens/utilities/multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,16 @@ def get_device_for_block_index(
This will be removed in 3.0
"""
assert cfg.device is not None
layers_per_device = cfg.n_layers // cfg.n_devices
if device is None:
device = cfg.device
device = torch.device(device)
if device.type == "cpu":
return device
device_index = (device.index or 0) + (index // layers_per_device)
# Multiplying first guarantees the result is in [0, n_devices - 1] and avoids
# the divide-by-zero when n_layers < n_devices. The naive form
# `index // (n_layers // n_devices)` floors the divisor and overshoots when
# n_layers is not a multiple of n_devices (e.g. 62 layers / 8 devices → 8).
device_index = (device.index or 0) + (index * cfg.n_devices) // cfg.n_layers
return torch.device(device.type, device_index)


Expand Down
Loading