Skip to content

Commit 4e428d6

Browse files
committed
Add test for CLM score parity to biofoundation
1 parent 2537367 commit 4e428d6

7 files changed

Lines changed: 168 additions & 58 deletions

File tree

experiments/plantcad/README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ Original tutorial: https://gist.github.com/eric-czech/31e5b79689d322f7becb94a109
1212
git clone https://github.com/marin-community/marin.git
1313
cd marin
1414
uv venv --python 3.11
15-
uv sync
1615
```
1716

1817
### Remote (SkyPilot)
@@ -29,7 +28,7 @@ envs:
2928
workdir: .
3029
setup: |
3130
uv venv --python 3.11
32-
uv sync --extra=cuda12
31+
uv sync --extra=cuda12 --extra=dna
3332
for var in HUGGING_FACE_HUB_TOKEN WANDB_API_KEY; do
3433
declare -n ref=$var
3534
grep -q "^export $var=" ~/.bashrc || echo "export $var=$ref" >> ~/.bashrc

experiments/plantcad/evaluation.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,11 @@
3939
from levanter.callbacks import StepInfo
4040
from marin.utilities.json_encoder import CustomJsonEncoder
4141

42-
from experiments.plantcad.utils import get_available_gpus, get_nucleotide_token_ids, get_plantcad_tokenizer
42+
from experiments.plantcad.utils import get_available_gpus, get_nucleotide_token_ids
43+
from levanter.utils.hf_utils import HfTokenizer
4344
from marin.execution.executor import InputName, this_output_path, versioned
45+
from jax.sharding import Mesh
46+
from haliax.partitioning import ResourceMapping
4447

4548
logger = logging.getLogger("ray")
4649

@@ -49,9 +52,6 @@
4952
class DnaEvalBaseConfig:
5053
"""Base configuration for DNA evaluation with fields needed for training callbacks"""
5154

52-
model_config: str
53-
"""Model configuration size (e.g., '300m', '100m', etc.)"""
54-
5555
dataset_path: str = "plantcad/evolutionary-constraint-example"
5656
"""Dataset repository path"""
5757

@@ -292,12 +292,12 @@ def create_alternate_sequences(
292292
assert 0 <= ref_cts.max().item() <= 1
293293
if (invalid := ref_cts == 0).any().item():
294294
pos = nucleotide_positions[Batch, invalid]
295-
tok = tokens_expanded[Batch, invalid, Position, pos]
295+
tok = tokens_expanded[Batch, invalid][Position, pos]
296296
raise ValueError(
297-
"Found invalid sequences in batch with OOV nucleotides at target positions; "
298-
f"Target positions: {pos} "
299-
f"Valid nucleotide token IDs: {nucleotide_token_ids} "
300-
f"Invalid tokens: {tok} "
297+
"Found invalid sequences in batch with OOV nucleotides at target positions;\n"
298+
f"Target positions: {pos.array} \n"
299+
f"Valid nucleotide token IDs: {nucleotide_token_ids} \n"
300+
f"Invalid tokens: {tok.array} "
301301
)
302302
ref = hax.argmax(ref_mask, axis=Variant)
303303
assert ref.axes == (Batch,)
@@ -439,7 +439,7 @@ def score_eval_dataset(
439439
assert isinstance(pos, list)
440440

441441
# Tokenize and convert to JAX arrays
442-
tokenized = tokenizer(sequences, padding=True, truncation=True, max_length=512, return_tensors="np")
442+
tokenized = tokenizer(sequences, padding=False, add_special_tokens=False, truncation=False, return_tensors="np")
443443
tokens = hax.named(tokenized["input_ids"], ("batch", "position"))
444444
nucleotide_positions = hax.named(pos, ("batch",))
445445

@@ -473,12 +473,25 @@ def score_eval_dataset(
473473
# ------------------------------------------------------------------------------------------------
474474

475475

476-
def create_dna_eval_callback(config: DnaEvalBaseConfig) -> Callable[[StepInfo], None]:
477-
"""Create a training callback for DNA evaluation."""
476+
def create_dna_eval_callback(
477+
config: DnaEvalBaseConfig,
478+
tokenizer: HfTokenizer,
479+
device_mesh: Mesh,
480+
compute_axis_mapping: ResourceMapping,
481+
parameter_axis_mapping: ResourceMapping,
482+
) -> Callable[[StepInfo], None]:
483+
"""Create a training callback for DNA evaluation.
484+
485+
Args:
486+
config: DNA evaluation configuration
487+
tokenizer: Tokenizer provided by Levanter's training loop
488+
device_mesh: JAX device mesh for distributed computation
489+
compute_axis_mapping: Axis mapping for computation (used during model inference)
490+
parameter_axis_mapping: Axis mapping for parameter storage (used for model sharding)
478491
479-
# Load tokenizer
480-
# TODO: fix this; how can the tokenizer be referenced during training without reloading?
481-
tokenizer = get_plantcad_tokenizer()
492+
Returns:
493+
Callback function that evaluates DNA conservation at training steps
494+
"""
482495

483496
# Load and validate dataset once when creating the callback
484497
dataset = load_eval_dataset(config)
@@ -488,7 +501,7 @@ def dna_conservation_callback(step_info: StepInfo) -> None:
488501
logger.debug(f"Running DNA conservation evaluation ({step=})")
489502
eval_model = step_info.eval_model
490503

491-
# Create logit function for Levanter model
504+
# Create logit function for Levanter model with proper axis mapping
492505
def logit_function(
493506
tokens: ht.Int[ht.NamedArray, "batch position"],
494507
) -> ht.Float[ht.NamedArray, "batch position vocab"]:
@@ -501,7 +514,6 @@ def logit_function(
501514
logit_function=logit_function,
502515
eval_dataset=dataset,
503516
batch_size=config.batch_size,
504-
# TODO: make configurable or disable?
505517
log_progress=True,
506518
)
507519

@@ -710,18 +722,12 @@ def run_conservation_eval(config: DnaEvalConfig) -> dict[str, float]:
710722
# Usage examples:
711723

712724
# 1. Training callback (uses Levanter model from training state):
713-
# config = DnaEvalConfig(
714-
# checkpoint_path="/path/to/checkpoint", # Not used for callbacks
715-
# model_config="300m",
716-
# dataset_path="plantcad/evolutionary-constraint-example",
717-
# dataset_config="10k"
718-
# )
719-
# trainer.add_hook(create_dna_eval_callback(config), every=1000)
725+
# The plugin system handles this automatically via PlantCADEvaluationPlugin
726+
# See plugin.py for implementation details
720727

721728
# 2. Standalone evaluation with HuggingFace checkpoint:
722729
# config = DnaEvalConfig(
723730
# checkpoint_path="/path/to/hf/checkpoint",
724-
# model_config="300m",
725731
# device="cuda", # or "cpu" for CPU inference
726732
# num_workers=None, # defaults to number of GPUs
727733
# dataset_path="plantcad/evolutionary-constraint-example",

experiments/plantcad/plugin.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
"""PlantCAD evaluation plugin for Levanter training."""
1717

1818
import logging
19-
from typing import Any
2019
from collections.abc import Callable
2120

2221
from levanter.eval import EvalPlugin
2322
from levanter.callbacks import StepInfo
23+
from levanter.utils.hf_utils import HfTokenizer
24+
from jax.sharding import Mesh
25+
from haliax.partitioning import ResourceMapping
2426
from experiments.plantcad.evaluation import DnaEvalBaseConfig, create_dna_eval_callback
2527

2628
logger = logging.getLogger("ray")
@@ -29,11 +31,39 @@
2931
class PlantCADEvaluationPlugin(EvalPlugin):
3032
"""PlantCAD DNA conservation evaluation plugin for Levanter."""
3133

32-
def __init__(self, config: dict[str, Any]):
33-
# Store the dict config directly
34-
self.config = DnaEvalBaseConfig(**config)
35-
logger.info(f"Initialized PlantCAD evaluation plugin with config: {self.config}")
34+
def __init__(self):
35+
logger.info("Initialized PlantCAD evaluation plugin")
3636

37-
def create_callback(self, **kwargs) -> Callable[[StepInfo], None]:
38-
"""Create DNA conservation evaluation callback."""
39-
return create_dna_eval_callback(self.config)
37+
def create_callback(
38+
self,
39+
*,
40+
tokenizer: HfTokenizer,
41+
device_mesh: Mesh,
42+
compute_axis_mapping: ResourceMapping,
43+
parameter_axis_mapping: ResourceMapping,
44+
batch_size: int,
45+
) -> Callable[[StepInfo], None]:
46+
"""Create DNA conservation evaluation callback.
47+
48+
Args:
49+
tokenizer: Tokenizer for the model
50+
device_mesh: JAX device mesh for distributed computation
51+
compute_axis_mapping: Axis mapping for computation
52+
parameter_axis_mapping: Axis mapping for parameter storage
53+
batch_size: Evaluation batch size
54+
55+
Returns:
56+
Callback function for DNA conservation evaluation
57+
"""
58+
# Cut batch size in half because current eval runs with
59+
# model in full precision rather than bf16
60+
# TODO: implement mixed-precision in eval
61+
config = DnaEvalBaseConfig(batch_size=batch_size // 2)
62+
logger.info(f"Creating conservation evaluation callback with config: {config}")
63+
return create_dna_eval_callback(
64+
config=config,
65+
tokenizer=tokenizer,
66+
device_mesh=device_mesh,
67+
compute_axis_mapping=compute_axis_mapping,
68+
parameter_axis_mapping=parameter_axis_mapping,
69+
)

experiments/plantcad/scripts/exp_pc1_train.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,13 @@
3131
from experiments.simple_train_config import SimpleTrainConfig
3232
from marin.execution.executor import executor_main
3333
from marin.resources import GpuConfig
34+
from experiments.plantcad.plugin import PlantCADEvaluationPlugin
35+
from experiments.plantcad.evaluation import resolve_checkpoint_path
3436

3537
logger = logging.getLogger("ray")
3638

3739
# Run iteration
38-
run_number = 12
40+
run_number = 16
3941

4042
# Resources
4143
num_gpus = get_available_gpus(local_only=True)
@@ -45,8 +47,10 @@
4547
learning_rate = 3e-4
4648

4749
# Batch size
48-
# TODO: How do you tune micro/macro batch size instead of global batch?
49-
# This needs to not be a function of device count.
50+
# Ideally global batch would be fixed and device count wouldn't
51+
# matter (w/ grad accum); however, OOMs are unavoidable on GPUs
52+
# unless the global batch varies as a function of device count.
53+
# TODO: find out how to fix global batch on GPUs
5054
micro_batch_size = 256
5155
global_batch_size = micro_batch_size * num_gpus
5256

@@ -56,14 +60,22 @@
5660
# Training configuration
5761
num_train_steps = target_examples // global_batch_size
5862
steps_per_export = num_train_steps // 10
59-
steps_per_eval = num_train_steps // 10
63+
steps_per_cycle = num_train_steps // 10
64+
steps_per_eval = num_train_steps // 100
6065

6166
# Model configuration - use 600M by default
6267
model_size = "600m"
6368
plant_model_config = get_plantcad_config(model_size)
6469

6570
# PlantCAD1 training dataset
6671
plant_data_tokenized = get_plantcad_training_dataset(use_pretokenized=True)
72+
plugin_class = PlantCADEvaluationPlugin
73+
74+
hf_checkpoint_path = (
75+
"hf://plantcad/_dev_marin_plantcad1_v2_train/local_store/checkpoints/plantcad-train-600m-r12-7ea0fc/hf/step-26782"
76+
)
77+
if hf_checkpoint_path is not None:
78+
hf_checkpoint_path = resolve_checkpoint_path(hf_checkpoint_path)
6779

6880
# Training configuration
6981
train_config = SimpleTrainConfig(
@@ -73,27 +85,21 @@
7385
lr_schedule="inv",
7486
warmup=0.05,
7587
decay=0.1,
76-
cycle_length=steps_per_eval,
88+
cycle_length=steps_per_cycle,
7789
train_batch_size=global_batch_size,
7890
per_device_eval_parallelism=micro_batch_size,
7991
steps_per_eval=steps_per_eval,
8092
num_train_steps=num_train_steps,
8193
learning_rate=learning_rate,
8294
steps_per_export=steps_per_export,
83-
eval_plugins=[
84-
{
85-
"plugin": "experiments.plantcad.plugin.PlantCADEvaluationPlugin",
86-
"config": {
87-
"model_config": model_size,
88-
"dataset_config": "10k",
89-
# TODO: I regularly get OOMs with the same per-device batch size on this
90-
# eval even though it only runs on one device? Cut it down by half for now..
91-
"batch_size": micro_batch_size // 2,
92-
"max_samples": 10000,
93-
},
94-
"steps": steps_per_eval,
95-
}
96-
],
95+
initialize_from_hf=hf_checkpoint_path,
96+
# TODO: figure out why this is broken
97+
# eval_plugins=[
98+
# EvalPluginConfig(
99+
# plugin_class=f"{plugin_class.__module__}.{plugin_class.__qualname__}",
100+
# steps=steps_per_cycle,
101+
# )
102+
# ]
97103
)
98104

99105
# Create training step

experiments/plantcad/tests/test_evaluation.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
16+
import numpy as np
1517
import pytest
1618
import jax
1719
import jax.numpy as jnp
1820
import haliax as hax
21+
from datasets import load_dataset
22+
from huggingface_hub import snapshot_download
23+
from levanter.models.llama import LlamaConfig
24+
from levanter.utils.jax_utils import use_cpu_device
25+
from transformers import PretrainedConfig as HfConfig, AutoTokenizer
1926
from experiments.plantcad.evaluation import (
2027
create_alternate_sequences,
2128
compute_sequence_logprob,
@@ -277,6 +284,68 @@ def test_compute_causal_conservation():
277284
)
278285

279286

287+
def test_compute_causal_conservation_accuracy():
288+
"""End-to-end parity test against reference scores.
289+
290+
Reference scores come from https://github.com/Open-Athena/biofoundation/commit/23f6745defdd54cac09b43c066f249789bf74d56
291+
"""
292+
# Download model and dataset
293+
data_path = snapshot_download(
294+
repo_id="plantcad/ci",
295+
repo_type="dataset",
296+
allow_patterns="unit_tests/evolutionary_constraint/ref_logprob_clm_sim/*",
297+
)
298+
ds = load_dataset("plantcad/ci", name="ut_ec_ref_logprob_clm_sim", split="train")
299+
model_dir = os.path.join(data_path, "unit_tests/evolutionary_constraint/ref_logprob_clm_sim/model")
300+
301+
# Load tokenizer and config
302+
hf_config = HfConfig.from_pretrained(model_dir)
303+
config = LlamaConfig.from_hf_config(hf_config)
304+
tokenizer = AutoTokenizer.from_pretrained(model_dir)
305+
306+
# Load sequences and positions
307+
sequences = ds["seq"] if "seq" in ds.column_names else ds["sequence"]
308+
positions = np.asarray(ds["pos"], dtype=np.int32)
309+
tokens_np = np.asarray([tokenizer(s, add_special_tokens=False)["input_ids"] for s in sequences], dtype=np.int32)
310+
tokens = hax.named(jnp.array(tokens_np), ("batch", "position"))
311+
nucleotide_positions = hax.named(jnp.array(positions), ("batch",))
312+
nucleotide_token_ids = [int(tokenizer.convert_tokens_to_ids(nt)) for nt in "ACGT"]
313+
314+
# Load model
315+
converter = config.hf_checkpoint_converter().replaced(reference_checkpoint=model_dir, tokenizer=tokenizer)
316+
with use_cpu_device():
317+
model = converter.load_pretrained(
318+
config.model_type,
319+
ref=model_dir,
320+
resize_vocab_to_match_tokenizer=False,
321+
dtype=jnp.float32,
322+
)
323+
324+
def logit_fn(x):
325+
return model(x)
326+
327+
# Compute conservation scores
328+
actual = compute_causal_conservation(
329+
tokens=tokens,
330+
logit_function=logit_fn,
331+
nucleotide_positions=nucleotide_positions,
332+
nucleotide_token_ids=nucleotide_token_ids,
333+
)
334+
335+
# Compare with expected scores
336+
expected = np.asarray(ds["score"], dtype=np.float32)
337+
our_scores_np = np.asarray(actual.array, dtype=np.float32)
338+
339+
assert len(our_scores_np) == len(expected) == 8
340+
assert jnp.all(jnp.isfinite(actual.array))
341+
assert np.all(np.isfinite(expected))
342+
343+
# Order parity
344+
assert np.array_equal(np.argsort(-expected), np.argsort(-our_scores_np))
345+
# Value parity within tolerance
346+
np.testing.assert_allclose(our_scores_np, expected, rtol=1e-3, atol=1e-3)
347+
348+
280349
def _assert_batch_variants(alt_array, batch_idx, expected_variants, seq_length, batch_name):
281350
"""Helper to assert variant sequences match expected values for a batch."""
282351
for variant_idx in range(4):

experiments/simple_train_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414

1515
import dataclasses
1616
from dataclasses import dataclass
17-
from typing import Any
1817

1918
from levanter.callbacks.watch import WatchConfig
2019
from levanter.optim import OptimizerConfig
2120
from levanter.schedule import IntSchedule
2221

22+
from levanter.eval import EvalPluginConfig
2323
from marin.resources import ResourceConfig, TpuPodConfig
2424

2525

@@ -88,8 +88,8 @@ class SimpleTrainConfig:
8888
watch: WatchConfig = dataclasses.field(default_factory=WatchConfig)
8989
"""Config for watching gradients, parameters, etc. Default is to log norms of gradients and parameters."""
9090

91-
eval_plugins: list[dict[str, Any]] | None = None
92-
"""List of evaluation plugin configs. Each should have 'plugin' (module.class) and 'config' keys."""
91+
eval_plugins: list[EvalPluginConfig] | None = None
92+
"""List of evaluation plugin configs."""
9393

9494
@property
9595
def tpu_type(self) -> str | None:

0 commit comments

Comments (0)