Skip to content

Commit 3ff862e

Browse files
committed
Add plugin implementation for conservation eval during training
1 parent 1fa3a22 commit 3ff862e

14 files changed

Lines changed: 641 additions & 344 deletions

File tree

experiments/defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ def default_train(
386386
data_seed=train_config.data_seed,
387387
eval_harness_steps=train_config.steps_per_task_eval or 10000,
388388
eval_harness=harness_config,
389+
eval_plugins=train_config.eval_plugins,
389390
)
390391

391392
# Create the pod config

experiments/plantcad/README.md

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ EOF
4040
#### Lambda
4141

4242
```bash
43-
uv pip install "skypilot[lambda]==0.10.3"
43+
uv pip install "skypilot[lambda]==0.10.3.post1"
4444
sky check lambda
4545
sky launch \
46-
--cluster marin --infra lambda --num-nodes 1 --gpus "A10:1" --disk-size 100 \
46+
--cluster marin --infra lambda --num-nodes 1 --gpus "A100:8" --disk-size 100 \
4747
--env HUGGING_FACE_HUB_TOKEN --env WANDB_API_KEY \
4848
output/cluster.sky.yaml --retry-until-up --yes
4949
REMOTE_USER=ubuntu
@@ -52,7 +52,7 @@ REMOTE_USER=ubuntu
5252
#### GCP
5353

5454
```bash
55-
uv pip install "skypilot[gcp]==0.10.3"
55+
uv pip install "skypilot[gcp]==0.10.3.post1"
5656
sky check gcp
5757
sky launch \
5858
--cluster marin --infra gcp --num-nodes 1 --gpus "A100:1" --disk-size 100 \
@@ -65,8 +65,16 @@ REMOTE_USER=gcpuser
6565
#### CoreWeave
6666

6767
```bash
68-
uv pip install "skypilot[kubernetes]==0.10.3"
68+
# The default timeout for pod launch is too conservative in SkyPilot and needs to be increased:
69+
mkdir -p ~/.sky; cat > ~/.sky/config.yaml << EOF
70+
kubernetes:
71+
provision_timeout: 180 # Wait 3 minutes for provisioning before timeout
72+
autoscaler: coreweave
73+
EOF
74+
75+
uv pip install "skypilot[kubernetes]==0.10.3.post1"
6976
sky check k8s
77+
sky show-gpus --infra k8s
7078
sky launch \
7179
--cluster marin --num-nodes 1 --infra k8s --gpus "H100_NVLINK_80GB:8" \
7280
--cpus 124 --memory 2008 \
@@ -79,6 +87,10 @@ sudo apt update
7987
sudo apt install build-essential g++ cmake ninja-build
8088
# uv sync --extra cuda12
8189
# hint: This error likely indicates that you need to install a library that provides "cuda_runtime_api.h" for `transformer-engine-jax@2.6.0.post1`
90+
91+
# For a manual debugging pod:
92+
kubectl get nodes -o wide # get name "gd92c2c"
93+
kubectl debug node/gd92c2c -i -t --image=ubuntu
8294
```
8395

8496
## Execution
@@ -97,16 +109,18 @@ rsync -rPz ./ marin:/home/$REMOTE_USER/sky_workdir \
97109
python -m experiments.plantcad.scripts.exp_pc1_tutorial --prefix local_store --force_run_failed true
98110
python -m experiments.plantcad.scripts.exp_pc1_batch_tune --prefix local_store --force_run_failed true
99111
python -m experiments.plantcad.scripts.exp_pc1_lr_tune --prefix local_store --force_run_failed true
100-
find local_store | grep -E 'step-668$' | xargs -I {} echo "hf upload plantcad/_dev_marin_plantcad1_v1_lr_tune {} {} --repo-type model"
101112

102113
# Training
114+
sudo apt-get install screen -y; screen -S train
103115
mkdir -p logs
104-
screen -S train
105116
python -m experiments.plantcad.scripts.exp_pc1_train \
106117
--prefix local_store --force_run_failed true 2>&1 | tee logs/exp_pc1_train.log
107118

108119
# Evaluation
109120
rm -rf local_store/evaluation/dna-conservation*; python -m experiments.plantcad.scripts.exp_pc1_eval --prefix local_store --force_run_failed true
121+
122+
# Checkpoint upload
123+
find local_store | grep -E 'hf/step-[0-9]+$' | xargs -I {} echo "hf upload plantcad/_dev_marin_plantcad1_v2_train {} {} --repo-type model" | bash /dev/stdin
110124
```
111125

112126
```bash
@@ -127,6 +141,18 @@ roc_auc step
127141
0.593178 21749 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-21749
128142
```
129143

144+
Second iteration:
145+
146+
```
147+
python experiments/plantcad/misc/agg_eval_results.py
148+
roc_auc step checkpoint_path
149+
0.549341 2678 hf://plantcad/_dev_marin_plantcad1_v2_train/local_store/checkpoints/plantcad-train-600m-r12-7ea0fc/hf/step-2678
150+
0.566597 5356 hf://plantcad/_dev_marin_plantcad1_v2_train/local_store/checkpoints/plantcad-train-600m-r12-7ea0fc/hf/step-5356
151+
0.604521 8034 hf://plantcad/_dev_marin_plantcad1_v2_train/local_store/checkpoints/plantcad-train-600m-r12-7ea0fc/hf/step-8034
152+
0.626729 10712 hf://plantcad/_dev_marin_plantcad1_v2_train/local_store/checkpoints/plantcad-train-600m-r12-7ea0fc/hf/step-10712
153+
0.631095 13390 hf://plantcad/_dev_marin_plantcad1_v2_train/local_store/checkpoints/plantcad-train-600m-r12-7ea0fc/hf/step-13390
154+
```
155+
130156
## EDA
131157

132158
Stats on kuleshov-group/Angiosperm_16_genomes:

experiments/plantcad/evaluation.py

Lines changed: 59 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import json
2323
import dataclasses
2424
from dataclasses import dataclass
25-
from typing import Any
2625
from collections.abc import Callable
2726
from datasets import Dataset
2827

@@ -38,7 +37,6 @@
3837
from huggingface_hub import HfApi
3938
from transformers import AutoModelForCausalLM
4039
from levanter.callbacks import StepInfo
41-
from levanter.utils.tree_utils import inference_mode
4240
from marin.utilities.json_encoder import CustomJsonEncoder
4341

4442
from experiments.plantcad.utils import get_available_gpus, get_nucleotide_token_ids, get_plantcad_tokenizer
@@ -48,21 +46,12 @@
4846

4947

5048
@dataclass
51-
class DnaEvalConfig:
52-
"""Configuration for DNA model evolutionary conservation evaluation"""
53-
54-
checkpoint_path: str | InputName
55-
"""Path to the model checkpoint directory"""
49+
class DnaEvalBaseConfig:
50+
"""Base configuration for DNA evaluation with fields needed for training callbacks"""
5651

5752
model_config: str
5853
"""Model configuration size (e.g., '300m', '100m', etc.)"""
5954

60-
device: str = "cuda"
61-
"""Device to use for model inference (e.g., 'cuda', 'cpu')"""
62-
63-
dtype: str | None = None
64-
"""Dtype to use for model inference (e.g., 'float32', 'float16', 'bfloat16' or any torch dtype)"""
65-
6655
dataset_path: str = "plantcad/evolutionary-constraint-example"
6756
"""Dataset repository path"""
6857

@@ -75,15 +64,29 @@ class DnaEvalConfig:
7564
batch_size: int = 32
7665
"""Batch size to use for inference"""
7766

78-
num_workers: int | None = None
79-
"""Number of workers to use for parallel evaluation (defaults to number of GPUs if None)"""
80-
8167
max_samples: int | None = None
8268
"""Maximum number of samples to evaluate (for quick testing)"""
8369

84-
random_seed: int = versioned(42)
70+
random_seed: int = 42
8571
"""Random seed for data shuffling prior to downsampling"""
8672

73+
74+
@dataclass
75+
class DnaEvalConfig(DnaEvalBaseConfig):
76+
"""Configuration for standalone DNA model evolutionary conservation evaluation"""
77+
78+
checkpoint_path: str | InputName | None = None
79+
"""Path to the model checkpoint directory (None for training callbacks)"""
80+
81+
device: str = "cuda"
82+
"""Device to use for model inference (e.g., 'cuda', 'cpu')"""
83+
84+
dtype: str | None = None
85+
"""Dtype to use for model inference (e.g., 'float32', 'float16', 'bfloat16' or any torch dtype)"""
86+
87+
num_workers: int | None = None
88+
"""Number of workers to use for parallel evaluation (defaults to number of GPUs if None)"""
89+
8790
revision: str = versioned("0.1")
8891
"""Revision number to force re-runs when needed"""
8992

@@ -380,6 +383,8 @@ def compute_causal_conservation(
380383

381384
# Run inference for all reference/alternate sequences
382385
logits = logit_function(batch_alt_sequences)
386+
# Always promote to full precision for zero-shot evaluation
387+
logits = logits.astype(jnp.float32)
383388
Vocab = logits.resolve_axis("vocab")
384389
assert logits.axes == (VariantBatch, Position, Vocab)
385390

@@ -406,6 +411,7 @@ def score_eval_dataset(
406411
eval_dataset: Dataset,
407412
logit_function: Callable[[TokenArray], LogitArray],
408413
batch_size: int = 32,
414+
log_progress: bool = True,
409415
) -> ConservationResult:
410416
"""Score evaluation dataset based on zero-shot conservation prediction."""
411417

@@ -420,7 +426,8 @@ def score_eval_dataset(
420426
batches = eval_dataset.with_format(None).batch(batch_size=batch_size)
421427
total_batches = len(batches)
422428
progress_interval = max(1, total_batches // 20) # Every 5%
423-
logger.info(f"Processing {len(eval_dataset)} samples in {total_batches} batches (batch_size={batch_size})")
429+
if log_progress:
430+
logger.info(f"Processing {len(eval_dataset)} samples in {total_batches} batches (batch_size={batch_size})")
424431

425432
for batch_index, batch_data in enumerate(batches):
426433
# Tokenize sequences
@@ -451,7 +458,7 @@ def score_eval_dataset(
451458
total_processed += len(sequences)
452459

453460
# Log progress every 5% of batches
454-
if batch_index % progress_interval == 0 or batch_index == total_batches - 1:
461+
if log_progress and (batch_index % progress_interval == 0 or batch_index == total_batches - 1):
455462
progress_pct = ((batch_index + 1) / total_batches) * 100
456463
logger.info(
457464
f"Progress: {batch_index + 1}/{total_batches} batches ({progress_pct:.1f}%) - "
@@ -466,44 +473,7 @@ def score_eval_dataset(
466473
# ------------------------------------------------------------------------------------------------
467474

468475

469-
def evaluate_dna_conservation(
470-
tokenizer: AutoTokenizer,
471-
logit_function: Callable[[Any], Any],
472-
eval_dataset: Dataset,
473-
batch_size: int = 32,
474-
step: int | None = None,
475-
) -> dict[str, float]:
476-
"""
477-
Core evaluation logic - works for both training callbacks and standalone evaluation.
478-
479-
Args:
480-
logit_function: Function that takes tokens and returns logits
481-
eval_dataset: HuggingFace dataset with 'seq' field and binary 'label' field
482-
batch_size: Batch size for evaluation
483-
step: Training step (for logging), None for standalone
484-
485-
Returns:
486-
Dictionary with evaluation metrics including ROC AUC
487-
"""
488-
# Collect scores and labels using shared function
489-
result = score_eval_dataset(
490-
tokenizer=tokenizer, logit_function=logit_function, eval_dataset=eval_dataset, batch_size=batch_size
491-
)
492-
493-
# Calculate metrics using shared function
494-
results = evaluate_conservation_scores(result)
495-
496-
# Log during training, log for standalone
497-
if step is not None:
498-
levanter.tracker.log({"eval/dna_conservation/roc": results["roc_auc"]}, step=step)
499-
logger.info(f"Step {step}: ROC AUC = {results['roc_auc']:.3f}")
500-
else:
501-
logger.info(f"ROC AUC = {results['roc_auc']:.4f} ({results['n_total']} valid nucleotides)")
502-
503-
return results
504-
505-
506-
def create_dna_eval_callback(config: DnaEvalConfig) -> Callable[[StepInfo], None]:
476+
def create_dna_eval_callback(config: DnaEvalBaseConfig) -> Callable[[StepInfo], None]:
507477
"""Create a training callback for DNA evaluation."""
508478

509479
# Load tokenizer
@@ -514,25 +484,43 @@ def create_dna_eval_callback(config: DnaEvalConfig) -> Callable[[StepInfo], None
514484
dataset = load_eval_dataset(config)
515485

516486
def dna_conservation_callback(step_info: StepInfo) -> None:
517-
# Put model in inference mode
518-
eval_model = inference_mode(step_info.state.model, True)
487+
step = step_info.step
488+
logger.debug(f"Running DNA conservation evaluation ({step=})")
489+
eval_model = step_info.eval_model
519490

520491
# Create logit function for Levanter model
521492
def logit_function(
522493
tokens: ht.Int[ht.NamedArray, "batch position"],
523494
) -> ht.Float[ht.NamedArray, "batch position vocab"]:
524-
# TODO: validate input / output types
525-
return eval_model(tokens)
495+
logits = eval_model(tokens)
496+
return logits
526497

527-
# Run evaluation
528-
evaluate_dna_conservation(
498+
# Compute scores with binary labels
499+
scores = score_eval_dataset(
529500
tokenizer=tokenizer,
530501
logit_function=logit_function,
531-
eval_dataset=dataset, # Use the loaded dataset
502+
eval_dataset=dataset,
532503
batch_size=config.batch_size,
504+
# TODO: make configurable or disable?
505+
log_progress=True,
506+
)
507+
508+
# Evaluate scores and labels
509+
metrics = evaluate_conservation_scores(scores)
510+
511+
# Log results
512+
levanter.tracker.log(
513+
{
514+
"eval/dna_conservation_roc": metrics["roc_auc"],
515+
},
533516
step=step_info.step,
534517
)
535518

519+
logger.info(
520+
f"DNA conservation evaluation complete ({step=}): "
521+
f"ROC AUC = {metrics['roc_auc']:.4f}, n_samples = {metrics['n_total']}"
522+
)
523+
536524
return dna_conservation_callback
537525

538526

@@ -614,7 +602,11 @@ def logit_function(
614602

615603
# Generate raw conservation scores and labels
616604
result = score_eval_dataset(
617-
tokenizer=tokenizer, logit_function=logit_function, eval_dataset=dataset, batch_size=config.batch_size
605+
tokenizer=tokenizer,
606+
logit_function=logit_function,
607+
eval_dataset=dataset,
608+
batch_size=config.batch_size,
609+
log_progress=True,
618610
)
619611

620612
logger.info(f"Generated {len(result.scores)} conservation scores")
@@ -640,10 +632,7 @@ def evaluate_conservation_scores(scores: ConservationResult) -> dict[str, float]
640632
if len(scores.scores) == 0:
641633
raise ValueError("No valid conservation scores found")
642634

643-
# Log total before filtering and filter out NaN scores
644635
n_unmasked_total = len(scores.scores)
645-
logger.info(f"n_unmasked_total: {n_unmasked_total}")
646-
647636
valid_mask = ~np.isnan(scores.scores)
648637
filtered_scores = np.array(scores.scores)[valid_mask]
649638
filtered_labels = np.array(scores.labels)[valid_mask]
@@ -691,7 +680,8 @@ def save_conservation_results(config: DnaEvalConfig, results: dict[str, float])
691680
logger.info(f"Saved evaluation results to: {results_file}")
692681

693682

694-
@ray.remote(max_calls=1)
683+
# TODO: fix this which forces only one checkpoint to run at a time
684+
@ray.remote(max_calls=1, resources={"head_node": 1})
695685
def run_conservation_eval(config: DnaEvalConfig) -> dict[str, float]:
696686
# Determine number of workers
697687
if config.num_workers is None:

0 commit comments

Comments
 (0)