Skip to content

Commit 1c11828

Browse files
authored
Merge pull request #138 from epiforecasts/variant-phase
Add variant phase
2 parents 41769c2 + 3277829 commit 1c11828

16 files changed

Lines changed: 159996 additions & 97 deletions

R/analysis-descriptive.R

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Aim: describe interval score in terms of model structure and country target type
22
# Load data:
33
# source(here("R", "process-data.R"))
4-
# scores <- prep_data(scoring_scale = "log")
4+
# scores <- process_data(scoring_scale = "log")
55
library(here)
66
library(dplyr)
77
library(purrr)
@@ -258,7 +258,7 @@ plot_ridges <- function(scores, target = "Deaths") {
258258
# Table of targets by model -------------
259259
table_targets <- function(scores) {
260260
table_targets <- scores |>
261-
select(Model, outcome_target, forecast_date, location) |>
261+
select(Model, outcome_target, forecast_date, Location) |>
262262
distinct() |>
263263
group_by(Model, outcome_target, forecast_date) |>
264264
summarise(target_count = n(), .groups = "drop") |>
@@ -312,11 +312,12 @@ table_metadata <- function(scores) {
312312
# Data --------------------
313313
data_plot <- function(scores, log = FALSE, all = FALSE) {
314314
data <- scores |>
315-
select(location, outcome_target, target_end_date, Incidence) |>
315+
select(Location, outcome_target, target_end_date, Incidence) |>
316316
distinct()
317-
pop <- read_csv(here("data", "populations.csv"), show_col_types = FALSE)
317+
pop <- read_csv(here("data", "populations.csv"), show_col_types = FALSE) |>
318+
rename(Location = location)
318319
data <- data |>
319-
left_join(pop, by = join_by(location)) |>
320+
left_join(pop, by = join_by(Location)) |>
320321
mutate(
321322
rel_inc = Incidence / population * 1e5,
322323
log_inc = log(Incidence + 1)
@@ -331,11 +332,11 @@ data_plot <- function(scores, log = FALSE, all = FALSE) {
331332
mutate(
332333
rel_inc = Incidence / population * 1e5,
333334
log_inc = log(Incidence + 1),
334-
location = "Total"
335+
Location = "Total"
335336
)
336337
var_name <- ifelse(log, "log_inc", "rel_inc")
337338
plot <- ggplot(mapping = aes(
338-
x = target_end_date, y = .data[[var_name]], group = location
339+
x = target_end_date, y = .data[[var_name]], group = Location
339340
))
340341

341342
if (all) {
@@ -360,14 +361,14 @@ data_plot <- function(scores, log = FALSE, all = FALSE) {
360361

361362
trends_plot <- function(scores) {
362363
trends <- scores |>
363-
select(location, target_end_date, Incidence, Trend) |>
364+
select(Location, target_end_date, Incidence, Trend) |>
364365
distinct()
365366
p <- ggplot(trends, aes(x = target_end_date, y = Incidence)) +
366367
geom_point(mapping = aes(colour = Trend), size = 1) +
367368
geom_line() +
368369
scale_colour_brewer(palette = "Set2", na.value = "grey") +
369370
theme(legend.position = "bottom") +
370-
facet_wrap(~location, scales = "free_y") +
371+
facet_wrap(~Location, scales = "free_y") +
371372
theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)) +
372373
xlab("")
373374
return(p)

R/analysis-model.R

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,16 @@
11
# Aim: use a GAMM to model the effects of model structure and country target type on WIS
2+
# Model:
3+
#
4+
# Method: model method (mechanistic, statistical, etc.)
5+
# CountryTargets: model predicts for single- vs multi-country
6+
# Trend: epidemic trend (stable, increasing, decreasing)
7+
# Location: location (random effect)
8+
# VariantPhase: dominant variant phase (random effect)
9+
# Horizon: forecast horizon (smooth, by model)
10+
# Model: individual model (random effect)
11+
#
12+
# Response: WIS (log-transformed, Gaussian family with log link)
13+
214
library(here)
315
library(dplyr)
416
library(readr)
@@ -11,44 +23,31 @@ source(here("R", "process-data.R"))
1123
source(here("R", "analysis-descriptive.R"))
1224

1325
# --- Get data ---
14-
data <- prep_data(scoring_scale = "log")
15-
outcomes <- unique(data$outcome_target)
16-
classification <- classify_models()
17-
targets <- table_targets(data)
18-
26+
data <- process_data(scoring_scale = "log")
1927
m.data <- data |>
20-
filter(!grepl("EuroCOVIDhub-", Model)) |>
21-
mutate(location = factor(location)) |>
22-
group_by(location) |>
23-
mutate(
24-
time = as.numeric(forecast_date - min(forecast_date)) / 7,
25-
Horizon = as.numeric(Horizon),
26-
wis = wis + 1e-7
27-
) |>
28-
ungroup()
28+
filter(!grepl("EuroCOVIDhub-", Model))
29+
outcomes <- unique(data$outcome_target)
2930

3031
# --- Model formula ---
31-
# Univariate for explanatory variables
32-
m.formula_uni_type <- wis ~ s(Method, bs = "re")
33-
m.formula_uni_tgt <- wis ~ s(CountryTargets, bs = "re")
34-
m.formula_uni_model <- wis ~ s(Model, bs = "re")
32+
# Univariate for each explanatory variable
33+
m.formulas_uni <- list(
34+
method = wis ~ s(Method, bs = "re"),
35+
target = wis ~ s(CountryTargets, bs = "re"),
36+
trend = wis ~ s(Trend, bs = "re"),
37+
location = wis ~ s(Location, bs = "re"),
38+
variant = wis ~ s(VariantPhase, bs = "re"),
39+
horizon = wis ~ s(Horizon, by = Model, k = 3, bs = "sz"),
40+
model = wis ~ s(Model, bs = "re")
41+
)
3542

36-
# Full model
37-
m.formula <- wis ~
38-
# Method
43+
# Full joint model
44+
m.formula_joint <- wis ~
3945
s(Method, bs = "re") +
40-
# Number of target countries
4146
s(CountryTargets, bs = "re") +
42-
# -----------------------------
43-
# Trend
4447
s(Trend, bs = "re") +
45-
# Location
46-
s(location, bs = "re") +
47-
# Week * location
48-
s(time, by = location, k = 40) +
49-
# Horizon
50-
s(Horizon, k = 3, by = Model, bs = "sz") +
51-
# Individual model
48+
s(Location, bs = "re") +
49+
s(VariantPhase, bs = "re") +
50+
s(Horizon, by = Model, k = 3, bs = "sz") +
5251
s(Model, bs = "re")
5352

5453
# --- Model fitting ---
@@ -69,23 +68,22 @@ m.fit <- function(outcomes, m.formula) {
6968
}
7069
# Fit
7170
cat("--------fitting univariate models")
72-
m.fits_uni_type <- m.fit(outcomes, m.formula_uni_type)
73-
m.fits_uni_tgt <- m.fit(outcomes, m.formula_uni_tgt)
74-
m.fits_uni_model <- m.fit(outcomes, m.formula_uni_model)
71+
m.fits_uni <- map(m.formulas_uni, ~ m.fit(outcomes, .x))
72+
7573
cat("--------fitting joint model")
76-
m.fits_joint <- m.fit(outcomes, m.formula)
77-
cat("finished fitting")
74+
m.fits_joint <- m.fit(outcomes, m.formula_joint)
75+
7876
# --- Output handling ---
7977
# Extract estimates for random effects
80-
random_effects_uni <- map_df(
81-
c(m.fits_uni_type, m.fits_uni_tgt, m.fits_uni_model),
82-
extract_ranef,
83-
.id = "outcome_target") |>
78+
random_effects_uni <- m.fits_uni[!grepl("horizon", names(m.fits_uni))] |>
79+
map_depth(.depth = 2, ~ extract_ranef(.x)) |>
80+
map(~ list_rbind(.x, names_to = "outcome_target")) |>
81+
list_rbind() |>
8482
mutate(model = "Unadjusted")
8583

8684
random_effects_joint <- map_df(m.fits_joint,
87-
extract_ranef,
88-
.id = "outcome_target") |>
85+
extract_ranef,
86+
.id = "outcome_target") |>
8987
mutate(model = "Adjusted")
9088

9189
random_effects <- random_effects_joint |>

R/plot-model-results.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ plot_models <- function(random_effects, scores, x_labels = TRUE,
6363
}
6464

6565
plot_effects <- function(random_effects,
66-
variables = c("Method", "CountryTargets")) {
66+
variables = NULL) {
67+
if(is.null(variables)){variables <- unique(random_effects$group_var)}
68+
6769
random_effects |>
6870
filter(group_var %in% variables) |>
6971
mutate(group = factor(group, levels = unique(as.character(rev(group)))),

R/process-data.R

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
library("dplyr")
2-
library("tidyr")
3-
library("purrr")
4-
library("readr")
5-
library("lubridate")
1+
library(here)
2+
library(dplyr)
3+
library(tidyr)
4+
library(purrr)
5+
library(readr)
6+
library(lubridate)
7+
source(here("R", "utils-variants.R"))
68

79
# Metadata ----------------------------------------------------------------
810
# Get classification of model types
@@ -32,20 +34,21 @@ classify_models <- function(file = here("data", "model-classification.csv")) {
3234
return(methods)
3335
}
3436

35-
# Scores data: add explanatory variables -----------------------------
36-
# Get scores for all forecasts and add explanatory variables used:
37-
# number of country targets, method classification, trend of observed incidence
38-
prep_data <- function(scoring_scale = "log") {
37+
# Prepare data for analysis -----------------------------
38+
# Get scores for all forecasts; and add explanatory variables in a single dframe
39+
process_data <- function(scoring_scale = "log") {
40+
# Get raw interval scores ----------------------------------------
41+
# scores data created in: R/process-score.r
3942
scores_files <- list.files(here("data"), pattern = "scores-raw-.*\\.csv")
4043
names(scores_files) <- sub("scores-raw-(.*)\\..*$", "\\1", scores_files)
41-
# Get raw interval score
4244
scores_raw <- scores_files |>
4345
map(\(file) {
4446
read_csv(here("data", file))
4547
}) |>
4648
bind_rows(.id = "outcome_target") |>
4749
filter(scale == scoring_scale)
4850

51+
# Add variables of interest to scores dataframe ----------------------
4952
# Target type
5053
country_targets <- scores_raw |>
5154
select(model, forecast_date, location) |>
@@ -66,7 +69,7 @@ prep_data <- function(scoring_scale = "log") {
6669
methods <- classify_models() |>
6770
select(model, Method = classification, agreement)
6871

69-
# Incidence level + trend (see: R/import-data.r)
72+
# Incidence level + trend (observed data from: R/utils-data.r)
7073
obs <- names(scores_files) |>
7174
set_names() |>
7275
map(~ read_csv(here("data", paste0("observed-", .x, ".csv")))) |>
@@ -76,19 +79,23 @@ prep_data <- function(scoring_scale = "log") {
7679
rename(Incidence = observed) |>
7780
select(target_end_date, location, outcome_target, Trend, Incidence)
7881

82+
# Variant phase
83+
variant_phase <- classify_variant_phases()
84+
85+
# Combine all data -----------------------------------------------------
7986
data <- scores_raw |>
8087
left_join(obs, by = c("location", "target_end_date", "outcome_target")) |>
88+
left_join(variant_phase, by = c("location", "target_end_date")) |>
8189
left_join(country_targets, by = "model") |>
8290
left_join(methods, by = "model") |>
83-
rename(Model = model, Horizon = horizon) |>
91+
# set to factors
92+
rename(Model = model, Horizon = horizon, Location = location) |>
8493
mutate(
94+
Horizon = ifelse(!Horizon %in% 1:4, NA_integer_, Horizon),
8595
Model = as.factor(Model),
96+
Location = as.factor(Location),
8697
outcome_target = paste0(str_to_title(outcome_target), "s"),
87-
Horizon = ordered(Horizon,
88-
levels = 1:4, labels = 1:4
89-
),
90-
log_wis = log(wis + 0.01)
91-
) |>
98+
wis = wis + 1e-7) |>
9299
filter(!is.na(Horizon)) ## horizon not in 1:4
93100
return(data)
94101
}

0 commit comments

Comments
 (0)