diff --git a/biapy/config/config.py b/biapy/config/config.py index 82b28b8a9..f5f73c361 100644 --- a/biapy/config/config.py +++ b/biapy/config/config.py @@ -1162,12 +1162,15 @@ def __init__(self, job_dir: str, job_identifier: str): # * Semantic segmentation: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2' # * Instance segmentation: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2' # * Detection: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2' - # * Denoising: 'unet', 'resunet', 'resunet++', 'attention_unet', 'seunet', 'resunet_se', 'unext_v1', 'unext_v2' + # * Denoising: 'unet', 'resunet', 'resunet++', 'attention_unet', 'seunet', 'resunet_se', 'unext_v1', 'unext_v2', 'nafnet' # * Super-resolution: 'edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unext_v1', 'unext_v2' # * Self-supervision: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'edsr', 'rcan', 'dfcan', 'wdsr', 'vit', 'mae', 'unext_v1', 'unext_v2' # * Classification: 'simple_cnn', 'vit', 'efficientnet_b[0-7]' (only 2D) # * Image to image: 'edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'unetr', 'multiresunet', 'unext_v1', 'unext_v2' _C.MODEL.ARCHITECTURE = "unet" + # Architecture of the GAN discriminator network. Possible values are: + # * 'patchgan' + _C.MODEL.ARCHITECTURE_D = "patchgan" # Number of feature maps on each level of the network. _C.MODEL.FEATURE_MAPS = [16, 32, 64, 128, 256] # Values to make the dropout with. Set to 0 to prevent dropout. 
When using it with 'ViT' or 'unetr' @@ -1306,6 +1309,26 @@ def __init__(self, job_dir: str, job_identifier: str): # Whether to use a pretrained version of STUNet on ImageNet _C.MODEL.STUNET.PRETRAINED = False + # NafNet + _C.MODEL.NAFNET = CN() + # Base number of channels (width) used in the first layer and base levels. + _C.MODEL.NAFNET.WIDTH = 16 + # Number of NAFBlocks stacked at the bottleneck (deepest level). + _C.MODEL.NAFNET.MIDDLE_BLK_NUM = 12 + # Number of NAFBlocks assigned to each downsampling level of the encoder. + _C.MODEL.NAFNET.ENC_BLK_NUMS = [2, 2, 4, 8] + # Number of NAFBlocks assigned to each upsampling level of the decoder. + _C.MODEL.NAFNET.DEC_BLK_NUMS = [2, 2, 2, 2] + # Channel expansion factor for the depthwise convolution within the gating unit. + _C.MODEL.NAFNET.DW_EXPAND = 2 + # Expansion factor for the hidden layer within the feed-forward network. + _C.MODEL.NAFNET.FFN_EXPAND = 2 + + # Discriminator PATCHGAN + _C.MODEL.PATCHGAN = CN() + # Number of initial convolutional filters in the first layer of the discriminator. + _C.MODEL.PATCHGAN.BASE_FILTERS = 64 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Loss # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1371,7 +1394,24 @@ def __init__(self, job_dir: str, job_identifier: str): _C.LOSS.CONTRAST.MEMORY_SIZE = 5000 _C.LOSS.CONTRAST.PROJ_DIM = 256 _C.LOSS.CONTRAST.PIXEL_UPD_FREQ = 10 - + + # Fine-grained GAN composition. Set any weight to 0.0 to disable that term. + # Used when LOSS.TYPE == "COMPOSED_GAN". + _C.LOSS.COMPOSED_GAN = CN() + # Weight for adversarial BCE term. + _C.LOSS.COMPOSED_GAN.LAMBDA_GAN = 1.0 + # Weight for L1 reconstruction term. + _C.LOSS.COMPOSED_GAN.LAMBDA_RECON = 10.0 + # Weight for MSE reconstruction term. + _C.LOSS.COMPOSED_GAN.DELTA_MSE = 0.0 + # Weight for VGG perceptual term. + _C.LOSS.COMPOSED_GAN.ALPHA_PERCEPTUAL = 0.0 + # Weight for SSIM term. 
+ _C.LOSS.COMPOSED_GAN.GAMMA_SSIM = 1.0 + + # Backward-compatible alias for previous naming. + _C.LOSS.GAN = CN() + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Training phase # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1381,12 +1421,18 @@ def __init__(self, job_dir: str, job_identifier: str): _C.TRAIN.VERBOSE = False # Optimizer to use. Possible values: "SGD", "ADAM" or "ADAMW" _C.TRAIN.OPTIMIZER = "SGD" + # Optimizer to use. Possible values: "SGD", "ADAM" or "ADAMW" for GAN discriminator + _C.TRAIN.OPTIMIZER_D = "SGD" # Learning rate _C.TRAIN.LR = 1.0e-4 + # Learning rate for GAN discriminator + _C.TRAIN.LR_D = 1.0e-4 # Weight decay _C.TRAIN.W_DECAY = 0.02 # Coefficients used for computing running averages of gradient and its square. Used in ADAM and ADAMW optmizers _C.TRAIN.OPT_BETAS = (0.9, 0.999) + # Coefficients used for computing running averages of gradient and its square. Used in ADAM and ADAMW optmizers for GANS discriminator + _C.TRAIN.OPT_BETAS_D = (0.5, 0.999) # Batch size _C.TRAIN.BATCH_SIZE = 2 # If memory or # gpus is limited, use this variable to maintain the effective batch size, which is diff --git a/biapy/data/generators/__init__.py b/biapy/data/generators/__init__.py index ca072706b..dcc7720f0 100644 --- a/biapy/data/generators/__init__.py +++ b/biapy/data/generators/__init__.py @@ -246,7 +246,7 @@ def create_train_val_augmentors( dic["zflip"] = cfg.AUGMENTOR.ZFLIP if cfg.PROBLEM.TYPE == "INSTANCE_SEG": dic["instance_problem"] = True - elif cfg.PROBLEM.TYPE == "DENOISING": + elif cfg.PROBLEM.TYPE == "DENOISING" and cfg.LOSS.TYPE != "COMPOSED_GAN": dic["n2v"] = True dic["n2v_perc_pix"] = cfg.PROBLEM.DENOISING.N2V_PERC_PIX dic["n2v_manipulator"] = cfg.PROBLEM.DENOISING.N2V_MANIPULATOR @@ -293,7 +293,7 @@ def create_train_val_augmentors( ) if cfg.PROBLEM.TYPE == "INSTANCE_SEG": dic["instance_problem"] = True - elif cfg.PROBLEM.TYPE == "DENOISING": + elif cfg.PROBLEM.TYPE == "DENOISING" and 
cfg.LOSS.TYPE != "COMPOSED_GAN": dic["n2v"] = True dic["n2v_perc_pix"] = cfg.PROBLEM.DENOISING.N2V_PERC_PIX dic["n2v_manipulator"] = cfg.PROBLEM.DENOISING.N2V_MANIPULATOR diff --git a/biapy/engine/__init__.py b/biapy/engine/__init__.py index 8991943f7..1873d9ac7 100644 --- a/biapy/engine/__init__.py +++ b/biapy/engine/__init__.py @@ -21,6 +21,7 @@ def prepare_optimizer( cfg: CN, model_without_ddp: nn.Module | nn.parallel.DistributedDataParallel, steps_per_epoch: int, + is_gan: bool = False, ) -> Tuple[Optimizer, Scheduler | None]: """ Create and configure the optimizer and learning rate scheduler for the given model. @@ -33,57 +34,89 @@ def prepare_optimizer( ---------- cfg : YACS CN object Configuration object with optimizer and scheduler settings. - model_without_ddp : nn.Module or nn.parallel.DistributedDataParallel + model_without_ddp : nn.Module or nn.parallel.DistributedDataParallel or dict The model to optimize. steps_per_epoch : int Number of steps (batches) per training epoch. + is_gan : bool, optional + Whether to create optimizer/scheduler pairs for GAN generator and discriminator. Returns ------- - optimizer : Optimizer - Configured optimizer for the model. - lr_scheduler : Scheduler or None - Configured learning rate scheduler, or None if not specified. + optimizer : Optimizer or dict + Configured optimizer for the model or dict with generator/discriminator optimizers in GAN mode. + lr_scheduler : Scheduler or None or dict + Configured scheduler for the model or dict with generator/discriminator schedulers in GAN mode. 
""" - lr = cfg.TRAIN.LR if cfg.TRAIN.LR_SCHEDULER.NAME != "warmupcosine" else cfg.TRAIN.LR_SCHEDULER.MIN_LR - opt_args = {} - if cfg.TRAIN.OPTIMIZER in ["ADAM", "ADAMW"]: - opt_args["betas"] = cfg.TRAIN.OPT_BETAS - optimizer = timm.optim.create_optimizer_v2( - model_without_ddp, - opt=cfg.TRAIN.OPTIMIZER, - lr=lr, - weight_decay=cfg.TRAIN.W_DECAY, - **opt_args, - ) - print(optimizer) - - # Learning rate schedulers - lr_scheduler = None - if cfg.TRAIN.LR_SCHEDULER.NAME != "": - if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": - lr_scheduler = ReduceLROnPlateau( - optimizer, - patience=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_PATIENCE, - factor=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_FACTOR, - min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, - ) - elif cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine": - lr_scheduler = WarmUpCosineDecayScheduler( - lr=cfg.TRAIN.LR, - min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, - warmup_epochs=cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS, - epochs=cfg.TRAIN.EPOCHS, - ) - elif cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": - lr_scheduler = OneCycleLR( - optimizer, - cfg.TRAIN.LR, - epochs=cfg.TRAIN.EPOCHS, - steps_per_epoch=steps_per_epoch, - ) - - return optimizer, lr_scheduler + def _make_scheduler(optimizer: Optimizer, lr_value: float) -> Scheduler | None: + lr_scheduler = None + if cfg.TRAIN.LR_SCHEDULER.NAME != "": + if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": + lr_scheduler = ReduceLROnPlateau( + optimizer, + patience=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_PATIENCE, + factor=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_FACTOR, + min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, + ) + elif cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine": + lr_scheduler = WarmUpCosineDecayScheduler( + lr=lr_value, + min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, + warmup_epochs=cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS, + epochs=cfg.TRAIN.EPOCHS, + ) + elif cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": + lr_scheduler = OneCycleLR( + optimizer, + lr_value, + epochs=cfg.TRAIN.EPOCHS, 
+ steps_per_epoch=steps_per_epoch, + ) + return lr_scheduler + + def _make_optimizer(model: nn.Module | nn.parallel.DistributedDataParallel, train_cfg: dict): + lr_value = train_cfg["lr"] + opt_name = train_cfg["optimizer"] + betas = train_cfg["betas"] + w_decay = train_cfg["weight_decay"] + + lr = lr_value if cfg.TRAIN.LR_SCHEDULER.NAME != "warmupcosine" else cfg.TRAIN.LR_SCHEDULER.MIN_LR + opt_args = {} + if opt_name in ["ADAM", "ADAMW"]: + opt_args["betas"] = betas + + optimizer = timm.optim.create_optimizer_v2( + model, + opt=opt_name, + lr=lr, + weight_decay=w_decay, + **opt_args, + ) + print(optimizer) + lr_scheduler = _make_scheduler(optimizer, lr_value) + return optimizer, lr_scheduler + + g_train_cfg = { + "lr": cfg.TRAIN.LR, + "optimizer": cfg.TRAIN.OPTIMIZER, + "betas": cfg.TRAIN.OPT_BETAS, + "weight_decay": cfg.TRAIN.W_DECAY, + } + + if not is_gan: + return _make_optimizer(model_without_ddp, g_train_cfg) + + d_train_cfg = { + "lr": cfg.TRAIN.LR_D, + "optimizer": cfg.TRAIN.OPTIMIZER_D, + "betas": cfg.TRAIN.OPT_BETAS_D, + "weight_decay": cfg.TRAIN.W_DECAY, + } + + optimizer_g, scheduler_g = _make_optimizer(model_without_ddp["generator"], g_train_cfg) + optimizer_d, scheduler_d = _make_optimizer(model_without_ddp["discriminator"], d_train_cfg) + + return {"generator": optimizer_g, "discriminator": optimizer_d}, {"generator": scheduler_g, "discriminator": scheduler_d,} def build_callbacks(cfg: CN) -> EarlyStopping | None: diff --git a/biapy/engine/base_workflow.py b/biapy/engine/base_workflow.py index 4f28f200d..5f2b48464 100644 --- a/biapy/engine/base_workflow.py +++ b/biapy/engine/base_workflow.py @@ -837,6 +837,11 @@ def prepare_logging_tool(self): self.plot_values = {} self.plot_values["loss"] = [] self.plot_values["val_loss"] = [] + if getattr(self, "is_gan_mode", False): + self.plot_values["loss_g"] = [] + self.plot_values["loss_d"] = [] + self.plot_values["val_loss_g"] = [] + self.plot_values["val_loss_d"] = [] for i in 
range(len(self.train_metric_names)): self.plot_values[self.train_metric_names[i]] = [] self.plot_values["val_" + self.train_metric_names[i]] = [] @@ -853,9 +858,28 @@ def train(self): assert ( self.start_epoch is not None and self.model is not None and self.model_without_ddp is not None and self.loss ) - self.optimizer, self.lr_scheduler = prepare_optimizer( - self.cfg, self.model_without_ddp, len(self.train_generator) - ) + is_gan_mode = getattr(self, "is_gan_mode", False) + if is_gan_mode: + discriminator_wo_ddp = getattr(self, "discriminator_without_ddp", None) + if discriminator_wo_ddp is None: + raise ValueError("GAN mode requires a discriminator model before training") + optimizers, schedulers = prepare_optimizer( + self.cfg, + { + "generator": self.model_without_ddp, + "discriminator": discriminator_wo_ddp, + }, + len(self.train_generator), + is_gan=True, + ) + self.optimizer = optimizers["generator"] + self.lr_scheduler = schedulers["generator"] + self.optimizer_d = optimizers["discriminator"] + self.lr_scheduler_d = schedulers["discriminator"] + else: + self.optimizer, self.lr_scheduler = prepare_optimizer( + self.cfg, self.model_without_ddp, len(self.train_generator) + ) contrast_init_iter = 0 if self.cfg.LOSS.CONTRAST.ENABLE: @@ -910,6 +934,9 @@ def train(self): memory_bank=self.memory_bank, total_iters=total_iters, contrast_warmup_iters=contrast_init_iter, + model_d=getattr(self, "discriminator", None) if is_gan_mode else None, + optimizer_d=getattr(self, "optimizer_d", None) if is_gan_mode else None, + lr_scheduler_d=getattr(self, "lr_scheduler_d", None) if is_gan_mode else None, ) total_iters += iterations_done @@ -929,6 +956,8 @@ def train(self): epoch=epoch + 1, model_build_kwargs=self.model_build_kwargs, extension=self.cfg.MODEL.OUT_CHECKPOINT_FORMAT, + discriminator_without_ddp=getattr(self, "discriminator_without_ddp", None) if is_gan_mode else None, + optimizer_d=getattr(self, "optimizer_d", None) if is_gan_mode else None, ) # Validation @@ 
-944,24 +973,29 @@ def train(self): data_loader=self.val_generator, lr_scheduler=self.lr_scheduler, memory_bank=self.memory_bank, + model_d=getattr(self, "discriminator", None) if is_gan_mode else None, + lr_scheduler_d=getattr(self, "lr_scheduler_d", None) if is_gan_mode else None, + device=self.device if is_gan_mode else None, ) + val_loss_key = "loss_g" if is_gan_mode else "loss" + # Save checkpoint is val loss improved - if test_stats["loss"] < self.val_best_loss: + if test_stats[val_loss_key] < self.val_best_loss: f = os.path.join( self.cfg.PATHS.CHECKPOINT, "{}-checkpoint-best.pth".format(self.job_identifier), ) print( "Val loss improved from {} to {}, saving model to {}".format( - self.val_best_loss, test_stats["loss"], f + self.val_best_loss, test_stats[val_loss_key], f ) ) m = " " for i in range(len(self.val_best_metric)): self.val_best_metric[i] = test_stats[self.train_metric_names[i]] m += f"{self.train_metric_names[i]}: {self.val_best_metric[i]:.4f} " - self.val_best_loss = test_stats["loss"] + self.val_best_loss = test_stats[val_loss_key] if is_main_process(): self.checkpoint_path = save_model( @@ -973,12 +1007,14 @@ def train(self): epoch="best", model_build_kwargs=self.model_build_kwargs, extension=self.cfg.MODEL.OUT_CHECKPOINT_FORMAT, + discriminator_without_ddp=getattr(self, "discriminator_without_ddp", None) if is_gan_mode else None, + optimizer_d=getattr(self, "optimizer_d", None) if is_gan_mode else None, ) print(f"[Val] best loss: {self.val_best_loss:.4f} best " + m) # Store validation stats if self.log_writer: - self.log_writer.update(test_loss=test_stats["loss"], head="perf", step=epoch) + self.log_writer.update(test_loss=test_stats[val_loss_key], head="perf", step=epoch) for i in range(len(self.train_metric_names)): self.log_writer.update( test_iou=test_stats[self.train_metric_names[i]], @@ -1006,9 +1042,20 @@ def train(self): f.write(json.dumps(log_stats) + "\n") # Create training plot - self.plot_values["loss"].append(train_stats["loss"]) 
+ train_loss_key = "loss_g" if is_gan_mode else "loss" + val_loss_key = "loss_g" if is_gan_mode else "loss" + self.plot_values["loss"].append(train_stats[train_loss_key]) if self.val_generator: - self.plot_values["val_loss"].append(test_stats["loss"]) + self.plot_values["val_loss"].append(test_stats[val_loss_key]) + if is_gan_mode: + if "loss_g" in self.plot_values: + self.plot_values["loss_g"].append(train_stats.get("loss_g", 0)) + if "loss_d" in self.plot_values: + self.plot_values["loss_d"].append(train_stats.get("loss_d", 0)) + if self.val_generator and "val_loss_g" in self.plot_values: + self.plot_values["val_loss_g"].append(test_stats.get("loss_g", 0)) + if self.val_generator and "val_loss_d" in self.plot_values: + self.plot_values["val_loss_d"].append(test_stats.get("loss_d", 0)) for i in range(len(self.train_metric_names)): self.plot_values[self.train_metric_names[i]].append(train_stats[self.train_metric_names[i]]) if self.val_generator: @@ -1024,7 +1071,8 @@ def train(self): ) if self.val_generator and self.early_stopping: - self.early_stopping(test_stats["loss"]) + val_loss_key = "loss_g" if is_gan_mode else "loss" + self.early_stopping(test_stats[val_loss_key]) if self.early_stopping.early_stop: print("Early stopping") break @@ -1043,7 +1091,8 @@ def train(self): self.total_training_time_str = str(datetime.timedelta(seconds=int(total_time))) print("Training time: {}".format(self.total_training_time_str)) - self.train_metrics_message += ("Train loss: {}\n".format(train_stats["loss"])) + train_loss_key = "loss_g" if is_gan_mode else "loss" + self.train_metrics_message += ("Train loss: {}\n".format(train_stats[train_loss_key])) for i in range(len(self.train_metric_names)): self.train_metrics_message += ("Train {}: {}\n".format(self.train_metric_names[i], train_stats[self.train_metric_names[i]])) if self.val_generator: diff --git a/biapy/engine/check_configuration.py b/biapy/engine/check_configuration.py index f4b5d5b6a..0374e17c7 100644 --- 
a/biapy/engine/check_configuration.py +++ b/biapy/engine/check_configuration.py @@ -1242,7 +1242,7 @@ def sort_key(item): ], "LOSS.CLASS_REBALANCE not in ['none', 'auto'] for INSTANCE_SEG workflow" elif cfg.PROBLEM.TYPE == "DENOISING": loss = "MSE" if cfg.LOSS.TYPE == "" else cfg.LOSS.TYPE - assert loss == "MSE", "LOSS.TYPE must be 'MSE'" + assert loss in ["MSE", "COMPOSED_GAN"], "LOSS.TYPE must be in ['MSE', 'COMPOSED_GAN'] for DENOISING" elif cfg.PROBLEM.TYPE == "CLASSIFICATION": loss = "CE" if cfg.LOSS.TYPE == "" else cfg.LOSS.TYPE assert loss == "CE", "LOSS.TYPE must be 'CE'" @@ -1797,12 +1797,16 @@ def sort_key(item): #### Denoising #### elif cfg.PROBLEM.TYPE == "DENOISING": - if cfg.DATA.TEST.LOAD_GT: - raise ValueError( - "Denoising is made in an unsupervised way so there is no ground truth required. Disable 'DATA.TEST.LOAD_GT'" - ) - if not check_value(cfg.PROBLEM.DENOISING.N2V_PERC_PIX): - raise ValueError("PROBLEM.DENOISING.N2V_PERC_PIX not in [0, 1] range") + if cfg.LOSS.TYPE == "COMPOSED_GAN": + if not cfg.DATA.TRAIN.GT_PATH and not cfg.DATA.TRAIN.INPUT_ZARR_MULTIPLE_DATA: + raise ValueError("Denoising with COMPOSED_GAN is supervised. 'DATA.TRAIN.GT_PATH' is required.") + else: + if cfg.DATA.TEST.LOAD_GT: + raise ValueError( + "Denoising is made in an unsupervised way so there is no ground truth required. 
Disable 'DATA.TEST.LOAD_GT'" + ) + if not check_value(cfg.PROBLEM.DENOISING.N2V_PERC_PIX): + raise ValueError("PROBLEM.DENOISING.N2V_PERC_PIX not in [0, 1] range") if cfg.MODEL.SOURCE == "torchvision": raise ValueError("'MODEL.SOURCE' as 'torchvision' is not available in denoising workflow") @@ -2341,6 +2345,7 @@ def sort_key(item): "hrnet48", "hrnet64", "stunet", + "nafnet", ], "MODEL.ARCHITECTURE not in ['unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'simple_cnn', 'efficientnet_b[0-7]', 'unetr', 'edsr', 'rcan', 'dfcan', 'wdsr', 'vit', 'mae', 'unext_v1', 'unext_v2', 'hrnet18', 'hrnet32', 'hrnet48', 'hrnet64', 'stunet']" if ( model_arch @@ -2514,6 +2519,7 @@ def sort_key(item): "hrnet48", "hrnet64", "stunet", + "nafnet", ]: raise ValueError( "Architectures available for {} are: ['unet', 'resunet', 'resunet++', 'seunet', 'attention_unet', 'resunet_se', 'unetr', 'multiresunet', 'unext_v1', 'unext_v2', 'hrnet18', 'hrnet32', 'hrnet48', 'hrnet64', 'stunet']".format( @@ -2976,6 +2982,7 @@ def compare_configurations_without_model(actual_cfg, old_cfg, header_message="", "DATA.PATCH_SIZE", "PROBLEM.INSTANCE_SEG.DATA_CHANNELS", "PROBLEM.SUPER_RESOLUTION.UPSCALING", + "MODEL.ARCHITECTURE_D", "DATA.N_CLASSES", ] diff --git a/biapy/engine/denoising.py b/biapy/engine/denoising.py index 9cac8d574..b5c0efac8 100644 --- a/biapy/engine/denoising.py +++ b/biapy/engine/denoising.py @@ -24,9 +24,10 @@ merge_3D_data_with_overlap, ) from biapy.engine.base_workflow import Base_Workflow +from biapy.models import build_discriminator from biapy.data.data_manipulation import save_tif -from biapy.utils.misc import to_pytorch_format, is_main_process, MetricLogger -from biapy.engine.metrics import n2v_loss_mse, loss_encapsulation +from biapy.utils.misc import to_pytorch_format, MetricLogger +from biapy.engine.metrics import n2v_loss_mse, loss_encapsulation, ComposedGANLoss class Denoising_Workflow(Base_Workflow): @@ -72,10 +73,20 @@ def __init__(self, cfg, 
job_identifier, device, system_dict, args, **kwargs): # From now on, no modification of the cfg will be allowed self.cfg.freeze() + self.is_gan_mode = str(cfg.LOSS.TYPE).upper() == "COMPOSED_GAN" + self.discriminator = None + self.discriminator_without_ddp = None + self.optimizer_d = None + self.lr_scheduler_d = None + # Workflow specific training variables - self.mask_path = cfg.DATA.TRAIN.GT_PATH if cfg.PROBLEM.DENOISING.LOAD_GT_DATA else None + if self.is_gan_mode: + self.mask_path = cfg.DATA.TRAIN.GT_PATH + self.load_Y_val = True + else: + self.mask_path = cfg.DATA.TRAIN.GT_PATH if cfg.PROBLEM.DENOISING.LOAD_GT_DATA else None + self.load_Y_val = cfg.PROBLEM.DENOISING.LOAD_GT_DATA self.is_y_mask = False - self.load_Y_val = cfg.PROBLEM.DENOISING.LOAD_GT_DATA self.norm_module.mask_norm = "as_image" self.test_norm_module.mask_norm = "as_image" @@ -166,6 +177,8 @@ def define_metrics(self): # print("Overriding 'LOSS.TYPE' to set it to N2V loss (masked MSE)") if self.cfg.LOSS.TYPE == "MSE": self.loss = loss_encapsulation(n2v_loss_mse) + elif self.cfg.LOSS.TYPE == "COMPOSED_GAN": + self.loss = ComposedGANLoss(cfg=self.cfg, device=self.device) super().define_metrics() @@ -232,7 +245,11 @@ def metric_calculation( with torch.no_grad(): for i, metric in enumerate(list_to_use): - val = metric(_output.contiguous(), _targets[:, _output.shape[1]:].contiguous()) + if self.is_gan_mode: + target_for_metric = _targets.contiguous() + else: + target_for_metric = _targets[:, _output.shape[1]:].contiguous() + val = metric(_output.contiguous(), target_for_metric) val = val.item() if not torch.isnan(val) else 0 out_metrics[list_names_to_use[i]] = val @@ -328,6 +345,36 @@ def process_test_sample(self): verbose=self.cfg.TEST.VERBOSE, ) + def prepare_model(self): + """Build generator model and discriminator when running denoising in GAN mode.""" + super().prepare_model() + + # It is not a GAN model, or we already have a discriminator loaded from checkpoint + if not self.is_gan_mode or 
self.discriminator is not None: + return + + print("#######################") + print("# Build Discriminator #") + print("#######################") + self.discriminator = build_discriminator(self.cfg, self.device) + self.discriminator_without_ddp = self.discriminator + + if self.args.distributed: + self.discriminator = torch.nn.parallel.DistributedDataParallel( + self.discriminator, + device_ids=[self.args.gpu], + find_unused_parameters=False, + ) + self.discriminator_without_ddp = self.discriminator.module + + if self.cfg.MODEL.SOURCE == "biapy" and self.cfg.MODEL.LOAD_CHECKPOINT and self.checkpoint_path: + checkpoint = torch.load(self.checkpoint_path, map_location=self.device) + if "discriminator_state_dict" in checkpoint: + self.discriminator_without_ddp.load_state_dict(checkpoint["discriminator_state_dict"]) + print("Discriminator weights loaded successfully.") + else: + print("Warning: 'discriminator_state_dict' not found in checkpoint.") + def torchvision_model_call(self, in_img: torch.Tensor, is_train: bool = False) -> torch.Tensor | None: """ Call a regular Pytorch model. @@ -933,4 +980,4 @@ def get_value_manipulation(n2v_manipulator, n2v_neighborhood_radius): Callable Value manipulation function. 
""" - return eval("pm_{0}({1})".format(n2v_manipulator, str(n2v_neighborhood_radius))) + return eval("pm_{0}({1})".format(n2v_manipulator, str(n2v_neighborhood_radius))) \ No newline at end of file diff --git a/biapy/engine/metrics.py b/biapy/engine/metrics.py index 32dab4283..80bb21db1 100644 --- a/biapy/engine/metrics.py +++ b/biapy/engine/metrics.py @@ -18,6 +18,8 @@ import torch.nn.functional as F import torch.nn as nn from typing import Optional, List, Tuple, Dict +from torchvision import transforms +from torchvision.models import vgg16, VGG16_Weights def jaccard_index_numpy(y_true, y_pred): """ @@ -2162,4 +2164,115 @@ def forward( loss = loss / B iou = iou / B - return loss + prediction.sum() * 0, float(iou), "IoU" # keep graph identical to originals \ No newline at end of file + return loss + prediction.sum() * 0, float(iou), "IoU" # keep graph identical to originals + +class VGGLoss(nn.Module): + def __init__(self, device): + super().__init__() + self.vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features[:16].eval().to(device) + for param in self.vgg.parameters(): + param.requires_grad = False + self.loss = nn.L1Loss() + self.preprocess = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + def forward(self, pred, target): + if isinstance(pred, dict): + pred = pred["pred"] + if isinstance(target, dict): + target = target["pred"] + + # If 3D, fold Depth (dim 2) into Batch (dim 0) -> (B*D, C, H, W) + if pred.dim() == 5: + B, C, D, H, W = pred.shape + pred = pred.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) + target = target.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) + + # 2D behavior remains identical + if pred.shape[1] == 1: + pred = pred.repeat(1, 3, 1, 1) + target = target.repeat(1, 3, 1, 1) + + pred = self.preprocess(pred) + target = self.preprocess(target) + pred_vgg = self.vgg(pred) + target_vgg = self.vgg(target) + return self.loss(pred_vgg, target_vgg) + +class ComposedGANLoss(nn.Module): + """ + Dynamic composite 
loss for GANs. + Only instantiates heavy loss models (like VGG) if their config weight is > 0. + """ + def __init__(self, cfg, device): + super().__init__() + self.device = device + self.w_gan = cfg.LOSS.COMPOSED_GAN.LAMBDA_GAN + self.w_l1 = cfg.LOSS.COMPOSED_GAN.LAMBDA_RECON + self.w_vgg = cfg.LOSS.COMPOSED_GAN.ALPHA_PERCEPTUAL + self.w_ssim = cfg.LOSS.COMPOSED_GAN.GAMMA_SSIM + self.w_mse = cfg.LOSS.COMPOSED_GAN.DELTA_MSE + + # Don't load the heavy VGG model unless its perceptual weight is > 0 + if self.w_vgg > 0: + self.vgg = VGGLoss(device) + if self.w_ssim > 0: + self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device) + + # Standard lightweight losses are always initialized + self.l1 = nn.L1Loss() + self.mse = nn.MSELoss() + self.bce = nn.BCEWithLogitsLoss() + + def forward_generator(self, pred, target, d_fake): + # Dict extraction + if isinstance(pred, dict): pred = pred["pred"] + if isinstance(target, dict): target = target["pred"] + + # NaN Band-aid + pred = torch.nan_to_num(pred, nan=0.0, posinf=1.0, neginf=-1.0) + target = torch.nan_to_num(target, nan=0.0, posinf=1.0, neginf=-1.0) + + total_loss = torch.tensor(0.0, device=self.device) + + # 2. Dynamically build the loss based on config weights + if self.w_l1 > 0: + total_loss += self.w_l1 * self.l1(pred, target) + + if self.w_mse > 0: + total_loss += self.w_mse * self.mse(pred, target) + + if self.w_vgg > 0: + total_loss += self.w_vgg * self.vgg(pred, target) + + if self.w_ssim > 0: + # SSIM requires 4D tensors. Safely route 3D to 2D slices. 
+ if pred.dim() == 5: + B, C, D, H, W = pred.shape + pred_ssim = pred.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) + target_ssim = target.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) + total_loss += self.w_ssim * (1.0 - self.ssim(pred_ssim, target_ssim)) + else: + total_loss += self.w_ssim * (1.0 - self.ssim(pred, target)) + + if self.w_gan > 0: + total_loss += self.w_gan * self.bce(d_fake, torch.ones_like(d_fake)) + + # NaN Safety Check + if torch.isnan(total_loss): + print("Warning: NaN detected in generator loss. Returning zero loss.") + total_loss = torch.tensor(0.0, requires_grad=True).to(self.device) + + return total_loss + + def forward_discriminator(self, d_real, d_fake): + # Calculate Adversarial Loss for Discriminator + real_loss = self.bce(d_real, torch.full_like(d_real, 0.9)) # Label smoothing (0.9 instead of 1.0) + fake_loss = self.bce(d_fake, torch.zeros_like(d_fake)) + total_loss = (real_loss + fake_loss) / 2.0 + + # NaN Safety Check + if torch.isnan(total_loss): + print("Warning: NaN detected in discriminator loss. Returning zero loss.") + total_loss = torch.tensor(0.0, requires_grad=True).to(self.device) + + return total_loss \ No newline at end of file diff --git a/biapy/engine/train_engine.py b/biapy/engine/train_engine.py index e4b05e3b9..c1939cabc 100644 --- a/biapy/engine/train_engine.py +++ b/biapy/engine/train_engine.py @@ -38,6 +38,9 @@ def train_one_epoch( memory_bank: Optional[MemoryBank] = None, total_iters: int=0, contrast_warmup_iters: int=0, + model_d: Optional[nn.Module | nn.parallel.DistributedDataParallel] = None, + optimizer_d: Optional[Optimizer] = None, + lr_scheduler_d: Optional[Scheduler] = None, ): """ Train the model for one epoch. @@ -87,18 +90,28 @@ def train_one_epoch( int Number of steps (batches) processed. 
""" + is_gan = model_d is not None and optimizer_d is not None + # Switch to training mode model.train(True) + if is_gan: + model_d.train(True) # Ensure correct order of each epoch info by adding loss first metric_logger = MetricLogger(delimiter=" ", verbose=verbose) - metric_logger.add_meter("loss", SmoothedValue()) + if is_gan: + metric_logger.add_meter("loss_g", SmoothedValue()) + metric_logger.add_meter("loss_d", SmoothedValue()) + else: + metric_logger.add_meter("loss", SmoothedValue()) # Set up the header for logging header = "Epoch: [{}]".format(epoch + 1) print_freq = 10 optimizer.zero_grad() + if is_gan: + optimizer_d.zero_grad() for step, (batch, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): @@ -111,6 +124,9 @@ def train_one_epoch( and isinstance(lr_scheduler, WarmUpCosineDecayScheduler) ): lr_scheduler.adjust_learning_rate(optimizer, step / len(data_loader) + epoch) + if is_gan and lr_scheduler_d and isinstance(lr_scheduler_d, WarmUpCosineDecayScheduler): + lr_scheduler_d.adjust_learning_rate(optimizer_d, step / len(data_loader) + epoch) + # Gather inputs targets = prepare_targets(targets, batch) @@ -121,6 +137,86 @@ def train_one_epoch( f" Input: {batch.shape[1:-1]} vs PATCH_SIZE: {cfg.DATA.PATCH_SIZE[:-1]}" ) + if is_gan: + assert model_d is not None and optimizer_d is not None + + if ( + torch.isnan(batch).any() + or torch.isinf(batch).any() + or torch.isnan(targets).any() + or torch.isinf(targets).any() + ): + print("Warning: NaN or Inf detected in input. 
Skipping batch.") + continue + + # Phase 1: discriminator update + optimizer_d.zero_grad() + fake_img = model_call_func(batch, is_train=True) + if isinstance(fake_img, dict): + fake_img = fake_img["pred"] + fake_img = torch.clamp(fake_img, 0, 1) + + d_real = model_d(targets) + d_fake = model_d(fake_img.detach()) + loss_d = loss_function.forward_discriminator(d_real, d_fake) + + if torch.isnan(loss_d) or torch.isinf(loss_d): + print("Warning: NaN or Inf detected in discriminator loss. Skipping batch.") + continue + + loss_d.backward() + optimizer_d.step() + + if lr_scheduler_d and isinstance(lr_scheduler_d, OneCycleLR) and cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": + lr_scheduler_d.step() + + # Phase 2: generator update + optimizer.zero_grad() + outputs = model_call_func(batch, is_train=True) + if isinstance(outputs, dict): + outputs = outputs["pred"] + outputs = torch.clamp(outputs, 0, 1) + + d_fake_for_g = model_d(outputs) + loss = loss_function.forward_generator(outputs, targets, d_fake_for_g) + + if torch.isnan(loss) or torch.isinf(loss): + print("Warning: NaN or Inf detected in generator loss. 
Skipping batch.") + continue + + loss.backward() + optimizer.step() + + if lr_scheduler and isinstance(lr_scheduler, OneCycleLR) and cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": + lr_scheduler.step() + + metric_function(outputs, targets, metric_logger=metric_logger) + + loss_g_value = loss.item() + loss_d_value = loss_d.item() + metric_logger.update(loss_g=loss_g_value, loss_d=loss_d_value) + + if log_writer: + log_writer.update(loss_g=all_reduce_mean(loss_g_value), head="loss") + log_writer.update(loss_d=all_reduce_mean(loss_d_value), head="loss") + + max_lr_g = 0.0 + max_lr_d = 0.0 + for group in optimizer.param_groups: + max_lr_g = max(max_lr_g, group["lr"]) + for group in optimizer_d.param_groups: + max_lr_d = max(max_lr_d, group["lr"]) + + if step == 0: + metric_logger.add_meter("lr_g", SmoothedValue(window_size=1, fmt="{value:.6f}")) + metric_logger.add_meter("lr_d", SmoothedValue(window_size=1, fmt="{value:.6f}")) + + metric_logger.update(lr_g=max_lr_g, lr_d=max_lr_d) + if log_writer: + log_writer.update(lr_g=max_lr_g, head="opt") + log_writer.update(lr_d=max_lr_d, head="opt") + continue + # Pass the images through the model outputs = model_call_func(batch, is_train=True) @@ -193,6 +289,12 @@ def train_one_epoch( if log_writer: log_writer.update(lr=max_lr, head="opt") + if is_gan and cfg.TRAIN.LR_SCHEDULER.NAME not in ["reduceonplateau", "onecycle", "warmupcosine"]: + if lr_scheduler: + lr_scheduler.step() + if lr_scheduler_d: + lr_scheduler_d.step() + # Gather the stats from all processes metric_logger.synchronize_between_processes() print("[Train] averaged stats:", metric_logger) @@ -211,6 +313,9 @@ def evaluate( data_loader: DataLoader, lr_scheduler: Optional[Scheduler] = None, memory_bank: Optional[MemoryBank] = None, + model_d: Optional[nn.Module | nn.parallel.DistributedDataParallel] = None, + lr_scheduler_d: Optional[Scheduler] = None, + device: Optional[torch.device] = None, ): """ Evaluate the model on the validation set. 
@@ -246,13 +351,21 @@ def evaluate( dict Dictionary of averaged metrics for the validation set. """ + is_gan = model_d is not None + # Ensure correct order of each epoch info by adding loss first metric_logger = MetricLogger(delimiter=" ") - metric_logger.add_meter("loss", SmoothedValue()) + if is_gan: + metric_logger.add_meter("loss_g", SmoothedValue()) + metric_logger.add_meter("loss_d", SmoothedValue()) + else: + metric_logger.add_meter("loss", SmoothedValue()) header = "Epoch: [{}]".format(epoch + 1) # Switch to evaluation mode model.eval() + if is_gan: + model_d.eval() for batch in metric_logger.log_every(data_loader, 10, header): # Gather inputs @@ -260,9 +373,30 @@ def evaluate( targets = batch[1] targets = prepare_targets(targets, images) + if is_gan: + assert model_d is not None + outputs = model_call_func(images, is_train=False) + if isinstance(outputs, dict): + outputs = outputs["pred"] + outputs = torch.clamp(outputs, 0, 1) + + d_fake_val = model_d(outputs) + d_real_val = model_d(targets) + loss = loss_function.forward_generator(outputs, targets, d_fake_val) + loss_d = loss_function.forward_discriminator(d_real_val, d_fake_val) + + loss_value = loss.item() + if not math.isfinite(loss_value): + print(f"Validation loss is {loss_value}, skipping batch.") + continue + + metric_function(outputs, targets, metric_logger=metric_logger) + metric_logger.update(loss_g=loss_value, loss_d=loss_d.item()) + continue + # Pass the images through the model outputs = model_call_func(images, is_train=True) - + # Loss function call if memory_bank is not None: with_embed = False @@ -285,7 +419,7 @@ def evaluate( precalculated_metric = loss[1] precalculated_metric_name = loss[2] loss = loss[0] - + loss_value = loss.item() if not math.isfinite(loss_value): print("Loss is {}, stopping training".format(loss_value)) @@ -306,10 +440,12 @@ def evaluate( print("[Val] averaged stats:", metric_logger) # Apply reduceonplateau scheduler if the global validation has been reduced - if ( - 
lr_scheduler
-        and isinstance(lr_scheduler, ReduceLROnPlateau)
-        and cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau"
-    ):
-        lr_scheduler.step(metric_logger.meters["loss"].global_avg, epoch=epoch)
+    if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau":
+        if is_gan:
+            if lr_scheduler and isinstance(lr_scheduler, ReduceLROnPlateau):
+                lr_scheduler.step(metric_logger.meters["loss_g"].global_avg, epoch=epoch)
+            if lr_scheduler_d and isinstance(lr_scheduler_d, ReduceLROnPlateau):
+                lr_scheduler_d.step(metric_logger.meters["loss_d"].global_avg, epoch=epoch)
+        elif lr_scheduler and isinstance(lr_scheduler, ReduceLROnPlateau):
+            lr_scheduler.step(metric_logger.meters["loss"].global_avg, epoch=epoch)
 
     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
diff --git a/biapy/models/__init__.py b/biapy/models/__init__.py
index 5a0f98b23..238903f77 100644
--- a/biapy/models/__init__.py
+++ b/biapy/models/__init__.py
@@ -366,6 +366,20 @@ def build_model(
         )
         model = MaskedAutoencoderViT(**args)  # type: ignore
         callable_model = MaskedAutoencoderViT  # type: ignore
+    elif modelname == "nafnet":
+        args = dict(
+            img_channel=cfg.DATA.PATCH_SIZE[-1],
+            width=cfg.MODEL.NAFNET.WIDTH,
+            middle_blk_num=cfg.MODEL.NAFNET.MIDDLE_BLK_NUM,
+            enc_blk_nums=cfg.MODEL.NAFNET.ENC_BLK_NUMS,
+            dec_blk_nums=cfg.MODEL.NAFNET.DEC_BLK_NUMS,
+            drop_out_rate=cfg.MODEL.DROPOUT_VALUES[0],
+            dw_expand=cfg.MODEL.NAFNET.DW_EXPAND,
+            ffn_expand=cfg.MODEL.NAFNET.FFN_EXPAND
+        )
+        callable_model = NAFNet  # type: ignore
+        model = callable_model(**args)  # type: ignore
+
     # Check the network created
     model.to(device)
     if cfg.PROBLEM.NDIM == "2D":
@@ -405,6 +419,78 @@
     return model, str(callable_model.__name__), collected_sources, all_import_lines, scanned_files, args, network_stride  # type: ignore
 
+
+def build_discriminator(cfg: CN, device: torch.device):
+    """
+    Build the selected discriminator network for GAN training.
+
+    Parameters
+    ----------
+    cfg : YACS CN object
+        Configuration.
+
+    device : Torch device
+        Using device. Most commonly "cpu" or "cuda" for GPU, but also potentially "mps",
+        "xpu", "xla" or "meta".
+
+    Returns
+    -------
+    model : torch.nn.Module
+        Instantiated discriminator already moved to ``device``.
+    """
+    # 1. Standardize name and import the module
+    modelname = str(cfg.MODEL.ARCHITECTURE_D).lower()
+
+    # Size the banner border from the title so it always matches its width
+    banner = f"# Build {modelname.upper()} Disc #"
+    print("#" * len(banner))
+    print(banner)
+    print("#" * len(banner))
+
+    # Dynamic import like build_model
+    mdl = import_module("biapy.models." + modelname)
+
+    names = [x for x in mdl.__dict__ if not x.startswith("_")]
+    globals().update({k: getattr(mdl, k) for k in names})
+
+    # 2. Model building block
+    if modelname == "patchgan":
+        args = dict(
+            in_channels=cfg.DATA.PATCH_SIZE[-1],
+            base_filters=cfg.MODEL.PATCHGAN.BASE_FILTERS
+        )
+        callable_model = PatchGANDiscriminator  # type: ignore
+    else:
+        raise ValueError(f"Discriminator {modelname} is not implemented or registered.")
+
+    # Instantiate
+    model = callable_model(**args)
+    model.to(device)
+
+    # 3. Summary Logic (channel-first sample shape expected by torchinfo)
+    if cfg.PROBLEM.NDIM == "2D":
+        sample_size = (
+            1,
+            cfg.DATA.PATCH_SIZE[2],
+            cfg.DATA.PATCH_SIZE[0],
+            cfg.DATA.PATCH_SIZE[1],
+        )
+    else:
+        sample_size = (
+            1,
+            cfg.DATA.PATCH_SIZE[3],
+            cfg.DATA.PATCH_SIZE[0],
+            cfg.DATA.PATCH_SIZE[1],
+            cfg.DATA.PATCH_SIZE[2],
+        )
+
+    summary(
+        model,
+        input_size=sample_size,
+        col_names=("input_size", "output_size", "num_params"),
+        depth=10,
+        device=device.type,
+    )
+
+    return model
 
 def init_embedding_output(model: nn.Module, n_sigma: int = 2):
     """
diff --git a/biapy/models/nafnet.py b/biapy/models/nafnet.py
new file mode 100644
index 000000000..868f1fdf1
--- /dev/null
+++ b/biapy/models/nafnet.py
@@ -0,0 +1,165 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+class SimpleGate(nn.Module):
+    def forward(self, x):
+        x1, x2 = x.chunk(2, dim=1)
+        return x1 * x2
+
+class LayerNorm2d(nn.Module):
+    def __init__(self, channels, eps=1e-6):
+        super(LayerNorm2d, self).__init__()
+        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
+
self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
+        self.eps = eps
+
+    def forward(self, x):
+        N, C, H, W = x.size()
+        mu = x.mean(1, keepdim=True)  # per-sample mean over the channel axis
+        var = (x - mu).pow(2).mean(1, keepdim=True)  # biased variance over channels
+        y = (x - mu) / (var + self.eps).sqrt()
+        y = self.weight.view(1, C, 1, 1) * y + self.bias.view(1, C, 1, 1)  # learnable affine scale/shift
+        return y
+
+class NAFBlock(nn.Module):
+    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
+        super().__init__()
+        dw_channel = c * DW_Expand
+        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)  # pointwise expand
+        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, bias=True)  # depthwise 3x3 (groups == channels)
+        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)  # pointwise project; input halved by SimpleGate
+
+        # Simplified Channel Attention
+        self.sca = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, groups=1, bias=True),
+        )
+
+        # SimpleGate
+        self.sg = SimpleGate()
+
+        ffn_channel = FFN_Expand * c
+        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)  # FFN expand
+        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)  # FFN project; input halved by SimpleGate
+
+        self.norm1 = LayerNorm2d(c)
+        self.norm2 = LayerNorm2d(c)
+
+        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
+        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
+
+        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)  # learnable residual scale, attention branch (init 0)
+        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)  # learnable residual scale, FFN branch (init 0)
+
+    def forward(self, inp):
+        x = inp
+
+        x = self.norm1(x)
+
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.sg(x)
+        x = x * self.sca(x)
+        x = self.conv3(x)
+
+        x = self.dropout1(x)
+
+        y = inp + x * self.beta
+
+        x = self.conv4(self.norm2(y))
+        x = self.sg(x)
+        x = self.conv5(x)
+
+        x = self.dropout2(x)
+
+        return y + x * self.gamma
+
+class NAFNet(nn.Module):
+    def __init__(
+        self,
+        img_channel=3,
+        width=16,
+        middle_blk_num=1,
+        enc_blk_nums=[],
+        dec_blk_nums=[],
+        drop_out_rate=0.0,
+        dw_expand=2,
+        ffn_expand=2
+    ):
+        super().__init__()
+
+        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, bias=True)  # stem
+        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1, bias=True)  # head
+
+        self.encoders = nn.ModuleList()
+        self.decoders = nn.ModuleList()
+        self.middle_blks = nn.ModuleList()
+        self.ups = nn.ModuleList()
+        self.downs = nn.ModuleList()
+
+        chan = width
+        for num in enc_blk_nums:
+            self.encoders.append(
+                nn.Sequential(
+                    # Pass the new parameters into the NAFBlock
+                    *[NAFBlock(chan, DW_Expand=dw_expand, FFN_Expand=ffn_expand, drop_out_rate=drop_out_rate) for _ in range(num)]
+                )
+            )
+            self.downs.append(
+                nn.Conv2d(chan, 2*chan, 2, 2)  # strided 2x2 conv halves H/W, doubles channels
+            )
+            chan = chan * 2
+
+        self.middle_blks = nn.Sequential(
+            *[NAFBlock(chan, DW_Expand=dw_expand, FFN_Expand=ffn_expand, drop_out_rate=drop_out_rate) for _ in range(middle_blk_num)]
+        )
+
+        for num in dec_blk_nums:
+            self.ups.append(
+                nn.Sequential(
+                    nn.Conv2d(chan, chan * 2, 1, bias=False),
+                    nn.PixelShuffle(2)  # upsample 2x; net channel change is chan -> chan // 2
+                )
+            )
+            chan = chan // 2
+            self.decoders.append(
+                nn.Sequential(
+                    *[NAFBlock(chan, DW_Expand=dw_expand, FFN_Expand=ffn_expand, drop_out_rate=drop_out_rate) for _ in range(num)]
+                )
+            )
+
+        self.padder_size = 2 **
len(self.encoders)
+
+    def forward(self, inp):
+        B, C, H, W = inp.shape
+        inp = self.check_image_size(inp)  # pad so H/W divide by 2**num_levels
+
+        x = self.intro(inp)
+
+        encs = []
+
+        for encoder, down in zip(self.encoders, self.downs):
+            x = encoder(x)
+            encs.append(x)
+            x = down(x)
+
+        x = self.middle_blks(x)
+
+        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
+            x = up(x)
+            x = x + enc_skip  # additive skip connection
+            x = decoder(x)
+
+        x = self.ending(x)
+        x = x + inp  # global residual: network predicts the correction
+
+        return x[:, :, :H, :W]
+
+    def check_image_size(self, x):
+        _, _, h, w = x.size()
+        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
+        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
+        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
+        return x
diff --git a/biapy/models/patchgan.py b/biapy/models/patchgan.py
new file mode 100644
index 000000000..bf8fa8933
--- /dev/null
+++ b/biapy/models/patchgan.py
@@ -0,0 +1,23 @@
+import torch.nn as nn
+
+class PatchGANDiscriminator(nn.Module):
+    def __init__(self, in_channels=1, base_filters=64):
+        super(PatchGANDiscriminator, self).__init__()
+
+        def discriminator_block(in_filters, out_filters, normalization=True):
+            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
+            if normalization:
+                layers.append(nn.BatchNorm2d(out_filters))
+            layers.append(nn.LeakyReLU(0.2, inplace=True))
+            return layers
+
+        self.model = nn.Sequential(
+            *discriminator_block(in_channels, base_filters, normalization=False),
+            *discriminator_block(base_filters, base_filters * 2),
+            *discriminator_block(base_filters * 2, base_filters * 4),
+            *discriminator_block(base_filters * 4, base_filters * 8),
+            nn.Conv2d(base_filters * 8, 1, 4, stride=1, padding=1)
+        )
+
+    def forward(self, img):
+        return self.model(img)
diff --git a/biapy/utils/misc.py b/biapy/utils/misc.py
index 0ce41e73b..9c4dffa35 100644
--- a/biapy/utils/misc.py
+++ b/biapy/utils/misc.py
@@ -305,7 +305,19 @@ def get_grad_norm_(parameters,
norm_type: float = 2.0) -> torch.Tensor: return total_norm -def save_model(cfg, biapy_version, jobname, epoch, model_without_ddp, optimizer, model_build_kwargs=None, extension="pth"): +def save_model( + cfg, + biapy_version, + jobname, + epoch, + model_without_ddp, + optimizer, + model_build_kwargs=None, + extension="pth", + discriminator_without_ddp=None, + optimizer_d=None, + extra_checkpoint_items=None, +): """ Save the model checkpoint to the specified path. @@ -333,6 +345,12 @@ def save_model(cfg, biapy_version, jobname, epoch, model_without_ddp, optimizer, extension : str, optional The file extension for the checkpoint file. Options are 'pth' (native PyTorch format) or 'safetensors' (https://github.com/huggingface/safetensors). Defaults to "pth". + discriminator_without_ddp : Optional[nn.Module], optional + Optional discriminator model to include in checkpoints for GAN training. + optimizer_d : Optional[torch.optim.Optimizer], optional + Optional discriminator optimizer state to include in checkpoints for GAN training. + extra_checkpoint_items : Optional[dict], optional + Additional custom fields to append to the checkpoint payload. Returns ------- @@ -352,6 +370,13 @@ def save_model(cfg, biapy_version, jobname, epoch, model_without_ddp, optimizer, "biapy_version": biapy_version, } + if discriminator_without_ddp is not None: + to_save["discriminator_state_dict"] = discriminator_without_ddp.state_dict() + if optimizer_d is not None: + to_save["optimizer_d_state_dict"] = optimizer_d.state_dict() + if extra_checkpoint_items: + to_save.update(extra_checkpoint_items) + save_on_master(to_save, checkpoint_path) if len(checkpoint_paths) > 0: return checkpoint_paths[0]