From 21052255e649fbb4bfcf1c7ac67ecda07639dcfb Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Fri, 17 May 2024 02:05:49 -0700 Subject: [PATCH] Make `FactoredMatrix` compatible with tensor-like arguments I'd like to be able to use `FactoredMatrix` with things that implement the interface of `torch.Tensor` without subclassing it. This slight change allows `FactoredMatrix` to work with such classes rather than returning `None` in various places. --- transformer_lens/FactoredMatrix.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_lens/FactoredMatrix.py b/transformer_lens/FactoredMatrix.py index 1e1c813a6..a27bea568 100644 --- a/transformer_lens/FactoredMatrix.py +++ b/transformer_lens/FactoredMatrix.py @@ -66,7 +66,9 @@ def __matmul__( "FactoredMatrix", ], ) -> Union["FactoredMatrix", Float[torch.Tensor, "... ldim"]]: - if isinstance(other, torch.Tensor): + if isinstance(other, FactoredMatrix): + return (self @ other.A) @ other.B + else: if other.ndim < 2: # It's a vector, so we collapse the factorisation and just return a vector # Squeezing/Unsqueezing is to preserve broadcasting working nicely @@ -79,8 +81,6 @@ def __matmul__( return FactoredMatrix(self.A, self.B @ other) else: return FactoredMatrix(self.AB, other) - elif isinstance(other, FactoredMatrix): - return (self @ other.A) @ other.B @overload def __rmatmul__( # type: ignore @@ -107,7 +107,9 @@ def __rmatmul__( # type: ignore "FactoredMatrix", ], ) -> Union["FactoredMatrix", Float[torch.Tensor, "... rdim"]]: - if isinstance(other, torch.Tensor): + if isinstance(other, FactoredMatrix): + return other.A @ (other.B @ self) + else: assert ( other.size(-1) == self.ldim ), f"Left matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}" @@ -118,8 +120,6 @@ def __rmatmul__( # type: ignore return FactoredMatrix(other @ self.A, self.B) else: return FactoredMatrix(other, self.AB) - elif isinstance(other, FactoredMatrix): - return other.A @ (other.B @ self) def __mul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix: """