
Use Case: Multi-GPU support for 35B SAE training & minor fix for legacy n_devices bug #1356

@summer0517

Description:

Following the v3.0 release notes, which removed n_devices, and a recent comment saying that multi-device support is now possible via device_map, I want to share our use case to show why multi-GPU support is critical for us.

1. Our Use Case: 35B Model SAE Training via sae_lens

We train sparse autoencoders (SAEs) on large models (e.g., 35B parameters). Even when we offload activations to the CPU (act_store_device="cpu"), the model weights plus the memory needed for the forward pass immediately cause an out-of-memory (OOM) error on a single GPU.
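For a sense of scale, here is a rough estimate of the memory taken by the weights alone. It assumes 4 bytes per parameter for float32 and 2 for bfloat16, and it ignores activations, the SAE itself, and framework overhead, so real usage is higher. The even per-GPU split is also an assumption:

```python
# Back-of-the-envelope estimate of weight memory only (assumption: no
# activations, SAE parameters, optimizer state, or framework overhead).

def weight_memory_gb(n_params: float, bytes_per_param: int) -> float:
    """Memory needed just to hold the weights, in decimal GB (1e9 bytes)."""
    return n_params * bytes_per_param / 1e9

N_PARAMS = 35e9   # nominal 35B-parameter model
N_DEVICES = 8     # GPUs the layers are sharded across (as in the n_devices=8 setup)

for dtype, nbytes in [("float32", 4), ("bfloat16", 2)]:
    total = weight_memory_gb(N_PARAMS, nbytes)
    # Assumes the layers are split roughly evenly across devices
    print(f"{dtype}: ~{total:.0f} GB total, ~{total / N_DEVICES:.1f} GB per GPU")
```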

Previously, we worked around this by sharding the model's layers across multiple GPUs, passing n_devices through sae_lens's model_from_pretrained_kwargs:

model_from_pretrained_kwargs={"trust_remote_code": True, "n_devices": 8}

I understand that device_map is now the standard approach under TransformerBridge, and I plan to test it. Before I do, I have a quick question: I previously tried "device_map": "auto" in older versions, but unlike n_devices, it did not prevent OOM errors. Does the new implementation include any multi-GPU-specific changes that would address this?

Thank you for your incredible work on this library!


2. Side Note: Quick Fix for Legacy "invalid device ordinal" Bug

If n_devices remains available as a fallback, there is an arithmetic flaw in transformer_lens/utilities/devices that causes crashes when n_layers is not evenly divisible by n_devices.

The Bug:

offset = index // (cfg.n_layers // cfg.n_devices)

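(Example: 62 layers, 8 devices, so n_layers // n_devices = 62 // 8 = 7. For layer 60, the offset is 60 // 7 = 8, but valid offsets are only 0 to 7, so it is out of bounds. The same happens for every layer from 56 through 61.)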
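The out-of-range mapping can be reproduced without any GPUs by evaluating the legacy formula directly. This is a pure-Python sketch that mirrors only the arithmetic quoted above; legacy_offset is a hypothetical helper name, not a TransformerLens function:

```python
def legacy_offset(index: int, n_layers: int, n_devices: int) -> int:
    # Legacy formula: integer layers-per-device first, then divide the layer index by it
    return index // (n_layers // n_devices)

n_layers, n_devices = 62, 8

# Layers whose computed offset points at a device that does not exist
bad = [i for i in range(n_layers) if legacy_offset(i, n_layers, n_devices) >= n_devices]
print(bad)
```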

The Fix:
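Multiplying first guarantees that the resulting offset is strictly less than n_devices: since index < n_layers, index * n_devices < n_layers * n_devices, so the floor division yields at most n_devices - 1: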

device_offset = (index * cfg.n_devices) // cfg.n_layers

(Example: (60 * 8) // 62 = 7, safely within bounds.)
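A quick pure-Python check of the proposed formula follows. fixed_offset is a hypothetical helper name, and the sketch models only the offset arithmetic, not the surrounding device-selection code:

```python
from collections import Counter

def fixed_offset(index: int, n_layers: int, n_devices: int) -> int:
    # Proposed formula: multiply first, then divide.
    # For 0 <= index < n_layers the result always lies in [0, n_devices - 1].
    return (index * n_devices) // n_layers

n_layers, n_devices = 62, 8
offsets = [fixed_offset(i, n_layers, n_devices) for i in range(n_layers)]

assert all(0 <= o < n_devices for o in offsets)  # every offset is a valid device
assert offsets == sorted(offsets)                # each device gets a contiguous block
counts = Counter(offsets)
print([counts[d] for d in range(n_devices)])     # layers assigned to each device

# Exhaustive check over a range of model/device sizes, including n_layers < n_devices
for L in range(1, 129):
    for D in range(1, 17):
        assert all(0 <= fixed_offset(i, L, D) < D for i in range(L))
```

For 62 layers on 8 devices, this assigns contiguous blocks of 7 or 8 layers to each device, so the load stays close to even.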

Labels: TransformerBridge, bug, complexity-moderate, question, seen_by_maintainers
