Skip to content

Commit 78f30a5

Browse files
authored
Merge pull request #207 from StochasticTree/split_count_hotfix
Fix split count extraction bugs
2 parents f621350 + 2100c1b commit 78f30a5

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

src/forest.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -770,8 +770,9 @@ cpp11::writable::integers get_overall_split_counts_active_forest_cpp(cpp11::exte
770770
StochTree::Tree* tree = active_forest->GetTree(i);
771771
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
772772
for (int j = 0; j < split_nodes.size(); j++) {
773-
auto split_feature = split_nodes.at(j);
774-
output.at(split_feature)++;
773+
auto node_id = split_nodes.at(j);
774+
auto feature_split = tree->SplitIndex(node_id);
775+
output.at(feature_split)++;
775776
}
776777
}
777778
return output;
@@ -786,8 +787,9 @@ cpp11::writable::integers get_granular_split_count_array_active_forest_cpp(cpp11
786787
StochTree::Tree* tree = active_forest->GetTree(i);
787788
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
788789
for (int j = 0; j < split_nodes.size(); j++) {
789-
auto split_feature = split_nodes.at(j);
790-
output.at(split_feature*num_trees + i)++;
790+
auto node_id = split_nodes.at(j);
791+
auto feature_split = tree->SplitIndex(node_id);
792+
output.at(feature_split*num_trees + i)++;
791793
}
792794
}
793795
return output;

src/py_stochtree.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,8 @@ class ForestContainerCpp {
501501
StochTree::Tree* tree = ensemble->GetTree(tree_num);
502502
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
503503
for (int i = 0; i < split_nodes.size(); i++) {
504-
auto split_feature = split_nodes.at(i);
504+
auto node_id = split_nodes.at(i);
505+
auto split_feature = tree->SplitIndex(node_id);
505506
accessor(split_feature)++;
506507
}
507508
return result;

0 commit comments

Comments
 (0)