
Commit 6f674fd

refactor: improve type hints and code formatting in train.py for better readability and maintainability
1 parent c17197c commit 6f674fd

1 file changed: 35 additions & 89 deletions

src/protpardelle/train.py
@@ -12,7 +12,7 @@
 from collections.abc import Callable
 from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import cast
+from typing import Self, cast
 
 import numpy as np
 import torch
@@ -82,7 +82,7 @@ def ddp_enabled(self) -> bool:
         return self.world_size > 1
 
     @classmethod
-    def empty_context(cls) -> DistributedContext:
+    def empty_context(cls) -> Self:
         """Return an empty distributed context."""
 
         return cls(rank=0, local_rank=0, world_size=1)
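
For reference, a minimal sketch (not part of this commit) of what the Self annotation changes, assuming DistributedContext is a dataclass whose fields are the three arguments passed above; with Self, the classmethod types as whatever subclass it is called on, instead of always as DistributedContext.

from dataclasses import dataclass
from typing import Self  # Python 3.11+


@dataclass
class DistributedContext:
    rank: int
    local_rank: int
    world_size: int

    @classmethod
    def empty_context(cls) -> Self:
        # When called on a subclass, this now types as that subclass.
        return cls(rank=0, local_rank=0, world_size=1)
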
@@ -260,7 +260,6 @@ def __init__(
         min_lr: float = 1e-6,
         **kwargs,
     ) -> None:
-
         self.max_lr = max_lr
         self.min_lr = min_lr
         self.warmup_steps = warmup_steps
@@ -279,13 +278,9 @@ def get_lr(self) -> list[float]:
         # Cosine decay phase
         elif (self.decay_steps > 0) and (self.last_epoch < self.total_steps):
             # Fraction of decay completed (0 at start of decay, 1 at end)
-            decay_progress = (self.last_epoch - self.warmup_steps) / max(
-                1, self.decay_steps
-            )
+            decay_progress = (self.last_epoch - self.warmup_steps) / max(1, self.decay_steps)
             time = decay_progress * np.pi
-            curr_lr = self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (
-                1.0 + float(np.cos(time))
-            )
+            curr_lr = self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (1.0 + float(np.cos(time)))
         else:
             curr_lr = self.min_lr
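
The reflowed lines above implement the decay half of a warmup-then-cosine schedule. A standalone sketch of the same arithmetic, with hypothetical values for max_lr, min_lr, warmup_steps, and decay_steps (the real scheduler reads them from self):

import numpy as np

def cosine_decay_lr(step: int, max_lr: float = 1e-3, min_lr: float = 1e-6,
                    warmup_steps: int = 1_000, decay_steps: int = 9_000) -> float:
    # Fraction of the decay completed: 0 at the end of warmup, 1 at the final step.
    decay_progress = (step - warmup_steps) / max(1, decay_steps)
    # cos goes from 1 to -1 over the decay, so the lr sweeps from max_lr down to min_lr.
    return min_lr + (max_lr - min_lr) * 0.5 * (1.0 + float(np.cos(decay_progress * np.pi)))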

@@ -326,9 +321,7 @@ def __init__(
 
         # Determine batch size and num_workers
         self.batch_size = (
-            batch_size_override
-            if batch_size_override is not None
-            else self.config.train.batch_size
+            batch_size_override if batch_size_override is not None else self.config.train.batch_size
         )
         self.num_workers = (
             num_workers_override
@@ -382,11 +375,7 @@ def module(self) -> Protpardelle:
             nn.DataParallel,
             DDP,
         )
-        return (
-            self.model.module
-            if isinstance(self.model, parallel_wrappers)
-            else self.model
-        )
+        return self.model.module if isinstance(self.model, parallel_wrappers) else self.model
 
     @property
     def device(self) -> torch.device:
@@ -450,9 +439,7 @@ def save_checkpoint(
         }
         checkpoint["rng"] = {
             "torch": torch.get_rng_state(),
-            "cuda": (
-                torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
-            ),
+            "cuda": (torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None),
             "numpy": np.random.get_state(),
             "python": random.getstate(),
             "sampler_seed": (
@@ -485,9 +472,7 @@ def load_checkpoint(
         checkpoint_path = norm_path(checkpoint_path)
         if not checkpoint_path.is_file():
             raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
-        checkpoint = torch.load(
-            checkpoint_path, map_location=self.device, weights_only=False
-        )
+        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
 
         self.module.load_state_dict(checkpoint["model_state_dict"])
         self.optimizer.load_state_dict(checkpoint["optimizer"])
@@ -497,9 +482,7 @@ def load_checkpoint(
         torch.set_rng_state(checkpoint["rng"]["torch"].cpu())
         if torch.cuda.is_available():
             if checkpoint["rng"]["cuda"] is None:
-                raise ValueError(
-                    "Checkpoint was trained with CUDA but current device is CPU"
-                )
+                raise ValueError("Checkpoint was trained with CUDA but current device is CPU")
             cuda_states = [state.cpu() for state in checkpoint["rng"]["cuda"]]
             torch.cuda.set_rng_state_all(cuda_states)
         np.random.set_state(checkpoint["rng"]["numpy"])
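
The two load_checkpoint hunks above mirror the RNG bookkeeping written out in save_checkpoint. A self-contained sketch of the same save/restore round trip, using only the torch, numpy, and random calls that appear in the diff:

import random

import numpy as np
import torch

# Capture every RNG stream the trainer uses (CUDA states only if a GPU is present).
rng = {
    "torch": torch.get_rng_state(),
    "cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
    "numpy": np.random.get_state(),
    "python": random.getstate(),
}

# Restore them later, e.g. after loading a checkpoint dict with torch.load.
torch.set_rng_state(rng["torch"].cpu())
if torch.cuda.is_available() and rng["cuda"] is not None:
    torch.cuda.set_rng_state_all([state.cpu() for state in rng["cuda"]])
np.random.set_state(rng["numpy"])
random.setstate(rng["python"])
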
@@ -522,9 +505,7 @@ def initialize_training_parameters(self) -> tuple[int, int]:
         if seed is not None:
             if self.ddp_enabled:
                 seed += self.distributed.rank
-            seed_everything(
-                seed, freeze_cuda=True
-            )  # use deterministic pytorch for training
+            seed_everything(seed, freeze_cuda=True)  # use deterministic pytorch for training
 
         return start_epoch, total_steps
 
@@ -550,17 +531,13 @@ def start_or_resume(self) -> tuple[int, int]:
             )
             return start_epoch, total_steps
         except FileNotFoundError:
-            logger.warning(
-                "Checkpoint file not found: %s; starting from scratch", checkpoint_path
-            )
+            logger.warning("Checkpoint file not found: %s; starting from scratch", checkpoint_path)
 
         return self.initialize_training_parameters()
 
     def log_training_info(self) -> None:
         """Log training information."""
-        logger.info(
-            "Total params: %d", sum(p.numel() for p in self.module.parameters())
-        )
+        logger.info("Total params: %d", sum(p.numel() for p in self.module.parameters()))
         logger.info(
             "Trainable params: %d",
             sum(p.numel() for p in self.module.parameters() if p.requires_grad),
@@ -575,21 +552,18 @@ def collate_fn(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
             batch_dict = cast(dict[str, torch.Tensor], default_collate(batch))
 
             if self.config.train.crop_conditional:
-
                 atom_coords = batch_dict["coords_in"]
                 atom_mask = batch_dict["atom_mask"]
                 aatype = batch_dict["aatype"]
                 chain_index = batch_dict["chain_index"]
 
                 # Pre-compute crop conditioning mask and recentered coords for efficiency
-                atom_coords, crop_cond_mask, hotspot_mask = (
-                    make_crop_cond_mask_and_recenter_coords(
-                        atom_coords=atom_coords,
-                        atom_mask=atom_mask,
-                        aatype=aatype,
-                        chain_index=chain_index,
-                        **vars(self.config.train.crop_cond),
-                    )
+                atom_coords, crop_cond_mask, hotspot_mask = make_crop_cond_mask_and_recenter_coords(
+                    atom_coords=atom_coords,
+                    atom_mask=atom_mask,
+                    aatype=aatype,
+                    chain_index=chain_index,
+                    **vars(self.config.train.crop_cond),
                 )
                 struct_crop_cond = atom_coords * crop_cond_mask.unsqueeze(-1)
 
@@ -670,18 +644,14 @@ def compute_loss(
         # Crop conditioning
         if self.config.train.crop_conditional:
             if self.config.model.compute_loss_on_all_atoms:
-                raise NotImplementedError(
-                    "Crop conditioning with all atom loss not implemented"
-                )
+                raise NotImplementedError("Crop conditioning with all atom loss not implemented")
 
             crop_cond_mask = input_dict.get("crop_cond_mask")
             struct_crop_cond = input_dict.get("struct_crop_cond")
             hotspot_mask = input_dict.get("hotspot_mask")
 
             # If using correct data loader and collate_fn, these should never be None
-            assert all(
-                x is not None for x in [crop_cond_mask, struct_crop_cond, hotspot_mask]
-            )
+            assert all(x is not None for x in [crop_cond_mask, struct_crop_cond, hotspot_mask])
 
             if "hotspots" not in self.config.model.conditioning_style:
                 hotspot_mask = None  # type: ignore
@@ -698,9 +668,7 @@ def compute_loss(
             adj_cond = None
 
         # Noise data
-        timestep = torch.rand(batch_size, device=self.device).clamp(
-            min=tol, max=1 - tol
-        )
+        timestep = torch.rand(batch_size, device=self.device).clamp(min=tol, max=1 - tol)
         noise_level = self.module.training_noise_schedule(timestep)
         noised_coords = dummy_fill_noise_coords(
             atom_coords,
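
The joined timestep line samples one diffusion time per example and clamps it away from the interval endpoints, so the noise schedule is never evaluated at exactly 0 or 1. A tiny sketch with a hypothetical tolerance (the real tol is defined elsewhere in compute_loss):

import torch

tol = 1e-4  # hypothetical tolerance
batch_size = 8
timestep = torch.rand(batch_size).clamp(min=tol, max=1 - tol)  # values lie in [tol, 1 - tol]
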
@@ -713,9 +681,7 @@ def compute_loss(
         bb_atom_mask = atom37_mask_from_aatype(bb_seq, seq_mask)
 
         # Some backbone atoms may be missing; mask them to zeros
-        bb_atom_mask = (
-            bb_atom_mask * atom_mask
-        )  # float masks; multiply instead of boolean ops
+        bb_atom_mask = bb_atom_mask * atom_mask  # float masks; multiply instead of boolean ops
         if self.config.model.task == "backbone":
             noised_coords = noised_coords * bb_atom_mask.unsqueeze(-1)
         elif self.config.model.task == "ai-allatom":
@@ -759,18 +725,14 @@ def compute_loss(
             "codesign",
         }:
             if self.config.model.task == "backbone":
-                struct_loss_mask = torch.ones_like(
-                    atom_coords
-                ) * bb_atom_mask.unsqueeze(-1)
+                struct_loss_mask = torch.ones_like(atom_coords) * bb_atom_mask.unsqueeze(-1)
             elif self.config.model.compute_loss_on_all_atoms:
                 # Compute loss on all 37 atoms
-                struct_loss_mask = torch.ones_like(
-                    atom_coords
-                ) * unsqueeze_trailing_dims(seq_mask, atom_coords)
-            else:
-                struct_loss_mask = torch.ones_like(atom_coords) * atom_mask.unsqueeze(
-                    -1
+                struct_loss_mask = torch.ones_like(atom_coords) * unsqueeze_trailing_dims(
+                    seq_mask, atom_coords
                 )
+            else:
+                struct_loss_mask = torch.ones_like(atom_coords) * atom_mask.unsqueeze(-1)
 
         sigma_fp32 = torch.tensor(
             self.config.data.sigma_data,
@@ -790,9 +752,7 @@ def compute_loss(
             alpha = self.config.model.mpnn_model.label_smoothing
             aatype_oh = F.one_hot(aatype, self.config.data.n_aatype_tokens).float()
             target_oh = (1 - alpha) * aatype_oh + alpha / self.module.num_tokens
-            mpnn_loss = masked_cross_entropy_loss(
-                aatype_logprobs, target_oh, seq_mask
-            ).mean()
+            mpnn_loss = masked_cross_entropy_loss(aatype_logprobs, target_oh, seq_mask).mean()
             loss = loss + mpnn_loss
             log_dict["mpnn_loss"] = mpnn_loss.detach().cpu().item()
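
The reflowed mpnn_loss line sits on top of a standard label-smoothing target. A standalone sketch of the target construction with hypothetical alpha and token count (the real code reads them from the config and the model), leaving out the project-specific masked_cross_entropy_loss helper:

import torch
import torch.nn.functional as F

alpha, num_tokens = 0.1, 21                        # hypothetical values
aatype = torch.randint(0, num_tokens, (2, 16))     # (batch, length) residue types
aatype_oh = F.one_hot(aatype, num_tokens).float()  # hard one-hot targets
# Blend toward the uniform distribution; each row still sums to 1.
target_oh = (1 - alpha) * aatype_oh + alpha / num_tokens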

@@ -843,7 +803,7 @@ def train_step(self, input_dict: dict[str, torch.Tensor]) -> dict[str, float]:
 
 
 @record
-def train(
+def train(  # noqa: C901
     config_path: StrPath,
     output_dir: StrPath,
     device: Device = None,
@@ -873,10 +833,7 @@ def train(
         logger.info("Enabled TF32 mode for faster training on Ampere+ GPUs")
 
     # Set and resolve device with DDP if applicable
-    if device is None:
-        requested_device = get_default_device()
-    else:
-        requested_device = torch.device(device)
+    requested_device = get_default_device() if device is None else torch.device(device)
    resolved_device, distributed = resolve_device_with_distributed(requested_device)
 
     # Load config
@@ -974,22 +931,15 @@ def train(
         wandb_run = wandb.init(**wandb_kwargs)
         if wandb_run is None:
             raise RuntimeError("Failed to initialize wandb run")
-        if (
-            (wandb_run.name is None)
-            or (wandb_run.dir is None)
-            or (wandb_run.id is None)
-        ):
+        if (wandb_run.name is None) or (wandb_run.dir is None) or (wandb_run.id is None):
             raise RuntimeError("wandb returned an incomplete run object")
         run_name = wandb_run.name
         run_dir = wandb_run.dir
         run_id = wandb_run.id
     else:
         # Non-main ranks reuse exp_name for logging clarity
         if exp_name is None:
-            if distributed.ddp_enabled:
-                run_name = f"run-rank{distributed.rank}"
-            else:
-                run_name = "run"
+            run_name = f"run-rank{distributed.rank}" if distributed.ddp_enabled else "run"
         else:
             run_name = exp_name
 
@@ -1051,19 +1001,15 @@ def train(
             disable=not trainer.is_main,
         )
         for input_dict in progress:
-            input_dict: dict[str, torch.Tensor] = {
-                k: v.to(
-                    trainer.device, non_blocking=True
-                )  # non_blocking for pin_memory
+            input_dict: dict[str, torch.Tensor] = {  # noqa: PLW2901
+                k: v.to(trainer.device, non_blocking=True)  # non_blocking for pin_memory
                 for k, v in input_dict.items()
             }
             assert "cyclic_mask" in input_dict  # TODO: test and remove
             log_dict = trainer.train_step(input_dict)
             log_dict["learning_rate"] = trainer.scheduler.get_last_lr()[0]
             log_dict["epoch"] = epoch
-            log_dict = log_distributed_mean(
-                log_dict, trainer.device, distributed
-            )
+            log_dict = log_distributed_mean(log_dict, trainer.device, distributed)
 
             # Log to wandb on main rank only
             if trainer.is_main:
