Skip to content

Commit e63d809

Browse files
authored
Merge pull request #130 from StochasticTree/serialize-covariate-preprocessor
Add serialization for covariate preprocessor objects
2 parents 79317c8 + 0ef2cdd commit e63d809

23 files changed

+1185
-113
lines changed

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ export(computeForestLeafVariances)
1212
export(computeMaxLeafIndex)
1313
export(convertBARTModelToJson)
1414
export(convertBCFModelToJson)
15+
export(convertPreprocessorToJson)
1516
export(createBARTModelFromCombinedJson)
1617
export(createBARTModelFromCombinedJsonString)
1718
export(createBARTModelFromJson)
@@ -31,6 +32,8 @@ export(createForestCovariatesFromMetadata)
3132
export(createForestDataset)
3233
export(createForestModel)
3334
export(createOutcome)
35+
export(createPreprocessorFromJson)
36+
export(createPreprocessorFromJsonString)
3437
export(createRNG)
3538
export(createRandomEffectSamples)
3639
export(createRandomEffectsDataset)
@@ -69,6 +72,7 @@ export(saveBARTModelToJsonFile)
6972
export(saveBARTModelToJsonString)
7073
export(saveBCFModelToJsonFile)
7174
export(saveBCFModelToJsonString)
75+
export(savePreprocessorToJsonString)
7276
importFrom(R6,R6Class)
7377
importFrom(stats,coef)
7478
importFrom(stats,lm)

R/bart.R

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1215,6 +1215,12 @@ convertBARTModelToJson <- function(object){
12151215
jsonobj$add_string_vector("rfx_unique_group_ids", object$rfx_unique_group_ids)
12161216
}
12171217

1218+
# Add covariate preprocessor metadata
1219+
preprocessor_metadata_string <- savePreprocessorToJsonString(
1220+
object$train_set_metadata
1221+
)
1222+
jsonobj$add_string("preprocessor_metadata", preprocessor_metadata_string)
1223+
12181224
return(jsonobj)
12191225
}
12201226

@@ -1322,7 +1328,7 @@ saveBARTModelToJsonFile <- function(object, filename){
13221328
#' Convert the persistent aspects of a BART model to (in-memory) JSON string
13231329
#'
13241330
#' @param object Object of type `bartmodel` containing draws of a BART model and associated sampling outputs.
1325-
#' @return JSON string
1331+
#' @return in-memory JSON string
13261332
#' @export
13271333
#'
13281334
#' @examples
@@ -1460,6 +1466,12 @@ createBARTModelFromJson <- function(json_object){
14601466
output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0)
14611467
}
14621468

1469+
# Unpack covariate preprocessor
1470+
preprocessor_metadata_string <- json_object$get_string("preprocessor_metadata")
1471+
output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
1472+
preprocessor_metadata_string
1473+
)
1474+
14631475
class(output) <- "bartmodel"
14641476
return(output)
14651477
}
@@ -1686,6 +1698,12 @@ createBARTModelFromCombinedJson <- function(json_object_list){
16861698
output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(json_object_list, 0)
16871699
}
16881700

1701+
# Unpack covariate preprocessor
1702+
preprocessor_metadata_string <- json_object$get_string("preprocessor_metadata")
1703+
output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
1704+
preprocessor_metadata_string
1705+
)
1706+
16891707
class(output) <- "bartmodel"
16901708
return(output)
16911709
}
@@ -1832,6 +1850,12 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){
18321850
output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(json_object_list, 0)
18331851
}
18341852

1853+
# Unpack covariate preprocessor
1854+
preprocessor_metadata_string <- json_object$get_string("preprocessor_metadata")
1855+
output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
1856+
preprocessor_metadata_string
1857+
)
1858+
18351859
class(output) <- "bartmodel"
18361860
return(output)
18371861
}

R/bcf.R

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1708,6 +1708,12 @@ convertBCFModelToJson <- function(object){
17081708
jsonobj$add_string("bart_propensity_model", bart_propensity_string)
17091709
}
17101710

1711+
# Add covariate preprocessor metadata
1712+
preprocessor_metadata_string <- savePreprocessorToJsonString(
1713+
object$train_set_metadata
1714+
)
1715+
jsonobj$add_string("preprocessor_metadata", preprocessor_metadata_string)
1716+
17111717
return(jsonobj)
17121718
}
17131719

@@ -1716,7 +1722,7 @@ convertBCFModelToJson <- function(object){
17161722
#' @param object Object of type `bcf` containing draws of a Bayesian causal forest model and associated sampling outputs.
17171723
#' @param filename String of filepath, must end in ".json"
17181724
#'
1719-
#' @return NULL
1725+
#' @return NULL (the JSON output is written to `filename` rather than returned)
17201726
#' @export
17211727
#'
17221728
#' @examples
@@ -2018,6 +2024,12 @@ createBCFModelFromJson <- function(json_object){
20182024
)
20192025
}
20202026

2027+
# Unpack covariate preprocessor
2028+
preprocessor_metadata_string <- json_object$get_string("preprocessor_metadata")
2029+
output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
2030+
preprocessor_metadata_string
2031+
)
2032+
20212033
class(output) <- "bcf"
20222034
return(output)
20232035
}
@@ -2393,6 +2405,12 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){
23932405
output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(json_object_list, 0)
23942406
}
23952407

2408+
# Unpack covariate preprocessor
2409+
preprocessor_metadata_string <- json_object_default$get_string("preprocessor_metadata")
2410+
output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
2411+
preprocessor_metadata_string
2412+
)
2413+
23962414
class(output) <- "bcf"
23972415
return(output)
23982416
}

R/cpp11.R

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,14 @@ json_add_double_cpp <- function(json_ptr, field_name, field_value) {
612612
invisible(.Call(`_stochtree_json_add_double_cpp`, json_ptr, field_name, field_value))
613613
}
614614

615+
json_add_integer_subfolder_cpp <- function(json_ptr, subfolder_name, field_name, field_value) {
616+
invisible(.Call(`_stochtree_json_add_integer_subfolder_cpp`, json_ptr, subfolder_name, field_name, field_value))
617+
}
618+
619+
json_add_integer_cpp <- function(json_ptr, field_name, field_value) {
620+
invisible(.Call(`_stochtree_json_add_integer_cpp`, json_ptr, field_name, field_value))
621+
}
622+
615623
json_add_bool_subfolder_cpp <- function(json_ptr, subfolder_name, field_name, field_value) {
616624
invisible(.Call(`_stochtree_json_add_bool_subfolder_cpp`, json_ptr, subfolder_name, field_name, field_value))
617625
}
@@ -628,6 +636,14 @@ json_add_vector_cpp <- function(json_ptr, field_name, field_vector) {
628636
invisible(.Call(`_stochtree_json_add_vector_cpp`, json_ptr, field_name, field_vector))
629637
}
630638

639+
json_add_integer_vector_subfolder_cpp <- function(json_ptr, subfolder_name, field_name, field_vector) {
640+
invisible(.Call(`_stochtree_json_add_integer_vector_subfolder_cpp`, json_ptr, subfolder_name, field_name, field_vector))
641+
}
642+
643+
json_add_integer_vector_cpp <- function(json_ptr, field_name, field_vector) {
644+
invisible(.Call(`_stochtree_json_add_integer_vector_cpp`, json_ptr, field_name, field_vector))
645+
}
646+
631647
json_add_string_vector_subfolder_cpp <- function(json_ptr, subfolder_name, field_name, field_vector) {
632648
invisible(.Call(`_stochtree_json_add_string_vector_subfolder_cpp`, json_ptr, subfolder_name, field_name, field_vector))
633649
}
@@ -660,6 +676,14 @@ json_extract_double_cpp <- function(json_ptr, field_name) {
660676
.Call(`_stochtree_json_extract_double_cpp`, json_ptr, field_name)
661677
}
662678

679+
json_extract_integer_subfolder_cpp <- function(json_ptr, subfolder_name, field_name) {
680+
.Call(`_stochtree_json_extract_integer_subfolder_cpp`, json_ptr, subfolder_name, field_name)
681+
}
682+
683+
json_extract_integer_cpp <- function(json_ptr, field_name) {
684+
.Call(`_stochtree_json_extract_integer_cpp`, json_ptr, field_name)
685+
}
686+
663687
json_extract_bool_subfolder_cpp <- function(json_ptr, subfolder_name, field_name) {
664688
.Call(`_stochtree_json_extract_bool_subfolder_cpp`, json_ptr, subfolder_name, field_name)
665689
}
@@ -684,6 +708,14 @@ json_extract_vector_cpp <- function(json_ptr, field_name) {
684708
.Call(`_stochtree_json_extract_vector_cpp`, json_ptr, field_name)
685709
}
686710

711+
json_extract_integer_vector_subfolder_cpp <- function(json_ptr, subfolder_name, field_name) {
712+
.Call(`_stochtree_json_extract_integer_vector_subfolder_cpp`, json_ptr, subfolder_name, field_name)
713+
}
714+
715+
json_extract_integer_vector_cpp <- function(json_ptr, field_name) {
716+
.Call(`_stochtree_json_extract_integer_vector_cpp`, json_ptr, field_name)
717+
}
718+
687719
json_extract_string_vector_subfolder_cpp <- function(json_ptr, subfolder_name, field_name) {
688720
.Call(`_stochtree_json_extract_string_vector_subfolder_cpp`, json_ptr, subfolder_name, field_name)
689721
}

R/serialization.R

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,20 @@ CppJson <- R6::R6Class(
8181
}
8282
},
8383

84+
#' @description
85+
#' Add an integer value to the json object under the name "field_name" (with optional subfolder "subfolder_name")
86+
#' @param field_name The name of the field to be added to json
87+
#' @param field_value Integer value of the field to be added to json
88+
#' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value
89+
#' @return NULL
90+
add_integer = function(field_name, field_value, subfolder_name = NULL) {
91+
if (is.null(subfolder_name)) {
92+
json_add_integer_cpp(self$json_ptr, field_name, field_value)
93+
} else {
94+
json_add_integer_subfolder_cpp(self$json_ptr, subfolder_name, field_name, field_value)
95+
}
96+
},
97+
8498
#' @description
8599
#' Add a boolean value to the json object under the name "field_name" (with optional subfolder "subfolder_name")
86100
#' @param field_name The name of the field to be added to json
@@ -110,7 +124,7 @@ CppJson <- R6::R6Class(
110124
},
111125

112126
#' @description
113-
#' Add an array to the json object under the name "field_name" (with optional subfolder "subfolder_name")
127+
#' Add a vector to the json object under the name "field_name" (with optional subfolder "subfolder_name")
114128
#' @param field_name The name of the field to be added to json
115129
#' @param field_vector Vector to be stored in json
116130
#' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value
@@ -124,6 +138,21 @@ CppJson <- R6::R6Class(
124138
}
125139
},
126140

141+
#' @description
142+
#' Add an integer vector to the json object under the name "field_name" (with optional subfolder "subfolder_name")
143+
#' @param field_name The name of the field to be added to json
144+
#' @param field_vector Vector to be stored in json
145+
#' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value
146+
#' @return NULL
147+
add_integer_vector = function(field_name, field_vector, subfolder_name = NULL) {
148+
field_vector <- as.numeric(field_vector)
149+
if (is.null(subfolder_name)) {
150+
json_add_integer_vector_cpp(self$json_ptr, field_name, field_vector)
151+
} else {
152+
json_add_integer_vector_subfolder_cpp(self$json_ptr, subfolder_name, field_name, field_vector)
153+
}
154+
},
155+
127156
#' @description
128157
#' Add an array to the json object under the name "field_name" (with optional subfolder "subfolder_name")
129158
#' @param field_name The name of the field to be added to json
@@ -184,6 +213,22 @@ CppJson <- R6::R6Class(
184213
return(result)
185214
},
186215

216+
#' @description
217+
#' Retrieve an integer value from the json object under the name "field_name" (with optional subfolder "subfolder_name")
218+
#' @param field_name The name of the field to be accessed from json
219+
#' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored
220+
#' @return Integer value stored under "field_name"
221+
get_integer = function(field_name, subfolder_name = NULL) {
222+
if (is.null(subfolder_name)) {
223+
stopifnot(json_contains_field_cpp(self$json_ptr, field_name))
224+
result <- json_extract_integer_cpp(self$json_ptr, field_name)
225+
} else {
226+
stopifnot(json_contains_field_subfolder_cpp(self$json_ptr, subfolder_name, field_name))
227+
result <- json_extract_integer_subfolder_cpp(self$json_ptr, subfolder_name, field_name)
228+
}
229+
return(result)
230+
},
231+
187232
#' @description
188233
#' Retrieve a boolean value from the json object under the name "field_name" (with optional subfolder "subfolder_name")
189234
#' @param field_name The name of the field to be accessed from json
@@ -232,6 +277,22 @@ CppJson <- R6::R6Class(
232277
return(result)
233278
},
234279

280+
#' @description
281+
#' Retrieve an integer vector from the json object under the name "field_name" (with optional subfolder "subfolder_name")
282+
#' @param field_name The name of the field to be accessed from json
283+
#' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored
284+
#' @return Integer vector stored under "field_name"
285+
get_integer_vector = function(field_name, subfolder_name = NULL) {
286+
if (is.null(subfolder_name)) {
287+
stopifnot(json_contains_field_cpp(self$json_ptr, field_name))
288+
result <- json_extract_integer_vector_cpp(self$json_ptr, field_name)
289+
} else {
290+
stopifnot(json_contains_field_subfolder_cpp(self$json_ptr, subfolder_name, field_name))
291+
result <- json_extract_integer_vector_subfolder_cpp(self$json_ptr, subfolder_name, field_name)
292+
}
293+
return(result)
294+
},
295+
235296
#' @description
236297
#' Retrieve a character vector from the json object under the name "field_name" (with optional subfolder "subfolder_name")
237298
#' @param field_name The name of the field to be accessed from json

0 commit comments

Comments
 (0)