Skip to content

Commit 0b97aa0

Browse files
committed
changes for #1229
1 parent bba92cd commit 0b97aa0

File tree

3 files changed

+50
-6
lines changed

3 files changed

+50
-6
lines changed

R/misc.R

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,13 +241,15 @@ prompt_missing_implementation <- function(spec,
241241
#' @keywords internal
242242
#' @export
243243
show_call <- function(object) {
244-
object$method$fit$args <-
245-
map(object$method$fit$args, convert_arg)
244+
object$method$fit$args <- map(object$method$fit$args, convert_arg)
246245

247-
call2(object$method$fit$func["fun"],
248-
!!!object$method$fit$args,
249-
.ns = object$method$fit$func["pkg"]
250-
)
246+
fn_info <- as.list(object$method$fit$func)
247+
if (!any(names(fn_info) == "pkg")) {
248+
res <- call2(fn_info$fun, !!!object$method$fit$args)
249+
} else {
250+
res <- call2(fn_info$fun, !!!object$method$fit$args, .ns = fn_info$pkg)
251+
}
252+
res
251253
}
252254

253255
convert_arg <- function(x) {

parsnip.Rproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
Version: 1.0
2+
ProjectId: e5169c4e-5aba-443d-938b-8765efc1d040
23

34
RestoreWorkspace: No
45
SaveWorkspace: No

tests/testthat/test-misc.R

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,44 @@ test_that('obtaining prediction columns', {
299299
)
300300

301301
})
302+
303+
304+
# ------------------------------------------------------------------------------
305+
306+
# https://github.com/tidymodels/parsnip/issues/1229
307+
test_that('register local models', {
308+
set_new_model("my_model")
309+
set_model_mode(model = "my_model", mode = "regression")
310+
set_model_engine(
311+
"my_model",
312+
mode = "regression",
313+
eng = "my_engine"
314+
)
315+
316+
my_model <-
317+
function(mode = "regression") {
318+
new_model_spec(
319+
"my_model",
320+
args = list(),
321+
eng_args = NULL,
322+
mode = mode,
323+
method = NULL,
324+
engine = NULL
325+
)
326+
}
327+
328+
set_fit(
329+
model = "my_model",
330+
eng = "my_engine",
331+
mode = "regression",
332+
value = list(
333+
interface = "matrix",
334+
protect = c("formula", "data"),
335+
func = c(fun = "my_model_fun"),
336+
defaults = list()
337+
)
338+
)
339+
340+
expect_snapshot(my_model() %>% translate("my_engine"))
341+
})
342+

0 commit comments

Comments
 (0)