diff --git a/transformer_lens/FactoredMatrix.py b/transformer_lens/FactoredMatrix.py index 1e1c813a6..a27bea568 100644 --- a/transformer_lens/FactoredMatrix.py +++ b/transformer_lens/FactoredMatrix.py @@ -66,7 +66,9 @@ def __matmul__( "FactoredMatrix", ], ) -> Union["FactoredMatrix", Float[torch.Tensor, "... ldim"]]: - if isinstance(other, torch.Tensor): + if isinstance(other, FactoredMatrix): + return (self @ other.A) @ other.B + else: if other.ndim < 2: # It's a vector, so we collapse the factorisation and just return a vector # Squeezing/Unsqueezing is to preserve broadcasting working nicely @@ -79,8 +81,6 @@ def __matmul__( return FactoredMatrix(self.A, self.B @ other) else: return FactoredMatrix(self.AB, other) - elif isinstance(other, FactoredMatrix): - return (self @ other.A) @ other.B @overload def __rmatmul__( # type: ignore @@ -107,7 +107,9 @@ def __rmatmul__( # type: ignore "FactoredMatrix", ], ) -> Union["FactoredMatrix", Float[torch.Tensor, "... rdim"]]: - if isinstance(other, torch.Tensor): + if isinstance(other, FactoredMatrix): + return other.A @ (other.B @ self) + else: assert ( other.size(-1) == self.ldim ), f"Left matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}" @@ -118,8 +120,6 @@ def __rmatmul__( # type: ignore return FactoredMatrix(other @ self.A, self.B) else: return FactoredMatrix(other, self.AB) - elif isinstance(other, FactoredMatrix): - return other.A @ (other.B @ self) def __mul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix: """