Skip to content

Commit 2c6d82c

Browse files
authored
Merge pull request #116 from StochasticTree/r-tree-interface
Adding R methods to inspect tree structure
2 parents b6fe05a + 0640fa5 commit 2c6d82c

File tree

8 files changed

+1106
-29
lines changed

8 files changed

+1106
-29
lines changed

R/cpp11.R

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,74 @@ set_leaf_vector_forest_container_cpp <- function(forest_samples, leaf_vector) {
332332
invisible(.Call(`_stochtree_set_leaf_vector_forest_container_cpp`, forest_samples, leaf_vector))
333333
}
334334

335+
is_leaf_node_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
336+
.Call(`_stochtree_is_leaf_node_forest_container_cpp`, forest_samples, forest_num, tree_num, node_id)
337+
}
338+
339+
is_numeric_split_node_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
340+
.Call(`_stochtree_is_numeric_split_node_forest_container_cpp`, forest_samples, forest_num, tree_num, node_id)
341+
}
342+
343+
is_categorical_split_node_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
344+
.Call(`_stochtree_is_categorical_split_node_forest_container_cpp`, forest_samples, forest_num, tree_num, node_id)
345+
}
346+
347+
parent_node_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
348+
.Call(`_stochtree_parent_node_forest_container_cpp`, forest_samples, forest_num, tree_num, node_id)
349+
}
350+
351+
left_child_node_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
352+
.Call(`_stochtree_left_child_node_forest_container_cpp`, forest_samples, forest_num, tree_num, node_id)
353+
}
354+
355+
right_child_node_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
356+
.Call(`_stochtree_right_child_node_forest_container_cpp`, forest_samples, forest_num, tree_num, node_id)
357+
}
358+
359+
node_depth_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
360+
.Call(`_stochtree_node_depth_forest_container_cpp`, forest_samples, forest_num, tree_num, node_id)
361+
}
362+
363+
split_index_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
364+
.Call(`_stochtree_split_index_forest_container_cpp`, forest_samples, forest_num, tree_num, node_id)
365+
}
366+
367+
split_theshold_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
368+
.Call(`_stochtree_split_theshold_forest_container_cpp`, forest_samples, forest_num, tree_num, node_id)
369+
}
370+
371+
split_categories_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
372+
.Call(`_stochtree_split_categories_forest_container_cpp`, forest_samples, forest_num, tree_num, node_id)
373+
}
374+
375+
leaf_values_forest_container_cpp <- function(forest_samples, forest_num, tree_num, node_id) {
376+
.Call(`_stochtree_leaf_values_forest_container_cpp`, forest_samples, forest_num, tree_num, node_id)
377+
}
378+
379+
num_nodes_forest_container_cpp <- function(forest_samples, forest_num, tree_num) {
380+
.Call(`_stochtree_num_nodes_forest_container_cpp`, forest_samples, forest_num, tree_num)
381+
}
382+
383+
num_leaves_forest_container_cpp <- function(forest_samples, forest_num, tree_num) {
384+
.Call(`_stochtree_num_leaves_forest_container_cpp`, forest_samples, forest_num, tree_num)
385+
}
386+
387+
num_leaf_parents_forest_container_cpp <- function(forest_samples, forest_num, tree_num) {
388+
.Call(`_stochtree_num_leaf_parents_forest_container_cpp`, forest_samples, forest_num, tree_num)
389+
}
390+
391+
num_split_nodes_forest_container_cpp <- function(forest_samples, forest_num, tree_num) {
392+
.Call(`_stochtree_num_split_nodes_forest_container_cpp`, forest_samples, forest_num, tree_num)
393+
}
394+
395+
nodes_forest_container_cpp <- function(forest_samples, forest_num, tree_num) {
396+
.Call(`_stochtree_nodes_forest_container_cpp`, forest_samples, forest_num, tree_num)
397+
}
398+
399+
leaves_forest_container_cpp <- function(forest_samples, forest_num, tree_num) {
400+
.Call(`_stochtree_leaves_forest_container_cpp`, forest_samples, forest_num, tree_num)
401+
}
402+
335403
initialize_forest_model_cpp <- function(data, residual, forest_samples, tracker, init_values, leaf_model_int) {
336404
invisible(.Call(`_stochtree_initialize_forest_model_cpp`, data, residual, forest_samples, tracker, init_values, leaf_model_int))
337405
}

R/forest.R

Lines changed: 187 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -297,12 +297,12 @@ ForestSamples <- R6::R6Class(
297297
n_samples <- self$num_samples()
298298
n_trees <- self$num_trees()
299299
output <- get_granular_split_count_array_forest_container_cpp(self$forest_container_ptr, num_features)
300-
dim(output) <- c(n_trees, num_features, n_samples)
300+
dim(output) <- c(n_samples, n_trees, num_features)
301301
return(output)
302302
},
303303

304304
#' @description
305-
#' Maximum depth of a specific tree in a specific ensemble in a `ForestContainer` object
305+
#' Maximum depth of a specific tree in a specific ensemble in a `ForestSamples` object
306306
#' @param ensemble_num Ensemble number
307307
#' @param tree_num Tree index within ensemble `ensemble_num`
308308
#' @return Maximum leaf depth
@@ -311,7 +311,7 @@ ForestSamples <- R6::R6Class(
311311
},
312312

313313
#' @description
314-
#' Average the maximum depth of each tree in a given ensemble in a `ForestContainer` object
314+
#' Average the maximum depth of each tree in a given ensemble in a `ForestSamples` object
315315
#' @param ensemble_num Ensemble number
316316
#' @return Average maximum depth
317317
average_ensemble_max_depth = function(ensemble_num) {
@@ -326,19 +326,200 @@ ForestSamples <- R6::R6Class(
326326
},
327327

328328
#' @description
329-
#' Number of leaves in a given ensemble in a `ForestContainer` object
329+
#' Number of leaves in a given ensemble in a `ForestSamples` object
330330
#' @param forest_num Index of the ensemble to be queried
331331
#' @return Count of leaves in the ensemble stored at `forest_num`
332-
num_leaves = function(forest_num) {
332+
num_forest_leaves = function(forest_num) {
333333
return(num_leaves_ensemble_forest_container_cpp(self$forest_container_ptr, forest_num))
334334
},
335335

336336
#' @description
337-
#' Sum of squared (raw) leaf values in a given ensemble in a `ForestContainer` object
337+
#' Sum of squared (raw) leaf values in a given ensemble in a `ForestSamples` object
338338
#' @param forest_num Index of the ensemble to be queried
339339
#' @return Average maximum depth
340340
sum_leaves_squared = function(forest_num) {
341341
return(sum_leaves_squared_ensemble_forest_container_cpp(self$forest_container_ptr, forest_num))
342+
},
343+
344+
#' @description
345+
#' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a leaf
346+
#' @param forest_num Index of the forest to be queried
347+
#' @param tree_num Index of the tree to be queried
348+
#' @param node_id Index of the node to be queried
349+
#' @return `TRUE` if node is a leaf, `FALSE` otherwise
350+
is_leaf_node = function(forest_num, tree_num, node_id) {
351+
return(is_leaf_node_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id))
352+
},
353+
354+
#' @description
355+
#' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a numeric split node
356+
#' @param forest_num Index of the forest to be queried
357+
#' @param tree_num Index of the tree to be queried
358+
#' @param node_id Index of the node to be queried
359+
#' @return `TRUE` if node is a numeric split node, `FALSE` otherwise
360+
is_numeric_split_node = function(forest_num, tree_num, node_id) {
361+
return(is_numeric_split_node_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id))
362+
},
363+
364+
#' @description
365+
#' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a categorical split node
366+
#' @param forest_num Index of the forest to be queried
367+
#' @param tree_num Index of the tree to be queried
368+
#' @param node_id Index of the node to be queried
369+
#' @return `TRUE` if node is a categorical split node, `FALSE` otherwise
370+
is_categorical_split_node = function(forest_num, tree_num, node_id) {
371+
return(is_categorical_split_node_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id))
372+
},
373+
374+
#' @description
375+
#' Parent node of given node of a given tree in a given forest in a `ForestSamples` object
376+
#' @param forest_num Index of the forest to be queried
377+
#' @param tree_num Index of the tree to be queried
378+
#' @param node_id Index of the node to be queried
379+
#' @return Integer ID of the parent node
380+
parent_node = function(forest_num, tree_num, node_id) {
381+
return(parent_node_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id))
382+
},
383+
384+
#' @description
385+
#' Left child node of given node of a given tree in a given forest in a `ForestSamples` object
386+
#' @param forest_num Index of the forest to be queried
387+
#' @param tree_num Index of the tree to be queried
388+
#' @param node_id Index of the node to be queried
389+
#' @return Integer ID of the left child node
390+
left_child_node = function(forest_num, tree_num, node_id) {
391+
return(left_child_node_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id))
392+
},
393+
394+
#' @description
395+
#' Right child node of given node of a given tree in a given forest in a `ForestSamples` object
396+
#' @param forest_num Index of the forest to be queried
397+
#' @param tree_num Index of the tree to be queried
398+
#' @param node_id Index of the node to be queried
399+
#' @return Integer ID of the right child node
400+
right_child_node = function(forest_num, tree_num, node_id) {
401+
return(right_child_node_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id))
402+
},
403+
404+
#' @description
405+
#' Depth of given node of a given tree in a given forest in a `ForestSamples` object, with 0 depth for the root node.
406+
#' @param forest_num Index of the forest to be queried
407+
#' @param tree_num Index of the tree to be queried
408+
#' @param node_id Index of the node to be queried
409+
#' @return Integer valued depth of the node
410+
node_depth = function(forest_num, tree_num, node_id) {
411+
return(node_depth_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id))
412+
},
413+
414+
#' @description
415+
#' Split index of given node of a given tree in a given forest in a `ForestSamples` object. Returns `-1` is node is a leaf.
416+
#' @param forest_num Index of the forest to be queried
417+
#' @param tree_num Index of the tree to be queried
418+
#' @param node_id Index of the node to be queried
419+
#' @return Integer valued depth of the node
420+
node_split_index = function(forest_num, tree_num, node_id) {
421+
if (self$is_leaf_node(forest_num, tree_num, node_id)) {
422+
return(-1)
423+
} else {
424+
return(split_index_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id))
425+
}
426+
},
427+
428+
#' @description
429+
#' Threshold that defines a numeric split for a given node of a given tree in a given forest in a `ForestSamples` object.
430+
#' Returns `Inf` if the node is a leaf or a categorical split node.
431+
#' @param forest_num Index of the forest to be queried
432+
#' @param tree_num Index of the tree to be queried
433+
#' @param node_id Index of the node to be queried
434+
#' @return Threshold defining a split for the node
435+
node_split_threshold = function(forest_num, tree_num, node_id) {
436+
if (self$is_leaf_node(forest_num, tree_num, node_id) ||
437+
self$is_categorical_split_node(forest_num, tree_num, node_id)) {
438+
return(Inf)
439+
} else {
440+
return(split_theshold_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id))
441+
}
442+
},
443+
444+
#' @description
445+
#' Array of category indices that define a categorical split for a given node of a given tree in a given forest in a `ForestSamples` object.
446+
#' Returns `c(Inf)` if the node is a leaf or a numeric split node.
447+
#' @param forest_num Index of the forest to be queried
448+
#' @param tree_num Index of the tree to be queried
449+
#' @param node_id Index of the node to be queried
450+
#' @return Categories defining a split for the node
451+
node_split_categories = function(forest_num, tree_num, node_id) {
452+
if (self$is_leaf_node(forest_num, tree_num, node_id) ||
453+
self$is_numeric_split_node(forest_num, tree_num, node_id)) {
454+
return(c(Inf))
455+
} else {
456+
return(split_categories_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id))
457+
}
458+
},
459+
460+
#' @description
461+
#' Leaf node value(s) for a given node of a given tree in a given forest in a `ForestSamples` object.
462+
#' Values are stale if the node is a split node.
463+
#' @param forest_num Index of the forest to be queried
464+
#' @param tree_num Index of the tree to be queried
465+
#' @param node_id Index of the node to be queried
466+
#' @return Vector (often univariate) of leaf values
467+
node_leaf_values = function(forest_num, tree_num, node_id) {
468+
return(leaf_values_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id))
469+
},
470+
471+
#' @description
472+
#' Number of nodes in a given tree in a given forest in a `ForestSamples` object.
473+
#' @param forest_num Index of the forest to be queried
474+
#' @param tree_num Index of the tree to be queried
475+
#' @return Count of total tree nodes
476+
num_nodes = function(forest_num, tree_num) {
477+
return(num_nodes_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num))
478+
},
479+
480+
#' @description
481+
#' Number of leaves in a given tree in a given forest in a `ForestSamples` object.
482+
#' @param forest_num Index of the forest to be queried
483+
#' @param tree_num Index of the tree to be queried
484+
#' @return Count of total tree leaves
485+
num_leaves = function(forest_num, tree_num) {
486+
return(num_leaves_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num))
487+
},
488+
489+
#' @description
490+
#' Number of leaf parents (split nodes with two leaves as children) in a given tree in a given forest in a `ForestSamples` object.
491+
#' @param forest_num Index of the forest to be queried
492+
#' @param tree_num Index of the tree to be queried
493+
#' @return Count of total tree leaf parents
494+
num_leaf_parents = function(forest_num, tree_num) {
495+
return(num_leaf_parents_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num))
496+
},
497+
498+
#' @description
499+
#' Number of split nodes in a given tree in a given forest in a `ForestSamples` object.
500+
#' @param forest_num Index of the forest to be queried
501+
#' @param tree_num Index of the tree to be queried
502+
#' @return Count of total tree split nodes
503+
num_split_nodes = function(forest_num, tree_num) {
504+
return(num_split_nodes_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num))
505+
},
506+
507+
#' @description
508+
#' Array of node indices in a given tree in a given forest in a `ForestSamples` object.
509+
#' @param forest_num Index of the forest to be queried
510+
#' @param tree_num Index of the tree to be queried
511+
#' @return Indices of tree nodes
512+
nodes = function(forest_num, tree_num) {
513+
return(nodes_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num))
514+
},
515+
516+
#' @description
517+
#' Array of leaf indices in a given tree in a given forest in a `ForestSamples` object.
518+
#' @param forest_num Index of the forest to be queried
519+
#' @param tree_num Index of the tree to be queried
520+
#' @return Indices of leaf nodes
521+
leaves = function(forest_num, tree_num) {
522+
return(leaves_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num))
342523
}
343524
)
344525
)

demo/notebooks/tree_inspection.ipynb

Lines changed: 12 additions & 12 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)