3939from levanter .callbacks import StepInfo
4040from marin .utilities .json_encoder import CustomJsonEncoder
4141
42- from experiments .plantcad .utils import get_available_gpus , get_nucleotide_token_ids , get_plantcad_tokenizer
42+ from experiments .plantcad .utils import get_available_gpus , get_nucleotide_token_ids
43+ from levanter .utils .hf_utils import HfTokenizer
4344from marin .execution .executor import InputName , this_output_path , versioned
45+ from jax .sharding import Mesh
46+ from haliax .partitioning import ResourceMapping
4447
4548logger = logging .getLogger ("ray" )
4649
4952class DnaEvalBaseConfig :
5053 """Base configuration for DNA evaluation with fields needed for training callbacks"""
5154
52- model_config : str
53- """Model configuration size (e.g., '300m', '100m', etc.)"""
54-
5555 dataset_path : str = "plantcad/evolutionary-constraint-example"
5656 """Dataset repository path"""
5757
@@ -292,12 +292,12 @@ def create_alternate_sequences(
292292 assert 0 <= ref_cts .max ().item () <= 1
293293 if (invalid := ref_cts == 0 ).any ().item ():
294294 pos = nucleotide_positions [Batch , invalid ]
295- tok = tokens_expanded [Batch , invalid , Position , pos ]
295+ tok = tokens_expanded [Batch , invalid ][ Position , pos ]
296296 raise ValueError (
297- "Found invalid sequences in batch with OOV nucleotides at target positions; "
298- f"Target positions: { pos } "
299- f"Valid nucleotide token IDs: { nucleotide_token_ids } "
300- f"Invalid tokens: { tok } "
297+ "Found invalid sequences in batch with OOV nucleotides at target positions;\n "
298+ f"Target positions: { pos . array } \n "
299+ f"Valid nucleotide token IDs: { nucleotide_token_ids } \n "
300+ f"Invalid tokens: { tok . array } "
301301 )
302302 ref = hax .argmax (ref_mask , axis = Variant )
303303 assert ref .axes == (Batch ,)
@@ -439,7 +439,7 @@ def score_eval_dataset(
439439 assert isinstance (pos , list )
440440
441441 # Tokenize and convert to JAX arrays
442- tokenized = tokenizer (sequences , padding = True , truncation = True , max_length = 512 , return_tensors = "np" )
442+ tokenized = tokenizer (sequences , padding = False , add_special_tokens = False , truncation = False , return_tensors = "np" )
443443 tokens = hax .named (tokenized ["input_ids" ], ("batch" , "position" ))
444444 nucleotide_positions = hax .named (pos , ("batch" ,))
445445
@@ -473,12 +473,25 @@ def score_eval_dataset(
473473# ------------------------------------------------------------------------------------------------
474474
475475
476- def create_dna_eval_callback (config : DnaEvalBaseConfig ) -> Callable [[StepInfo ], None ]:
477- """Create a training callback for DNA evaluation."""
476+ def create_dna_eval_callback (
477+ config : DnaEvalBaseConfig ,
478+ tokenizer : HfTokenizer ,
479+ device_mesh : Mesh ,
480+ compute_axis_mapping : ResourceMapping ,
481+ parameter_axis_mapping : ResourceMapping ,
482+ ) -> Callable [[StepInfo ], None ]:
483+ """Create a training callback for DNA evaluation.
484+
485+ Args:
486+ config: DNA evaluation configuration
487+ tokenizer: Tokenizer provided by Levanter's training loop
488+ device_mesh: JAX device mesh for distributed computation
489+ compute_axis_mapping: Axis mapping for computation (used during model inference)
490+ parameter_axis_mapping: Axis mapping for parameter storage (used for model sharding)
478491
479- # Load tokenizer
480- # TODO: fix this; how can the tokenizer be referenced during training without reloading?
481- tokenizer = get_plantcad_tokenizer ()
492+ Returns:
493+ Callback function that evaluates DNA conservation at training steps
494+ """
482495
483496 # Load and validate dataset once when creating the callback
484497 dataset = load_eval_dataset (config )
@@ -488,7 +501,7 @@ def dna_conservation_callback(step_info: StepInfo) -> None:
488501 logger .debug (f"Running DNA conservation evaluation ({ step = } )" )
489502 eval_model = step_info .eval_model
490503
491- # Create logit function for Levanter model
504+ # Create logit function for Levanter model with proper axis mapping
492505 def logit_function (
493506 tokens : ht .Int [ht .NamedArray , "batch position" ],
494507 ) -> ht .Float [ht .NamedArray , "batch position vocab" ]:
@@ -501,7 +514,6 @@ def logit_function(
501514 logit_function = logit_function ,
502515 eval_dataset = dataset ,
503516 batch_size = config .batch_size ,
504- # TODO: make configurable or disable?
505517 log_progress = True ,
506518 )
507519
@@ -710,18 +722,12 @@ def run_conservation_eval(config: DnaEvalConfig) -> dict[str, float]:
710722# Usage examples:
711723
712724# 1. Training callback (uses Levanter model from training state):
713- # config = DnaEvalConfig(
714- # checkpoint_path="/path/to/checkpoint", # Not used for callbacks
715- # model_config="300m",
716- # dataset_path="plantcad/evolutionary-constraint-example",
717- # dataset_config="10k"
718- # )
719- # trainer.add_hook(create_dna_eval_callback(config), every=1000)
725+ # The plugin system handles this automatically via PlantCADEvaluationPlugin
726+ # See plugin.py for implementation details
720727
721728# 2. Standalone evaluation with HuggingFace checkpoint:
722729# config = DnaEvalConfig(
723730# checkpoint_path="/path/to/hf/checkpoint",
724- # model_config="300m",
725731# device="cuda", # or "cpu" for CPU inference
726732# num_workers=None, # defaults to number of GPUs
727733# dataset_path="plantcad/evolutionary-constraint-example",
0 commit comments