diff --git a/R/PipeOpFeatureUnion.R b/R/PipeOpFeatureUnion.R
index ecc05448..cd7dc074 100644
--- a/R/PipeOpFeatureUnion.R
+++ b/R/PipeOpFeatureUnion.R
@@ -57,6 +57,25 @@ cbind_tasks = function(inputs, assert_targets_equal) {
     stopf("All tasks must have the same target columns")
   }
 
+  if (is.null(names(inputs))) names(inputs) = paste0("input", seq_along(inputs))
+  input_names = map(inputs, function(input) input$feature_names)
+  if (any(duplicated(unlist(input_names)))) {
+    # Disambiguate feature names that occur in more than one input: suffix them with the input's name
+    new_names = imap(input_names, function(input, name) {
+      dupe = input %in% unlist(input_names[!(names(input_names) == name)])
+      input[dupe] = paste(input[dupe], name, sep = "_")
+      if (any(dupe)) return(input) else return(NULL)
+    })
+    new_names = discard(new_names, is.null)
+    # Only inputs with clashes are touched: fetch their feature data, rename the columns, write it back via replace_features (inefficient)
+    inputs[names(new_names)] = imap(inputs[names(new_names)], function(input, name) {
+      dt = input$data(ids, input$feature_names)
+      colnames(dt) = new_names[[name]]
+      input$replace_features(dt)
+      return(input)
+    })
+  }
+
   new_cols = Reduce(function(x, y) rcbind(x, y$data(ids, y$feature_names)), tail(inputs, -1L), init = data.table())
   task$clone(deep = TRUE)$cbind(new_cols)
 }
diff --git a/tests/testthat/test_pipeop_featureunion.R b/tests/testthat/test_pipeop_featureunion.R
index 36de060d..f90c613e 100644
--- a/tests/testthat/test_pipeop_featureunion.R
+++ b/tests/testthat/test_pipeop_featureunion.R
@@ -75,6 +75,21 @@ test_that("Test wrong inputs", {
   task = mlr_tasks$get("iris")
   expect_error(g$train(task), "Assertion on 'rows'")
 })
+
+# FIXME: Somewhat depends on https://github.com/mlr-org/mlr3/issues/268
+# test_that("Duplicate Features", {
+#   # Define PipeOps
+#   tsk = mlr_tasks$get("iris")
+#   t1 = tsk$clone()$set_col_role("Sepal.Length", character())
+#   t2 = tsk$clone()$set_col_role(c("Petal.Length", "Petal.Width"), character())
+
+#   po = PipeOpFeatureUnion$new(2)
+
+#   tout = train_pipeop(po, list(t1, t2))
+#   expect_equivalent(tout[[1]]$feature_names, c())
+#   expect_equivalent(tout[[1]]$target_names, tsk$target_names)
+# })
+
 # FIXME: depends on mlr-org/mlr3#179
 ## test_that("PipeOpFeatureUnion - levels are preserved", {
 ##
diff --git a/vignettes/examples.Rmd b/vignettes/examples.Rmd
index 03036d67..4c071d92 100644
--- a/vignettes/examples.Rmd
+++ b/vignettes/examples.Rmd
@@ -453,7 +453,7 @@ Additionally, we also keep a version of the level 0 output (via `PipeOpNull`) an
   # Plot the resulting graph
   level_2$plot(html = TRUE)
 
-  task = mlr_tasks$get("iris"),
+  task = mlr_tasks$get("iris")
   lrn = GraphLearner$new(level_2)
   lrn$