You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
#' Set a constant predicted value for every tree in the ensemble.
139
+
#' Stops program if any tree is more than a root node.
140
+
#' @param dataset `ForestDataset` Dataset class (covariates, basis, etc...)
141
+
#' @param outcome `Outcome` Outcome class (residual / partial residual)
142
+
#' @param forest_model `ForestModel` object storing tracking structures used in training / sampling
143
+
#' @param leaf_model_int Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance).
144
+
#' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension.
#' corresponds to the observations for which outcomes are unobserved and must be estimated
133
133
#' based on the kernels k(X_test,X_test), k(X_test,X_train), and k(X_train,X_train). If not provided,
134
134
#' this function will only compute k(X_train, X_train).
135
-
#' @param forest_num (Option) Index of the forest sample to use for kernel computation. If not provided,
135
+
#' @param forest_num (Optional) Index of the forest sample to use for kernel computation. If not provided,
136
136
#' this function will use the last forest.
137
+
#' @param forest_type (Optional) Whether to compute the kernel from the mean or variance forest. Default: "mean". Specify "variance" for the variance forest.
138
+
#' All other inputs are invalid. Must have sampled the relevant forest or an error will occur.
137
139
#' @return List of kernel matrices. If `X_test = NULL`, the list contains
138
140
#' one `n_train` x `n_train` matrix, where `n_train = nrow(X_train)`.
139
141
#' This matrix is the kernel defined by `W_train %*% t(W_train)` where `W_train`
140
142
#' is a matrix with `n_train` rows and as many columns as there are total leaves in an ensemble.
141
143
#' If `X_test` is not `NULL`, the list contains two more matrices defined by
142
144
#' `W_test %*% t(W_train)` and `W_test %*% t(W_test)`.
#' corresponds to the observations for which outcomes are unobserved and must be estimated
193
213
#' based on the kernels k(X_test,X_test), k(X_test,X_train), and k(X_train,X_train). If not provided,
194
214
#' this function will only compute k(X_train, X_train).
195
-
#' @param forest_num (Option) Index of the forest sample to use for kernel computation. If not provided,
215
+
#' @param forest_num (Optional) Index of the forest sample to use for kernel computation. If not provided,
196
216
#' this function will use the last forest.
217
+
#' @param forest_type (Optional) Whether to compute the kernel from the mean or variance forest. Default: "mean". Specify "variance" for the variance forest.
218
+
#' All other inputs are invalid. Must have sampled the relevant forest or an error will occur.
197
219
#' @return List of vectors. If `X_test = NULL`, the list contains
198
220
#' one vector of length `n_train * num_trees`, where `n_train = nrow(X_train)`
199
221
#' and `num_trees` is the number of trees in `bart_model`. If `X_test` is not `NULL`,
200
222
#' the list contains another vector of length `n_test * num_trees`.
0 commit comments