Description:
Following the v3.0 release notes regarding the removal of n_devices and the recent comment that multi-device support is now possible via device_map, I want to share our use case to highlight the critical need for multi-GPU support.
1. Our Use Case: 35B Model SAE Training via sae_lens
We run Sparse Autoencoder (SAE) training on large models (e.g., 35B parameters). Even when activations are offloaded to the CPU (act_store_device="cpu"), the model weights and the forward pass alone immediately cause out-of-memory (OOM) errors on a single GPU.
Previously, we bypassed this by sharding layers across multiple GPUs, passing n_devices indirectly through sae_lens:
model_from_pretrained_kwargs={"trust_remote_code": True, "n_devices": 8}
I understand device_map is now the standard under TransformerBridge, and I plan to test it. Before I do, I have a quick question: I previously tried "device_map": "auto" in older versions, but unlike n_devices, it failed to prevent OOM errors. Does the new implementation include specific optimizations for multi-GPU training to address this?
Thank you for your incredible work on this library!
2. Side Note: Quick Fix for Legacy "invalid device ordinal" Bug
If n_devices remains as a fallback, there is an integer-division bug in transformer_lens/utilities/devices.py that crashes with "invalid device ordinal" whenever n_layers is not evenly divisible by n_devices.
The Bug:
offset = index // (cfg.n_layers // cfg.n_devices)
(Example: 62 layers, 8 devices. Layers per device is 62 // 8 = 7, so for layer 60 the offset evaluates to 60 // 7 = 8, which is out of bounds for device indices 0 to 7. Layers 56 to 61 all hit this.)
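The failure can be reproduced with plain Python. This is a minimal sketch of only the offset arithmetic quoted above, not the library's full function; legacy_offset is a hypothetical helper name.

```python
n_layers, n_devices = 62, 8

def legacy_offset(index: int) -> int:
    # Hypothetical helper mirroring the legacy expression:
    # 62 // 8 == 7 layers per device, so the last 62 - 7 * 8 == 6 layers overflow.
    return index // (n_layers // n_devices)

# Collect every layer whose offset points past the last valid device (index 7).
bad = [i for i in range(n_layers) if legacy_offset(i) >= n_devices]
print(legacy_offset(60))  # 8
print(bad)  # [56, 57, 58, 59, 60, 61]
```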
The Fix:
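Multiplying before dividing guarantees the resulting offset is strictly < n_devices: since index < n_layers, index * n_devices < n_layers * n_devices, so the floored quotient is at most n_devices - 1.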
device_offset = (index * cfg.n_devices) // cfg.n_layers
(Example: (60 * 8) // 62 = 7, safely within bounds.)
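A matching sketch of the proposed formula, again covering only the offset computation (fixed_offset is a hypothetical name; it assumes the library keeps adding this offset to the base device index as before). It also checks the bound over a range of layer/device counts.

```python
from collections import Counter

def fixed_offset(index: int, n_layers: int, n_devices: int) -> int:
    """Hypothetical helper: device offset for block `index`, multiply-then-divide."""
    # index < n_layers implies index * n_devices < n_layers * n_devices,
    # so the floored quotient never exceeds n_devices - 1.
    return (index * n_devices) // n_layers

n_layers, n_devices = 62, 8
offsets = [fixed_offset(i, n_layers, n_devices) for i in range(n_layers)]
layers_per_device = Counter(offsets)

print(fixed_offset(60, n_layers, n_devices))  # 7
print(sorted(layers_per_device.items()))
# [(0, 8), (1, 8), (2, 8), (3, 7), (4, 8), (5, 8), (6, 8), (7, 7)]

# Sanity check over many (n_layers, n_devices) pairs: every offset stays in range.
for L in range(1, 129):
    for D in range(1, 17):
        assert all(0 <= fixed_offset(i, L, D) < D for i in range(L))
```

Because the offset is non-decreasing in the layer index, consecutive layers stay on the same GPU and activations only cross devices at shard boundaries. When n_devices <= n_layers, every device receives at least one layer.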