22
33import logging
44import os
5- import time
65from collections import defaultdict
76from contextlib import AbstractContextManager
87from typing import Any , Dict , List , Optional , Tuple , Union
1918from arch_eval .distributed import cleanup_distributed , get_wrapped_model , init_distributed
2019from arch_eval .logging .logger_config import LoggerAdapter
2120from arch_eval .metrics .calculator import MetricCalculator
22- from arch_eval .plugins .manager import PluginManager , hook
21+ from arch_eval .plugins .manager import PluginManager
2322from arch_eval .profiler import profiler_context
24- from arch_eval .utils .device import memory_summary
23+ from arch_eval .utils .device import memory_summary , auto_device
2524from arch_eval .viz .viz import PlotSaver , RealtimeWindow , VideoRecorder
2625
2726logger = logging .getLogger (__name__ )
@@ -61,13 +60,15 @@ def __init__(self, model: nn.Module, config: TrainingConfig):
6160 else :
6261 self .model = model
6362
64- self ._validate_model ()
6563 self .device = torch .device (config .device )
6664 self .model = self .model .to (self .device ).to (config .dtype )
6765
6866 self .dataset_handler = DatasetHandler (config )
6967 self .train_loader , self .val_loader , self .test_loader = self .dataset_handler .prepare_loaders ()
7068
69+ # Validate model with a real batch (if train_loader exists)
70+ self ._validate_model_with_data ()
71+
7172 self .metric_calculator = MetricCalculator (
7273 config .task , config .device , output_transform = config .model_output_transform
7374 )
@@ -81,15 +82,15 @@ def __init__(self, model: nn.Module, config: TrainingConfig):
8182 self .amp_dtype = self ._get_amp_dtype ()
8283 self .scaler = torch .cuda .amp .GradScaler () if self .use_amp and config .grad_scaler else None
8384
84- # Gradient checkpointing (experimental) TODO
85+ # Gradient checkpointing (experimental)
8586 if config .gradient_checkpointing :
8687 self ._apply_gradient_checkpointing ()
8788
8889 # Visualization
8990 self .window = None
9091 if config .realtime :
9192 try :
92- self .window = RealtimeWindow (config )
93+ self .window = RealtimeWindow (config , metric_names = config . viz_metrics )
9394 if getattr (self .window , "disabled" , False ):
9495 self .window = None
9596 except Exception as e :
@@ -120,6 +121,9 @@ def __init__(self, model: nn.Module, config: TrainingConfig):
120121 self .accumulation_steps = config .gradient_accumulation_steps
121122 self .current_accum_step = 0
122123
124+ # Initialize checkpoint best metric
125+ self .checkpoint_best_metric = None
126+
123127 self .logger .info (f"Trainer initialized on { self .device } \n { memory_summary ()} " )
124128
125129 def _get_amp_dtype (self ):
@@ -135,10 +139,7 @@ def _get_amp_dtype(self):
135139 return torch .float16
136140
137141 def _apply_gradient_checkpointing (self ):
138- """Experimental: attempts to enable gradient checkpointing on specified modules.
139- This is not a standard PyTorch feature; models must implement it internally.
140- The current implementation sets a '_gradient_checkpointing' attribute on modules,
141- which may be used by custom layers. For most models, this will have no effect."""
142+ """Experimental: attempts to enable gradient checkpointing on specified modules."""
142143 if self .config .gradient_checkpointing_modules :
143144 for name in self .config .gradient_checkpointing_modules :
144145 module = dict (self .model .named_modules ()).get (name )
@@ -150,14 +151,24 @@ def _apply_gradient_checkpointing(self):
150151 module ._gradient_checkpointing = True
151152 self .logger .warning ("Gradient checkpointing is experimental and may not work as expected." )
152153
153- def _validate_model (self ):
154- shape = self .config .input_shape or (1 , 10 )
155- dummy = torch .randn (1 , * shape ).to (torch .device (self .config .device ))
154+ def _validate_model_with_data (self ):
155+ """Run a forward pass on a single batch to ensure the model accepts the data."""
156+ if self .train_loader is None :
157+ self .logger .warning ("No training loader – skipping model validation." )
158+ return
156159 try :
160+ # Get one batch
161+ data , targets = next (iter (self .train_loader ))
162+ data = data .to (self .device )
163+ targets = targets .to (self .device )
164+ self .model .eval ()
157165 with torch .no_grad ():
158- self .model (dummy )
166+ _ = self .model (data )
167+ self .model .train ()
168+ self .logger .info ("Model validation passed." )
159169 except Exception as e :
160- raise ModelError (f"Model validation failed: { e } " )
170+ raise ModelError (f"Model validation failed on a real batch: { e } . "
171+ "Check that your model's input size matches the dataset features." )
161172
162173 def _setup_optimizers (self ):
163174 self .optimizers = []
@@ -231,7 +242,7 @@ def _setup_loss_function(self):
231242
232243 def _compute_loss (self , output , targets ):
233244 if isinstance (output , tuple ) and len (output ) == 2 :
234- return output [1 ] # Assume second element is loss
245+ return output [1 ] # Assume second element is loss
235246 else :
236247 return self .criterion (output , targets )
237248
@@ -298,8 +309,6 @@ def train(self) -> Dict[str, List[float]]:
298309 self .window .close ()
299310 if self .config .log_to_wandb :
300311 wandb .finish ()
301- if hasattr (self , "tb_writer" ) and self .tb_writer :
302- self .tb_writer .close ()
303312 if self .config .distributed_backend != DistributedBackend .NONE :
304313 cleanup_distributed ()
305314
@@ -334,7 +343,6 @@ def _train_epoch(self) -> Dict[str, float]:
334343 else :
335344 loss .backward ()
336345
337- # self.plugin_manager.execute_hook("on_backward", self, loss.item() * self.accumulation_steps) # FIXME: Possible Overhead
338346 self .current_accum_step += 1
339347
340348 if self .current_accum_step % self .accumulation_steps == 0 :
@@ -397,7 +405,6 @@ def _evaluate(self, loader: DataLoader, split: str) -> Dict[str, float]:
397405 total_loss = 0.0
398406 metric_accum = defaultdict (float )
399407 count = 0
400- autocast = torch .cuda .amp .autocast if self .use_amp else NullContext
401408
402409 # Reset confusion matrix accumulator if needed
403410 if self .config .log_confusion_matrix and split == "val" :
@@ -437,8 +444,8 @@ def _evaluate(self, loader: DataLoader, split: str) -> Dict[str, float]:
437444 wandb .log ({
438445 f"confusion_matrix/{ split } " : wandb .plot .confusion_matrix (
439446 probs = None ,
440- y_true = self .metric_calculator ._all_targets ,
441- preds = self .metric_calculator ._all_preds ,
447+ y_true = np . array ( self .metric_calculator ._all_targets ) ,
448+ preds = np . array ( self .metric_calculator ._all_preds ) ,
442449 class_names = class_names
443450 )
444451 }, step = self .current_epoch )
@@ -475,9 +482,6 @@ def _log_metrics(self, metrics: Dict[str, float], step: int):
475482
476483 if self .config .log_to_wandb :
477484 wandb .log (metrics , step = step )
478- if hasattr (self , "tb_writer" ) and self .tb_writer :
479- for k , v in metrics .items ():
480- self .tb_writer .add_scalar (k , v , step )
481485
482486 self .plugin_manager .execute_hook ("on_log" , self , metrics , step )
483487
@@ -502,13 +506,18 @@ def _save_checkpoint(self, epoch: int, metrics: Dict[str, float]):
502506 current = metrics .get (self .config .checkpoint_metric )
503507 if current is not None :
504508 mode = "min" if "loss" in self .config .checkpoint_metric else "max"
505- improved = (mode == "min" and current < self .checkpoint_best_metric ) or (
506- mode == "max" and current > self .checkpoint_best_metric
507- )
508- if improved :
509+ if self .checkpoint_best_metric is None :
509510 self .checkpoint_best_metric = current
510511 is_best = True
511512 save_this = True
513+ else :
514+ improved = (mode == "min" and current < self .checkpoint_best_metric ) or (
515+ mode == "max" and current > self .checkpoint_best_metric
516+ )
517+ if improved :
518+ self .checkpoint_best_metric = current
519+ is_best = True
520+ save_this = True
512521 else :
513522 if epoch % self .config .save_frequency == 0 :
514523 save_this = True
0 commit comments