diff --git a/REGRESSIONS.md b/REGRESSIONS.md
index c189038d..faccd255 100644
--- a/REGRESSIONS.md
+++ b/REGRESSIONS.md
@@ -1,15 +1,16 @@
 # Regressions
 
-Last checked: 2025-07-31
+Last checked: 2025-08-19
 
-# 10 failing test(s)
+# 11 failing test(s)
 
+- tests/test_nn_modeling.py::test_auxiliary_activation
+- tests/test_nn_modeling.py::test_auxiliary_activation_k_exceeds_size
+- tests/test_nn_modeling.py::test_batch_topk_activation
+- tests/test_nn_objectives.py::test_auxiliary_coeff
+- tests/test_nn_objectives.py::test_auxiliary_mse_same
 - tests/test_nn_objectives.py::test_safe_mse_hypothesis
-- tests/test_ordered_dataloader.py::test_ordered_dataloader_with_tiny_fake_dataset
-- tests/test_reservoir_buffer.py::test_blocking_get_when_empty[proc]
-- tests/test_reservoir_buffer.py::test_blocking_put_when_full[proc]
-- tests/test_ring_buffer.py::test_blocking_get_when_empty[proc]
-- tests/test_ring_buffer.py::test_blocking_put_when_full[proc]
+- tests/test_writers_properties.py::test_dataloader_batches
 - tests/test_writers_properties.py::test_metadata_json_has_required_keys
 - tests/test_writers_properties.py::test_roundtrip
 - tests/test_writers_properties.py::test_shard_size_consistency
@@ -17,4 +18,4 @@ Last checked: 2025-07-31
 
 # Coverage
 
-Coverage: 1210/1816 lines (66.6%)
+Coverage: 694/1933 lines (35.9%)
diff --git a/src/saev/nn/modeling.py b/src/saev/nn/modeling.py
index 5b480523..8d00351d 100644
--- a/src/saev/nn/modeling.py
+++ b/src/saev/nn/modeling.py
@@ -43,6 +43,13 @@ class BatchTopK:
 ActivationConfig = Relu | TopK | BatchTopK
 
 
+@beartype.beartype
+@dataclasses.dataclass(frozen=True)
+class AuxiliaryConfig:
+    top_k: int = 512
+    """How many dead latents to consider for auxiliary loss."""
+
+
 @beartype.beartype
 @dataclasses.dataclass(frozen=True)
 class SparseAutoencoderConfig:
@@ -319,6 +326,39 @@ def forward(self, x: Float[Tensor, "batch d_sae"]) -> Float[Tensor, "batch d_sae
         return torch.mul(mask, x)
 
 
+class AuxiliaryLossActivation(torch.nn.Module):
+    """
+    Auxiliary loss activation function. Selects the top-k dead latents before the auxiliary loss is calculated.
+    """
+
+    def __init__(self, cfg: AuxiliaryConfig = AuxiliaryConfig()):
+        super().__init__()
+        self.cfg = cfg
+
+    def forward(
+        self,
+        f_x: Float[Tensor, "batch d_sae"],
+        dead_latents: Tensor,
+    ) -> Float[Tensor, "batch d_sae"]:
+        """
+        Keep the top-k dead latents of each row of `f_x` and zero out every other entry.
+
+        `dead_latents` is a mask that is non-zero where a latent is dead. It is multiplied into `f_x`, so it only has to broadcast against it: a per-latent `(d_sae,)` mask works as well as a `(batch, d_sae)` one.
+        """
+        masked = f_x * dead_latents
+        k = min(self.cfg.top_k, masked.shape[-1])
+        if k <= 0:
+            return torch.zeros_like(masked)
+
+        # Rank dead latents only. Live entries of `masked` are 0 and would otherwise outrank negative dead latents, so they score -inf.
+        scores = masked.masked_fill(dead_latents == 0, float("-inf"))
+        k_inds = torch.topk(scores, k, dim=-1, sorted=False).indices
+        top_k_mask = torch.zeros_like(masked).scatter_(dim=-1, index=k_inds, value=1.0)
+
+        # A row with fewer than k dead latents gets padded with live ones, but those are already 0 in `masked`.
+        return top_k_mask * masked
+
+
 @beartype.beartype
 def get_activation(cfg: ActivationConfig) -> torch.nn.Module:
     if isinstance(cfg, Relu):
diff --git a/src/saev/nn/objectives.py b/src/saev/nn/objectives.py
index 3394a477..ee945952 100644
--- a/src/saev/nn/objectives.py
+++ b/src/saev/nn/objectives.py
@@ -32,6 +32,19 @@ class Matryoshka:
 ObjectiveConfig = Vanilla | Matryoshka
 
 
+@beartype.beartype
+@dataclasses.dataclass(frozen=True, slots=True)
+class Auxiliary:
+    """
+    Config for the Auxiliary loss (not for the SAE itself, but for auxiliary loss).
+
+    Reference paper is https://doi.org/10.48550/arXiv.2412.06410
+    """
+
+    aux_coeff: float = 0.03125
+    """Coefficient for the auxiliary loss term."""
+
+
 @jaxtyped(typechecker=beartype.beartype)
 @dataclasses.dataclass(frozen=True, slots=True)
 class Loss:
@@ -163,6 +176,48 @@ def forward(
         return MatryoshkaLoss(mse_loss, sparsity_loss, l0, l1)
 
 
+@jaxtyped(typechecker=beartype.beartype)
+@dataclasses.dataclass(frozen=True, slots=True)
+class AuxiliaryLoss(Loss):
+    """The auxiliary loss terms for a training batch."""
+
+    mse: Float[Tensor, ""]
+    """Auxiliary reconstruction loss (mean squared error); `AuxiliaryObjective` stores it already scaled by `aux_coeff`."""
+
+    @property
+    def loss(self) -> Float[Tensor, ""]:
+        """Total loss."""
+        return self.mse
+
+    def metrics(self) -> dict[str, object]:
+        """Scalar values of the loss terms, keyed by name."""
+        return {
+            "loss": self.loss.item(),
+            "mse": self.mse.item(),
+        }
+
+
+@jaxtyped(typechecker=beartype.beartype)
+class AuxiliaryObjective(Objective):
+    """Scores a reconstruction by its mean squared error, scaled by `cfg.aux_coeff`."""
+
+    def __init__(self, cfg: Auxiliary):
+        super().__init__()
+        self.cfg = cfg
+
+    def forward(
+        self,
+        x: Float[Tensor, "batch d_model"],
+        x_hat: Float[Tensor, "batch d_model"],
+    ) -> AuxiliaryLoss:
+        # Some values of x and x_hat can be very large. We can calculate a safe MSE
+        mse_loss = mean_squared_err(x_hat, x)
+
+        mse_loss = mse_loss.mean()
+
+        return AuxiliaryLoss(self.cfg.aux_coeff * mse_loss)
+
+
 @beartype.beartype
 def get_objective(cfg: ObjectiveConfig) -> Objective:
     if isinstance(cfg, Vanilla):
diff --git a/tests/test_nn_modeling.py b/tests/test_nn_modeling.py
index cf954b36..8b504685 100644
--- a/tests/test_nn_modeling.py
+++ b/tests/test_nn_modeling.py
@@ -60,6 +60,26 @@ def batch_topk_cfgs():
     return st.builds(modeling.BatchTopK, top_k=st.sampled_from([1, 2, 4, 8]))
 
 
+def aux_cfgs():
+    return st.builds(modeling.AuxiliaryConfig, top_k=st.sampled_from([1, 2, 4, 8]))
+
+
+@given(
+    cfg=aux_cfgs(),
+    batch=st.integers(min_value=1, max_value=4),
+    d_sae=st.integers(min_value=256, max_value=2048),
+)
+def test_auxiliary_activation(cfg, batch, d_sae):
+    act = modeling.AuxiliaryLossActivation(cfg)
+    dead_lts = torch.ones((batch, d_sae))
+    x = torch.rand(batch, d_sae)
+    y = act(x, dead_lts)
+
+    assert y.shape == (batch, d_sae)
+    # Check that only k elements are non-zero per sample
+    assert (y != 0).sum(dim=1).le(cfg.top_k).all()
+
+
 @given(
     cfg=topk_cfgs(),
     batch=st.integers(min_value=1, max_value=4),
@@ -114,6 +134,20 @@ def test_topk_ties():
     assert y[y != 0].unique().item() == 2.0
 
 
+def test_auxiliary_activation_ties():
+    """Test Auxiliary activation behavior with tied values."""
+    cfg = modeling.AuxiliaryConfig(top_k=2)
+    act = modeling.AuxiliaryLossActivation(cfg)
+
+    x = torch.tensor([[2.0, 2.0, 2.0, 2.0]])
+    y = act(x, torch.ones_like(x))
+
+    # Exactly k of the tied elements are selected (which ones is unspecified)
+    assert (y != 0).sum() == 2
+    # Verify the selected values are correct
+    assert y[y != 0].unique().item() == 2.0
+
+
 def test_topk_k_equals_size():
     """Test TopK when k equals tensor size."""
     cfg = modeling.TopK(top_k=4)
@@ -126,6 +160,30 @@ def test_topk_k_equals_size():
     torch.testing.assert_close(y, x)
 
 
+def test_auxiliary_activation_k_equals_size():
+    """Test Auxiliary activation when k equals tensor size."""
+    cfg = modeling.AuxiliaryConfig(top_k=4)
+    act = modeling.AuxiliaryLossActivation(cfg)
+
+    x = torch.tensor([[5.0, 1.0, 3.0, 2.0]])
+    y = act(x, torch.ones_like(x))
+
+    # All values should be preserved
+    torch.testing.assert_close(y, x)
+
+
+def test_auxiliary_activation_k_exceeds_size():
+    """Test Auxiliary activation when k exceeds tensor size."""
+    cfg = modeling.AuxiliaryConfig(top_k=8)
+    act = modeling.AuxiliaryLossActivation(cfg)
+
+    x = torch.tensor([[5.0, 1.0, 3.0, 2.0]])
+    y = act(x, torch.ones_like(x))
+
+    # All values should be preserved
+    torch.testing.assert_close(y, x)
+
+
 def test_topk_negative_values():
     """Test TopK with negative values."""
     cfg = modeling.TopK(top_k=2)
@@ -139,6 +197,19 @@ def test_topk_negative_values():
     torch.testing.assert_close(y, expected)
 
 
+def test_auxiliary_activation_negative_values():
+    """Test Auxiliary activation with negative values."""
+    cfg = modeling.AuxiliaryConfig(top_k=2)
+    act = modeling.AuxiliaryLossActivation(cfg)
+
+    x = torch.tensor([[-5.0, -1.0, -3.0, -2.0]])
+    y = act(x, torch.ones_like(x))
+
+    # Should select -1.0 and -2.0 (largest values)
+    expected = torch.tensor([[0.0, -1.0, 0.0, -2.0]])
+    torch.testing.assert_close(y, expected)
+
+
 def test_topk_gradient_flow():
     """Test that gradients flow correctly through TopK."""
     cfg = modeling.TopK(top_k=2)
@@ -156,6 +227,23 @@ def test_topk_gradient_flow():
     torch.testing.assert_close(x.grad, expected_grad)
 
 
+def test_auxiliary_activation_gradient_flow():
+    """Test that gradients flow correctly through Auxiliary activation."""
+    cfg = modeling.AuxiliaryConfig(top_k=2)
+    act = modeling.AuxiliaryLossActivation(cfg)
+
+    x = torch.tensor([[5.0, 1.0, 3.0, 2.0], [2.0, 4.0, 1.0, 3.0]], requires_grad=True)
+    y = act(x, torch.ones_like(x))
+
+    # Create a simple loss (sum of outputs)
+    loss = y.sum()
+    loss.backward()
+
+    # Expected gradient: 1.0 for selected elements, 0.0 for others
+    expected_grad = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]])
+    torch.testing.assert_close(x.grad, expected_grad)
+
+
 def test_topk_gradient_sparsity():
     """Verify gradient sparsity matches forward pass selection."""
     cfg = modeling.TopK(top_k=3)
@@ -180,6 +268,30 @@ def test_topk_gradient_sparsity():
     torch.testing.assert_close(selected_grads, expected_grads)
 
 
+def test_auxiliary_activation_gradient_sparsity():
+    """Verify gradient sparsity matches forward pass selection for Auxiliary activation."""
+    cfg = modeling.AuxiliaryConfig(top_k=3)
+    act = modeling.AuxiliaryLossActivation(cfg)
+
+    torch.manual_seed(42)
+    x = torch.randn(2, 8, requires_grad=True)
+    y = act(x, torch.ones_like(x))
+
+    # Use a different upstream gradient
+    grad_output = torch.randn_like(y)
+    y.backward(grad_output)
+
+    # Check that gradient sparsity matches forward pass
+    forward_mask = (y != 0).float()
+    grad_mask = (x.grad != 0).float()
+    torch.testing.assert_close(forward_mask, grad_mask)
+
+    # Verify gradient values for selected elements
+    selected_grads = x.grad * forward_mask
+    expected_grads = grad_output * forward_mask
+    torch.testing.assert_close(selected_grads, expected_grads)
+
+
 def test_topk_zero_gradient_for_unselected():
     """Explicitly verify that non-selected elements have exactly 0.0 gradients."""
     cfg = modeling.TopK(top_k=2)
@@ -201,6 +313,27 @@ def test_topk_zero_gradient_for_unselected():
     torch.testing.assert_close(x.grad[0, 5], torch.tensor(1.0))
 
 
+def test_auxiliary_activation_zero_gradient_for_unselected():
+    """Explicitly verify that non-selected elements have exactly 0.0 gradients."""
+    cfg = modeling.AuxiliaryConfig(top_k=2)
+    act = modeling.AuxiliaryLossActivation(cfg)
+
+    x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]], requires_grad=True)
+    y = act(x, torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0, 1.0]]))
+
+    loss = y.sum()
+    loss.backward()
+
+    # Elements at indices 0, 1, 2, 3 should have zero gradients
+    torch.testing.assert_close(x.grad[0, 0], torch.tensor(0.0))
+    torch.testing.assert_close(x.grad[0, 1], torch.tensor(0.0))
+    torch.testing.assert_close(x.grad[0, 2], torch.tensor(0.0))
+    torch.testing.assert_close(x.grad[0, 3], torch.tensor(0.0))
+    # Elements at indices 4, 5 should have non-zero gradients
+    torch.testing.assert_close(x.grad[0, 4], torch.tensor(1.0))
+    torch.testing.assert_close(x.grad[0, 5], torch.tensor(1.0))
+
+
 # BatchTopK Edge Case Tests
 def test_batchtopk_basic_forward():
     """Test basic BatchTopK forward pass with known values."""
diff --git a/tests/test_nn_objectives.py b/tests/test_nn_objectives.py
index 6aefdeb6..b8ca9637 100644
--- a/tests/test_nn_objectives.py
+++ b/tests/test_nn_objectives.py
@@ -46,6 +46,26 @@ def test_safe_mse_large_x():
     assert not safe.isnan().any()
 
 
+def test_auxiliary_mse_same():
+    x = torch.ones((45, 12), dtype=torch.float)
+    x_hat = torch.ones((45, 12), dtype=torch.float)
+    aux_objective = objectives.AuxiliaryObjective(objectives.Auxiliary(aux_coeff=1.0))
+    torch.testing.assert_close(
+        aux_objective(x, x_hat).loss,
+        objectives.mean_squared_err(x_hat, x).mean(),
+    )
+
+
+def test_auxiliary_coeff():
+    x = torch.ones((45, 12), dtype=torch.float)
+    x_hat = torch.full((45, 12), 3, dtype=torch.float)
+    aux_objective = objectives.AuxiliaryObjective(objectives.Auxiliary(aux_coeff=0.5))
+    torch.testing.assert_close(
+        aux_objective(x, x_hat).loss,
+        0.5 * objectives.mean_squared_err(x_hat, x).mean(),
+    )
+
+
 def test_factories():
     assert isinstance(
         objectives.get_objective(objectives.Vanilla()), objectives.VanillaObjective
diff --git a/train.py b/train.py
index 141a438e..8f8ec49f 100644
--- a/train.py
+++ b/train.py
@@ -28,13 +28,13 @@ import psutil
 
 import torch
 import tyro
-import wandb
 from jaxtyping import Float
 from torch import Tensor
 
 import saev.data.shuffled
 import saev.utils.scheduling
 import saev.utils.wandb
+import wandb
 from saev import helpers, nn
 
 logger = logging.getLogger("train.py")
@@ -63,6 +63,14 @@ class Config:
default_factory=nn.objectives.Vanilla ) """SAE loss configuration.""" + auxiliary_loss: bool = False + """Auxiliary Loss configuration.""" + auxiliary_loss_coeff: float = 0.03125 + """Coefficient for the auxiliary loss term.""" + tokens_until_dead: int = 10_000_000 + """Number of tokens for feature to not fire after which a feature is considered dead.""" + dead_top_k: int = 512 + """Number of dead features to reconstruct from.""" n_sparsity_warmup: int = 0 """Number of sparsity coefficient warmup steps.""" lr: float = 0.0004 @@ -211,10 +219,30 @@ def train( objectives.train() objectives = objectives.to(cfg.device) + aux_activations = [ + nn.modeling.AuxiliaryLossActivation(nn.modeling.AuxiliaryConfig(c.dead_top_k)) + for c in cfgs + ] + + aux_objectives = [ + nn.objectives.AuxiliaryObjective( + nn.objectives.Auxiliary(c.auxiliary_loss_coeff) + ) + for c in cfgs + ] + global_step, n_patches_seen = 0, 0 p_dataloader, p_children, last_rb, last_t = None, None, 0, time.time() + iterations_dead = [ + torch.zeros((s.cfg.d_sae), dtype=torch.float, device=cfg.device) for s in saes + ] + + dead_latents = [ + torch.zeros((s.cfg.d_sae), dtype=torch.float, device=cfg.device) for s in saes + ] + for batch in helpers.progress(dataloader, every=cfg.log_every): if p_dataloader is None: p_dataloader = psutil.Process(dataloader.manager_pid) @@ -227,15 +255,37 @@ def train( losses = [] x_hats = [] f_xs = [] - for sae, objective in zip(saes, objectives): + for sae, objective, aux_activation, aux_objective, iters_dead, dead_lts in zip( + saes, + objectives, + aux_activations, + aux_objectives, + iterations_dead, + dead_latents, + ): if isinstance(objective, nn.objectives.MatryoshkaObjective): # Specific case has to be given for Matryoshka SAEs since we need to decode several times with varying prefix lengths x_hat, f_x = sae.matryoshka_forward(acts_BD, cfg.n_prefixes) else: x_hat, f_x = sae(acts_BD) + x_hats.append(x_hat) f_xs.append(f_x) - losses.append(objective(acts_BD, f_x, x_hat)) 
+ if cfg.auxiliary_loss: + # Auxiliary loss is a separate term from the main objective, so we add it separately. + aux_f_x = aux_activation(f_x, dead_latents=dead_lts) + aux_x_hat = torch.matmul(sae.W_dec, aux_f_x) + + losses.append( + objective(acts_BD, f_x, x_hat) + aux_objective(acts_BD, aux_x_hat) + ) + else: + losses.append(objective(acts_BD, f_x, x_hat)) + + # Count if feature was dead this iteration, update dead latents mask + iters_dead += ((f_x.abs() > 1e-8).sum(0) == 0).float() * acts_BD.shape[0] + iters_dead[(f_x.abs() > 1e-8).sum(0) != 0] = 0 + dead_lts = (iters_dead > cfg.tokens_until_dead).sum(0).float() n_patches_seen += len(acts_BD)