This release contains the model weights from our paper. These weights are only guaranteed to work with this tag, so make sure to run `git checkout v1` first.
There were 5 folds, so for each model we provide 5 sets of weights, one per fold. In practice, you may wish to ensemble them.
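With one model per fold, a simple way to ensemble is to average the per-pixel class probabilities of the fold models. The sketch below assumes each model has been created and loaded as in the example below, and that `model(ex)` returns unnormalised logits over dimension 1 (drop the `softmax` if it already returns probabilities). `ensemble_predict` is a hypothetical helper, not part of `gff`.

```python
import torch


@torch.no_grad()
def ensemble_predict(models, ex):
    """Average per-fold class probabilities (hypothetical helper, not part of gff).

    Assumption: each model maps the input dict `ex` to logits of shape (B, 2, H, W).
    """
    if not models:
        raise ValueError("need at least one model")
    total = None
    for model in models:
        model.eval()
        # Assumed logits -> per-pixel class probabilities over the class dimension
        probs = torch.softmax(model(ex), dim=1)
        total = probs if total is None else total + probs
    # Mean over folds; result has the same shape as a single model's output
    return total / len(models)
```

For example, `ensemble_predict(fold_models, ex)`, where `fold_models` is a list of the five loaded models, returns a `(B, 2, H, W)` probability tensor, and `.argmax(dim=1)` turns it into a per-pixel class map. Averaging probabilities is only one option; averaging logits or majority-voting the per-fold masks are common alternatives.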
You can use these weights from within this repo as follows:

```python
import torch
import yaml

import gff.models.creation

# Paths to downloaded files
config_path = "path/to/config.yml"
weights_path = "path/to/weights.th"

# Create the model from its config
with open(config_path) as f:
    C = yaml.safe_load(f)
model = gff.models.creation.create(C)

# Load weights into the model (map to CPU first so loading works on any machine)
checkpoint = torch.load(weights_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()
model.cuda()

# Dummy forward pass with random inputs. Obtaining real data is left to the user
# and is not easy; the comments below point to the relevant scripts and functions.
B, T = 2, C["weather_window"]
cH, cW = 32, 32    # size of context inputs (ERA5, GloFAS, HydroATLAS, context DEM)
fH, fW = 224, 224  # size of local inputs (Sentinel-1, local DEM, HAND)
ex = {
    # ERA5/ERA5-land: see scripts/dl-era5-land.py, scripts/export-context.py and gff.data_sources.load_exported_era5_nc
    "era5": torch.randn((B, T, len(C["era5_keys"]), cH, cW)).cuda(),
    "era5_land": torch.randn((B, T, len(C["era5_land_keys"]), cH, cW)).cuda(),
    # GloFAS: see scripts/export-glofas.py, scripts/export-context.py and gff.data_sources.load_glofas
    "glofas": torch.randn((B, T, len(C["glofas_keys"]), cH, cW)).cuda(),
    # HydroATLAS: download rasterised hydroatlas, see gff.data_sources.load_pregenerated_raster
    "hydroatlas_basin": torch.randn((B, len(C["hydroatlas_keys"]), cH, cW)).cuda(),
    # DEM (context): see scripts/export-context.py and gff.data_sources.load_pregenerated_raster
    "dem_context": torch.randn((B, 1, cH, cW)).cuda(),
    # Sentinel-1: see ./preprocessing, gff.data_sources.download_s1, gff.data_sources.export_s1 and scripts/export-local.py
    "s1": torch.randn((B, 2, fH, fW)).cuda(),
    "s1_lead_days": torch.randint(0, 20, (B,)).cuda(),  # computed field
    # DEM (local): see scripts/export-local.py and gff.data_sources.get_dem
    "dem_local": torch.randn((B, 1, fH, fW)).cuda(),
    # HAND: see scripts/export-local.py and gff.data_sources.get_hand
    "hand": torch.randn((B, 1, fH, fW)).cuda(),
}

# Inference only, so disable gradient tracking
with torch.no_grad():
    output = model(ex)  # the model handles input normalisation internally
print(output.shape)  # water/no-water segmentation: (B, 2, fH, fW) = (2, 2, 224, 224)
```

If you want to use these weights outside this repo, you have two main options:
- Copy the `gff/models` folder in its entirety into your project. It depends only on PyTorch and NumPy.
- Clone the repo and run `pip install -e .` from the project root. You can then `import gff` in your own project to access all of our dataset utility functions. However, this requires all of the dependencies listed in `environment.yml`.
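When replacing the random tensors in the loading example above with real data, a quick shape check can catch layout mistakes before the forward pass. This is a minimal sketch that mirrors the dummy example's shapes, assuming its sizes (32x32 context inputs, 224x224 local inputs); the model may accept other spatial sizes. `expected_shapes` and `check_input_shapes` are hypothetical helpers, not part of `gff`.

```python
from typing import Any, Mapping


def expected_shapes(C: Mapping[str, Any], B: int, cH: int = 32, cW: int = 32,
                    fH: int = 224, fW: int = 224) -> dict:
    """Input shapes mirroring the dummy example (hypothetical helper, not part of gff)."""
    T = C["weather_window"]
    return {
        # Context inputs: (B, T, channels, cH, cW) for time series, (B, channels, cH, cW) otherwise
        "era5": (B, T, len(C["era5_keys"]), cH, cW),
        "era5_land": (B, T, len(C["era5_land_keys"]), cH, cW),
        "glofas": (B, T, len(C["glofas_keys"]), cH, cW),
        "hydroatlas_basin": (B, len(C["hydroatlas_keys"]), cH, cW),
        "dem_context": (B, 1, cH, cW),
        # Local inputs at the fH x fW resolution
        "s1": (B, 2, fH, fW),
        "s1_lead_days": (B,),
        "dem_local": (B, 1, fH, fW),
        "hand": (B, 1, fH, fW),
    }


def check_input_shapes(ex: Mapping[str, Any], C: Mapping[str, Any], B: int, **sizes) -> list:
    """Return readable problems with `ex`; an empty list means every shape matches.

    Works with anything exposing a `.shape` attribute (torch tensors, NumPy arrays).
    Extra keyword arguments (cH, cW, fH, fW) override the assumed spatial sizes.
    """
    problems = []
    for key, shape in expected_shapes(C, B, **sizes).items():
        if key not in ex:
            problems.append(f"missing key {key!r}")
        elif tuple(ex[key].shape) != shape:
            problems.append(f"{key}: expected {shape}, got {tuple(ex[key].shape)}")
    return problems
```

Calling `check_input_shapes(ex, C, B)` before `model(ex)` and printing any returned problems makes mismatched channel counts or time windows easy to spot.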