This repo implements a top-down deep generative model trained with Alternating Back-Propagation (ABP).
Given a low-dimensional latent variable z (here 2D), a generator network g(z; θ) synthesizes an image.
Training alternates between two steps:
- Inferential back-propagation: infer `z` for each training image using Langevin dynamics (approximate sampling from p(z | x, θ)).
- Learning back-propagation: update the generator parameters θ by back-propagation given the inferred `z`.
The experiments are conducted on a lion–tiger image dataset resized to 128×128.
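The Langevin inference step above can be sketched as follows. This is a minimal illustration, not the repo's implementation: the function name `langevin_infer` and its keyword arguments are hypothetical, chosen to mirror the CLI flags described later.

```python
import torch

def langevin_infer(g, z, x, step_size=0.05, num_steps=120,
                   mse_sigma=1.0, prior_sigma=1.0):
    """Approximate sampling from p(z | x, theta) via Langevin dynamics.

    g: generator network, z: (N, z_dim) initial latents, x: target images.
    Each step moves z along the gradient of log p(x, z) plus Gaussian noise.
    """
    z = z.clone().detach()
    for _ in range(num_steps):
        z.requires_grad_(True)
        recon = g(z)
        # log p(x, z) up to an additive constant:
        # Gaussian likelihood (mse_sigma) + Gaussian prior (prior_sigma)
        log_joint = (-0.5 / mse_sigma**2 * ((x - recon) ** 2).sum()
                     - 0.5 / prior_sigma**2 * (z ** 2).sum())
        grad = torch.autograd.grad(log_joint, z)[0]
        with torch.no_grad():
            z = z + 0.5 * step_size**2 * grad + step_size * torch.randn_like(z)
    return z.detach()
```

In the warm-start regime the `z` returned here would be fed back in as the initial value at the next epoch; in the cold-start regime `z` is redrawn from the prior instead.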
```
.
├── abp.py                   # Main training script (ABP + generator)
├── run_all_experiments.sh   # Runs warm-start + cold-start experiments
├── report.tex               # LaTeX report (placeholders reference output images)
├── images/                  # Dataset folder (not tracked by git)
└── .gitignore
```
Note: your main script may be named `adp.py` in some setups. The provided `run_all_experiments.sh` will auto-detect `abp.py` or `adp.py`.
- Python 3.10+ recommended
- PyTorch
- torchvision
- numpy
- matplotlib
- pillow
Example installation (conda):
```
conda create -n abp python=3.10 -y
conda activate abp
pip install torch torchvision numpy matplotlib pillow
```

Place all lion/tiger images under:

```
./images/
```
The script loads all files inside images/ as training data.
Images are resized to 128×128 and normalized to [-1, 1].
```
python abp.py \
--start warm \
--lr 4e-4 \
--langevin_step_size 0.05 \
--langevin_num_steps 120 \
--mse_sigma 1 \
--prior_sigma 1 \
--n_epochs 2000 \
--n_log 100 \
--n_stats 100 \
--n_plot 200 \
--seed 1
```

```
python abp.py \
--start cold \
--lr 4e-4 \
--langevin_step_size 0.05 \
--langevin_num_steps 120 \
--mse_sigma 1 \
--prior_sigma 1 \
--n_epochs 2000 \
--n_log 100 \
--n_stats 100 \
--n_plot 200 \
--seed 1
```

Each run produces:
- Reconstructed images (using inferred `z`): `XXXX_recon.png`
- Randomly generated images (sampling `z ~ N(0, I)`): `XXXX_sampled.png`
- Latent interpolation grid (2D interpolation): `XXXX_interp.png`
- Loss plot over iterations: `stat.png` and `stat.pdf`
Each experiment folder also includes:
- `output.log` — training logs (loss, z mean/std)
Example output files:
```
warm_0.0004_0.05_120_1/
    0200_recon.png
    0200_sampled.png
    0200_interp.png
    ...
    2000_recon.png
    2000_sampled.png
    2000_interp.png
    stat.png
    stat.pdf
    output.log
```
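Because the latent space is 2D, the interpolation grid can be produced by decoding a regular grid of latent points. A minimal sketch of the idea (the function name `interp_grid` and the latent range [-2, 2] are assumptions, not taken from the script):

```python
import torch

def interp_grid(g, lo=-2.0, hi=2.0, n=8):
    """Decode an n x n regular grid over the 2D latent space.

    Returns the decoded images stacked along the batch dimension;
    arranging them into an image grid is left to the caller.
    """
    axis = torch.linspace(lo, hi, n)
    zz = torch.cartesian_prod(axis, axis)  # (n*n, 2) latent coordinates
    with torch.no_grad():
        imgs = g(zz)                       # (n*n, C, H, W) decoded images
    return imgs
```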
- Warm-start vs Cold-start:
  - Warm-start reuses the previous `z` values and often converges faster and more stably.
  - Cold-start reinitializes `z` from the prior each epoch and can be noisier.
- If samples look noisy or unstable:
  - Reduce `--langevin_step_size` (e.g., `0.03`)
  - Increase `--langevin_num_steps` (e.g., `200`)
  - Try adjusting `--mse_sigma` (e.g., `0.5`)
All runs are controlled by `--seed`.
To reproduce results exactly, use the same seed and hyperparameters.
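A common seeding pattern for exact reproducibility looks like the following. This is a generic sketch; the helper name `set_seed` is an assumption, and the script's own seeding may differ.

```python
import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    """Seed every RNG the training loop might touch."""
    random.seed(seed)          # Python's stdlib RNG
    np.random.seed(seed)       # NumPy RNG
    torch.manual_seed(seed)    # PyTorch CPU RNG
    torch.cuda.manual_seed_all(seed)  # all CUDA devices (no-op without GPU)
```

Note that even with identical seeds, results can differ across hardware or PyTorch versions due to nondeterministic CUDA kernels.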
MIT License