Skip to content
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.1.13
Version: 0.1.14
Authors@R: c(
person("Daniel J.", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"),
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
- Add `climatological_forecaster()` to automatically create climate baselines
- Replace `dist_quantiles()` with `hardhat::quantile_pred()`
- Allow `quantile()` to threshold to an interval if desired (#434)
- `arx_forecaster()` now checks that there is enough data to predict and errors if there is not

## Bug fixes

Expand Down
4 changes: 3 additions & 1 deletion R/arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ arx_fcast_epi_workflow <- function(
step_epi_ahead(!!outcome, ahead = args_list$ahead)
r <- r %>%
step_epi_naomit() %>%
step_training_window(n_recent = args_list$n_training)
step_training_window(n_recent = args_list$n_training) %>%
check_enough_train_data(all_predictors(), n = args_list$check_enough_data_n, skip = FALSE)

if (!is.null(args_list$check_enough_data_n)) {
r <- r %>% check_enough_train_data(
all_predictors(),
Expand Down
26 changes: 23 additions & 3 deletions R/check_enough_train_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ check_enough_train_data <-
role = NA,
trained = FALSE,
columns = NULL,
skip = TRUE,
skip = FALSE,
id = rand_id("enough_train_data")) {
recipes::add_check(
recipe,
Expand Down Expand Up @@ -90,7 +90,7 @@ prep.check_enough_train_data <- function(x, training, info = NULL, ...) {
}

if (x$drop_na) {
training <- tidyr::drop_na(training)
training <- tidyr::drop_na(training, any_of(unname(col_names)))
}
cols_not_enough_data <- training %>%
group_by(across(all_of(.env$x$epi_keys))) %>%
Expand All @@ -101,7 +101,8 @@ prep.check_enough_train_data <- function(x, training, info = NULL, ...) {

if (length(cols_not_enough_data) > 0) {
cli_abort(
"The following columns don't have enough data to predict: {cols_not_enough_data}."
"The following columns don't have enough data to predict: {cols_not_enough_data}.",
class = "epipredict__not_enough_train_data"
)
}

Expand All @@ -120,6 +121,25 @@ prep.check_enough_train_data <- function(x, training, info = NULL, ...) {

#' @export
bake.check_enough_train_data <- function(object, new_data, ...) {
  # Bake-time version of the check: count the rows of `new_data` within each
  # `object$epi_keys` group and abort when any group has fewer than `object$n`
  # rows. When the check passes, `new_data` is returned exactly as received.
  col_names <- object$columns

  # When `drop_na` is set, rows with an NA in any checked column are left out
  # of the count. Only this working copy is filtered, never the return value.
  non_na_data <- new_data
  if (object$drop_na) {
    non_na_data <- tidyr::drop_na(non_na_data, any_of(unname(col_names)))
  }

  # One logical per (group, column): is the group's row count below `n`?
  too_few_by_group <- non_na_data %>%
    group_by(across(all_of(.env$object$epi_keys))) %>%
    summarise(
      across(all_of(.env$col_names), ~ dplyr::n() < .env$object$n),
      .groups = "drop"
    )

  # Collapse over groups: a column is flagged if any group fell short.
  is_short <- too_few_by_group %>%
    summarise(across(all_of(.env$col_names), any), .groups = "drop") %>%
    unlist()

  # NOTE: this variable name is interpolated by the cli message below, so it
  # must stay in sync with the `{cols_not_enough_data}` placeholder.
  cols_not_enough_data <- names(is_short)[is_short]

  if (length(cols_not_enough_data) == 0) {
    return(new_data)
  }

  cli_abort(
    "The following columns don't have enough data to predict: {cols_not_enough_data}.",
    class = "epipredict__not_enough_train_data"
  )
}

Expand Down
2 changes: 1 addition & 1 deletion R/epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,6 @@ forecast.epi_workflow <- function(object, ..., n_recent = NULL, forecast_date =
hardhat::extract_preprocessor(object),
object$original_data
)

test_data
predict(object, new_data = test_data)
}
2 changes: 1 addition & 1 deletion tests/testthat/_snaps/check_enough_train_data.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

Code
epi_recipe(toy_epi_df) %>% step_epi_lag(x, lag = c(1, 2)) %>%
check_enough_train_data(all_predictors(), y, n = 2 * n - 5) %>% prep(
check_enough_train_data(all_predictors(), y, n = 2 * n - 4) %>% prep(
toy_epi_df) %>% bake(new_data = NULL)
Condition
Error in `prep()`:
Expand Down
19 changes: 19 additions & 0 deletions tests/testthat/test-arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,22 @@ test_that("arx_forecaster errors if forecast date, target date, and ahead are in
class = "epipredict__arx_args__inconsistent_target_ahead_forecaste_date"
)
})

test_that("errors if there's not enough data to predict", {
  # Build a single-location daily series that only has observations from
  # October through May of each season, truncated at (and versioned as of)
  # 2022-10-12, so the most recent stretch of data follows a June-September gap.
  edf <- tibble(
    geo_value = "ct",
    time_value = seq(as.Date("2020-10-01"), as.Date("2023-05-31"), by = "day")
  ) %>%
    mutate(value = seq_len(nrow(.)) + rnorm(nrow(.))) %>%
    # Oct to May (flu season, ish) only:
    filter(!between(as.POSIXlt(time_value)$mon + 1L, 6L, 9L)) %>%
    # and actually, pretend we're around mid-October 2022:
    filter(time_value <= as.Date("2022-10-12")) %>%
    as_epi_df(as_of = as.Date("2022-10-12"))

  # The forecaster is expected to abort with the classed "not enough data"
  # condition rather than produce a forecast.
  expect_error(
    edf %>% arx_forecaster("value"),
    class = "epipredict__not_enough_train_data"
  )
})
11 changes: 6 additions & 5 deletions tests/testthat/test-check_enough_train_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,33 +94,34 @@ test_that("check_enough_train_data only checks train data", {
epiprocess::as_epi_df()
expect_no_error(
epi_recipe(toy_epi_df) %>%
check_enough_train_data(x, y, n = n - 2, epi_keys = "geo_value") %>%
check_enough_train_data(x, y, n = n - 2, epi_keys = "geo_value", skip = TRUE) %>%
prep(toy_epi_df) %>%
bake(new_data = toy_test_data)
)
# Same thing, but skip = FALSE
expect_no_error(
epi_recipe(toy_epi_df) %>%
check_enough_train_data(y, n = n - 2, epi_keys = "geo_value", skip = FALSE) %>%
check_enough_train_data(y, n = n - 2, epi_keys = "geo_value") %>%
prep(toy_epi_df) %>%
bake(new_data = toy_test_data)
)
})

test_that("check_enough_train_data works with all_predictors() downstream of constructed terms", {
# With a lag of 2, we will get 2 * n - 6 non-NA rows
# With a lag of 2, we will get 2 * n - 5 non-NA rows (NA's in x but not in the
# lags don't count)
expect_no_error(
epi_recipe(toy_epi_df) %>%
step_epi_lag(x, lag = c(1, 2)) %>%
check_enough_train_data(all_predictors(), y, n = 2 * n - 6) %>%
check_enough_train_data(all_predictors(), y, n = 2 * n - 5) %>%
prep(toy_epi_df) %>%
bake(new_data = NULL)
)
expect_snapshot(
error = TRUE,
epi_recipe(toy_epi_df) %>%
step_epi_lag(x, lag = c(1, 2)) %>%
check_enough_train_data(all_predictors(), y, n = 2 * n - 5) %>%
check_enough_train_data(all_predictors(), y, n = 2 * n - 4) %>%
prep(toy_epi_df) %>%
bake(new_data = NULL)
)
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-layer_residual_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ test_that("Canned forecasters work with / without", {
)

expect_silent(
arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"))
arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"), args_list = arx_args_list(check_enough_data_n = 1))
)
expect_silent(
flatline_forecaster(
Expand Down
Loading