Closed
Labels: enhancement (New feature or request)
Description
Right now, the "low-level" stochtree interface requires a long list of parameters for initialization and sampling, for example:
```r
# Current low-level interface: every model setting is passed as a separate argument
forest_model_mean <- createForestModel(
  forest_dataset, feature_types, num_trees, nrow(X),
  alpha, beta, min_samples_leaf, max_depth
)
forest_model_mean$sample_one_iteration(
  forest_dataset, outcome, forest_samples, active_forest, rng,
  feature_types, leaf_model, current_leaf_scale, variable_weights,
  a_forest, b_forest, current_sigma2, cutpoint_grid_size,
  keep_forest, gfr, pre_initialized
)
```
Even if we shield users from this with good defaults, these parameters have to cover every leaf and variance model we support. As stochtree grows, this interface will become more complicated and the function calls longer.
Our proposal is to replace long lists of parameters like `leaf_model` and `a_forest` with a `ModelParams` or `ModelConfig` class that users initialize and configure, and that is then passed to functions like `createForestModel()` and methods like `forest_model$sample_one_iteration()`.
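To illustrate the idea, here is a minimal base-R sketch of the "parameter object" pattern. `make_model_config()` and `sample_one_iteration_sketch()` are hypothetical helpers, not stochtree functions, and the default values are placeholders rather than stochtree's actual defaults. It uses a plain list instead of an R6 class only to stay dependency-free.

```r
# Hypothetical sketch: make_model_config() and sample_one_iteration_sketch()
# are illustrative names only and are not part of stochtree's API.

# Bundle model settings into one object (placeholder defaults).
make_model_config <- function(leaf_model = 0L, alpha = 0.95, beta = 2.0,
                              min_samples_leaf = 5L, max_depth = 10L,
                              cutpoint_grid_size = 100L) {
  list(
    leaf_model = leaf_model,
    alpha = alpha,
    beta = beta,
    min_samples_leaf = min_samples_leaf,
    max_depth = max_depth,
    cutpoint_grid_size = cutpoint_grid_size
  )
}

# A "sampler" that reads what it needs from the config object
# instead of taking each setting as a positional argument.
sample_one_iteration_sketch <- function(n_obs, config) {
  stopifnot(is.list(config), !is.null(config$leaf_model))
  sprintf(
    "n=%d leaf_model=%d alpha=%.2f beta=%.1f max_depth=%d",
    n_obs, config$leaf_model, config$alpha, config$beta, config$max_depth
  )
}

config <- make_model_config(leaf_model = 1L, max_depth = 6L)
print(sample_one_iteration_sketch(100L, config))
# [1] "n=100 leaf_model=1 alpha=0.95 beta=2.0 max_depth=6"
```

With this shape, supporting a new leaf or variance model means adding fields to the config rather than new positional arguments to every sampler call.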
Ideas for Implementation / Expected Behavior
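R6 class implementing a generic interface for specifying a model and its parameters:
```r
ModelConfig <- R6::R6Class(
  classname = "ModelConfig",
  cloneable = FALSE,
  public = list(
    #' @field leaf_model Integer-coded leaf model type
    leaf_model = NULL,
    #' @field feature_types Vector of integer-coded feature types
    feature_types = NULL,
    #' @field variance_forest_shape Shape parameter for the variance forest
    variance_forest_shape = NULL,
    #' @field variance_forest_scale Scale parameter for the variance forest
    variance_forest_scale = NULL,
    # ....
    initialize = function(model_type) {
      self$leaf_model <- model_type
    },
    # Fields assigned here must be declared above, since R6 objects are locked by default
    set_variance_forest_params = function(shape, scale) {
      self$variance_forest_shape <- shape
      self$variance_forest_scale <- scale
    },
    # ...
  )
)
```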
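A self-contained, runnable version of this class might look like the sketch below. It assumes the R6 package is installed and keeps only the `leaf_model` field and the variance-forest setter from the proposal; the input checks and the `invisible(self)` return (which allows chained calls) are illustrative additions, not settled design.

```r
# Assumes the R6 package is installed. Names mirror the proposal above;
# validation and method chaining are illustrative additions.
library(R6)

ModelConfig <- R6Class(
  classname = "ModelConfig",
  cloneable = FALSE,
  public = list(
    leaf_model = NULL,
    variance_forest_shape = NULL,
    variance_forest_scale = NULL,
    initialize = function(model_type) {
      stopifnot(is.numeric(model_type), length(model_type) == 1)
      self$leaf_model <- as.integer(model_type)
    },
    set_variance_forest_params = function(shape, scale) {
      stopifnot(shape > 0, scale > 0)
      self$variance_forest_shape <- shape
      self$variance_forest_scale <- scale
      invisible(self)  # return self so setters can be chained
    }
  )
)

config <- ModelConfig$new(model_type = 0)
config$set_variance_forest_params(shape = 1, scale = 1)
```

R6 locks object environments by default (`lock_objects = TRUE`), so every field a method assigns to, such as `variance_forest_shape`, has to be declared in `public` up front; assigning to an undeclared field raises an error at call time.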