@@ -885,6 +885,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
885
885
if (sample_sigma2_global ) global_var_samples <- rep(NA , num_retained_samples )
886
886
if (sample_sigma2_leaf_mu ) leaf_scale_mu_samples <- rep(NA , num_retained_samples )
887
887
if (sample_sigma2_leaf_tau ) leaf_scale_tau_samples <- rep(NA , num_retained_samples )
888
+ muhat_train_raw <- matrix (NA_real_ , nrow(X_train ), num_retained_samples )
889
+ if (include_variance_forest ) sigma2_x_train_raw <- matrix (NA_real_ , nrow(X_train ), num_retained_samples )
888
890
sample_counter <- 0
889
891
890
892
# Prepare adaptive coding structure
@@ -997,6 +999,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
997
999
global_model_config = global_model_config , keep_forest = keep_sample , gfr = TRUE
998
1000
)
999
1001
1002
+ # Cache train set predictions since they are already computed during sampling
1003
+ if (keep_sample ) {
1004
+ muhat_train_raw [,sample_counter ] <- forest_model_mu $ get_cached_forest_predictions()
1005
+ }
1006
+
1000
1007
# Sample variance parameters (if requested)
1001
1008
if (sample_sigma2_global ) {
1002
1009
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train , forest_dataset_train , rng , a_global , b_global )
@@ -1016,6 +1023,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
1016
1023
global_model_config = global_model_config , keep_forest = keep_sample , gfr = TRUE
1017
1024
)
1018
1025
1026
+ # Cannot cache train set predictions for tau because the cached predictions in the
1027
+ # tracking data structures are pre-multiplied by the basis (treatment)
1028
+ # ...
1029
+
1019
1030
# Sample coding parameters (if requested)
1020
1031
if (adaptive_coding ) {
1021
1032
# Estimate mu(X) and tau(X) and compute y - mu(X)
@@ -1060,6 +1071,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
1060
1071
active_forest = active_forest_variance , rng = rng , forest_model_config = forest_model_config_variance ,
1061
1072
global_model_config = global_model_config , keep_forest = keep_sample , gfr = TRUE
1062
1073
)
1074
+
1075
+ # Cache train set predictions since they are already computed during sampling
1076
+ if (keep_sample ) {
1077
+ sigma2_x_train_raw [,sample_counter ] <- forest_model_variance $ get_cached_forest_predictions()
1078
+ }
1063
1079
}
1064
1080
if (sample_sigma2_global ) {
1065
1081
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train , forest_dataset_train , rng , a_global , b_global )
@@ -1263,6 +1279,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
1263
1279
global_model_config = global_model_config , keep_forest = keep_sample , gfr = FALSE
1264
1280
)
1265
1281
1282
+ # Cache train set predictions since they are already computed during sampling
1283
+ if (keep_sample ) {
1284
+ muhat_train_raw [,sample_counter ] <- forest_model_mu $ get_cached_forest_predictions()
1285
+ }
1286
+
1266
1287
# Sample variance parameters (if requested)
1267
1288
if (sample_sigma2_global ) {
1268
1289
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train , forest_dataset_train , rng , a_global , b_global )
@@ -1282,6 +1303,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
1282
1303
global_model_config = global_model_config , keep_forest = keep_sample , gfr = FALSE
1283
1304
)
1284
1305
1306
+ # Cannot cache train set predictions for tau because the cached predictions in the
1307
+ # tracking data structures are pre-multiplied by the basis (treatment)
1308
+ # ...
1309
+
1285
1310
# Sample coding parameters (if requested)
1286
1311
if (adaptive_coding ) {
1287
1312
# Estimate mu(X) and tau(X) and compute y - mu(X)
@@ -1326,6 +1351,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
1326
1351
active_forest = active_forest_variance , rng = rng , forest_model_config = forest_model_config_variance ,
1327
1352
global_model_config = global_model_config , keep_forest = keep_sample , gfr = FALSE
1328
1353
)
1354
+
1355
+ # Cache train set predictions since they are already computed during sampling
1356
+ if (keep_sample ) {
1357
+ sigma2_x_train_raw [,sample_counter ] <- forest_model_variance $ get_cached_forest_predictions()
1358
+ }
1329
1359
}
1330
1360
if (sample_sigma2_global ) {
1331
1361
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train , forest_dataset_train , rng , a_global , b_global )
@@ -1372,11 +1402,15 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
1372
1402
b_1_samples <- b_1_samples [(num_gfr + 1 ): length(b_1_samples )]
1373
1403
b_0_samples <- b_0_samples [(num_gfr + 1 ): length(b_0_samples )]
1374
1404
}
1405
+ muhat_train_raw <- muhat_train_raw [,(num_gfr + 1 ): ncol(muhat_train_raw )]
1406
+ if (include_variance_forest ) {
1407
+ sigma2_x_train_raw <- sigma2_x_train_raw [,(num_gfr + 1 ): ncol(sigma2_x_train_raw )]
1408
+ }
1375
1409
num_retained_samples <- num_retained_samples - num_gfr
1376
1410
}
1377
1411
1378
1412
# Forest predictions
1379
- mu_hat_train <- forest_samples_mu $ predict( forest_dataset_train ) * y_std_train + y_bar_train
1413
+ mu_hat_train <- muhat_train_raw * y_std_train + y_bar_train
1380
1414
if (adaptive_coding ) {
1381
1415
tau_hat_train_raw <- forest_samples_tau $ predict_raw(forest_dataset_train )
1382
1416
tau_hat_train <- t(t(tau_hat_train_raw ) * (b_1_samples - b_0_samples ))* y_std_train
@@ -1395,7 +1429,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
1395
1429
y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test )
1396
1430
}
1397
1431
if (include_variance_forest ) {
1398
- sigma2_x_hat_train <- forest_samples_variance $ predict( forest_dataset_train )
1432
+ sigma2_x_hat_train <- exp( sigma2_x_train_raw )
1399
1433
if (has_test ) sigma2_x_hat_test <- forest_samples_variance $ predict(forest_dataset_test )
1400
1434
}
1401
1435
0 commit comments