
Deep Weeds via Deep Fake: GAN-Based Data Augmentation for Weed Classification


Abstract

This repository contains code to reproduce experiments exploring whether synthetic data generated by Generative Adversarial Networks (GANs) can improve deep learning classification accuracy on the DeepWeeds dataset. We evaluate two GAN architectures—DCGAN and ACGAN—for synthetic image generation, and assess their utility as data augmentation strategies when training a ResNet-50 classifier via transfer learning.

Key Finding: Our experiments indicate that GAN-based augmentation does not improve classification accuracy on this dataset. A ResNet-50 model trained with stratified k-fold cross-validation achieves 92% test accuracy without augmentation, compared to 91% with augmentation. We discuss potential explanations related to dataset characteristics and GAN training limitations.

Research Questions

  1. Can synthetic data from DCGAN and ACGAN achieve higher classification accuracy than traditional augmentation or no augmentation?
  2. Which model configuration is most effective for classifying the DeepWeeds dataset?

Background and Motivation

Detection and classification of weed species in situ remains a significant challenge in precision agriculture. Unlike segmented laboratory images, field conditions introduce variability in lighting, occlusion, background complexity, and plant growth stages. The DeepWeeds dataset (Olsen et al., 2019) was specifically designed to reflect this complexity, containing 17,509 images of eight weed species plus a "negative" class captured across multiple locations in northern Australia.

Training deep convolutional neural networks requires extensive labeled data, which is costly to collect and annotate in agricultural settings. This motivates exploration of synthetic data augmentation using GANs as a potentially more efficient alternative to manual data collection.

The DeepWeeds Dataset

| Class | Species        | Count  |
|-------|----------------|--------|
| 0     | Chinee apple   | 1,125  |
| 1     | Lantana        | 1,064  |
| 2     | Parkinsonia    | 1,031  |
| 3     | Parthenium     | 1,022  |
| 4     | Prickly acacia | 1,062  |
| 5     | Rubber vine    | 1,009  |
| 6     | Siam weed      | 1,074  |
| 7     | Snake weed     | 1,016  |
| 8     | Negatives      | 9,106  |
|       | Total          | 17,509 |

The dataset maintains an approximately 50:50 split between positive (weed) and negative (non-weed) classes. This distribution is intentional: field locations containing target weeds also contain numerous unlabeled plant species that must be classified as "not target."
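The class counts above can be checked directly; this short sketch sums the table and confirms the roughly balanced positive/negative split:

```python
# Class counts from the table above (Olsen et al., 2019).
weed_counts = {
    "Chinee apple": 1125, "Lantana": 1064, "Parkinsonia": 1031,
    "Parthenium": 1022, "Prickly acacia": 1062, "Rubber vine": 1009,
    "Siam weed": 1074, "Snake weed": 1016,
}
negatives = 9106

positives = sum(weed_counts.values())   # 8,403 weed images
total = positives + negatives           # 17,509 images

print(positives, total, round(positives / total, 3))  # → 8403 17509 0.48
```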

Methodology

1. Transfer Learning with ResNet-50

We adapt a PyTorch ResNet-50 model pretrained on ImageNet (1,000 classes) for 9-class weed classification. The pretrained convolutional layers are frozen, and only the final fully-connected layer is retrained. This approach:

  • Reduces computational cost
  • Reduces the amount of labeled training data required
  • Leverages learned feature representations from ImageNet

Training Configuration:

  • Optimizer: Adam
  • Loss function: Cross-entropy
  • Learning rate: 0.0001 (halved when validation loss plateaus for 16 epochs)
  • Cross-validation: Stratified 5-fold
  • Epochs per fold: 100
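The optimizer and learning-rate schedule above map directly onto PyTorch's `Adam` and `ReduceLROnPlateau`; the single-layer model here is just a stand-in to make the sketch runnable:

```python
from torch import nn, optim

model = nn.Linear(10, 9)  # stand-in for the trainable ResNet-50 head
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Halve the LR when validation loss has not improved for 16 epochs,
# matching LR_PATIENCE = 16 and the "halved on plateau" rule above.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=16)

# Simulate 18 epochs with a flat validation loss: epoch 1 sets the
# best value, the next 17 count as non-improving, so the LR halves.
for _ in range(18):
    scheduler.step(1.0)

print(optimizer.param_groups[0]["lr"])  # → 5e-05
```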

2. DCGAN (Deep Convolutional GAN)

Generates unlabeled 64×64 pixel synthetic images from the complete dataset.

| Parameter               | Value               |
|-------------------------|---------------------|
| Generator layers        | 5 (de-convolutional) |
| Discriminator layers    | 5 (convolutional)   |
| Input/output dimensions | 64×64×3             |
| Batch size              | 128                 |
| Epochs                  | 200                 |
| Latent dimension (z)    | 100                 |
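A minimal generator matching the table (five transposed-conv layers, z = 100, 64×64×3 output) can be sketched as follows; the specific channel widths (512 down to 3) are typical DCGAN choices, not taken from the repository:

```python
import torch
from torch import nn

class DCGANGenerator(nn.Module):
    """Five transposed-conv blocks: z (100) -> 64x64x3, per the table above."""
    def __init__(self, z_dim=100):
        super().__init__()
        def block(c_in, c_out):
            return [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(c_out), nn.ReLU(inplace=True)]
        self.net = nn.Sequential(
            # 1x1 -> 4x4
            nn.ConvTranspose2d(z_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512), nn.ReLU(inplace=True),
            *block(512, 256),   # 4 -> 8
            *block(256, 128),   # 8 -> 16
            *block(128, 64),    # 16 -> 32
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),  # 32 -> 64
            nn.Tanh(),          # pixel values in [-1, 1]
        )

    def forward(self, z):
        return self.net(z)

g = DCGANGenerator()
out = g(torch.randn(2, 100, 1, 1))
print(out.shape)  # → torch.Size([2, 3, 64, 64])
```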

3. ACGAN (Auxiliary Classifier GAN)

Generates labeled 128×128 pixel synthetic images, conditioned on class labels.

| Parameter               | Value                                                        |
|-------------------------|--------------------------------------------------------------|
| Generator               | 2 upsample layers, batch norm, 3 conv layers, Tanh           |
| Discriminator           | 4 conv layers, 1 FC + sigmoid (real/fake), 1 FC + softmax (class) |
| Input/output dimensions | 128×128×3                                                    |
| Batch size              | 64                                                           |
| Epochs                  | 200                                                          |
| Data augmentation       | imgaug library                                               |

4. Evaluation Metrics

Fréchet Inception Distance (FID): Measures similarity between generated and real image distributions using features extracted from a pretrained Inception v3 model. Lower scores indicate higher quality (0 = identical distributions).
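Concretely, FID(r, g) = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_rΣ_g)^{1/2}), computed over the means and covariances of the two feature distributions. A NumPy/SciPy sketch of the formula (the actual pipeline feeds Inception-v3 features into it, but any (n, d) feature arrays work):

```python
import numpy as np
from scipy import linalg

def fid(feats_real, feats_fake):
    """Frechet Inception Distance between two sets of feature
    vectors (rows = samples)."""
    mu1, mu2 = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    s1 = np.cov(feats_real, rowvar=False)
    s2 = np.cov(feats_fake, rowvar=False)
    covmean, _ = linalg.sqrtm(s1 @ s2, disp=False)
    covmean = covmean.real  # discard tiny imaginary parts from sqrtm
    return float(np.sum((mu1 - mu2) ** 2) + np.trace(s1 + s2 - 2 * covmean))

rng = np.random.default_rng(0)
a = rng.normal(size=(500, 8))
print(fid(a, a))  # ~0.0: identical distributions score the minimum
```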

Classification Metrics: Accuracy, precision, sensitivity, specificity, false positive rate, false negative rate, and false discovery rate computed per class.

Results Summary

GAN Image Quality (FID Scores)

| Model | FID Score | Interpretation                           |
|-------|-----------|------------------------------------------|
| DCGAN | 152       | High deviation from real images          |
| ACGAN | 158       | Higher deviation; fewer images per class |

The high FID scores indicate that generated images deviate substantially from real image distributions. Visual inspection confirms that ACGAN images appear heavily pixelated, likely due to insufficient training examples per class.

Classification Accuracy

| Configuration                 | Best Validation Acc. | Test Accuracy |
|-------------------------------|----------------------|---------------|
| ResNet-50 (no augmentation)   | 84.19%               | 92%           |
| ResNet-50 (with augmentation) | 83.24%               | 91%           |

Conclusion: Data augmentation—including GAN-generated synthetic images—did not improve classification accuracy on this dataset.

Repository Structure

.
├── conf/
│   └── params.py              # Global configuration parameters
├── data/
│   ├── labels/                # CSV files for k-fold splits
│   ├── images/                # Raw DeepWeeds images (download required)
│   ├── train/                 # Training images by class
│   ├── val/                   # Validation images by class
│   ├── test/                  # Test images by class
│   └── test_loader.py         # PyTorch dataset class
├── models/
│   └── resnet50.py            # ResNet architecture implementation
├── Model/                     # Saved model checkpoints
├── output/                    # Training outputs and logs
│
├── acgan_network.py           # ACGAN generator/discriminator
├── dcgan_network.py           # DCGAN generator/discriminator
│
├── main.ipynb                 # Main training notebook (Colab)
├── load_deepweeds_local.ipynb # Data preparation (local)
├── DataAugmentation.ipynb     # Traditional augmentation
├── DCGAN.ipynb                # DCGAN training
├── DCGAN Evaluation.ipynb     # DCGAN evaluation and FID
├── AC-GAN.ipynb               # ACGAN training
├── ACGAN Evaluation.ipynb     # ACGAN evaluation and FID
└── Generate_ACGAN_Images.ipynb # Generate synthetic images

Reproducing the Experiments

Prerequisites

  • Python 3.8+
  • PyTorch 1.9+
  • CUDA-capable GPU (recommended)
  • Google Colab account (for cloud execution)

Step 1: Obtain the Dataset

Download the DeepWeeds images from the original dataset release (Olsen et al., 2019) and extract them to data/images/ at the project root.

Data Organization: Images use filenames in the format YYYYMMDD-HHMMSS-ID.jpg (e.g., 20170320-093423-1.jpg). Labels are provided in data/labels/labels.csv:

Filename,Label,Species
20170207-154924-0.jpg,7,Snake weed
20170610-123859-1.jpg,1,Lantana
20180119-105722-1.jpg,8,Negative
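The timestamped filename format can be parsed with the standard library; this helper (the function name is ours, not part of the repository) splits a filename into a capture timestamp and sequence id:

```python
from datetime import datetime

def parse_deepweeds_filename(name):
    """Split a DeepWeeds filename (YYYYMMDD-HHMMSS-ID.jpg) into a
    capture timestamp and a sequence id."""
    stem, _, _ = name.rpartition(".")
    date_part, time_part, seq = stem.split("-")
    ts = datetime.strptime(date_part + time_part, "%Y%m%d%H%M%S")
    return ts, int(seq)

ts, seq = parse_deepweeds_filename("20170320-093423-1.jpg")
print(ts.isoformat(), seq)  # → 2017-03-20T09:34:23 1
```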

Step 2: Environment Setup

Option A: Google Colab (Recommended)

Open main.ipynb in Colab. The notebook handles dependencies automatically. Set runtime to GPU with High RAM for optimal performance.

Option B: Local Installation

git clone https://github.com/geraldmc/torch-draft-final_project.git
cd torch-draft-final_project

pip install torch torchvision pandas numpy scikit-learn matplotlib seaborn imgaug

Step 3: Prepare Data for Training

Run load_deepweeds_local.ipynb to organize images into the required directory structure:

data/train/[0-8]/
data/val/[0-8]/
data/test/[0-8]/

This notebook reads the k-fold CSV files and copies images to appropriate directories.

Step 4: Train and Evaluate the Classifier

Execute main.ipynb to run stratified 5-fold cross-validation:

  1. Downloads code from GitHub (if on Colab)
  2. Loads DeepWeeds images
  3. Trains ResNet-50 for 100 epochs per fold
  4. Saves best model checkpoints
  5. Generates confusion matrices and evaluation metrics

Expected Runtime: ~2-3 hours on Colab GPU (T4)

Step 5: Train GAN Models (Optional)

To reproduce GAN experiments:

DCGAN:

  1. Run DCGAN.ipynb — trains generator/discriminator
  2. Run DCGAN Evaluation.ipynb — computes FID score

ACGAN:

  1. (Optional) Run DataAugmentation.ipynb for augmented training data
  2. Run AC-GAN.ipynb — trains conditional generator/discriminator
  3. Run ACGAN Evaluation.ipynb — computes FID score
  4. Run Generate_ACGAN_Images.ipynb — saves synthetic images to data/test_generated/

Step 6: Train with Augmented Data

Modify the loader parameter in main.ipynb:

# No augmentation (default)
run_train_kfold('_no_aug', batch=32)

# With traditional augmentation
run_train_kfold('_aug1', batch=32)

# With GAN-generated images
run_train_kfold('_aug2', batch=32)

Configuration Parameters

Key parameters in conf/params.py:

IMG_SIZE = (224, 224)      # Input size for ResNet-50
MAX_EPOCH = 200            # Maximum epochs
BATCH_SIZE = 32            # Training batch size
FOLDS = 5                  # K-fold cross-validation
STOPPING_PATIENCE = 32     # Early stopping patience
LR_PATIENCE = 16           # Learning rate scheduler patience
INITIAL_LR = 0.0001        # Initial learning rate
NUM_CLASSES = 9            # Number of output classes

Discussion and Future Directions

Our results suggest several areas for improvement:

  1. DCGAN Resolution: The 64×64 outputs must be upsampled to 299×299 for FID computation via Inception v3; interpolation adds no real detail and likely inflates the score. Native high-resolution generation may yield better results.

  2. ACGAN Data Scarcity: With only ~1,000 images per class, the ACGAN struggles to learn meaningful class-conditional distributions. This presents a chicken-and-egg problem: GANs require sufficient data to generate quality synthetic samples.

  3. Alternative GAN Architectures:

    • StyleGAN: For improved resolution post-generation
    • CycleGAN: For domain transfer (e.g., generating wet-weather images from dry-weather images)
    • EVAGAN: Designed specifically for low-data scenarios requiring oversampling
  4. Spectral Normalization: Regularizing the discriminator by bounding its spectral norm may improve training stability.
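In PyTorch, spectral normalization is a one-line wrapper per discriminator layer; this sketch shows it applied to a single conv layer (the layer shape is illustrative, not taken from the repository):

```python
import torch
from torch import nn
from torch.nn.utils import spectral_norm

# Wrapping a discriminator conv layer with spectral normalization
# rescales its weight by an estimate of the largest singular value,
# bounding the layer's Lipschitz constant and tending to stabilize
# GAN training.
conv = spectral_norm(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1))

x = torch.randn(2, 3, 64, 64)
print(conv(x).shape)  # → torch.Size([2, 64, 32, 32])

# After a forward pass, the effective weight's spectral norm is ~1
# (approximate, since the estimate uses power iteration).
w = conv.weight.reshape(conv.weight.shape[0], -1)
print(float(torch.linalg.matrix_norm(w, ord=2)))
```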

References

  1. Olsen, A., et al. "DeepWeeds: A Multiclass Weed Species Image Dataset for Deep Learning." Scientific Reports 9, 2058 (2019). https://doi.org/10.1038/s41598-018-38343-3

  2. Radford, A., Metz, L., and Chintala, S. "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks." arXiv:1511.06434 (2016).

  3. Odena, A., Olah, C., and Shlens, J. "Conditional Image Synthesis with Auxiliary Classifier GANs." arXiv:1610.09585 (2016).

  4. Giuffrida, M.V., et al. "ARIGAN: Synthetic Arabidopsis Plants Using Generative Adversarial Network." arXiv:1709.00938 (2017).

  5. Nunn, E.J., Khadivi, P., and Samavi, S. "Compound Frechet Inception Distance for Quality Assessment of GAN Created Images." arXiv:2106.08575 (2021).

Citation

If you use this code in your research, please cite:

@article{mccollam2024deepweeds,
  title={Deep Weeds via Deep Fake: GAN-Based Data Augmentation for Weed Classification},
  author={McCollam, Gerald A.},
  year={2024},
  note={George Mason University}
}

Acknowledgements

  • Dr. Richard Johnson, Research Agronomist, USDA-ARS SRU
  • Dr. Alice Wright, Research Agronomist, Weed Science, USDA-ARS SRU
  • Dr. Al Ogeron, Agronomist, LSU Ag Center, Sugar Research Station
  • Dr. Erhan Guven, Data Scientist, Johns Hopkins University
  • Mohammed Rashed, Johns Hopkins University

License

This project is provided for research and educational purposes.
