Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,74 @@ set_leaf_vector_forest_container_cpp <- function(forest_samples, leaf_vector) {
invisible(.Call(`_stochtree_set_leaf_vector_forest_container_cpp`, forest_samples, leaf_vector))
}

is_leaf_node_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
  # Thin binding: passes all four arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  .Call(
    `_stochtree_is_leaf_node_forest_container_cpp`,
    forest_samples, forest_num, tree_num, node_id
  )
}

is_numeric_split_node_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
  # Thin binding: passes all four arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  .Call(
    `_stochtree_is_numeric_split_node_forest_container_cpp`,
    forest_samples, forest_num, tree_num, node_id
  )
}

is_categorical_split_node_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
  # Thin binding: passes all four arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  .Call(
    `_stochtree_is_categorical_split_node_forest_container_cpp`,
    forest_samples, forest_num, tree_num, node_id
  )
}

parent_node_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
  # Thin binding: passes all four arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  .Call(
    `_stochtree_parent_node_forest_container_cpp`,
    forest_samples, forest_num, tree_num, node_id
  )
}

left_child_node_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
  # Thin binding: passes all four arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  .Call(
    `_stochtree_left_child_node_forest_container_cpp`,
    forest_samples, forest_num, tree_num, node_id
  )
}

right_child_node_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
  # Thin binding: passes all four arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  .Call(
    `_stochtree_right_child_node_forest_container_cpp`,
    forest_samples, forest_num, tree_num, node_id
  )
}

node_depth_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
  # Thin binding: passes all four arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  .Call(
    `_stochtree_node_depth_forest_container_cpp`,
    forest_samples, forest_num, tree_num, node_id
  )
}

split_index_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
  # Thin binding: passes all four arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  .Call(
    `_stochtree_split_index_forest_container_cpp`,
    forest_samples, forest_num, tree_num, node_id
  )
}

split_theshold_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
  # Thin binding: passes all four arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  # NOTE(review): "theshold" is spelled this way in both the R name and the
  # native symbol; renaming must happen on both sides at once.
  .Call(
    `_stochtree_split_theshold_forest_container_cpp`,
    forest_samples, forest_num, tree_num, node_id
  )
}

split_categories_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
  # Thin binding: passes all four arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  .Call(
    `_stochtree_split_categories_forest_container_cpp`,
    forest_samples, forest_num, tree_num, node_id
  )
}

leaf_values_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
  # Thin binding: passes all four arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  .Call(
    `_stochtree_leaf_values_forest_container_cpp`,
    forest_samples, forest_num, tree_num, node_id
  )
}

num_nodes_forest_container_cpp <- function(forest_samples, forest_num, tree_num) {
  # Thin binding: passes all three arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  .Call(
    `_stochtree_num_nodes_forest_container_cpp`,
    forest_samples, forest_num, tree_num
  )
}

num_leaves_forest_container_cpp <- function(forest_samples, forest_num, tree_num) {
  # Thin binding: passes all three arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  .Call(
    `_stochtree_num_leaves_forest_container_cpp`,
    forest_samples, forest_num, tree_num
  )
}

num_leaf_parents_forest_container_cpp <- function(forest_samples, forest_num, tree_num) {
  # Thin binding: passes all three arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  .Call(
    `_stochtree_num_leaf_parents_forest_container_cpp`,
    forest_samples, forest_num, tree_num
  )
}

num_split_nodes_forest_container_cpp <- function(forest_samples, forest_num, tree_num) {
  # Thin binding: passes all three arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  .Call(
    `_stochtree_num_split_nodes_forest_container_cpp`,
    forest_samples, forest_num, tree_num
  )
}

nodes_forest_container_cpp <- function(forest_samples, forest_num, tree_num) {
  # Thin binding: passes all three arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  .Call(
    `_stochtree_nodes_forest_container_cpp`,
    forest_samples, forest_num, tree_num
  )
}

leaves_forest_container_cpp <- function(forest_samples, forest_num, tree_num) {
  # Thin binding: passes all three arguments, unchanged, to the native symbol
  # and returns whatever that routine produces.
  .Call(
    `_stochtree_leaves_forest_container_cpp`,
    forest_samples, forest_num, tree_num
  )
}

initialize_forest_model_cpp <- function(data, residual, forest_samples, tracker, init_values, leaf_model_int) {
  # Called for its effect on the native side: the arguments are forwarded
  # unchanged and the routine's result is handed back invisibly.
  native_result <- .Call(
    `_stochtree_initialize_forest_model_cpp`,
    data, residual, forest_samples, tracker, init_values, leaf_model_int
  )
  invisible(native_result)
}
Expand Down
193 changes: 187 additions & 6 deletions R/forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,12 @@ ForestSamples <- R6::R6Class(
n_samples <- self$num_samples()
n_trees <- self$num_trees()
output <- get_granular_split_count_array_forest_container_cpp(self$forest_container_ptr, num_features)
dim(output) <- c(n_trees, num_features, n_samples)
dim(output) <- c(n_samples, n_trees, num_features)
return(output)
},

#' @description
#' Maximum depth of a specific tree in a specific ensemble in a `ForestContainer` object
#' Maximum depth of a specific tree in a specific ensemble in a `ForestSamples` object
#' @param ensemble_num Ensemble number
#' @param tree_num Tree index within ensemble `ensemble_num`
#' @return Maximum leaf depth
Expand All @@ -311,7 +311,7 @@ ForestSamples <- R6::R6Class(
},

#' @description
#' Average the maximum depth of each tree in a given ensemble in a `ForestContainer` object
#' Average the maximum depth of each tree in a given ensemble in a `ForestSamples` object
#' @param ensemble_num Ensemble number
#' @return Average maximum depth
average_ensemble_max_depth = function(ensemble_num) {
Expand All @@ -326,19 +326,200 @@ ForestSamples <- R6::R6Class(
},

#' @description
#' Number of leaves in a given ensemble in a `ForestContainer` object
#' Number of leaves in a given ensemble in a `ForestSamples` object
#' @param forest_num Index of the ensemble to be queried
#' @return Count of leaves in the ensemble stored at `forest_num`
num_leaves = function(forest_num) {
num_forest_leaves = function(forest_num) {
return(num_leaves_ensemble_forest_container_cpp(self$forest_container_ptr, forest_num))
},

#' @description
#' Sum of squared (raw) leaf values in a given ensemble in a `ForestContainer` object
#' Sum of squared (raw) leaf values in a given ensemble in a `ForestSamples` object
#' @param forest_num Index of the ensemble to be queried
#' @return Sum of squared (raw) leaf values in the ensemble stored at `forest_num`
sum_leaves_squared = function(forest_num) {
  # Hand this object's forest container pointer and the requested forest
  # index to the ensemble-level wrapper; its result is returned as-is.
  container <- self$forest_container_ptr
  sum_leaves_squared_ensemble_forest_container_cpp(container, forest_num)
},

#' @description
#' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a leaf
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @param node_id Index of the node to be queried
#' @return `TRUE` if node is a leaf, `FALSE` otherwise
is_leaf_node = function(forest_num, tree_num, node_id) {
  # Delegate to the wrapper, addressing the node by forest / tree / node id
  # within this object's forest container.
  container <- self$forest_container_ptr
  is_leaf_node_forest_container_cpp(container, forest_num, tree_num, node_id)
},

#' @description
#' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a numeric split node
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @param node_id Index of the node to be queried
#' @return `TRUE` if node is a numeric split node, `FALSE` otherwise
is_numeric_split_node = function(forest_num, tree_num, node_id) {
  # Delegate to the wrapper, addressing the node by forest / tree / node id
  # within this object's forest container.
  container <- self$forest_container_ptr
  is_numeric_split_node_forest_container_cpp(container, forest_num, tree_num, node_id)
},

#' @description
#' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a categorical split node
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @param node_id Index of the node to be queried
#' @return `TRUE` if node is a categorical split node, `FALSE` otherwise
is_categorical_split_node = function(forest_num, tree_num, node_id) {
  # Delegate to the wrapper, addressing the node by forest / tree / node id
  # within this object's forest container.
  container <- self$forest_container_ptr
  is_categorical_split_node_forest_container_cpp(container, forest_num, tree_num, node_id)
},

#' @description
#' Parent node of given node of a given tree in a given forest in a `ForestSamples` object
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @param node_id Index of the node to be queried
#' @return Integer ID of the parent node
parent_node = function(forest_num, tree_num, node_id) {
  # Delegate to the wrapper, addressing the node by forest / tree / node id
  # within this object's forest container.
  container <- self$forest_container_ptr
  parent_node_forest_container_cpp(container, forest_num, tree_num, node_id)
},

#' @description
#' Left child node of given node of a given tree in a given forest in a `ForestSamples` object
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @param node_id Index of the node to be queried
#' @return Integer ID of the left child node
left_child_node = function(forest_num, tree_num, node_id) {
  # Delegate to the wrapper, addressing the node by forest / tree / node id
  # within this object's forest container.
  container <- self$forest_container_ptr
  left_child_node_forest_container_cpp(container, forest_num, tree_num, node_id)
},

#' @description
#' Right child node of given node of a given tree in a given forest in a `ForestSamples` object
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @param node_id Index of the node to be queried
#' @return Integer ID of the right child node
right_child_node = function(forest_num, tree_num, node_id) {
  # Delegate to the wrapper, addressing the node by forest / tree / node id
  # within this object's forest container.
  container <- self$forest_container_ptr
  right_child_node_forest_container_cpp(container, forest_num, tree_num, node_id)
},

#' @description
#' Depth of given node of a given tree in a given forest in a `ForestSamples` object, with 0 depth for the root node.
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @param node_id Index of the node to be queried
#' @return Integer valued depth of the node
node_depth = function(forest_num, tree_num, node_id) {
  # Delegate to the wrapper, addressing the node by forest / tree / node id
  # within this object's forest container.
  container <- self$forest_container_ptr
  node_depth_forest_container_cpp(container, forest_num, tree_num, node_id)
},

#' @description
#' Split index of given node of a given tree in a given forest in a `ForestSamples` object. Returns `-1` if the node is a leaf.
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @param node_id Index of the node to be queried
#' @return Integer valued split index of the node (`-1` for a leaf)
node_split_index = function(forest_num, tree_num, node_id) {
  # A leaf has no split to report: answer with the -1 sentinel without
  # querying the split index wrapper.
  if (self$is_leaf_node(forest_num, tree_num, node_id)) {
    return(-1)
  }
  split_index_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id)
},

#' @description
#' Threshold that defines a numeric split for a given node of a given tree in a given forest in a `ForestSamples` object.
#' Returns `Inf` if the node is a leaf or a categorical split node.
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @param node_id Index of the node to be queried
#' @return Threshold defining a split for the node
node_split_threshold = function(forest_num, tree_num, node_id) {
  # Leaves and categorical split nodes get the Inf sentinel instead of a
  # threshold lookup. The leaf test runs first and `||` short-circuits, so
  # the categorical test is only made for non-leaf nodes.
  lacks_threshold <- self$is_leaf_node(forest_num, tree_num, node_id) ||
    self$is_categorical_split_node(forest_num, tree_num, node_id)
  if (lacks_threshold) {
    return(Inf)
  }
  split_theshold_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id)
},

#' @description
#' Array of category indices that define a categorical split for a given node of a given tree in a given forest in a `ForestSamples` object.
#' Returns `c(Inf)` if the node is a leaf or a numeric split node.
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @param node_id Index of the node to be queried
#' @return Categories defining a split for the node
node_split_categories = function(forest_num, tree_num, node_id) {
  # Leaves and numeric split nodes get the `c(Inf)` sentinel instead of a
  # category lookup. The leaf test runs first and `||` short-circuits, so
  # the numeric test is only made for non-leaf nodes.
  lacks_categories <- self$is_leaf_node(forest_num, tree_num, node_id) ||
    self$is_numeric_split_node(forest_num, tree_num, node_id)
  if (lacks_categories) {
    return(c(Inf))
  }
  split_categories_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id)
},

#' @description
#' Leaf node value(s) for a given node of a given tree in a given forest in a `ForestSamples` object.
#' Values are stale if the node is a split node.
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @param node_id Index of the node to be queried
#' @return Vector (often univariate) of leaf values
node_leaf_values = function(forest_num, tree_num, node_id) {
  # No leaf check is made here: the wrapper is queried for any node id, and
  # its result is returned as-is.
  container <- self$forest_container_ptr
  leaf_values_forest_container_cpp(container, forest_num, tree_num, node_id)
},

#' @description
#' Number of nodes in a given tree in a given forest in a `ForestSamples` object.
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @return Count of total tree nodes
num_nodes = function(forest_num, tree_num) {
  # Delegate to the wrapper for tree `tree_num` of forest `forest_num` in
  # this object's forest container.
  container <- self$forest_container_ptr
  num_nodes_forest_container_cpp(container, forest_num, tree_num)
},

#' @description
#' Number of leaves in a given tree in a given forest in a `ForestSamples` object.
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @return Count of total tree leaves
num_leaves = function(forest_num, tree_num) {
  # Delegate to the tree-level wrapper for tree `tree_num` of forest
  # `forest_num` in this object's forest container.
  container <- self$forest_container_ptr
  num_leaves_forest_container_cpp(container, forest_num, tree_num)
},

#' @description
#' Number of leaf parents (split nodes with two leaves as children) in a given tree in a given forest in a `ForestSamples` object.
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @return Count of total tree leaf parents
num_leaf_parents = function(forest_num, tree_num) {
  # Delegate to the wrapper for tree `tree_num` of forest `forest_num` in
  # this object's forest container.
  container <- self$forest_container_ptr
  num_leaf_parents_forest_container_cpp(container, forest_num, tree_num)
},

#' @description
#' Number of split nodes in a given tree in a given forest in a `ForestSamples` object.
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @return Count of total tree split nodes
num_split_nodes = function(forest_num, tree_num) {
  # Delegate to the wrapper for tree `tree_num` of forest `forest_num` in
  # this object's forest container.
  container <- self$forest_container_ptr
  num_split_nodes_forest_container_cpp(container, forest_num, tree_num)
},

#' @description
#' Array of node indices in a given tree in a given forest in a `ForestSamples` object.
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @return Indices of tree nodes
nodes = function(forest_num, tree_num) {
  # Delegate to the wrapper for tree `tree_num` of forest `forest_num` in
  # this object's forest container.
  container <- self$forest_container_ptr
  nodes_forest_container_cpp(container, forest_num, tree_num)
},

#' @description
#' Array of leaf indices in a given tree in a given forest in a `ForestSamples` object.
#' @param forest_num Index of the forest to be queried
#' @param tree_num Index of the tree to be queried
#' @return Indices of leaf nodes
leaves = function(forest_num, tree_num) {
  # Delegate to the wrapper for tree `tree_num` of forest `forest_num` in
  # this object's forest container.
  container <- self$forest_container_ptr
  leaves_forest_container_cpp(container, forest_num, tree_num)
}
)
)
Expand Down
24 changes: 12 additions & 12 deletions demo/notebooks/tree_inspection.ipynb

Large diffs are not rendered by default.

Loading