3 changes: 3 additions & 0 deletions backend/.env.example
@@ -0,0 +1,3 @@
APP_NAME="Transformer Visualizer API"
MODEL_NAME="gpt2"
DEVICE="cpu"
100 changes: 98 additions & 2 deletions backend/backend.md
@@ -1,5 +1,101 @@
# Transformer Visualizer Backend

## run server

```bash
uvicorn main:app --reload
```
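
Once the server is running, a quick way to confirm it is reachable (a minimal sketch using the `requests` library, which is not part of `requirements.txt`, and assuming uvicorn's default address `http://127.0.0.1:8000`):

```python
import requests

# hit the health endpoint on the default uvicorn host/port (assumption)
resp = requests.get("http://127.0.0.1:8000/health")
print(resp.json())  # e.g. {'status': 'healthy', 'model_loaded': True, 'model_name': 'gpt2'}
```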


## API endpoints

### health check

```text
GET /health
```

response:

```json
{
  "status": "healthy",
  "model_loaded": true,
  "model_name": "gpt2"
}
```

### predict next token

```text
POST /v1/predict
```

request body:

```json
{
  "text": "The quick brown fox",
  "temperature": 1.0,
  "top_k": 5
}
```

response:

```json
{
  "input_text": "The quick brown fox",
  "generated_text": "The quick brown fox jumps",
  "next_token_probabilities": [
    {
      "token": " jumps",
      "probability": 0.45,
      "token_id": 14523
    },
    {
      "token": " jumped",
      "probability": 0.25,
      "token_id": 11687
    },
    {
      "token": " is",
      "probability": 0.15,
      "token_id": 318
    }
  ]
}
```
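
Calling the endpoint from Python looks roughly like this (a sketch using `requests`, assuming the server started with the run command above is listening on `http://127.0.0.1:8000`):

```python
import requests

payload = {"text": "The quick brown fox", "temperature": 1.0, "top_k": 5}
resp = requests.post("http://127.0.0.1:8000/v1/predict", json=payload)
resp.raise_for_status()

data = resp.json()
print(data["generated_text"])  # the input plus the single most likely next token
for entry in data["next_token_probabilities"]:
    print(entry["token"], entry["probability"], entry["token_id"])
```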

### generate text

```text
POST /v1/generate
```

request body:

```json
{
  "text": "Once upon a time",
  "max_tokens": 50,
  "temperature": 0.8,
  "top_k": 5
}
```

response:

```json
{
  "input_text": "Once upon a time",
  "generated_text": "Once upon a time there was a little girl who lived in a small village...",
  "tokens_generated": 50
}
```
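
A corresponding call from Python (same assumptions as the `/v1/predict` sketch above):

```python
import requests

payload = {"text": "Once upon a time", "max_tokens": 50, "temperature": 0.8, "top_k": 5}
resp = requests.post("http://127.0.0.1:8000/v1/generate", json=payload)

data = resp.json()
print(data["tokens_generated"])  # 50
print(data["generated_text"])    # the prompt followed by 50 sampled tokens
```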

## request parameters

- `text` (required): the input prompt to generate from
- `max_tokens` (optional, default: 50): number of tokens to generate (used by `/v1/generate`)
- `temperature` (optional, default: 1.0): controls randomness; lower values make the output more deterministic (see the sketch below)
- `top_k` (optional, default: 5): number of top candidate tokens to return with their probabilities (used by `/v1/predict`)
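
To illustrate what `temperature` does, here is a small `torch` sketch of the same scaling the backend applies before sampling (`softmax(logits / temperature)`); lower values concentrate probability on the top token, higher values flatten the distribution:

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5, 0.1])

for temperature in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    # lower temperature -> probability mass concentrates on the largest logit
    print(temperature, [round(p, 3) for p in probs.tolist()])
```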
15 changes: 15 additions & 0 deletions backend/config.py
@@ -0,0 +1,15 @@
from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    # application configuration
    app_name: str = "Transformer Visualizer API"

    # model configuration
    model_name: str = "gpt2"
    device: str = "cpu"

    class Config:
        env_file = ".env"

# global settings instance
settings = Settings()
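
A quick way to confirm which settings are active (illustrative only; pydantic-settings reads variables from `.env` per `Config.env_file` and matches them to field names case-insensitively, so `MODEL_NAME` in `backend/.env.example` overrides `model_name`):

```python
from config import settings

print(settings.app_name)    # "Transformer Visualizer API" unless APP_NAME is set
print(settings.model_name)  # "gpt2" unless MODEL_NAME is set
print(settings.device)      # "cpu" unless DEVICE is set
```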
34 changes: 32 additions & 2 deletions backend/main.py
@@ -1,7 +1,37 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager

from config import settings
from models.model_loader import model_manager
from routes.inference import router as inference_router

@asynccontextmanager
async def lifespan(app: FastAPI):
    # load model on startup
    model_manager.load_model(model_name=settings.model_name, device=settings.device)
    yield
    # cleanup model on shutdown
    model_manager.model = None

app = FastAPI(title=settings.app_name, lifespan=lifespan)

# enable CORS for frontend integration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# register inference routes
app.include_router(inference_router)

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model_loaded": model_manager.is_loaded(),
        "model_name": model_manager.model_name
    }
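
A minimal sanity check of the wiring above (a sketch; FastAPI's `TestClient` additionally needs `httpx`, which is not pinned in `requirements.txt`, and entering the client context runs the lifespan hook, so GPT-2 is downloaded and loaded on first use):

```python
from fastapi.testclient import TestClient

from main import app

# the "with" block triggers lifespan startup/shutdown, so the model gets loaded
with TestClient(app) as client:
    resp = client.get("/health")
    print(resp.json())  # {'status': 'healthy', 'model_loaded': True, 'model_name': 'gpt2'}
```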
37 changes: 37 additions & 0 deletions backend/models/model_loader.py
@@ -0,0 +1,37 @@
from transformer_lens import HookedTransformer
from typing import Optional
import torch

class ModelManager:
    def __init__(self):
        self.model: Optional[HookedTransformer] = None
        self.model_name: Optional[str] = None
        self.device: str = "cpu"

    def load_model(self, model_name: str = "gpt2", device: str = "cpu"):
        # skip loading if model is already loaded
        if self.model is not None and self.model_name == model_name:
            return self.model

        # update model configuration
        self.device = device
        self.model_name = model_name

        # load pretrained model from TransformerLens
        self.model = HookedTransformer.from_pretrained(
            model_name,
            device=device
        )
        return self.model

    def get_model(self) -> HookedTransformer:
        # ensure model is loaded before returning
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model first.")
        return self.model

    def is_loaded(self) -> bool:
        return self.model is not None

# global model manager instance
model_manager = ModelManager()
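
A short sketch of how the singleton is intended to be used (illustrative only; the first call downloads the GPT-2 weights):

```python
from models.model_loader import model_manager

model = model_manager.load_model(model_name="gpt2", device="cpu")
print(model_manager.is_loaded())   # True

tokens = model.to_tokens("The quick brown fox")
print(tokens.shape)                # (1, sequence_length); TransformerLens prepends a BOS token by default
print(model.to_string(tokens[0]))  # round-trips back to text, including the BOS marker
```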
6 changes: 6 additions & 0 deletions backend/requirements.txt
@@ -0,0 +1,6 @@
fastapi
uvicorn[standard]
transformer-lens
torch
pydantic-settings
python-dotenv
97 changes: 97 additions & 0 deletions backend/routes/inference.py
@@ -0,0 +1,97 @@
from fastapi import APIRouter, HTTPException
import torch
from typing import List

from models.model_loader import model_manager
from schemas import InferenceRequest, InferenceResponse, TokenProbability

router = APIRouter(prefix="/v1", tags=["inference"])

@router.post("/predict", response_model=InferenceResponse)
async def predict_next_token(request: InferenceRequest):
    # check if model is loaded before processing
    if not model_manager.is_loaded():
        raise HTTPException(status_code=503, detail="Model not loaded")

    model = model_manager.get_model()

    try:
        # convert input text to tokens
        tokens = model.to_tokens(request.text)

        # run model forward pass to get logits
        with torch.no_grad():
            logits = model(tokens)

        # extract logits for the last token position
        final_logits = logits[0, -1, :]

        # apply temperature scaling and convert to probabilities
        probs = torch.softmax(final_logits / request.temperature, dim=-1)

        # get top k most probable tokens
        top_k_probs, top_k_indices = torch.topk(probs, request.top_k)

        # build list of token probabilities for response
        next_token_probs: List[TokenProbability] = []
        for prob, idx in zip(top_k_probs.tolist(), top_k_indices.tolist()):
            token_str = model.to_string(idx)
            next_token_probs.append(TokenProbability(
                token=token_str,
                probability=prob,
                token_id=idx
            ))

        # generate output text with most likely next token
        predicted_token_id = top_k_indices[0].item()
        generated_text = request.text + model.to_string(predicted_token_id)

        return InferenceResponse(
            input_text=request.text,
            generated_text=generated_text,
            next_token_probabilities=next_token_probs
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")

@router.post("/generate")
async def generate_text(request: InferenceRequest):
    # check if model is loaded before processing
    if not model_manager.is_loaded():
        raise HTTPException(status_code=503, detail="Model not loaded")

    model = model_manager.get_model()

    try:
        # convert input text to tokens
        tokens = model.to_tokens(request.text)

        # start with input tokens
        generated_tokens = tokens.clone()

        # generate tokens one at a time
        for _ in range(request.max_tokens):
            # get logits for current sequence
            with torch.no_grad():
                logits = model(generated_tokens)

            # apply temperature and get probabilities for next token
            final_logits = logits[0, -1, :]
            probs = torch.softmax(final_logits / request.temperature, dim=-1)

            # sample next token from probability distribution
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0)], dim=1)

        # convert tokens back to text
        generated_text = model.to_string(generated_tokens[0])

        return {
            "input_text": request.text,
            "generated_text": generated_text,
            "tokens_generated": request.max_tokens
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
25 changes: 25 additions & 0 deletions backend/schemas.py
@@ -0,0 +1,25 @@
from pydantic import BaseModel
from typing import List, Dict

class InferenceRequest(BaseModel):
    # input text to generate from
    text: str

    # generation parameters
    max_tokens: int = 50
    temperature: float = 1.0
    top_k: int = 5

class TokenProbability(BaseModel):
    # token information
    token: str
    probability: float
    token_id: int

class InferenceResponse(BaseModel):
    # input and output text
    input_text: str
    generated_text: str

    # top k token predictions with probabilities
    next_token_probabilities: List[TokenProbability]
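
The defaults above can be checked directly (a sketch assuming Pydantic v2, which `pydantic-settings` already requires):

```python
from schemas import InferenceRequest

req = InferenceRequest(text="Hello")
print(req.model_dump())
# {'text': 'Hello', 'max_tokens': 50, 'temperature': 1.0, 'top_k': 5}
```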