# Beyond Multiple Choice: Evaluating Steering Vectors for Summarization

This repository contains the implementation and experimental results for the paper "Beyond Multiple Choice: Evaluating Steering Vectors for Summarization".
## Repository Structure

The repository is organized as follows:

```text
.
├── config/                 # Configuration files for experiments
├── data/                   # Results from the experiments
├── datasets/               # Datasets used in the experiments
│   ├── sentiment/          # Synthetic training data for sentiment vectors
│   ├── readability/        # Synthetic training data for readability vectors
│   ├── toxicity/           # Synthetic training data for toxicity vectors
│   └── topic/              # Training representations for topic vectors
├── notebooks/              # Jupyter notebooks for exploratory analysis and visualization
├── scripts/                # Supporting scripts for data preparation and preprocessing
├── src/                    # Core source code
│   ├── __init__.py
│   ├── apply_vectors/      # Scripts for applying steering vectors
│   ├── data_loading/       # Dataset loaders for ArXiv, SamSum, NEWTS
│   ├── evaluation/         # Scoring and evaluation scripts
│   ├── experiments/        # Main steering experiments
│   ├── plot/               # Plotting and visualization scripts
│   ├── prompt_engineering/ # Prompt engineering for all datasets
│   ├── train_vectors/      # Steering vector training
│   └── utils/              # Utility functions and helpers
├── tests/                  # Tests for some of the steering functionalities
├── pyproject.toml          # Project configuration and dependencies
├── LICENSE                 # License for the repository
└── README.md               # Project documentation (this file)
```
## Installation

This project uses Python 3.11.11 and uv for dependency management. Follow these steps to set up the environment:

1. Clone the repository:

   ```shell
   git clone [repository-url]
   cd adaptive_steering_evaluation
   ```

2. Create a virtual environment and install dependencies using uv:

   ```shell
   uv venv --python 3.11.11
   source .venv/bin/activate  # On macOS/Linux
   uv pip install -e ".[dev]"
   ```
## Quick Start

After installation, try these commands:

```shell
# Generate baseline summaries
python src/generate_vanilla_summaries.py --dataset newts --model llama3_1b --num-articles 10

# Apply prompt engineering
python src/prompt_engineering/apply_prompt_engineering_to_dataset.py --dataset arxiv --model llama3_1b --num-samples 10

# Train steering vectors
python src/train_vectors/get_steering_vector.py --model llama3_1b --behavior-type sentiment
```

The project includes development dependencies for testing (pytest), type checking (mypy), code formatting (black), and linting (ruff), as specified in pyproject.toml.
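For orientation, the dev extra might be declared along these lines in pyproject.toml (a sketch only; the actual file is authoritative and may pin versions):

```toml
[project.optional-dependencies]
dev = ["pytest", "mypy", "black", "ruff"]
```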
## Datasets

The project supports three main summarization datasets:
- NEWTS: News articles (3,000 samples)
- SamSum: Dialogue conversations (16,369 samples, filtered version available)
- ArXiv: Scientific papers (18,255 samples, filtered version available)
The synthetic datasets used for training the steering vectors:

- Sentiment data: `datasets/sentiment/`
- Readability data: `datasets/readability/`
- Toxicity data: `datasets/toxicity/`
- Topic representations: `datasets/topic/`
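These contrastive datasets are what steering-vector training consumes. As a rough illustration of the general technique — a mean activation difference, which may differ from the exact method in `src/train_vectors/` — here is a sketch with made-up activations in place of real model hidden states:

```python
import numpy as np

# Toy stand-ins for hidden activations captured at one transformer layer:
# one row per training example, one column per hidden dimension.
rng = np.random.default_rng(seed=0)
positive_acts = rng.normal(loc=1.0, scale=0.1, size=(8, 16))   # e.g. positive-sentiment texts
negative_acts = rng.normal(loc=-1.0, scale=0.1, size=(8, 16))  # e.g. negative-sentiment texts

# The steering vector is the difference between the two mean activations.
steering_vector = positive_acts.mean(axis=0) - negative_acts.mean(axis=0)

# At inference time the vector is added, scaled by a strength alpha, to the
# hidden state at the same layer to push generations toward the behavior.
alpha = 2.0
hidden_state = rng.normal(size=16)
steered_state = hidden_state + alpha * steering_vector
```

The scaling factor `alpha` trades off steering strength against fluency; the repo's scripts expose the analogous knobs via their CLI flags.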
## Usage

Generate vanilla summaries without steering or prompt engineering:

```shell
python src/generate_vanilla_summaries.py --dataset newts --model llama3_1b --num-articles 100 --test-set
```

Apply behavior-encouraging prompts to any dataset:

```shell
# For NEWTS dataset
python src/prompt_engineering/apply_prompt_engineering_to_dataset.py --dataset newts --model llama3_1b --num-samples 100 --test-set

# For ArXiv dataset
python src/prompt_engineering/apply_prompt_engineering_to_dataset.py --dataset arxiv --model llama3_1b --num-samples 100 --test-set

# For SamSum dataset
python src/prompt_engineering/apply_prompt_engineering_to_dataset.py --dataset samsum --model llama3_1b --num-samples 100 --test-set
```

Train steering vectors for different behaviors:

```shell
python src/train_vectors/get_steering_vector.py --model llama3_1b --behavior-type sentiment
```

Build prompts programmatically:

```python
from src.utils.get_prompt import get_summary_prompt

# Generate a neutral summary prompt
prompt = get_summary_prompt(
    text="Your article/paper/conversation text here",
    dataset="newts",  # or "arxiv" or "samsum"
)

# Generate a behavior-specific prompt
prompt = get_summary_prompt(
    text="Your text here",
    dataset="arxiv",
    behavior_type="sentiment",
    use_behavior_encouraging_prompt=True,
    encourage_positive_sentiment=True,
)
```

Load datasets through the provided loaders:

```python
from src.data_loading.arxiv_loader import ArXivDataLoader
from src.data_loading.samsum_loader import SamSumDataLoader

# Load ArXiv dataset
arxiv_loader = ArXivDataLoader()
arxiv_data = arxiv_loader.load_dataset(split="test", max_samples=100)

# Load SamSum dataset
samsum_loader = SamSumDataLoader()
samsum_data = samsum_loader.load_dataset(split="test", max_samples=100)
```

## License

The code in this repository is licensed under the MIT License. The datasets provided are licensed under CC BY-NC 4.0. Note: the datasets contain derivatives of the SAMSum, NEWTS, and arXiv datasets; please ensure you comply with the original licenses of those datasets. See the LICENSE file for details.