@@ -164,7 +164,7 @@ class ForestContainerCpp {
164
164
return forest_samples_->NumSamples ();
165
165
}
166
166
167
- int NumLeaves (int forest_num) {
167
+ int NumLeavesForest (int forest_num) {
168
168
StochTree::TreeEnsemble* forest = forest_samples_->GetEnsemble (forest_num);
169
169
return forest->NumLeaves ();
170
170
}
@@ -428,7 +428,8 @@ class ForestContainerCpp {
428
428
StochTree::Tree* tree = ensemble->GetTree (i);
429
429
std::vector<int32_t > split_nodes = tree->GetInternalNodes ();
430
430
for (int j = 0 ; j < split_nodes.size (); j++) {
431
- auto split_feature = split_nodes.at (j);
431
+ auto node_id = split_nodes.at (j);
432
+ auto split_feature = tree->SplitIndex (node_id);
432
433
accessor (split_feature)++;
433
434
}
434
435
}
@@ -449,7 +450,8 @@ class ForestContainerCpp {
449
450
StochTree::Tree* tree = ensemble->GetTree (j);
450
451
std::vector<int32_t > split_nodes = tree->GetInternalNodes ();
451
452
for (int k = 0 ; k < split_nodes.size (); k++) {
452
- auto split_feature = split_nodes.at (k);
453
+ auto node_id = split_nodes.at (k);
454
+ auto split_feature = tree->SplitIndex (node_id);
453
455
accessor (split_feature)++;
454
456
}
455
457
}
@@ -460,11 +462,11 @@ class ForestContainerCpp {
460
462
py::array_t <int > GetGranularSplitCounts (int num_features) {
461
463
int num_samples = forest_samples_->NumSamples ();
462
464
int num_trees = forest_samples_->NumTrees ();
463
- auto result = py::array_t <int >(py::detail::any_container<py::ssize_t >({num_trees,num_features,num_samples }));
465
+ auto result = py::array_t <int >(py::detail::any_container<py::ssize_t >({num_samples, num_trees,num_features}));
464
466
auto accessor = result.mutable_unchecked <3 >();
465
- for (int i = 0 ; i < num_trees ; i++) {
466
- for (int j = 0 ; j < num_features ; j++) {
467
- for (int k = 0 ; k < num_samples ; k++) {
467
+ for (int i = 0 ; i < num_samples ; i++) {
468
+ for (int j = 0 ; j < num_trees ; j++) {
469
+ for (int k = 0 ; k < num_features ; k++) {
468
470
accessor (i,j,k) = 0 ;
469
471
}
470
472
}
@@ -475,14 +477,144 @@ class ForestContainerCpp {
475
477
StochTree::Tree* tree = ensemble->GetTree (j);
476
478
std::vector<int32_t > split_nodes = tree->GetInternalNodes ();
477
479
for (int k = 0 ; k < split_nodes.size (); k++) {
478
- auto split_feature = split_nodes.at (k);
479
- accessor (j,split_feature,i)++;
480
+ auto node_id = split_nodes.at (k);
481
+ auto split_feature = tree->SplitIndex (node_id);
482
+ accessor (i,j,split_feature)++;
480
483
}
481
484
}
482
485
}
483
486
return result;
484
487
}
485
488
489
+ bool IsLeafNode (int forest_id, int tree_id, int node_id) {
490
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
491
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
492
+ return tree->IsLeaf (node_id);
493
+ }
494
+
495
+ bool IsNumericSplitNode (int forest_id, int tree_id, int node_id) {
496
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
497
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
498
+ return tree->IsNumericSplitNode (node_id);
499
+ }
500
+
501
+ bool IsCategoricalSplitNode (int forest_id, int tree_id, int node_id) {
502
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
503
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
504
+ return tree->IsCategoricalSplitNode (node_id);
505
+ }
506
+
507
+ int ParentNode (int forest_id, int tree_id, int node_id) {
508
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
509
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
510
+ return tree->Parent (node_id);
511
+ }
512
+
513
+ int LeftChildNode (int forest_id, int tree_id, int node_id) {
514
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
515
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
516
+ return tree->LeftChild (node_id);
517
+ }
518
+
519
+ int RightChildNode (int forest_id, int tree_id, int node_id) {
520
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
521
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
522
+ return tree->RightChild (node_id);
523
+ }
524
+
525
+ int SplitIndex (int forest_id, int tree_id, int node_id) {
526
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
527
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
528
+ return tree->SplitIndex (node_id);
529
+ }
530
+
531
+ int NodeDepth (int forest_id, int tree_id, int node_id) {
532
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
533
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
534
+ return tree->GetDepth (node_id);
535
+ }
536
+
537
+ double SplitThreshold (int forest_id, int tree_id, int node_id) {
538
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
539
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
540
+ return tree->Threshold (node_id);
541
+ }
542
+
543
+ py::array_t <int > SplitCategories (int forest_id, int tree_id, int node_id) {
544
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
545
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
546
+ std::vector<std::uint32_t > raw_categories = tree->CategoryList (node_id);
547
+ int num_categories = raw_categories.size ();
548
+ auto result = py::array_t <int >(py::detail::any_container<py::ssize_t >({num_categories}));
549
+ auto accessor = result.mutable_unchecked <1 >();
550
+ for (int i = 0 ; i < num_categories; i++) {
551
+ accessor (i) = raw_categories.at (i);
552
+ }
553
+ return result;
554
+ }
555
+
556
+ py::array_t <double > NodeLeafValues (int forest_id, int tree_id, int node_id) {
557
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
558
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
559
+ int num_outputs = tree->OutputDimension ();
560
+ auto result = py::array_t <double >(py::detail::any_container<py::ssize_t >({num_outputs}));
561
+ auto accessor = result.mutable_unchecked <1 >();
562
+ for (int i = 0 ; i < num_outputs; i++) {
563
+ accessor (i) = tree->LeafValue (node_id, i);
564
+ }
565
+ return result;
566
+ }
567
+
568
+ int NumNodes (int forest_id, int tree_id) {
569
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
570
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
571
+ return tree->NumValidNodes ();
572
+ }
573
+
574
+ int NumLeaves (int forest_id, int tree_id) {
575
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
576
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
577
+ return tree->NumLeaves ();
578
+ }
579
+
580
+ int NumLeafParents (int forest_id, int tree_id) {
581
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
582
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
583
+ return tree->NumLeafParents ();
584
+ }
585
+
586
+ int NumSplitNodes (int forest_id, int tree_id) {
587
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
588
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
589
+ return tree->NumSplitNodes ();
590
+ }
591
+
592
+ py::array_t <int > Nodes (int forest_id, int tree_id) {
593
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
594
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
595
+ std::vector<std::int32_t > nodes = tree->GetNodes ();
596
+ int num_nodes = nodes.size ();
597
+ auto result = py::array_t <int >(py::detail::any_container<py::ssize_t >({num_nodes}));
598
+ auto accessor = result.mutable_unchecked <1 >();
599
+ for (int i = 0 ; i < num_nodes; i++) {
600
+ accessor (i) = nodes.at (i);
601
+ }
602
+ return result;
603
+ }
604
+
605
+ py::array_t <int > Leaves (int forest_id, int tree_id) {
606
+ StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble (forest_id);
607
+ StochTree::Tree* tree = ensemble->GetTree (tree_id);
608
+ std::vector<std::int32_t > leaves = tree->GetLeaves ();
609
+ int num_leaves = leaves.size ();
610
+ auto result = py::array_t <int >(py::detail::any_container<py::ssize_t >({num_leaves}));
611
+ auto accessor = result.mutable_unchecked <1 >();
612
+ for (int i = 0 ; i < num_leaves; i++) {
613
+ accessor (i) = leaves.at (i);
614
+ }
615
+ return result;
616
+ }
617
+
486
618
private:
487
619
std::unique_ptr<StochTree::ForestContainer> forest_samples_;
488
620
int num_trees_;
@@ -1044,8 +1176,25 @@ PYBIND11_MODULE(stochtree_cpp, m) {
1044
1176
.def (" GetForestSplitCounts" , &ForestContainerCpp::GetForestSplitCounts)
1045
1177
.def (" GetOverallSplitCounts" , &ForestContainerCpp::GetOverallSplitCounts)
1046
1178
.def (" GetGranularSplitCounts" , &ForestContainerCpp::GetGranularSplitCounts)
1179
+ .def (" NumLeavesForest" , &ForestContainerCpp::NumLeavesForest)
1180
+ .def (" SumLeafSquared" , &ForestContainerCpp::SumLeafSquared)
1181
+ .def (" IsLeafNode" , &ForestContainerCpp::IsLeafNode)
1182
+ .def (" IsNumericSplitNode" , &ForestContainerCpp::IsNumericSplitNode)
1183
+ .def (" IsCategoricalSplitNode" , &ForestContainerCpp::IsCategoricalSplitNode)
1184
+ .def (" ParentNode" , &ForestContainerCpp::ParentNode)
1185
+ .def (" LeftChildNode" , &ForestContainerCpp::LeftChildNode)
1186
+ .def (" RightChildNode" , &ForestContainerCpp::RightChildNode)
1187
+ .def (" SplitIndex" , &ForestContainerCpp::SplitIndex)
1188
+ .def (" NodeDepth" , &ForestContainerCpp::NodeDepth)
1189
+ .def (" SplitThreshold" , &ForestContainerCpp::SplitThreshold)
1190
+ .def (" SplitCategories" , &ForestContainerCpp::SplitCategories)
1191
+ .def (" NodeLeafValues" , &ForestContainerCpp::NodeLeafValues)
1192
+ .def (" NumNodes" , &ForestContainerCpp::NumNodes)
1047
1193
.def (" NumLeaves" , &ForestContainerCpp::NumLeaves)
1048
- .def (" SumLeafSquared" , &ForestContainerCpp::SumLeafSquared);
1194
+ .def (" NumLeafParents" , &ForestContainerCpp::NumLeafParents)
1195
+ .def (" NumSplitNodes" , &ForestContainerCpp::NumSplitNodes)
1196
+ .def (" Nodes" , &ForestContainerCpp::Nodes)
1197
+ .def (" Leaves" , &ForestContainerCpp::Leaves);
1049
1198
1050
1199
py::class_<ForestSamplerCpp>(m, " ForestSamplerCpp" )
1051
1200
.def (py::init<ForestDatasetCpp&, py::array_t <int >, int , data_size_t , double , double , int , int >())
0 commit comments