Skip to content

Commit 06efbb5

Browse files
committed
Updated unit tests
1 parent 64c19e8 commit 06efbb5

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

test/cpp/test_model.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <stochtree/leaf_model.h>
55
#include <stochtree/log.h>
66
#include <stochtree/partition_tracker.h>
7+
#include <stochtree/tree_sampler.h>
78
#include <iostream>
89
#include <memory>
910
#include <vector>
@@ -49,9 +50,10 @@ TEST(LeafConstantModel, FullEnumeration) {
4950
StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(tau);
5051

5152
// Evaluate all possible cutpoints
52-
leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features,
53-
cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights,
54-
feature_types);
53+
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(
54+
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
55+
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types
56+
);
5557

5658
// Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
5759
ASSERT_EQ(log_cutpoint_evaluations.size(), (n - 2*min_samples_leaf + 1)*p + 1);
@@ -107,9 +109,10 @@ TEST(LeafConstantModel, CutpointThinning) {
107109
StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(tau);
108110

109111
// Evaluate all possible cutpoints
110-
leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features,
111-
cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights,
112-
feature_types);
112+
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(
113+
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
114+
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types
115+
);
113116

114117
// Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
115118
ASSERT_EQ(log_cutpoint_evaluations.size(), (cutpoint_grid_size - 1)*p + 1);
@@ -165,9 +168,10 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) {
165168
StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(tau);
166169

167170
// Evaluate all possible cutpoints
168-
leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features,
169-
cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights,
170-
feature_types);
171+
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(
172+
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
173+
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types
174+
);
171175

172176
// Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
173177
ASSERT_EQ(log_cutpoint_evaluations.size(), (n - 2*min_samples_leaf + 1)*p + 1);
@@ -224,9 +228,11 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) {
224228
StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(tau);
225229

226230
// Evaluate all possible cutpoints
227-
leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features,
228-
cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights,
229-
feature_types);
231+
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(
232+
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
233+
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types
234+
);
235+
230236

231237
// Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
232238
ASSERT_EQ(log_cutpoint_evaluations.size(), (cutpoint_grid_size - 1)*p + 1);

0 commit comments

Comments
 (0)