All notable changes to this project will be documented in this file.
The format is based on Keep a Changelog,
and this project adheres to Semantic Versioning.
- The logger now records model architecture tags by capturing each model config's `class_path` and setting `model.<idx>.class_path` at train start.
- The loss-group auto-logging routine `_log_loss_groups_config_and_tags` logs loss item names and weights as MLflow tags and persists the full loss group configuration as a JSON config artifact.
Introduces a lightweight wrapper abstraction for dataset composition and a MONAI-compatible adapter for dictionary-based augmentation pipelines. This enables augmentation workflows to be layered on top of existing dataset implementations without modifying core dataset logic.
- `BaseWrapperDataset` (`base_wrapper_dataset.py`): Abstract wrapper base class that forwards dataset access to an underlying dataset instance and provides recursive access to the original base dataset via the `original` property. Establishes a reusable pattern for composing dataset behaviors (e.g., augmentation, caching, preprocessing) while preserving compatibility with existing dataset APIs.
- `MonaiAdapter` (`monai_aug_adapter_dataset.py`): Wrapper dataset that adapts `(input, target)` tuple samples into MONAI dictionary format (`{"input": ..., "target": ...}`), applies optional MONAI `Compose` transforms, and returns transformed samples back as `(input, target)` tuples for trainer compatibility.
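The wrapper/adapter pattern described above can be sketched as follows. This is a minimal illustration, not the package's actual implementation; only the class names and the `original` property follow the description, everything else is assumed:

```python
class BaseWrapperDataset:
    """Forwards dataset access to a wrapped dataset instance (sketch)."""

    def __init__(self, dataset):
        self._dataset = dataset

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        return self._dataset[idx]

    @property
    def original(self):
        # Recurse through nested wrappers to reach the base dataset.
        if isinstance(self._dataset, BaseWrapperDataset):
            return self._dataset.original
        return self._dataset


class MonaiAdapter(BaseWrapperDataset):
    """Adapts (input, target) tuples to MONAI dict transforms and back."""

    def __init__(self, dataset, transform=None):
        super().__init__(dataset)
        self._transform = transform  # e.g., a monai.transforms.Compose

    def __getitem__(self, idx):
        x, y = self._dataset[idx]
        sample = {"input": x, "target": y}
        if self._transform is not None:
            sample = self._transform(sample)
        return sample["input"], sample["target"]
```

Because wrappers can be nested arbitrarily, augmentation can be layered onto any existing dataset without touching its code, and `original` still reaches the base instance.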
- Added/updated `4.data_augmentation_example.ipynb` demonstrating:
  - construction of a base dataset and crop dataset,
  - application of MONAI dictionary transforms through `MonaiAdapter`,
  - visualization of repeated stochastic augmentations,
  - integration of the augmented dataset into a standard training dataloader/trainer workflow.
Allows the dataset to return user-specified crops dynamically extracted from the full images. Supports serialization and deserialization to facilitate reproducibility.
- `CropImageDataset` (`crop_dataset.py`): Dataset class for serving image crops based on a `CropManifest`. Extends `BaseImageDataset` with crop-specific state management and lazy loading via `CropFileState`.
- `CropManifest` (`ds_engine/crop_manifest.py`): Immutable collection of crop definitions wrapping a `DatasetManifest` for file access. Supports serialization/deserialization and factory construction from coordinate specifications.
- `Crop` (`ds_engine/crop_manifest.py`): Dataclass defining a single crop region with manifest index, position (x, y), and dimensions (width, height).
- `CropIndexState` (`ds_engine/crop_manifest.py`): Mutable state tracker for the currently active crop region.
- `CropFileState` (`ds_engine/crop_manifest.py`): Lazy image loading backend that wraps `FileState` to load full images and dynamically extract crop regions on demand.
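A minimal sketch of the `Crop` dataclass and on-demand region extraction described above. Field names (`manifest_index`, `x`, `y`, `width`, `height`) are assumptions based on the description, and the extraction helper is hypothetical:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Crop:
    """A single crop region tied to one row of the dataset manifest."""
    manifest_index: int
    x: int
    y: int
    width: int
    height: int


def extract_crop(image, crop):
    """Slice a row-major 2D image (list of rows) down to the crop region."""
    rows = image[crop.y:crop.y + crop.height]
    return [row[crop.x:crop.x + crop.width] for row in rows]
```

Keeping `Crop` frozen matches the immutable-manifest design: crop definitions can be serialized once and replayed for reproducibility, while only the lightweight state trackers mutate.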
- Abstract away the forward pass and multiple loss accumulation from trainers
- New logging trainer for single generator model training using `engine`
- Add progress bar
- Added `LossGroup` class to abstract the complexity of computing multiple losses around a single forward-pass iteration out of the trainer. Intended as prep work for incorporating more complex wGAN training.
- `LossItem` (`loss_group.py`): Wrapper around a `torch.nn.Module` loss to specify the weight and arguments needed for computation.
- `LossGroup` (`loss_group.py`): Container class organizing all `LossItem`s to be computed during the same forward pass on the same set of context objects.
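The `LossItem`/`LossGroup` pattern can be sketched as follows. This is a minimal illustration with hypothetical attribute names (`arg_keys`, `compute`); any callable stands in for a `torch.nn.Module` loss here:

```python
class LossItem:
    """Pairs one loss callable with a weight and the context keys it consumes."""

    def __init__(self, name, module, weight=1.0, arg_keys=("pred", "target")):
        self.name = name
        self.module = module
        self.weight = weight
        self.arg_keys = arg_keys

    def compute(self, context):
        # Pull this item's arguments out of the shared forward-pass context.
        args = [context[k] for k in self.arg_keys]
        return self.weight * self.module(*args)


class LossGroup:
    """Evaluates every LossItem against one shared forward-pass context."""

    def __init__(self, items):
        self.items = list(items)

    def compute(self, context):
        losses = {item.name: item.compute(context) for item in self.items}
        losses["total"] = sum(losses.values())
        return losses
```

Because all items read from the same context dict, the trainer performs one forward pass, builds the context once, and hands it to the group, which is the separation of concerns the changelog describes.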
- Introduced a comprehensive refactoring of the dataset infrastructure for improved modularity, lazy loading, and memory efficiency.
- `DatasetManifest` (`manifest.py`): Immutable manifest class that defines the structure of a dataset, holding a file index DataFrame where each row corresponds to a sample/FOV and columns represent channels. Validates file paths and PIL image modes during initialization.
- `IndexState` (`manifest.py`): Lightweight tracker for maintaining the last accessed dataset index.
- `FileState` (`manifest.py`): Lazy loading backend that manages image loading with configurable LRU caching.
- `BaseImageDataset` (`base_dataset.py`): PyTorch-compatible dataset class built on the manifest infrastructure.
- Restructured dataset loading logic to use the new modular manifest-based architecture.
- Improved error handling and validation throughout the dataset pipeline.
- Enhanced type annotations and documentation for better developer experience.
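The lazy-loading-with-LRU-caching idea behind `FileState` can be sketched like this (a hypothetical stand-in built on `functools.lru_cache`, not the package's real class; the `loader` and `cache_size` parameters are assumptions):

```python
from functools import lru_cache


class FileState:
    """Lazy image-loading backend with a configurable LRU cache (sketch)."""

    def __init__(self, paths, loader, cache_size=16):
        self._paths = list(paths)
        self._loader = loader          # e.g., PIL.Image.open wrapped in a closure
        self._reads = 0                # instrumentation for this example only
        # Wrap the raw read in a per-instance LRU cache.
        self._load = lru_cache(maxsize=cache_size)(self._read)

    def _read(self, index):
        self._reads += 1
        return self._loader(self._paths[index])

    def __getitem__(self, index):
        return self._load(index)
```

Nothing is read from disk until first access, and repeated accesses within the cache window avoid re-decoding images, which is the memory/I-O trade-off the refactor targets.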
- Introduced a new modular and extensible `models` subpackage for building image-to-image translation models. The subpackage is designed around a declarative style for creating U-Net-like architectures, with a hierarchy of abstractions:
  - Blocks (`blocks.py`, `up_down_blocks.py`): Smallest modular units, categorized into computational blocks (e.g., `Conv2DNormActBlock`, `Conv2DConvNeXtBlock`) and spatial-dimension-altering blocks (e.g., `Conv2DDownBlock`, `PixelShuffle2DUpBlock`).
  - Stages (`stages.py`): Sequences of blocks for downsampling or upsampling, such as `DownStage` and `UpStage`.
  - Encoder (`encoder.py`): Implements the downsampling path of U-Net-like architectures using `DownStage` objects.
  - Decoder (`decoder.py`): Implements the upsampling path with skip connections using `UpStage` objects.
  - BaseModel and BaseGeneratorModel (`model.py`): Abstract base classes for models, including functionality for saving weights, configuration handling, and defining the forward pass.
  - UNet (`unet.py`): Predefined model class supporting fully convolutional and maxpooling-based U-Net variants.
  - UNeXt (`unext.py`): Predefined U-Net variant with a ConvNeXtV2_tiny encoder and customizable decoder.
- Added utility functions for normalization layers, activation functions, and type checking of block handles and configurations.
- Refer to the `models` README for detailed explanations of components and usage examples.
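To make the block/stage/encoder/decoder hierarchy concrete without pulling in PyTorch, here is a shape-bookkeeping sketch of how a symmetric encoder/decoder with skip connections composes. The classes mirror the names above conceptually but are purely illustrative:

```python
class DownStage:
    """Halves the spatial size and doubles channels (shape bookkeeping only)."""

    def __call__(self, shape):
        c, h, w = shape
        return (c * 2, h // 2, w // 2)


class UpStage:
    """Doubles the spatial size and halves channels, consuming one skip."""

    def __call__(self, shape, skip_shape):
        c, h, w = shape
        # The skip tensor must match the upsampled spatial size to concatenate.
        assert skip_shape[1:] == (h * 2, w * 2)
        return (c // 2, h * 2, w * 2)


def unet_shapes(input_shape, depth):
    """Trace tensor shapes through a symmetric encoder/decoder pair."""
    shape, skips = input_shape, []
    for _ in range(depth):            # encoder: record skip, then downsample
        skips.append(shape)
        shape = DownStage()(shape)
    for skip in reversed(skips):      # decoder: upsample and merge the skip
        shape = UpStage()(shape, skip)
    return shape
```

A symmetric network returns its input shape, which is exactly the invariant the declarative stage composition has to maintain for image-to-image translation.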
- Restructured the repository from a flat layout to the conventional `/src/package_name/` structure. This change improves module discoverability, aligns with modern Python packaging standards, and reduces potential import conflicts. All package-related code now resides under the `src/virtual_stain_flow/` directory.
- Updated import paths throughout the codebase to reflect the new structure.
- Adjusted setup scripts and documentation to accommodate the restructuring.
A minimal rework of the logging framework as the first step toward a complete overhaul of the virtual_stain_flow software.
This version defines a new logging subpackage that better integrates MLflow into the virtual staining model training process for a more comprehensive logging framework.
Notes: This class is similar to the old `virtual_stain_flow.callback.MlflowLogger` class, but promoted to an independent logger class with the ability to accept logger callbacks. Key design/functionality:
- Files, metrics, and parameters produced by logger callbacks are automatically logged to MLflow appropriately instead of remaining untracked side products.
- Includes additional pre-defined fine-grained run logging tags such as `experiment_type`, `model_architecture`, `target_channel_name`, and `description` as logger class parameters.
- Has `bind_trainer` and `unbind_trainer` methods to bind and unbind the trainer instance during the train step.
- User-controlled MLflow run cycle: the run no longer ends automatically with the train loop, so the user can perform additional logging operations before explicitly ending the run.
- Exposes `log_artifact`, `log_metric`, and `log_param` methods for manual logging of artifacts, metrics, and parameters.
- Provides access points to trainer attributes for use by logger callbacks, subject to optimization/change.
Notes: This class subclasses the `virtual_stain_flow.trainers.AbstractTrainer` class and preserves most of its behavior and functionality.
Design/functionality changes include:
- Binding of the logger class moved from initialization to the `train` method, reflecting the design that logger instances should live with the training sessions.
- Early termination mode is now a parameter of the class to allow selection of min/max optimization mode.
- The `train` method loop invokes the logger lifecycle methods `logger.on_train_start()`, `logger.on_epoch_start()`, `logger.on_epoch_end()`, and `logger.on_train_end()`, which in turn lead to the logger's invocation of the logger callback methods.
- Requires child classes to implement the abstract `save_model` method as a unified handle for saving model weights that can be called by the logger.
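The lifecycle-hook contract described above can be sketched as a minimal loop. Method names follow the notes; the constructor signature and everything else is assumed:

```python
class AbstractLoggingTrainer:
    """Sketch: the train loop drives the bound logger's lifecycle hooks."""

    def __init__(self, epochs):
        self.epochs = epochs

    def train(self, logger):
        logger.on_train_start(self)
        for epoch in range(self.epochs):
            logger.on_epoch_start(epoch)
            self.train_epoch(epoch)
            logger.on_epoch_end(epoch)
        logger.on_train_end(self)

    def train_epoch(self, epoch):
        raise NotImplementedError

    def save_model(self, path):
        # Child classes must implement this so the logger has one
        # unified handle for persisting model weights.
        raise NotImplementedError
```

Each logger hook fans out to the registered logger callbacks, so the trainer never talks to MLflow directly.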
Notes: This class is nearly identical to the old Trainer class, except that:
- It is a realization of the new `AbstractLoggingTrainer` class instead of the `AbstractTrainer` class.
- It overrides the parent class's `save_model` method, which defines saving of the model weights.
Notes: This is a new subpackage, distinct from the existing `virtual_stain_flow.callbacks` subpackage in that classes under this subpackage are passed to `logging.MlflowLogger` instances as opposed to `trainers.*` instances.
- `AbstractLoggerCallback` class: A newly introduced abstract class that defines the behavior of logger callbacks interacting with the `logging.MlflowLogger` class, so that callback products get logged appropriately as artifacts/metrics/parameters.
- `PlotPredictionCallback` class: A newly introduced realization of the `AbstractLoggerCallback` class, serving as an example implementation of a logger callback. Similar to the `virtual_stain_flow.callbacks.IntermediatePlot` callback, it plots predictions of the model on a subset of the dataset, but the additional interface with the `MlflowLogger` class ensures the plots produced are logged as MLflow artifacts.
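The callback-to-logger routing idea can be sketched as follows. Hook names and signatures here are illustrative assumptions, and the plotting step is stubbed out:

```python
class AbstractLoggerCallback:
    """Sketch: a logger callback hands its products to the logger for routing."""

    def on_epoch_end(self, logger, trainer, epoch):
        raise NotImplementedError


class PlotPredictionCallback(AbstractLoggerCallback):
    """Renders prediction plots and registers them as run artifacts."""

    def __init__(self, out_dir):
        self.out_dir = out_dir

    def on_epoch_end(self, logger, trainer, epoch):
        path = f"{self.out_dir}/predictions_epoch_{epoch}.png"
        # ... render model predictions on a dataset subset to `path` ...
        logger.log_artifact(path)  # routed to MLflow by the logger
```

The key difference from the older trainer callbacks is the first argument: the product goes through the logger, so nothing the callback produces is left untracked.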
- Internal function renames for clarity.
- Consistent attribute/property usage.
- Updated `__init__.py` to directly expose classes under the subpackage.
- Introduced a minimal yet self-contained virtual staining framework structured around modular components for model training, dataset handling, transformations, metrics, and logging.
- Added `FNet`: Fully convolutional encoder-decoder for image-to-image translation.
- Added `UNet`: U-Net variant using bilinear interpolation for upsampling.
- Added GAN discriminators:
  - `PatchBasedDiscriminator`: Outputs a probability map.
  - `GlobalDiscriminator`: Outputs a global scalar probability.
- `MinMaxNormalize`: Albumentations transform for range-based normalization.
- `ZScoreNormalize`: Albumentations transform for z-score normalization.
- `PixelDepthTransform`: Converts between image bit depths (e.g., 16-bit to 8-bit).
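The arithmetic behind two of these transforms, sketched on plain lists rather than image arrays (function names are hypothetical; the real classes are Albumentations transforms operating on NumPy arrays):

```python
def min_max_normalize(pixels, lo, hi):
    """Scale intensities from the range [lo, hi] into [0, 1]."""
    return [(p - lo) / (hi - lo) for p in pixels]


def pixel_depth_transform(pixels, src_bits=16, dst_bits=8):
    """Rescale integer intensities between bit depths (e.g., 16-bit -> 8-bit)."""
    scale = (2 ** dst_bits - 1) / (2 ** src_bits - 1)
    return [round(p * scale) for p in pixels]
```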
- `ImageDataset`: Dynamically loads multi-channel microscopy images from a PE2LoadData-formatted CSV; supports input/target channel selection and Albumentations transforms.
- `PatchDataset`: Extends `ImageDataset` with configurable fixed-size cropping; supports object-centric patching and state retrieval (e.g., patch coordinates).
- `GenericImageDataset`: A simplified dataset for user-formatted directories using regex-based site/channel parsing.
- `CachedDataset`: Caches any of the above datasets in RAM to reduce I/O and speed up training.
- `AbstractLoss`: Base class defining a standardized loss interface and trainer binding.
- `GeneratorLoss`: Combines image reconstruction and adversarial loss for training GAN generators.
- `WassersteinLoss`: Computes Wasserstein distance for GAN discriminator training.
- `GradientPenaltyLoss`: Adds gradient penalty to improve discriminator stability.
- `AbstractMetrics`: Base class for accumulating, aggregating, and resetting batch-wise metrics.
- `MetricsWrapper`: Wraps `torch.nn.Module` metrics with accumulation and aggregation logic.
- `PSNR`: Computes Peak Signal-to-Noise Ratio (PSNR) for image quality evaluation.
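For reference, the PSNR computation the metric performs, sketched on flat pixel lists (the real class operates on tensors and handles batch accumulation):

```python
import math


def psnr(pred, target, max_val=1.0):
    """Peak Signal-to-Noise Ratio: 10 * log10(MAX^2 / MSE)."""
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return float("inf")  # identical images
    return 10 * math.log10(max_val ** 2 / mse)
```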
- `AbstractCallback`: Base class for trainer-stage hooks (`on_train_start`, `on_epoch_end`, etc.).
- `IntermediatePlot`: Visualizes model inference during training.
- `MlflowLogger`: Logs trainer metrics and losses to an MLflow server.
- `AbstractTrainer`: Defines a modular training loop with support for custom models, datasets, losses, metrics, and callbacks. Exposes extensible hooks for batch- and epoch-level logic.