diff --git a/NEWS.md b/NEWS.md index d184da56a..f18c997b5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -23,6 +23,8 @@ * Aligned `null_model()` with other model types; the model type now has an engine argument that defaults to `"parsnip"` and is checked with the same machinery that checks other model types in the package (#1083). +* If linear regression is requested with a Poisson family, an error will occur and refer the user to `poisson_reg()` (#1219). + * The deprecated function `rpart_train()` was removed after its deprecation period (#1044). ## Bug Fixes diff --git a/R/linear_reg.R b/R/linear_reg.R index 0b7b636b4..e0def96fd 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -73,6 +73,7 @@ translate.linear_reg <- function(x, engine = x$engine, ...) { # evaluated value for the parameter. x$args$penalty <- rlang::eval_tidy(x$args$penalty) } + x } @@ -113,5 +114,26 @@ check_args.linear_reg <- function(object, call = rlang::caller_env()) { check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") + # ------------------------------------------------------------------------------ + # We want to avoid folks passing in a poisson family instead of using + # poisson_reg(). It's hard to detect this. + + is_fam <- names(object$eng_args) == "family" + if (any(is_fam)) { + eng_args <- rlang::eval_tidy(object$eng_args[[which(is_fam)]]) + if (is.function(eng_args)) { + eng_args <- try(eng_args(), silent = TRUE) + } + if (inherits(eng_args, "family")) { + eng_args <- eng_args$family + } + if (eng_args == "poisson") { + cli::cli_abort( + "A Poisson family was requested for {.fn linear_reg}. Please use + {.fn poisson_reg} and the engines in the {.pkg poissonreg} package.", + call = rlang::call2("linear_reg")) + } + } + invisible(object) } diff --git a/tests/testthat/_snaps/linear_reg.md b/tests/testthat/_snaps/linear_reg.md index f497ce3da..229828f30 100644 --- a/tests/testthat/_snaps/linear_reg.md +++ b/tests/testthat/_snaps/linear_reg.md @@ -139,3 +139,39 @@ Error in `fit()`: ! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1. +# prevent using a Poisson family + + Code + linear_reg(penalty = 1) %>% set_engine("glmnet", family = poisson) %>% fit(mpg ~ + ., data = mtcars) + Condition + Error in `fit()`: + ! Please install the glmnet package to use this engine. + +--- + + Code + linear_reg(penalty = 1) %>% set_engine("glmnet", family = stats::poisson) %>% + fit(mpg ~ ., data = mtcars) + Condition + Error in `fit()`: + ! Please install the glmnet package to use this engine. + +--- + + Code + linear_reg(penalty = 1) %>% set_engine("glmnet", family = stats::poisson()) %>% + fit(mpg ~ ., data = mtcars) + Condition + Error in `fit()`: + ! Please install the glmnet package to use this engine. + +--- + + Code + linear_reg(penalty = 1) %>% set_engine("glmnet", family = "poisson") %>% fit( + mpg ~ ., data = mtcars) + Condition + Error in `fit()`: + ! Please install the glmnet package to use this engine. + diff --git a/tests/testthat/test-linear_reg.R b/tests/testthat/test-linear_reg.R index 62567fc44..ef0022feb 100644 --- a/tests/testthat/test-linear_reg.R +++ b/tests/testthat/test-linear_reg.R @@ -358,3 +358,32 @@ test_that("check_args() works", { } ) }) + + +test_that("prevent using a Poisson family", { + skip_if(rlang::is_installed("glmnet")) + expect_snapshot( + linear_reg(penalty = 1) %>% + set_engine("glmnet", family = poisson) %>% + fit(mpg ~ ., data = mtcars), + error = TRUE + ) + expect_snapshot( + linear_reg(penalty = 1) %>% + set_engine("glmnet", family = stats::poisson) %>% + fit(mpg ~ ., data = mtcars), + error = TRUE + ) + expect_snapshot( + linear_reg(penalty = 1) %>% + set_engine("glmnet", family = stats::poisson()) %>% + fit(mpg ~ ., data = mtcars), + error = TRUE + ) + expect_snapshot( + linear_reg(penalty = 1) %>% + set_engine("glmnet", family = "poisson") %>% + fit(mpg ~ ., data = mtcars), + error = TRUE + ) +})