A research framework for the full LLM lifecycle - pre-training, post-training, and mechanistic interpretability - built from scratch in PyTorch.
| Component | What it does |
|---|---|
| `pre_training/` | Train language models from scratch with FSDP, tensor parallelism, and multiple architecture variants |
| `post_training/` | Inference utilities, KV-cache generation, and nucleus sampling rollouts |
| `interpretability/` | Train Sparse Autoencoders on model activations and steer features at generation time |
## Installation

```bash
git clone https://github.com/aman-17/911
cd 911
pip install -e .
```

Requires Python 3.11+ and PyTorch 2.7+ with CUDA 12.8:

```bash
pip install torch --index-url https://download.pytorch.org/whl/cu128
```

## Pre-training

```bash
python pre_training/train.py
```

Supports FSDP sharding strategies (FULL_SHARD, SHARD_GRAD_OP, HYBRID_SHARD), activation checkpointing, mixed precision, and W&B logging. Configure via YAML:
```yaml
model:
  emb_dim: 2048
  n_heads: 16
  n_layers: 24
  attention: grouped_query
training:
  batch_size: 512
  lr: 3e-4
  fsdp_strategy: FULL_SHARD
```
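For reference, a config in this shape can be read with PyYAML; a minimal sketch, assuming a `config.yaml` path (the actual loader in `pre_training/train.py` may differ):

```python
import yaml

# Hypothetical loader; the path and key layout mirror the example above.
with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

emb_dim = cfg["model"]["emb_dim"]   # 2048
lr = float(cfg["training"]["lr"])   # PyYAML loads bare "3e-4" as a string, so cast it
```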
## Post-training

Generate completions with nucleus (top-p) sampling:

```python
from post_training.inference.inference_utils import generate_top_p

output = generate_top_p(model, tokenizer, prompt="Hello!", max_new_tokens=200, top_p=0.9, temperature=0.8)
```
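For intuition, nucleus sampling keeps the smallest set of tokens whose cumulative probability mass reaches `top_p`, then renormalizes and samples from that set. A minimal single-step sketch in PyTorch (illustrative only, not the implementation behind `generate_top_p`):

```python
import torch

def top_p_step(logits: torch.Tensor, top_p: float = 0.9, temperature: float = 0.8) -> torch.Tensor:
    """Sample one token id from a (vocab_size,) logits vector with nucleus filtering."""
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop tokens whose preceding cumulative mass already reaches top_p
    # (the shift guarantees the single most likely token is always kept).
    outside_nucleus = cumulative - sorted_probs >= top_p
    sorted_probs[outside_nucleus] = 0.0
    sorted_probs /= sorted_probs.sum()  # renormalize over the nucleus
    return sorted_idx[torch.multinomial(sorted_probs, num_samples=1)]
```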
KV-cache rollouts for RLHF-style training:

```python
from post_training.inference.rollout import sample_response

tokens, text, log_probs = sample_response(model, tokenizer, prompt_ids, max_new_tokens=512)
```
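The per-token `log_probs` are the piece an RLHF-style objective consumes. As a rough sketch, a bare REINFORCE update might look like the following, assuming `log_probs` is a tensor and that `reward_fn` and `optimizer` are your own (neither is provided by this repo):

```python
# Hypothetical reward for the sampled text, e.g. from a reward model.
reward = reward_fn(text)

# REINFORCE: scale the sequence log-probability by the scalar reward.
loss = -log_probs.sum() * reward
loss.backward()
optimizer.step()
optimizer.zero_grad()
```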
## Interpretability

### Step 1: Collect residual stream activations

```bash
python -m interpretability.data.lymsys_chat1b
```

Runs OLMo-2 1B inference and saves activations from layer 8 to disk in 200K-token chunks.
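Generically, collecting residual-stream activations amounts to registering a forward hook on the target transformer block; a sketch of the idea, assuming an HF-style `model.model.layers` layout (the actual collection script may differ):

```python
import torch

buffer = []

def save_activations(module, inputs, output):
    # The block output is the residual stream: (batch, seq, d_model).
    hidden = output[0] if isinstance(output, tuple) else output
    buffer.append(hidden.detach().float().cpu())

handle = model.model.layers[8].register_forward_hook(save_activations)  # layer 8, as above
with torch.no_grad():
    model(input_ids)
handle.remove()

# Flatten to (n_tokens, d_model) for SAE training.
activations = torch.cat([b.reshape(-1, b.shape[-1]) for b in buffer])
```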
### Step 2: Train the SAE

```bash
python -m interpretability.train
```

Trains a TopK Sparse Autoencoder (k=32, 32K dictionary) on the collected activations for 50M tokens.
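For intuition, a TopK SAE encodes each activation vector into a wide dictionary, zeroes all but the k largest latents, and reconstructs from the survivors. A minimal architectural sketch with the dimensions above (illustrative, not the code in `interpretability.train`):

```python
import torch
import torch.nn as nn

class TopKSAE(nn.Module):
    def __init__(self, d_model: int = 2048, n_latents: int = 32_768, k: int = 32):
        super().__init__()
        self.k = k
        self.encoder = nn.Linear(d_model, n_latents)
        self.decoder = nn.Linear(n_latents, d_model)

    def forward(self, x: torch.Tensor):
        pre = self.encoder(x)                   # (batch, n_latents)
        topk = torch.topk(pre, self.k, dim=-1)  # keep the k largest latents per example
        latents = torch.zeros_like(pre).scatter_(-1, topk.indices, topk.values)
        recon = self.decoder(latents)
        return recon, latents

# Training minimizes reconstruction error, e.g. ((recon - x) ** 2).mean();
# the TopK constraint makes an explicit sparsity penalty unnecessary.
```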
### Step 3: Steer features at generation time

```python
from interpretability.inference import run_steered_generation

output = run_steered_generation(feature_idx=4821, scale=3.0, prompt="Tell me about your day")
print(output)
```

Or use `FeatureSteerer` directly for full control:
```python
from interpretability.inference import FeatureSteerer

with FeatureSteerer(model, sae, layer_idx=8).set_feature(4821, scale=3.0):
    output_ids = model.generate(**inputs, max_new_tokens=200)
```
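Mechanically, steering boils down to adding a scaled copy of one feature's decoder direction to the residual stream at the hooked layer during the forward pass. A conceptual sketch of that hook (the decoder weight layout and the `model.model.layers` path are assumptions; `FeatureSteerer` itself may work differently):

```python
def make_steering_hook(sae, feature_idx: int, scale: float):
    # The decoder column for one feature is the direction it writes back
    # into the residual stream. Layout (d_model, n_latents) is an assumption.
    direction = sae.decoder.weight[:, feature_idx]

    def hook(module, inputs, output):
        hidden = output[0] if isinstance(output, tuple) else output
        hidden = hidden + scale * direction.to(hidden.dtype)
        return (hidden, *output[1:]) if isinstance(output, tuple) else hidden

    return hook

handle = model.model.layers[8].register_forward_hook(make_steering_hook(sae, 4821, 3.0))
output_ids = model.generate(**inputs, max_new_tokens=200)
handle.remove()
```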