diff --git a/tests/unit/components/test_attention.py b/tests/unit/components/test_attention.py index 572529cec..f75e924fd 100644 --- a/tests/unit/components/test_attention.py +++ b/tests/unit/components/test_attention.py @@ -128,3 +128,45 @@ def test_remove_einsum_from_complex_attn_linear(): # Check if the results are the same assert torch.allclose(result_new, result_old, atol=1e-4) + + +@pytest.mark.skipif( + not torch.backends.mps.is_available() or torch.__version__ != "2.8.0", + reason="Issue with F.linear issue exclusive to mps and PyTorch 2.8\n" + "https://github.com/pytorch/pytorch/issues/161640", +) +def test_cpu_mps_outputs_match(): + torch.manual_seed(0) + + cfg = { + "n_layers": 1, + "d_model": 48, + "n_ctx": 256, + "d_head": 16, + "n_heads": 3, + "load_in_4bit": False, + "dtype": torch.float32, + "act_fn": "relu", + } + + def init_weights(attn_layer: nn.Module): + nn.init.normal_(attn_layer.W_Q, mean=0.0, std=0.02) + nn.init.normal_(attn_layer.W_K, mean=0.0, std=0.02) + nn.init.normal_(attn_layer.W_V, mean=0.0, std=0.02) + nn.init.normal_(attn_layer.W_O, mean=0.0, std=0.02) + return attn_layer + + attn_cpu = Attention(cfg) + attn_cpu = init_weights(attn_cpu) + + attn_mps = Attention(cfg).to("mps") + attn_mps.load_state_dict(attn_cpu.state_dict(), strict=True) + + batch = 1 + input_cpu = torch.randn(batch, cfg["n_ctx"], cfg["d_model"]) + input_mps = input_cpu.to("mps") + + cpu_output = attn_cpu(input_cpu, input_cpu, input_cpu) + mps_output = attn_mps(input_mps, input_mps, input_mps) + + assert torch.allclose(cpu_output, mps_output.cpu()) diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index c89586e93..f9af85637 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -304,11 +304,15 @@ def forward( if self.b_O.device != z.device: z = z.to(self.b_O.device) - out = F.linear( - z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads), - w, - self.b_O, - ) + z = z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads) + + # F.linear is a fused matmul+bias that matches HuggingFace exactly, + # but has a bug on MPS with PyTorch 2.8 (pytorch#161640). + # Fall back to manual matmul on MPS to work around it. + if z.device.type == "mps": + out = torch.matmul(z, w.T) + self.b_O + else: + out = F.linear(z, w, self.b_O) else: # Explicitly calculate the attention result so it can be accessed by a hook # This is off by default because it can easily eat through your GPU memory.