2222import json
2323import dataclasses
2424from dataclasses import dataclass
25- from typing import Any
2625from collections .abc import Callable
2726from datasets import Dataset
2827
3837from huggingface_hub import HfApi
3938from transformers import AutoModelForCausalLM
4039from levanter .callbacks import StepInfo
41- from levanter .utils .tree_utils import inference_mode
4240from marin .utilities .json_encoder import CustomJsonEncoder
4341
4442from experiments .plantcad .utils import get_available_gpus , get_nucleotide_token_ids , get_plantcad_tokenizer
4846
4947
5048@dataclass
51- class DnaEvalConfig :
52- """Configuration for DNA model evolutionary conservation evaluation"""
53-
54- checkpoint_path : str | InputName
55- """Path to the model checkpoint directory"""
49+ class DnaEvalBaseConfig :
50+ """Base configuration for DNA evaluation with fields needed for training callbacks"""
5651
5752 model_config : str
5853 """Model configuration size (e.g., '300m', '100m', etc.)"""
5954
60- device : str = "cuda"
61- """Device to use for model inference (e.g., 'cuda', 'cpu')"""
62-
63- dtype : str | None = None
64- """Dtype to use for model inference (e.g., 'float32', 'float16', 'bfloat16' or any torch dtype)"""
65-
6655 dataset_path : str = "plantcad/evolutionary-constraint-example"
6756 """Dataset repository path"""
6857
@@ -75,15 +64,29 @@ class DnaEvalConfig:
7564 batch_size : int = 32
7665 """Batch size to use for inference"""
7766
78- num_workers : int | None = None
79- """Number of workers to use for parallel evaluation (defaults to number of GPUs if None)"""
80-
8167 max_samples : int | None = None
8268 """Maximum number of samples to evaluate (for quick testing)"""
8369
84- random_seed : int = versioned ( 42 )
70+ random_seed : int = 42
8571 """Random seed for data shuffling prior to downsampling"""
8672
73+
74+ @dataclass
75+ class DnaEvalConfig (DnaEvalBaseConfig ):
76+ """Configuration for standalone DNA model evolutionary conservation evaluation"""
77+
78+ checkpoint_path : str | InputName | None = None
79+ """Path to the model checkpoint directory (None for training callbacks)"""
80+
81+ device : str = "cuda"
82+ """Device to use for model inference (e.g., 'cuda', 'cpu')"""
83+
84+ dtype : str | None = None
85+ """Dtype to use for model inference (e.g., 'float32', 'float16', 'bfloat16' or any torch dtype)"""
86+
87+ num_workers : int | None = None
88+ """Number of workers to use for parallel evaluation (defaults to number of GPUs if None)"""
89+
8790 revision : str = versioned ("0.1" )
8891 """Revision number to force re-runs when needed"""
8992
@@ -380,6 +383,8 @@ def compute_causal_conservation(
380383
381384 # Run inference for all reference/alternate sequences
382385 logits = logit_function (batch_alt_sequences )
386+ # Always promote to full precision for zero-shot evaluation
387+ logits = logits .astype (jnp .float32 )
383388 Vocab = logits .resolve_axis ("vocab" )
384389 assert logits .axes == (VariantBatch , Position , Vocab )
385390
@@ -406,6 +411,7 @@ def score_eval_dataset(
406411 eval_dataset : Dataset ,
407412 logit_function : Callable [[TokenArray ], LogitArray ],
408413 batch_size : int = 32 ,
414+ log_progress : bool = True ,
409415) -> ConservationResult :
410416 """Score evaluation dataset based on zero-shot conservation prediction."""
411417
@@ -420,7 +426,8 @@ def score_eval_dataset(
420426 batches = eval_dataset .with_format (None ).batch (batch_size = batch_size )
421427 total_batches = len (batches )
422428 progress_interval = max (1 , total_batches // 20 ) # Every 5%
423- logger .info (f"Processing { len (eval_dataset )} samples in { total_batches } batches (batch_size={ batch_size } )" )
429+ if log_progress :
430+ logger .info (f"Processing { len (eval_dataset )} samples in { total_batches } batches (batch_size={ batch_size } )" )
424431
425432 for batch_index , batch_data in enumerate (batches ):
426433 # Tokenize sequences
@@ -451,7 +458,7 @@ def score_eval_dataset(
451458 total_processed += len (sequences )
452459
453460 # Log progress every 5% of batches
454- if batch_index % progress_interval == 0 or batch_index == total_batches - 1 :
461+ if log_progress and ( batch_index % progress_interval == 0 or batch_index == total_batches - 1 ) :
455462 progress_pct = ((batch_index + 1 ) / total_batches ) * 100
456463 logger .info (
457464 f"Progress: { batch_index + 1 } /{ total_batches } batches ({ progress_pct :.1f} %) - "
@@ -466,44 +473,7 @@ def score_eval_dataset(
466473# ------------------------------------------------------------------------------------------------
467474
468475
def evaluate_dna_conservation(
    tokenizer: AutoTokenizer,
    logit_function: Callable[[Any], Any],
    eval_dataset: Dataset,
    batch_size: int = 32,
    step: int | None = None,
) -> dict[str, float]:
    """
    Core evaluation logic - works for both training callbacks and standalone evaluation.

    Args:
        logit_function: Function that takes tokens and returns logits
        eval_dataset: HuggingFace dataset with 'seq' field and binary 'label' field
        batch_size: Batch size for evaluation
        step: Training step (for logging), None for standalone

    Returns:
        Dictionary with evaluation metrics including ROC AUC
    """
    # Score every sample via the shared batched-inference path
    scores = score_eval_dataset(
        tokenizer=tokenizer,
        logit_function=logit_function,
        eval_dataset=eval_dataset,
        batch_size=batch_size,
    )

    # Reduce raw scores/labels to summary metrics (ROC AUC etc.)
    metrics = evaluate_conservation_scores(scores)

    if step is None:
        # Standalone run: log locally only
        logger.info(f"ROC AUC = {metrics['roc_auc']:.4f} ({metrics['n_total']} valid nucleotides)")
    else:
        # Training callback: report to the experiment tracker as well
        levanter.tracker.log({"eval/dna_conservation/roc": metrics["roc_auc"]}, step=step)
        logger.info(f"Step {step}: ROC AUC = {metrics['roc_auc']:.3f}")

    return metrics
504-
505-
506- def create_dna_eval_callback (config : DnaEvalConfig ) -> Callable [[StepInfo ], None ]:
476+ def create_dna_eval_callback (config : DnaEvalBaseConfig ) -> Callable [[StepInfo ], None ]:
507477 """Create a training callback for DNA evaluation."""
508478
509479 # Load tokenizer
@@ -514,25 +484,43 @@ def create_dna_eval_callback(config: DnaEvalConfig) -> Callable[[StepInfo], None
514484 dataset = load_eval_dataset (config )
515485
516486 def dna_conservation_callback (step_info : StepInfo ) -> None :
517- # Put model in inference mode
518- eval_model = inference_mode (step_info .state .model , True )
487+ step = step_info .step
488+ logger .debug (f"Running DNA conservation evaluation ({ step = } )" )
489+ eval_model = step_info .eval_model
519490
520491 # Create logit function for Levanter model
521492 def logit_function (
522493 tokens : ht .Int [ht .NamedArray , "batch position" ],
523494 ) -> ht .Float [ht .NamedArray , "batch position vocab" ]:
524- # TODO: validate input / output types
525- return eval_model ( tokens )
495+ logits = eval_model ( tokens )
496+ return logits
526497
527- # Run evaluation
528- evaluate_dna_conservation (
498+ # Compute scores with binary labels
499+ scores = score_eval_dataset (
529500 tokenizer = tokenizer ,
530501 logit_function = logit_function ,
531- eval_dataset = dataset , # Use the loaded dataset
502+ eval_dataset = dataset ,
532503 batch_size = config .batch_size ,
504+ # TODO: make configurable or disable?
505+ log_progress = True ,
506+ )
507+
508+ # Evaluate scores and labels
509+ metrics = evaluate_conservation_scores (scores )
510+
511+ # Log results
512+ levanter .tracker .log (
513+ {
514+ "eval/dna_conservation_roc" : metrics ["roc_auc" ],
515+ },
533516 step = step_info .step ,
534517 )
535518
519+ logger .info (
520+ f"DNA conservation evaluation complete ({ step = } ): "
521+ f"ROC AUC = { metrics ['roc_auc' ]:.4f} , n_samples = { metrics ['n_total' ]} "
522+ )
523+
536524 return dna_conservation_callback
537525
538526
@@ -614,7 +602,11 @@ def logit_function(
614602
615603 # Generate raw conservation scores and labels
616604 result = score_eval_dataset (
617- tokenizer = tokenizer , logit_function = logit_function , eval_dataset = dataset , batch_size = config .batch_size
605+ tokenizer = tokenizer ,
606+ logit_function = logit_function ,
607+ eval_dataset = dataset ,
608+ batch_size = config .batch_size ,
609+ log_progress = True ,
618610 )
619611
620612 logger .info (f"Generated { len (result .scores )} conservation scores" )
@@ -640,10 +632,7 @@ def evaluate_conservation_scores(scores: ConservationResult) -> dict[str, float]
640632 if len (scores .scores ) == 0 :
641633 raise ValueError ("No valid conservation scores found" )
642634
643- # Log total before filtering and filter out NaN scores
644635 n_unmasked_total = len (scores .scores )
645- logger .info (f"n_unmasked_total: { n_unmasked_total } " )
646-
647636 valid_mask = ~ np .isnan (scores .scores )
648637 filtered_scores = np .array (scores .scores )[valid_mask ]
649638 filtered_labels = np .array (scores .labels )[valid_mask ]
@@ -691,7 +680,8 @@ def save_conservation_results(config: DnaEvalConfig, results: dict[str, float])
691680 logger .info (f"Saved evaluation results to: { results_file } " )
692681
693682
694- @ray .remote (max_calls = 1 )
683+ # TODO: fix this which forces only one checkpoint to run at a time
684+ @ray .remote (max_calls = 1 , resources = {"head_node" : 1 })
695685def run_conservation_eval (config : DnaEvalConfig ) -> dict [str , float ]:
696686 # Determine number of workers
697687 if config .num_workers is None :
0 commit comments