|
4 | 4 | #include <stochtree/leaf_model.h>
|
5 | 5 | #include <stochtree/log.h>
|
6 | 6 | #include <stochtree/partition_tracker.h>
|
| 7 | +#include <stochtree/tree_sampler.h> |
7 | 8 | #include <iostream>
|
8 | 9 | #include <memory>
|
9 | 10 | #include <vector>
|
@@ -49,9 +50,10 @@ TEST(LeafConstantModel, FullEnumeration) {
|
49 | 50 | StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(tau);
|
50 | 51 |
|
51 | 52 | // Evaluate all possible cutpoints
|
52 |
| - leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, |
53 |
| - cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, |
54 |
| - feature_types); |
| 53 | + StochTree::EvaluateAllPossibleSplits<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>( |
| 54 | + dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, |
| 55 | + cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types |
| 56 | + ); |
55 | 57 |
|
56 | 58 | // Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
|
57 | 59 | ASSERT_EQ(log_cutpoint_evaluations.size(), (n - 2*min_samples_leaf + 1)*p + 1);
|
@@ -107,9 +109,10 @@ TEST(LeafConstantModel, CutpointThinning) {
|
107 | 109 | StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(tau);
|
108 | 110 |
|
109 | 111 | // Evaluate all possible cutpoints
|
110 |
| - leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, |
111 |
| - cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, |
112 |
| - feature_types); |
| 112 | + StochTree::EvaluateAllPossibleSplits<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>( |
| 113 | + dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, |
| 114 | + cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types |
| 115 | + ); |
113 | 116 |
|
114 | 117 | // Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
|
115 | 118 | ASSERT_EQ(log_cutpoint_evaluations.size(), (cutpoint_grid_size - 1)*p + 1);
|
@@ -165,9 +168,10 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) {
|
165 | 168 | StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(tau);
|
166 | 169 |
|
167 | 170 | // Evaluate all possible cutpoints
|
168 |
| - leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, |
169 |
| - cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, |
170 |
| - feature_types); |
| 171 | + StochTree::EvaluateAllPossibleSplits<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>( |
| 172 | + dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, |
| 173 | + cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types |
| 174 | + ); |
171 | 175 |
|
172 | 176 | // Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
|
173 | 177 | ASSERT_EQ(log_cutpoint_evaluations.size(), (n - 2*min_samples_leaf + 1)*p + 1);
|
@@ -224,9 +228,11 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) {
|
224 | 228 | StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(tau);
|
225 | 229 |
|
226 | 230 | // Evaluate all possible cutpoints
|
227 |
| - leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, |
228 |
| - cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, |
229 |
| - feature_types); |
| 231 | + StochTree::EvaluateAllPossibleSplits<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>( |
| 232 | + dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, |
| 233 | + cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types |
| 234 | + ); |
| 235 | + |
230 | 236 |
|
231 | 237 | // Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
|
232 | 238 | ASSERT_EQ(log_cutpoint_evaluations.size(), (cutpoint_grid_size - 1)*p + 1);
|
|
0 commit comments