Skip to content

Commit cd082dc

Browse files
authored
Merge pull request #113 from StochasticTree/tree-inspection-functionality
Tree inspection functionality and bug fixes
2 parents 19fb779 + 450779f commit cd082dc

File tree

6 files changed

+881
-19
lines changed

6 files changed

+881
-19
lines changed

demo/notebooks/prototype_interface.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@
766766
],
767767
"metadata": {
768768
"kernelspec": {
769-
"display_name": "stochtree-dev",
769+
"display_name": "venv",
770770
"language": "python",
771771
"name": "python3"
772772
},
@@ -780,7 +780,7 @@
780780
"name": "python",
781781
"nbconvert_exporter": "python",
782782
"pygments_lexer": "ipython3",
783-
"version": "3.10.14"
783+
"version": "3.8.17"
784784
}
785785
},
786786
"nbformat": 4,

demo/notebooks/tree_inspection.ipynb

Lines changed: 449 additions & 0 deletions
Large diffs are not rendered by default.

include/stochtree/tree.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,22 @@ class Tree {
461461
return node_type_[nid];
462462
}
463463

464+
/*!
465+
* \brief Whether the node is a numeric split node
466+
* \param nid ID of node being queried
467+
*/
468+
bool IsNumericSplitNode(std::int32_t nid) const {
469+
return node_type_[nid] == TreeNodeType::kNumericalSplitNode;
470+
}
471+
472+
/*!
473+
* \brief Whether the node is a numeric split node
474+
* \param nid ID of node being queried
475+
*/
476+
bool IsCategoricalSplitNode(std::int32_t nid) const {
477+
return node_type_[nid] == TreeNodeType::kCategoricalSplitNode;
478+
}
479+
464480
/*!
465481
* \brief Query whether this tree contains any categorical splits
466482
*/
@@ -500,18 +516,35 @@ class Tree {
500516
[[nodiscard]] std::vector<std::int32_t> const& GetInternalNodes() const {
501517
return internal_nodes_;
502518
}
519+
503520
/*!
504521
* \brief Get indices of all leaf nodes.
505522
*/
506523
[[nodiscard]] std::vector<std::int32_t> const& GetLeaves() const {
507524
return leaves_;
508525
}
526+
509527
/*!
510528
* \brief Get indices of all leaf parent nodes.
511529
*/
512530
[[nodiscard]] std::vector<std::int32_t> const& GetLeafParents() const {
513531
return leaf_parents_;
514532
}
533+
534+
/*!
535+
* \brief Get indices of all valid (non-deleted) nodes.
536+
*/
537+
[[nodiscard]] std::vector<std::int32_t> GetNodes() {
538+
std::vector<std::int32_t> output;
539+
auto const& self = *this;
540+
this->WalkTree([&output, &self](std::int32_t nidx) {
541+
if (!self.IsDeleted(nidx)) {
542+
output.push_back(nidx);
543+
}
544+
return true;
545+
});
546+
return output;
547+
}
515548

516549
/*!
517550
* \brief Get the depth of a node

src/forest.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,8 @@ cpp11::writable::integers get_tree_split_counts_forest_container_cpp(cpp11::exte
227227
StochTree::Tree* tree = ensemble->GetTree(tree_num);
228228
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
229229
for (int i = 0; i < split_nodes.size(); i++) {
230-
auto split_feature = split_nodes.at(i);
230+
auto node_id = split_nodes.at(i);
231+
auto split_feature = tree->SplitIndex(node_id);
231232
output.at(split_feature)++;
232233
}
233234
return output;
@@ -243,7 +244,8 @@ cpp11::writable::integers get_forest_split_counts_forest_container_cpp(cpp11::ex
243244
StochTree::Tree* tree = ensemble->GetTree(i);
244245
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
245246
for (int j = 0; j < split_nodes.size(); j++) {
246-
auto split_feature = split_nodes.at(j);
247+
auto node_id = split_nodes.at(j);
248+
auto split_feature = tree->SplitIndex(node_id);
247249
output.at(split_feature)++;
248250
}
249251
}
@@ -262,7 +264,8 @@ cpp11::writable::integers get_overall_split_counts_forest_container_cpp(cpp11::e
262264
StochTree::Tree* tree = ensemble->GetTree(j);
263265
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
264266
for (int k = 0; k < split_nodes.size(); k++) {
265-
auto split_feature = split_nodes.at(k);
267+
auto node_id = split_nodes.at(k);
268+
auto split_feature = tree->SplitIndex(node_id);
266269
output.at(split_feature)++;
267270
}
268271
}
@@ -282,8 +285,9 @@ cpp11::writable::integers get_granular_split_count_array_forest_container_cpp(cp
282285
StochTree::Tree* tree = ensemble->GetTree(j);
283286
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
284287
for (int k = 0; k < split_nodes.size(); k++) {
285-
auto split_feature = split_nodes.at(k);
286-
output.at(num_features*num_trees*i + split_feature*num_trees + j)++;
288+
auto node_id = split_nodes.at(k);
289+
auto split_feature = tree->SplitIndex(node_id);
290+
output.at(split_feature*num_samples*num_trees + j*num_samples + i)++;
287291
}
288292
}
289293
}

src/py_stochtree.cpp

Lines changed: 159 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ class ForestContainerCpp {
164164
return forest_samples_->NumSamples();
165165
}
166166

167-
int NumLeaves(int forest_num) {
167+
int NumLeavesForest(int forest_num) {
168168
StochTree::TreeEnsemble* forest = forest_samples_->GetEnsemble(forest_num);
169169
return forest->NumLeaves();
170170
}
@@ -428,7 +428,8 @@ class ForestContainerCpp {
428428
StochTree::Tree* tree = ensemble->GetTree(i);
429429
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
430430
for (int j = 0; j < split_nodes.size(); j++) {
431-
auto split_feature = split_nodes.at(j);
431+
auto node_id = split_nodes.at(j);
432+
auto split_feature = tree->SplitIndex(node_id);
432433
accessor(split_feature)++;
433434
}
434435
}
@@ -449,7 +450,8 @@ class ForestContainerCpp {
449450
StochTree::Tree* tree = ensemble->GetTree(j);
450451
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
451452
for (int k = 0; k < split_nodes.size(); k++) {
452-
auto split_feature = split_nodes.at(k);
453+
auto node_id = split_nodes.at(k);
454+
auto split_feature = tree->SplitIndex(node_id);
453455
accessor(split_feature)++;
454456
}
455457
}
@@ -460,11 +462,11 @@ class ForestContainerCpp {
460462
py::array_t<int> GetGranularSplitCounts(int num_features) {
461463
int num_samples = forest_samples_->NumSamples();
462464
int num_trees = forest_samples_->NumTrees();
463-
auto result = py::array_t<int>(py::detail::any_container<py::ssize_t>({num_trees,num_features,num_samples}));
465+
auto result = py::array_t<int>(py::detail::any_container<py::ssize_t>({num_samples,num_trees,num_features}));
464466
auto accessor = result.mutable_unchecked<3>();
465-
for (int i = 0; i < num_trees; i++) {
466-
for (int j = 0; j < num_features; j++) {
467-
for (int k = 0; k < num_samples; k++) {
467+
for (int i = 0; i < num_samples; i++) {
468+
for (int j = 0; j < num_trees; j++) {
469+
for (int k = 0; k < num_features; k++) {
468470
accessor(i,j,k) = 0;
469471
}
470472
}
@@ -475,14 +477,144 @@ class ForestContainerCpp {
475477
StochTree::Tree* tree = ensemble->GetTree(j);
476478
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
477479
for (int k = 0; k < split_nodes.size(); k++) {
478-
auto split_feature = split_nodes.at(k);
479-
accessor(j,split_feature,i)++;
480+
auto node_id = split_nodes.at(k);
481+
auto split_feature = tree->SplitIndex(node_id);
482+
accessor(i,j,split_feature)++;
480483
}
481484
}
482485
}
483486
return result;
484487
}
485488

489+
bool IsLeafNode(int forest_id, int tree_id, int node_id) {
490+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
491+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
492+
return tree->IsLeaf(node_id);
493+
}
494+
495+
bool IsNumericSplitNode(int forest_id, int tree_id, int node_id) {
496+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
497+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
498+
return tree->IsNumericSplitNode(node_id);
499+
}
500+
501+
bool IsCategoricalSplitNode(int forest_id, int tree_id, int node_id) {
502+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
503+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
504+
return tree->IsCategoricalSplitNode(node_id);
505+
}
506+
507+
int ParentNode(int forest_id, int tree_id, int node_id) {
508+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
509+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
510+
return tree->Parent(node_id);
511+
}
512+
513+
int LeftChildNode(int forest_id, int tree_id, int node_id) {
514+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
515+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
516+
return tree->LeftChild(node_id);
517+
}
518+
519+
int RightChildNode(int forest_id, int tree_id, int node_id) {
520+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
521+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
522+
return tree->RightChild(node_id);
523+
}
524+
525+
int SplitIndex(int forest_id, int tree_id, int node_id) {
526+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
527+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
528+
return tree->SplitIndex(node_id);
529+
}
530+
531+
int NodeDepth(int forest_id, int tree_id, int node_id) {
532+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
533+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
534+
return tree->GetDepth(node_id);
535+
}
536+
537+
double SplitThreshold(int forest_id, int tree_id, int node_id) {
538+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
539+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
540+
return tree->Threshold(node_id);
541+
}
542+
543+
py::array_t<int> SplitCategories(int forest_id, int tree_id, int node_id) {
544+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
545+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
546+
std::vector<std::uint32_t> raw_categories = tree->CategoryList(node_id);
547+
int num_categories = raw_categories.size();
548+
auto result = py::array_t<int>(py::detail::any_container<py::ssize_t>({num_categories}));
549+
auto accessor = result.mutable_unchecked<1>();
550+
for (int i = 0; i < num_categories; i++) {
551+
accessor(i) = raw_categories.at(i);
552+
}
553+
return result;
554+
}
555+
556+
py::array_t<double> NodeLeafValues(int forest_id, int tree_id, int node_id) {
557+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
558+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
559+
int num_outputs = tree->OutputDimension();
560+
auto result = py::array_t<double>(py::detail::any_container<py::ssize_t>({num_outputs}));
561+
auto accessor = result.mutable_unchecked<1>();
562+
for (int i = 0; i < num_outputs; i++) {
563+
accessor(i) = tree->LeafValue(node_id, i);
564+
}
565+
return result;
566+
}
567+
568+
int NumNodes(int forest_id, int tree_id) {
569+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
570+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
571+
return tree->NumValidNodes();
572+
}
573+
574+
int NumLeaves(int forest_id, int tree_id) {
575+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
576+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
577+
return tree->NumLeaves();
578+
}
579+
580+
int NumLeafParents(int forest_id, int tree_id) {
581+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
582+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
583+
return tree->NumLeafParents();
584+
}
585+
586+
int NumSplitNodes(int forest_id, int tree_id) {
587+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
588+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
589+
return tree->NumSplitNodes();
590+
}
591+
592+
py::array_t<int> Nodes(int forest_id, int tree_id) {
593+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
594+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
595+
std::vector<std::int32_t> nodes = tree->GetNodes();
596+
int num_nodes = nodes.size();
597+
auto result = py::array_t<int>(py::detail::any_container<py::ssize_t>({num_nodes}));
598+
auto accessor = result.mutable_unchecked<1>();
599+
for (int i = 0; i < num_nodes; i++) {
600+
accessor(i) = nodes.at(i);
601+
}
602+
return result;
603+
}
604+
605+
py::array_t<int> Leaves(int forest_id, int tree_id) {
606+
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
607+
StochTree::Tree* tree = ensemble->GetTree(tree_id);
608+
std::vector<std::int32_t> leaves = tree->GetLeaves();
609+
int num_leaves = leaves.size();
610+
auto result = py::array_t<int>(py::detail::any_container<py::ssize_t>({num_leaves}));
611+
auto accessor = result.mutable_unchecked<1>();
612+
for (int i = 0; i < num_leaves; i++) {
613+
accessor(i) = leaves.at(i);
614+
}
615+
return result;
616+
}
617+
486618
private:
487619
std::unique_ptr<StochTree::ForestContainer> forest_samples_;
488620
int num_trees_;
@@ -1044,8 +1176,25 @@ PYBIND11_MODULE(stochtree_cpp, m) {
10441176
.def("GetForestSplitCounts", &ForestContainerCpp::GetForestSplitCounts)
10451177
.def("GetOverallSplitCounts", &ForestContainerCpp::GetOverallSplitCounts)
10461178
.def("GetGranularSplitCounts", &ForestContainerCpp::GetGranularSplitCounts)
1179+
.def("NumLeavesForest", &ForestContainerCpp::NumLeavesForest)
1180+
.def("SumLeafSquared", &ForestContainerCpp::SumLeafSquared)
1181+
.def("IsLeafNode", &ForestContainerCpp::IsLeafNode)
1182+
.def("IsNumericSplitNode", &ForestContainerCpp::IsNumericSplitNode)
1183+
.def("IsCategoricalSplitNode", &ForestContainerCpp::IsCategoricalSplitNode)
1184+
.def("ParentNode", &ForestContainerCpp::ParentNode)
1185+
.def("LeftChildNode", &ForestContainerCpp::LeftChildNode)
1186+
.def("RightChildNode", &ForestContainerCpp::RightChildNode)
1187+
.def("SplitIndex", &ForestContainerCpp::SplitIndex)
1188+
.def("NodeDepth", &ForestContainerCpp::NodeDepth)
1189+
.def("SplitThreshold", &ForestContainerCpp::SplitThreshold)
1190+
.def("SplitCategories", &ForestContainerCpp::SplitCategories)
1191+
.def("NodeLeafValues", &ForestContainerCpp::NodeLeafValues)
1192+
.def("NumNodes", &ForestContainerCpp::NumNodes)
10471193
.def("NumLeaves", &ForestContainerCpp::NumLeaves)
1048-
.def("SumLeafSquared", &ForestContainerCpp::SumLeafSquared);
1194+
.def("NumLeafParents", &ForestContainerCpp::NumLeafParents)
1195+
.def("NumSplitNodes", &ForestContainerCpp::NumSplitNodes)
1196+
.def("Nodes", &ForestContainerCpp::Nodes)
1197+
.def("Leaves", &ForestContainerCpp::Leaves);
10491198

10501199
py::class_<ForestSamplerCpp>(m, "ForestSamplerCpp")
10511200
.def(py::init<ForestDatasetCpp&, py::array_t<int>, int, data_size_t, double, double, int, int>())

0 commit comments

Comments
 (0)