 from collections.abc import Callable
 from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import cast
+from typing import Self, cast

 import numpy as np
 import torch
@@ -82,7 +82,7 @@ def ddp_enabled(self) -> bool:
         return self.world_size > 1

     @classmethod
-    def empty_context(cls) -> DistributedContext:
+    def empty_context(cls) -> Self:
         """Return an empty distributed context."""

         return cls(rank=0, local_rank=0, world_size=1)
@@ -260,7 +260,6 @@ def __init__(
         min_lr: float = 1e-6,
         **kwargs,
     ) -> None:
-
         self.max_lr = max_lr
         self.min_lr = min_lr
         self.warmup_steps = warmup_steps
@@ -279,13 +278,9 @@ def get_lr(self) -> list[float]:
         # Cosine decay phase
         elif (self.decay_steps > 0) and (self.last_epoch < self.total_steps):
             # Fraction of decay completed (0 at start of decay, 1 at end)
-            decay_progress = (self.last_epoch - self.warmup_steps) / max(
-                1, self.decay_steps
-            )
+            decay_progress = (self.last_epoch - self.warmup_steps) / max(1, self.decay_steps)
             time = decay_progress * np.pi
-            curr_lr = self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (
-                1.0 + float(np.cos(time))
-            )
+            curr_lr = self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (1.0 + float(np.cos(time)))
         else:
             curr_lr = self.min_lr

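Note: the hunk above is the decay half of a standard warmup-plus-cosine schedule. A minimal standalone sketch of the same formula (the warmup branch is outside this hunk, so its linear ramp here is an assumption):

import numpy as np

def cosine_lr(step: int, max_lr: float, min_lr: float, warmup_steps: int, decay_steps: int) -> float:
    if step < warmup_steps:
        # assumed linear ramp; the warmup branch is not shown in the hunk
        return max_lr * (step + 1) / max(1, warmup_steps)
    # progress runs 0 -> 1, cos runs 1 -> -1, so lr runs max_lr -> min_lr
    progress = min(1.0, (step - warmup_steps) / max(1, decay_steps))
    return min_lr + (max_lr - min_lr) * 0.5 * (1.0 + float(np.cos(progress * np.pi)))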
@@ -326,9 +321,7 @@ def __init__(

         # Determine batch size and num_workers
         self.batch_size = (
-            batch_size_override
-            if batch_size_override is not None
-            else self.config.train.batch_size
+            batch_size_override if batch_size_override is not None else self.config.train.batch_size
         )
         self.num_workers = (
             num_workers_override
@@ -382,11 +375,7 @@ def module(self) -> Protpardelle:
             nn.DataParallel,
             DDP,
         )
-        return (
-            self.model.module
-            if isinstance(self.model, parallel_wrappers)
-            else self.model
-        )
+        return self.model.module if isinstance(self.model, parallel_wrappers) else self.model

     @property
     def device(self) -> torch.device:
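Note: the `module` property uses the common unwrap idiom for parallel wrappers, so checkpointing and parameter counts always target the inner network. A minimal sketch of the same pattern on a bare nn.Module:

import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel

def unwrap(model: nn.Module) -> nn.Module:
    # DataParallel/DDP hold the real network in .module; return it so
    # state_dicts are saved without the "module." key prefix
    return model.module if isinstance(model, (DataParallel, DistributedDataParallel)) else model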
@@ -450,9 +439,7 @@ def save_checkpoint(
         }
         checkpoint["rng"] = {
             "torch": torch.get_rng_state(),
-            "cuda": (
-                torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
-            ),
+            "cuda": (torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None),
             "numpy": np.random.get_state(),
             "python": random.getstate(),
             "sampler_seed": (
@@ -485,9 +472,7 @@ def load_checkpoint(
         checkpoint_path = norm_path(checkpoint_path)
         if not checkpoint_path.is_file():
             raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
-        checkpoint = torch.load(
-            checkpoint_path, map_location=self.device, weights_only=False
-        )
+        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

         self.module.load_state_dict(checkpoint["model_state_dict"])
         self.optimizer.load_state_dict(checkpoint["optimizer"])
@@ -497,9 +482,7 @@ def load_checkpoint(
             torch.set_rng_state(checkpoint["rng"]["torch"].cpu())
             if torch.cuda.is_available():
                 if checkpoint["rng"]["cuda"] is None:
-                    raise ValueError(
-                        "Checkpoint was trained with CUDA but current device is CPU"
-                    )
+                    raise ValueError("Checkpoint was trained with CUDA but current device is CPU")
                 cuda_states = [state.cpu() for state in checkpoint["rng"]["cuda"]]
                 torch.cuda.set_rng_state_all(cuda_states)
             np.random.set_state(checkpoint["rng"]["numpy"])
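Note: save_checkpoint and load_checkpoint together round-trip the full RNG state so a resumed run draws the same noise. A condensed sketch of that round trip, assuming the checkpoint dict layout shown above (the `sampler_seed` entry is omitted):

import random
import numpy as np
import torch

def rng_snapshot() -> dict:
    return {
        "torch": torch.get_rng_state(),
        "cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        "numpy": np.random.get_state(),
        "python": random.getstate(),
    }

def rng_restore(state: dict) -> None:
    torch.set_rng_state(state["torch"].cpu())
    if torch.cuda.is_available() and state["cuda"] is not None:
        # states may have been serialized on another device; move to CPU first
        torch.cuda.set_rng_state_all([s.cpu() for s in state["cuda"]])
    np.random.set_state(state["numpy"])
    random.setstate(state["python"])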
@@ -522,9 +505,7 @@ def initialize_training_parameters(self) -> tuple[int, int]:
         if seed is not None:
             if self.ddp_enabled:
                 seed += self.distributed.rank
-            seed_everything(
-                seed, freeze_cuda=True
-            )  # use deterministic pytorch for training
+            seed_everything(seed, freeze_cuda=True)  # use deterministic pytorch for training

         return start_epoch, total_steps

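Note: offsetting the seed by `rank` keeps each DDP rank reproducible while still sampling different data and noise per rank. `seed_everything` lives elsewhere in the repo; a plausible body, with the `freeze_cuda` semantics assumed rather than confirmed:

import random
import numpy as np
import torch

def seed_everything(seed: int, freeze_cuda: bool = False) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds the CUDA generators on current builds
    if freeze_cuda and torch.cuda.is_available():
        # assumption: freeze_cuda opts into deterministic CUDA kernels
        torch.backends.cudnn.deterministic = True
        torch.use_deterministic_algorithms(True, warn_only=True)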
@@ -550,17 +531,13 @@ def start_or_resume(self) -> tuple[int, int]:
             )
             return start_epoch, total_steps
         except FileNotFoundError:
-            logger.warning(
-                "Checkpoint file not found: %s; starting from scratch", checkpoint_path
-            )
+            logger.warning("Checkpoint file not found: %s; starting from scratch", checkpoint_path)

         return self.initialize_training_parameters()

     def log_training_info(self) -> None:
         """Log training information."""
-        logger.info(
-            "Total params: %d", sum(p.numel() for p in self.module.parameters())
-        )
+        logger.info("Total params: %d", sum(p.numel() for p in self.module.parameters()))
         logger.info(
             "Trainable params: %d",
             sum(p.numel() for p in self.module.parameters() if p.requires_grad),
@@ -575,21 +552,18 @@ def collate_fn(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
             batch_dict = cast(dict[str, torch.Tensor], default_collate(batch))

             if self.config.train.crop_conditional:
-
                 atom_coords = batch_dict["coords_in"]
                 atom_mask = batch_dict["atom_mask"]
                 aatype = batch_dict["aatype"]
                 chain_index = batch_dict["chain_index"]

                 # Pre-compute crop conditioning mask and recentered coords for efficiency
-                atom_coords, crop_cond_mask, hotspot_mask = (
-                    make_crop_cond_mask_and_recenter_coords(
-                        atom_coords=atom_coords,
-                        atom_mask=atom_mask,
-                        aatype=aatype,
-                        chain_index=chain_index,
-                        **vars(self.config.train.crop_cond),
-                    )
+                atom_coords, crop_cond_mask, hotspot_mask = make_crop_cond_mask_and_recenter_coords(
+                    atom_coords=atom_coords,
+                    atom_mask=atom_mask,
+                    aatype=aatype,
+                    chain_index=chain_index,
+                    **vars(self.config.train.crop_cond),
                 )
                 struct_crop_cond = atom_coords * crop_cond_mask.unsqueeze(-1)

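Note: computing `make_crop_cond_mask_and_recenter_coords` inside `collate_fn` moves the mask work onto DataLoader workers instead of the training step. The wrapper pattern, as a minimal sketch (the `precompute` hook is hypothetical and stands in for the crop-conditioning call):

from torch.utils.data.dataloader import default_collate

def make_collate_fn(precompute):
    def collate_fn(batch):
        batch_dict = default_collate(batch)
        # attach precomputed conditioning tensors (e.g. crop_cond_mask,
        # struct_crop_cond) so the GPU step receives them ready-made
        batch_dict.update(precompute(batch_dict))
        return batch_dict
    return collate_fn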
@@ -670,18 +644,14 @@ def compute_loss(
         # Crop conditioning
         if self.config.train.crop_conditional:
             if self.config.model.compute_loss_on_all_atoms:
-                raise NotImplementedError(
-                    "Crop conditioning with all atom loss not implemented"
-                )
+                raise NotImplementedError("Crop conditioning with all atom loss not implemented")

             crop_cond_mask = input_dict.get("crop_cond_mask")
             struct_crop_cond = input_dict.get("struct_crop_cond")
             hotspot_mask = input_dict.get("hotspot_mask")

             # If using correct data loader and collate_fn, these should never be None
-            assert all(
-                x is not None for x in [crop_cond_mask, struct_crop_cond, hotspot_mask]
-            )
+            assert all(x is not None for x in [crop_cond_mask, struct_crop_cond, hotspot_mask])

             if "hotspots" not in self.config.model.conditioning_style:
                 hotspot_mask = None  # type: ignore
@@ -698,9 +668,7 @@ def compute_loss(
             adj_cond = None

         # Noise data
-        timestep = torch.rand(batch_size, device=self.device).clamp(
-            min=tol, max=1 - tol
-        )
+        timestep = torch.rand(batch_size, device=self.device).clamp(min=tol, max=1 - tol)
         noise_level = self.module.training_noise_schedule(timestep)
         noised_coords = dummy_fill_noise_coords(
             atom_coords,
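Note: the clamp keeps sampled timesteps strictly inside (0, 1), since the endpoints of a noise schedule can be degenerate (no noise, or pure noise). Standalone, with a hypothetical `tol` (the real value is set outside this hunk):

import torch

tol = 1e-5  # hypothetical value; defined elsewhere in the module
timestep = torch.rand(16).clamp(min=tol, max=1 - tol)  # uniform draw in [tol, 1 - tol]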
@@ -713,9 +681,7 @@ def compute_loss(
         bb_atom_mask = atom37_mask_from_aatype(bb_seq, seq_mask)

         # Some backbone atoms may be missing; mask them to zeros
-        bb_atom_mask = (
-            bb_atom_mask * atom_mask
-        )  # float masks; multiply instead of boolean ops
+        bb_atom_mask = bb_atom_mask * atom_mask  # float masks; multiply instead of boolean ops
         if self.config.model.task == "backbone":
             noised_coords = noised_coords * bb_atom_mask.unsqueeze(-1)
         elif self.config.model.task == "ai-allatom":
@@ -759,18 +725,14 @@ def compute_loss(
             "codesign",
         }:
             if self.config.model.task == "backbone":
-                struct_loss_mask = torch.ones_like(
-                    atom_coords
-                ) * bb_atom_mask.unsqueeze(-1)
+                struct_loss_mask = torch.ones_like(atom_coords) * bb_atom_mask.unsqueeze(-1)
             elif self.config.model.compute_loss_on_all_atoms:
                 # Compute loss on all 37 atoms
-                struct_loss_mask = torch.ones_like(
-                    atom_coords
-                ) * unsqueeze_trailing_dims(seq_mask, atom_coords)
-            else:
-                struct_loss_mask = torch.ones_like(atom_coords) * atom_mask.unsqueeze(
-                    -1
+                struct_loss_mask = torch.ones_like(atom_coords) * unsqueeze_trailing_dims(
+                    seq_mask, atom_coords
                 )
+            else:
+                struct_loss_mask = torch.ones_like(atom_coords) * atom_mask.unsqueeze(-1)

             sigma_fp32 = torch.tensor(
                 self.config.data.sigma_data,
@@ -790,9 +752,7 @@ def compute_loss(
             alpha = self.config.model.mpnn_model.label_smoothing
             aatype_oh = F.one_hot(aatype, self.config.data.n_aatype_tokens).float()
             target_oh = (1 - alpha) * aatype_oh + alpha / self.module.num_tokens
-            mpnn_loss = masked_cross_entropy_loss(
-                aatype_logprobs, target_oh, seq_mask
-            ).mean()
+            mpnn_loss = masked_cross_entropy_loss(aatype_logprobs, target_oh, seq_mask).mean()
             loss = loss + mpnn_loss
             log_dict["mpnn_loss"] = mpnn_loss.detach().cpu().item()

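Note: the MPNN sequence loss smooths its targets, putting weight 1 - alpha on the true residue type and spreading alpha uniformly over all classes, then takes a cross entropy averaged over unmasked positions. A sketch of a plausible `masked_cross_entropy_loss` (its real internals are not shown in this diff):

import torch

def masked_cross_entropy_loss(logprobs: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # logprobs/targets: (B, L, K); mask: (B, L) with 1.0 at valid positions
    per_position = -(targets * logprobs).sum(-1)
    return (per_position * mask).sum(-1) / mask.sum(-1).clamp(min=1)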
@@ -843,7 +803,7 @@ def train_step(self, input_dict: dict[str, torch.Tensor]) -> dict[str, float]:


 @record
-def train(
+def train(  # noqa: C901
     config_path: StrPath,
     output_dir: StrPath,
     device: Device = None,
@@ -873,10 +833,7 @@ def train(
         logger.info("Enabled TF32 mode for faster training on Ampere+ GPUs")

     # Set and resolve device with DDP if applicable
-    if device is None:
-        requested_device = get_default_device()
-    else:
-        requested_device = torch.device(device)
+    requested_device = get_default_device() if device is None else torch.device(device)
     resolved_device, distributed = resolve_device_with_distributed(requested_device)

     # Load config
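Note: `get_default_device` and `resolve_device_with_distributed` are repo helpers. A plausible reading of the default-device half (an assumption, not the actual implementation):

import torch

def get_default_device() -> torch.device:
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")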
@@ -974,22 +931,15 @@ def train(
         wandb_run = wandb.init(**wandb_kwargs)
         if wandb_run is None:
             raise RuntimeError("Failed to initialize wandb run")
-        if (
-            (wandb_run.name is None)
-            or (wandb_run.dir is None)
-            or (wandb_run.id is None)
-        ):
+        if (wandb_run.name is None) or (wandb_run.dir is None) or (wandb_run.id is None):
             raise RuntimeError("wandb returned an incomplete run object")
         run_name = wandb_run.name
         run_dir = wandb_run.dir
         run_id = wandb_run.id
     else:
         # Non-main ranks reuse exp_name for logging clarity
         if exp_name is None:
-            if distributed.ddp_enabled:
-                run_name = f"run-rank{distributed.rank}"
-            else:
-                run_name = "run"
+            run_name = f"run-rank{distributed.rank}" if distributed.ddp_enabled else "run"
         else:
             run_name = exp_name

@@ -1051,19 +1001,15 @@ def train(
             disable=not trainer.is_main,
         )
         for input_dict in progress:
-            input_dict: dict[str, torch.Tensor] = {
-                k: v.to(
-                    trainer.device, non_blocking=True
-                )  # non_blocking for pin_memory
+            input_dict: dict[str, torch.Tensor] = {  # noqa: PLW2901
+                k: v.to(trainer.device, non_blocking=True)  # non_blocking for pin_memory
                 for k, v in input_dict.items()
             }
             assert "cyclic_mask" in input_dict  # TODO: test and remove
             log_dict = trainer.train_step(input_dict)
             log_dict["learning_rate"] = trainer.scheduler.get_last_lr()[0]
             log_dict["epoch"] = epoch
-            log_dict = log_distributed_mean(
-                log_dict, trainer.device, distributed
-            )
+            log_dict = log_distributed_mean(log_dict, trainer.device, distributed)

             # Log to wandb on main rank only
             if trainer.is_main:
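Note: `log_distributed_mean` is presumably an all-reduce average so that rank 0 logs global metrics rather than its local ones. A sketch under that assumption, with the signature simplified (the real third argument is a DistributedContext):

import torch
import torch.distributed as dist

def log_distributed_mean(log_dict: dict[str, float], device: torch.device, ddp_enabled: bool) -> dict[str, float]:
    if not (ddp_enabled and dist.is_initialized()):
        return log_dict
    out = {}
    for key, value in log_dict.items():
        t = torch.tensor(float(value), device=device)
        dist.all_reduce(t, op=dist.ReduceOp.SUM)  # sum across ranks, then divide
        out[key] = t.item() / dist.get_world_size()
    return out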