Skip to content

Commit 45a6eda

Browse files
authored
Merge pull request #194 from StochasticTree/var-weights-update-hotfix
Add ability to update variance weights
2 parents f6aed65 + a05c9af commit 45a6eda

File tree

15 files changed

+996
-5
lines changed

15 files changed

+996
-5
lines changed

R/cpp11.R

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,26 @@ forest_dataset_update_basis_cpp <- function(dataset_ptr, basis) {
3636
invisible(.Call(`_stochtree_forest_dataset_update_basis_cpp`, dataset_ptr, basis))
3737
}
3838

39+
forest_dataset_update_var_weights_cpp <- function(dataset_ptr, weights, exponentiate) {
40+
invisible(.Call(`_stochtree_forest_dataset_update_var_weights_cpp`, dataset_ptr, weights, exponentiate))
41+
}
42+
3943
forest_dataset_add_weights_cpp <- function(dataset_ptr, weights) {
4044
invisible(.Call(`_stochtree_forest_dataset_add_weights_cpp`, dataset_ptr, weights))
4145
}
4246

47+
forest_dataset_get_covariates_cpp <- function(dataset_ptr) {
48+
.Call(`_stochtree_forest_dataset_get_covariates_cpp`, dataset_ptr)
49+
}
50+
51+
forest_dataset_get_basis_cpp <- function(dataset_ptr) {
52+
.Call(`_stochtree_forest_dataset_get_basis_cpp`, dataset_ptr)
53+
}
54+
55+
forest_dataset_get_variance_weights_cpp <- function(dataset_ptr) {
56+
.Call(`_stochtree_forest_dataset_get_variance_weights_cpp`, dataset_ptr)
57+
}
58+
4359
create_column_vector_cpp <- function(outcome) {
4460
.Call(`_stochtree_create_column_vector_cpp`, outcome)
4561
}
@@ -68,6 +84,22 @@ create_rfx_dataset_cpp <- function() {
6884
.Call(`_stochtree_create_rfx_dataset_cpp`)
6985
}
7086

87+
rfx_dataset_update_basis_cpp <- function(dataset_ptr, basis) {
88+
invisible(.Call(`_stochtree_rfx_dataset_update_basis_cpp`, dataset_ptr, basis))
89+
}
90+
91+
rfx_dataset_update_var_weights_cpp <- function(dataset_ptr, weights, exponentiate) {
92+
invisible(.Call(`_stochtree_rfx_dataset_update_var_weights_cpp`, dataset_ptr, weights, exponentiate))
93+
}
94+
95+
rfx_dataset_update_group_labels_cpp <- function(dataset_ptr, group_labels) {
96+
invisible(.Call(`_stochtree_rfx_dataset_update_group_labels_cpp`, dataset_ptr, group_labels))
97+
}
98+
99+
rfx_dataset_num_basis_cpp <- function(dataset) {
100+
.Call(`_stochtree_rfx_dataset_num_basis_cpp`, dataset)
101+
}
102+
71103
rfx_dataset_num_rows_cpp <- function(dataset) {
72104
.Call(`_stochtree_rfx_dataset_num_rows_cpp`, dataset)
73105
}
@@ -96,6 +128,18 @@ rfx_dataset_add_weights_cpp <- function(dataset_ptr, weights) {
96128
invisible(.Call(`_stochtree_rfx_dataset_add_weights_cpp`, dataset_ptr, weights))
97129
}
98130

131+
rfx_dataset_get_group_labels_cpp <- function(dataset_ptr) {
132+
.Call(`_stochtree_rfx_dataset_get_group_labels_cpp`, dataset_ptr)
133+
}
134+
135+
rfx_dataset_get_basis_cpp <- function(dataset_ptr) {
136+
.Call(`_stochtree_rfx_dataset_get_basis_cpp`, dataset_ptr)
137+
}
138+
139+
rfx_dataset_get_variance_weights_cpp <- function(dataset_ptr) {
140+
.Call(`_stochtree_rfx_dataset_get_variance_weights_cpp`, dataset_ptr)
141+
}
142+
99143
rfx_container_cpp <- function(num_components, num_groups) {
100144
.Call(`_stochtree_rfx_container_cpp`, num_components, num_groups)
101145
}

R/data.R

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,16 @@ ForestDataset <- R6::R6Class(
3636
update_basis = function(basis) {
3737
stopifnot(self$has_basis())
3838
forest_dataset_update_basis_cpp(self$data_ptr, basis)
39-
},
39+
},
40+
41+
#' @description
#' Update variance_weights in a dataset
#' @param variance_weights Updated vector of variance weights used to define individual variance / case weights
#' @param exponentiate Whether or not input vector should be exponentiated before being written to the Dataset's variance weights. Default: `FALSE`.
update_variance_weights = function(variance_weights, exponentiate = FALSE) {
    # The dataset must already carry variance weights; this method only replaces their values
    stopifnot(self$has_variance_weights())
    forest_dataset_update_var_weights_cpp(self$data_ptr, variance_weights, exponentiate)
},
4049

4150
#' @description
4251
#' Return number of observations in a `ForestDataset` object
@@ -59,6 +68,27 @@ ForestDataset <- R6::R6Class(
5968
return(dataset_num_basis_cpp(self$data_ptr))
6069
},
6170

71+
#' @description
#' Extract the dataset's covariates as an R matrix
#' @return Covariate data
get_covariates = function() {
    covariate_data <- forest_dataset_get_covariates_cpp(self$data_ptr)
    return(covariate_data)
},
77+
78+
#' @description
#' Extract the dataset's bases as an R matrix
#' @return Basis data
get_basis = function() {
    basis_data <- forest_dataset_get_basis_cpp(self$data_ptr)
    return(basis_data)
},
84+
85+
#' @description
#' Extract the dataset's variance weights as an R vector
#' @return Variance weight data
get_variance_weights = function() {
    weight_data <- forest_dataset_get_variance_weights_cpp(self$data_ptr)
    return(weight_data)
},
91+
6292
#' @description
6393
#' Whether or not a dataset has a basis matrix
6494
#' @return True if basis matrix is loaded, false otherwise
@@ -190,11 +220,56 @@ RandomEffectsDataset <- R6::R6Class(
190220
}
191221
},
192222

223+
#' @description
#' Replace the values of the basis matrix held by a dataset
#' @param basis Updated matrix of bases used to define random slopes / intercepts
update_basis = function(basis) {
    # Only a dataset that already holds a basis can be updated
    basis_loaded <- self$has_basis()
    stopifnot(basis_loaded)
    rfx_dataset_update_basis_cpp(self$data_ptr, basis)
},
230+
231+
#' @description
#' Update variance_weights in a dataset
#' @param variance_weights Updated vector of variance weights used to define individual variance / case weights
#' @param exponentiate Whether or not input vector should be exponentiated before being written to the RandomEffectsDataset's variance weights. Default: `FALSE`.
update_variance_weights = function(variance_weights, exponentiate = FALSE) {
    # The dataset must already carry variance weights; this method only replaces their values
    stopifnot(self$has_variance_weights())
    rfx_dataset_update_var_weights_cpp(self$data_ptr, variance_weights, exponentiate)
},
239+
193240
#' @description
#' Number of observations stored in a `RandomEffectsDataset` object
#' @return Observation count
num_observations = function() {
    n_obs <- rfx_dataset_num_rows_cpp(self$data_ptr)
    return(n_obs)
},
246+
247+
#' @description
#' Dimension of the basis matrix stored in a `RandomEffectsDataset` object
#' @return Basis vector count
num_basis = function() {
    n_basis <- rfx_dataset_num_basis_cpp(self$data_ptr)
    return(n_basis)
},
253+
254+
#' @description
#' Extract the dataset's group labels as an R vector
#' @return Group label data
get_group_labels = function() {
    label_data <- rfx_dataset_get_group_labels_cpp(self$data_ptr)
    return(label_data)
},
260+
261+
#' @description
#' Extract the dataset's bases as an R matrix
#' @return Basis data
get_basis = function() {
    basis_data <- rfx_dataset_get_basis_cpp(self$data_ptr)
    return(basis_data)
},
267+
268+
#' @description
#' Extract the dataset's variance weights as an R vector
#' @return Variance weight data
get_variance_weights = function() {
    weight_data <- rfx_dataset_get_variance_weights_cpp(self$data_ptr)
    return(weight_data)
},
199274

200275
#' @description

include/stochtree/data.h

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ class RandomEffectsDataset {
497497
*/
498498
void AddBasis(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) {
499499
basis_ = ColumnMatrix(data_ptr, num_row, num_col, is_row_major);
500+
num_basis_ = num_col;
500501
has_basis_ = true;
501502
}
502503
/*!
@@ -509,6 +510,64 @@ class RandomEffectsDataset {
509510
var_weights_ = ColumnVector(data_ptr, num_row);
510511
has_var_weights_ = true;
511512
}
513+
/*!
514+
* \brief Update the data in the internal basis matrix to new values stored in a raw double array
515+
*
516+
* \param data_ptr Pointer to first element of a contiguous array of data storing a basis matrix
517+
* \param num_row Number of rows in the basis matrix
518+
* \param num_col Number of columns in the basis matrix
519+
* \param is_row_major Whether or not the data in `data_ptr` are organized in a row-major or column-major fashion
520+
*/
521+
void UpdateBasis(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) {
522+
CHECK(has_basis_);
523+
CHECK_EQ(num_col, num_basis_);
524+
// Copy data from R / Python process memory to Eigen matrix
525+
double temp_value;
526+
for (data_size_t i = 0; i < num_row; ++i) {
527+
for (int j = 0; j < num_col; ++j) {
528+
if (is_row_major){
529+
// Numpy 2-d arrays are stored in "row major" order
530+
temp_value = static_cast<double>(*(data_ptr + static_cast<data_size_t>(num_col) * i + j));
531+
} else {
532+
// R matrices are stored in "column major" order
533+
temp_value = static_cast<double>(*(data_ptr + static_cast<data_size_t>(num_row) * j + i));
534+
}
535+
basis_.SetElement(i, j, temp_value);
536+
}
537+
}
538+
}
539+
/*!
540+
* \brief Update the data in the internal variance weight vector to new values stored in a raw double array
541+
*
542+
* \param data_ptr Pointer to first element of a contiguous array of data storing a weight vector
543+
* \param num_row Number of rows in the weight vector
544+
* \param exponentiate Whether or not inputs should be exponentiated before being saved to var weight vector
545+
*/
546+
void UpdateVarWeights(double* data_ptr, data_size_t num_row, bool exponentiate = true) {
547+
CHECK(has_var_weights_);
548+
// Copy data from R / Python process memory to Eigen vector
549+
double temp_value;
550+
for (data_size_t i = 0; i < num_row; ++i) {
551+
if (exponentiate) temp_value = std::exp(static_cast<double>(*(data_ptr + i)));
552+
else temp_value = static_cast<double>(*(data_ptr + i));
553+
var_weights_.SetElement(i, temp_value);
554+
}
555+
}
556+
/*!
557+
* \brief Update a RandomEffectsDataset's group indices
558+
*
559+
* \param data_ptr Pointer to first element of a contiguous array of data storing a weight vector
560+
* \param num_row Number of rows in the weight vector
561+
* \param exponentiate Whether or not inputs should be exponentiated before being saved to var weight vector
562+
*/
563+
void UpdateGroupLabels(std::vector<int32_t>& group_labels, data_size_t num_row) {
564+
CHECK(has_group_labels_);
565+
CHECK_EQ(this->NumObservations(), num_row)
566+
// Copy data from R / Python process memory to internal vector
567+
for (data_size_t i = 0; i < num_row; ++i) {
568+
group_labels_[i] = group_labels[i];
569+
}
570+
}
512571
/*!
513572
* \brief Copy / load group indices for random effects
514573
*
@@ -570,6 +629,7 @@ class RandomEffectsDataset {
570629
ColumnMatrix basis_;
571630
ColumnVector var_weights_;
572631
std::vector<int32_t> group_labels_;
632+
int num_basis_{0};
573633
bool has_basis_{false};
574634
bool has_var_weights_{false};
575635
bool has_group_labels_{false};

man/ForestDataset.Rd

Lines changed: 62 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/ForestModel.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)