Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions R/PipeOpFeatureUnion.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,26 @@ cbind_tasks = function(inputs, assert_targets_equal) {
stopf("All tasks must have the same target columns")
}

if (is.null(names(inputs))) names(inputs) = paste0("input", seq_along(inputs))
input_names = map(inputs, function(input) input$feature_names)
if (any(duplicated(unlist(input_names)))) {
# Treat duplicate names: Suffix with name of the input
new_names = imap(input_names, function(input, name) {
dupe = input %in% unlist(input_names[!(names(input_names) == name)])
input[dupe] = paste(input[dupe], name, sep = "_")
if (any(dupe)) return(input) else return(NULL)
})
new_names = discard(new_names, is.null)
# Use replace_features on renamed data. This seems very inefficient
inputs[names(new_names)] = imap(inputs[names(new_names)], function(input, name) {
browser()
dt = input$data(ids, input$feature_names)
colnames(dt) = new_names[[name]]
input$replace_features(dt)
return(input)
})
}

new_cols = Reduce(function(x, y) rcbind(x, y$data(ids, y$feature_names)), tail(inputs, -1L), init = data.table())
task$clone(deep = TRUE)$cbind(new_cols)
}
15 changes: 15 additions & 0 deletions tests/testthat/test_pipeop_featureunion.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,21 @@ test_that("Test wrong inputs", {
task = mlr_tasks$get("iris")
expect_error(g$train(task), "Assertion on 'rows'")
})

# FIXME: Somewhat depends on https://github.com/mlr-org/mlr3/issues/268
# test_that("Duplicate Features", {
# # Define PipeOp's
# tsk = mlr_tasks$get("iris")
# t1 = tsk$clone()$set_col_role("Sepal.Length", character())
# t2 = tsk$clone()$set_col_role(c("Petal.Length", "Petal.Width"), character())

# po = PipeOpFeatureUnion$new(2)

# tout = train_pipeop(po, list(t1, t2))
# expect_equivalent(tout[[1]]$feature_names, c())
# expect_equivalent(tout[[1]]$target_names, tsk$target_names)
# })

# FIXME: depends on mlr-org/mlr3#179
## test_that("PipeOpFeatureUnion - levels are preserved", {
##
Expand Down
2 changes: 1 addition & 1 deletion vignettes/examples.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ Additionally, we also keep a version of the level 0 output (via `PipeOpNull`) an
# Plot the resulting graph
level_2$plot(html = TRUE)

task = mlr_tasks$get("iris"),
task = mlr_tasks$get("iris")
lrn = GraphLearner$new(level_2)

lrn$
Expand Down