@@ -504,14 +504,13 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
504
504
if (! is.null(X_test )) X_test <- preprocessPredictionData(X_test , X_train_metadata )
505
505
506
506
# Convert all input data to matrices if not already converted
507
- if ((is.null(dim(Z_train ))) && (! is.null(Z_train ))) {
508
- Z_train <- as.matrix(as.numeric(Z_train ))
509
- }
507
+ Z_col <- ifelse(is.null(dim(Z_train )), 1 , ncol(Z_train ))
508
+ Z_train <- matrix (as.numeric(Z_train ), ncol = Z_col )
510
509
if ((is.null(dim(propensity_train ))) && (! is.null(propensity_train ))) {
511
510
propensity_train <- as.matrix(propensity_train )
512
511
}
513
- if ((is.null(dim( Z_test ))) && ( ! is.null(Z_test ) )) {
514
- Z_test <- as. matrix(as.numeric(Z_test ))
512
+ if (! is.null(Z_test )) {
513
+ Z_test <- matrix (as.numeric(Z_test ), ncol = Z_col )
515
514
}
516
515
if ((is.null(dim(propensity_test ))) && (! is.null(propensity_test ))) {
517
516
propensity_test <- as.matrix(propensity_test )
@@ -580,9 +579,30 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
580
579
}
581
580
}
582
581
583
- # Stop if multivariate treatment is provided
584
- if (ncol(Z_train ) > 1 ) stop(" Multivariate treatments are not currently supported" )
585
-
582
+ # # Stop if multivariate treatment is provided
583
+ # if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported")
584
+
585
+ # Handle multivariate treatment
586
+ has_multivariate_treatment <- ncol(Z_train ) > 1
587
+ if (has_multivariate_treatment ) {
588
+ # Disable adaptive coding, internal propensity model, and
589
+ # leaf scale sampling if treatment is multivariate
590
+ if (adaptive_coding ) {
591
+ warning(" Adaptive coding is incompatible with multivariate treatment and will be ignored" )
592
+ adaptive_coding <- FALSE
593
+ }
594
+ if (is.null(propensity_train )) {
595
+ if (propensity_covariate != " none" ) {
596
+ warning(" No propensities were provided for the multivariate treatment; an internal propensity model will not be fitted to the multivariate treatment and propensity_covariate will be set to 'none'" )
597
+ propensity_covariate <- " none"
598
+ }
599
+ }
600
+ if (sample_sigma2_leaf_tau ) {
601
+ warning(" Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model." )
602
+ sample_sigma2_leaf_tau <- FALSE
603
+ }
604
+ }
605
+
586
606
# Random effects covariance prior
587
607
if (has_rfx ) {
588
608
if (is.null(rfx_prior_var )) {
@@ -835,18 +855,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
835
855
current_sigma2 <- sigma2_init
836
856
}
837
857
838
- # Switch off leaf scale sampling for multivariate treatments
839
- if (ncol(Z_train ) > 1 ) {
840
- if (sample_sigma2_leaf_tau ) {
841
- warning(" Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model." )
842
- sample_sigma2_leaf_tau <- FALSE
843
- }
844
- }
845
-
846
858
# Set mu and tau leaf models / dimensions
847
859
leaf_model_mu_forest <- 0
848
860
leaf_dimension_mu_forest <- 1
849
- if (ncol( Z_train ) > 1 ) {
861
+ if (has_multivariate_treatment ) {
850
862
leaf_model_tau_forest <- 2
851
863
leaf_dimension_tau_forest <- ncol(Z_train )
852
864
} else {
@@ -973,21 +985,21 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
973
985
974
986
# Container of forest samples
975
987
forest_samples_mu <- createForestSamples(num_trees_mu , 1 , TRUE )
976
- forest_samples_tau <- createForestSamples(num_trees_tau , 1 , FALSE )
988
+ forest_samples_tau <- createForestSamples(num_trees_tau , ncol( Z_train ) , FALSE )
977
989
active_forest_mu <- createForest(num_trees_mu , 1 , TRUE )
978
- active_forest_tau <- createForest(num_trees_tau , 1 , FALSE )
990
+ active_forest_tau <- createForest(num_trees_tau , ncol( Z_train ) , FALSE )
979
991
if (include_variance_forest ) {
980
992
forest_samples_variance <- createForestSamples(num_trees_variance , 1 , TRUE , TRUE )
981
993
active_forest_variance <- createForest(num_trees_variance , 1 , TRUE , TRUE )
982
994
}
983
995
984
996
# Initialize the leaves of each tree in the prognostic forest
985
- active_forest_mu $ prepare_for_sampler(forest_dataset_train , outcome_train , forest_model_mu , 0 , init_mu )
997
+ active_forest_mu $ prepare_for_sampler(forest_dataset_train , outcome_train , forest_model_mu , leaf_model_mu_forest , init_mu )
986
998
active_forest_mu $ adjust_residual(forest_dataset_train , outcome_train , forest_model_mu , FALSE , FALSE )
987
999
988
1000
# Initialize the leaves of each tree in the treatment effect forest
989
- init_tau <- 0 .
990
- active_forest_tau $ prepare_for_sampler(forest_dataset_train , outcome_train , forest_model_tau , 1 , init_tau )
1001
+ init_tau <- rep( 0 . , ncol( Z_train ))
1002
+ active_forest_tau $ prepare_for_sampler(forest_dataset_train , outcome_train , forest_model_tau , leaf_model_tau_forest , init_tau )
991
1003
active_forest_tau $ adjust_residual(forest_dataset_train , outcome_train , forest_model_tau , TRUE , FALSE )
992
1004
993
1005
# Initialize the leaves of each tree in the variance forest
@@ -1450,7 +1462,18 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
1450
1462
} else {
1451
1463
tau_hat_train <- forest_samples_tau $ predict_raw(forest_dataset_train )* y_std_train
1452
1464
}
1453
- y_hat_train <- mu_hat_train + tau_hat_train * as.numeric(Z_train )
1465
+ if (has_multivariate_treatment ) {
1466
+ tau_train_dim <- dim(tau_hat_train )
1467
+ tau_num_obs <- tau_train_dim [1 ]
1468
+ tau_num_samples <- tau_train_dim [3 ]
1469
+ treatment_term_train <- matrix (NA_real_ , nrow = tau_num_obs , tau_num_samples )
1470
+ for (i in 1 : nrow(Z_train )) {
1471
+ treatment_term_train [i ,] <- colSums(tau_hat_train [i ,,] * Z_train [i ,])
1472
+ }
1473
+ } else {
1474
+ treatment_term_train <- tau_hat_train * as.numeric(Z_train )
1475
+ }
1476
+ y_hat_train <- mu_hat_train + treatment_term_train
1454
1477
if (has_test ) {
1455
1478
mu_hat_test <- forest_samples_mu $ predict(forest_dataset_test )* y_std_train + y_bar_train
1456
1479
if (adaptive_coding ) {
@@ -1459,7 +1482,18 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
1459
1482
} else {
1460
1483
tau_hat_test <- forest_samples_tau $ predict_raw(forest_dataset_test )* y_std_train
1461
1484
}
1462
- y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test )
1485
+ if (has_multivariate_treatment ) {
1486
+ tau_test_dim <- dim(tau_hat_test )
1487
+ tau_num_obs <- tau_test_dim [1 ]
1488
+ tau_num_samples <- tau_test_dim [3 ]
1489
+ treatment_term_test <- matrix (NA_real_ , nrow = tau_num_obs , tau_num_samples )
1490
+ for (i in 1 : nrow(Z_test )) {
1491
+ treatment_term_test [i ,] <- colSums(tau_hat_test [i ,,] * Z_test [i ,])
1492
+ }
1493
+ } else {
1494
+ treatment_term_test <- tau_hat_test * as.numeric(Z_test )
1495
+ }
1496
+ y_hat_test <- mu_hat_test + treatment_term_test
1463
1497
}
1464
1498
if (include_variance_forest ) {
1465
1499
sigma2_x_hat_train <- exp(sigma2_x_train_raw )
@@ -1526,6 +1560,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
1526
1560
" treatment_dim" = ncol(Z_train ),
1527
1561
" propensity_covariate" = propensity_covariate ,
1528
1562
" binary_treatment" = binary_treatment ,
1563
+ " multivariate_treatment" = has_multivariate_treatment ,
1529
1564
" adaptive_coding" = adaptive_coding ,
1530
1565
" internal_propensity_model" = internal_propensity_model ,
1531
1566
" num_samples" = num_retained_samples ,
@@ -1722,6 +1757,17 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU
1722
1757
} else {
1723
1758
tau_hat <- object $ forests_tau $ predict_raw(forest_dataset_pred )* y_std
1724
1759
}
1760
+ if (object $ model_params $ multivariate_treatment ) {
1761
+ tau_dim <- dim(tau_hat )
1762
+ tau_num_obs <- tau_dim [1 ]
1763
+ tau_num_samples <- tau_dim [3 ]
1764
+ treatment_term <- matrix (NA_real_ , nrow = tau_num_obs , tau_num_samples )
1765
+ for (i in 1 : nrow(Z )) {
1766
+ treatment_term [i ,] <- colSums(tau_hat [i ,,] * Z [i ,])
1767
+ }
1768
+ } else {
1769
+ treatment_term <- tau_hat * as.numeric(Z )
1770
+ }
1725
1771
if (object $ model_params $ include_variance_forest ) {
1726
1772
s_x_raw <- object $ forests_variance $ predict(forest_dataset_pred )
1727
1773
}
@@ -1732,7 +1778,7 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU
1732
1778
}
1733
1779
1734
1780
# Compute overall "y_hat" predictions
1735
- y_hat <- mu_hat + tau_hat * as.numeric( Z )
1781
+ y_hat <- mu_hat + treatment_term
1736
1782
if (object $ model_params $ has_rfx ) y_hat <- y_hat + rfx_predictions
1737
1783
1738
1784
# Scale variance forest predictions
@@ -1974,6 +2020,7 @@ saveBCFModelToJson <- function(object){
1974
2020
jsonobj $ add_boolean(" has_rfx" , object $ model_params $ has_rfx )
1975
2021
jsonobj $ add_boolean(" has_rfx_basis" , object $ model_params $ has_rfx_basis )
1976
2022
jsonobj $ add_scalar(" num_rfx_basis" , object $ model_params $ num_rfx_basis )
2023
+ jsonobj $ add_boolean(" multivariate_treatment" , object $ model_params $ multivariate_treatment )
1977
2024
jsonobj $ add_boolean(" adaptive_coding" , object $ model_params $ adaptive_coding )
1978
2025
jsonobj $ add_boolean(" internal_propensity_model" , object $ model_params $ internal_propensity_model )
1979
2026
jsonobj $ add_scalar(" num_gfr" , object $ model_params $ num_gfr )
@@ -2305,6 +2352,7 @@ createBCFModelFromJson <- function(json_object){
2305
2352
model_params [[" has_rfx_basis" ]] <- json_object $ get_boolean(" has_rfx_basis" )
2306
2353
model_params [[" num_rfx_basis" ]] <- json_object $ get_scalar(" num_rfx_basis" )
2307
2354
model_params [[" adaptive_coding" ]] <- json_object $ get_boolean(" adaptive_coding" )
2355
+ model_params [[" multivariate_treatment" ]] <- json_object $ get_boolean(" multivariate_treatment" )
2308
2356
model_params [[" internal_propensity_model" ]] <- json_object $ get_boolean(" internal_propensity_model" )
2309
2357
model_params [[" num_gfr" ]] <- json_object $ get_scalar(" num_gfr" )
2310
2358
model_params [[" num_burnin" ]] <- json_object $ get_scalar(" num_burnin" )
@@ -2644,6 +2692,7 @@ createBCFModelFromCombinedJson <- function(json_object_list){
2644
2692
model_params [[" num_chains" ]] <- json_object_default $ get_scalar(" num_chains" )
2645
2693
model_params [[" keep_every" ]] <- json_object_default $ get_scalar(" keep_every" )
2646
2694
model_params [[" adaptive_coding" ]] <- json_object_default $ get_boolean(" adaptive_coding" )
2695
+ model_params [[" multivariate_treatment" ]] <- json_object_default $ get_boolean(" multivariate_treatment" )
2647
2696
model_params [[" internal_propensity_model" ]] <- json_object_default $ get_boolean(" internal_propensity_model" )
2648
2697
model_params [[" probit_outcome_model" ]] <- json_object_default $ get_boolean(" probit_outcome_model" )
2649
2698
@@ -2870,6 +2919,7 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){
2870
2919
model_params [[" num_covariates" ]] <- json_object_default $ get_scalar(" num_covariates" )
2871
2920
model_params [[" num_chains" ]] <- json_object_default $ get_scalar(" num_chains" )
2872
2921
model_params [[" keep_every" ]] <- json_object_default $ get_scalar(" keep_every" )
2922
+ model_params [[" multivariate_treatment" ]] <- json_object_default $ get_boolean(" multivariate_treatment" )
2873
2923
model_params [[" adaptive_coding" ]] <- json_object_default $ get_boolean(" adaptive_coding" )
2874
2924
model_params [[" internal_propensity_model" ]] <- json_object_default $ get_boolean(" internal_propensity_model" )
2875
2925
model_params [[" probit_outcome_model" ]] <- json_object_default $ get_boolean(" probit_outcome_model" )
0 commit comments