Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export(computeForestLeafVariances)
export(computeMaxLeafIndex)
export(convertBARTModelToJson)
export(convertBCFModelToJson)
export(convertPreprocessorToJson)
export(createBARTModelFromCombinedJson)
export(createBARTModelFromCombinedJsonString)
export(createBARTModelFromJson)
Expand All @@ -31,6 +32,8 @@ export(createForestCovariatesFromMetadata)
export(createForestDataset)
export(createForestModel)
export(createOutcome)
export(createPreprocessorFromJson)
export(createPreprocessorFromJsonString)
export(createRNG)
export(createRandomEffectSamples)
export(createRandomEffectsDataset)
Expand Down Expand Up @@ -69,6 +72,7 @@ export(saveBARTModelToJsonFile)
export(saveBARTModelToJsonString)
export(saveBCFModelToJsonFile)
export(saveBCFModelToJsonString)
export(savePreprocessorToJsonString)
importFrom(R6,R6Class)
importFrom(stats,coef)
importFrom(stats,lm)
Expand Down
26 changes: 25 additions & 1 deletion R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -1215,6 +1215,12 @@ convertBARTModelToJson <- function(object){
jsonobj$add_string_vector("rfx_unique_group_ids", object$rfx_unique_group_ids)
}

# Add covariate preprocessor metadata
preprocessor_metadata_string <- savePreprocessorToJsonString(
object$train_set_metadata
)
jsonobj$add_string("preprocessor_metadata", preprocessor_metadata_string)

return(jsonobj)
}

Expand Down Expand Up @@ -1322,7 +1328,7 @@ saveBARTModelToJsonFile <- function(object, filename){
#' Convert the persistent aspects of a BART model to (in-memory) JSON string
#'
#' @param object Object of type `bartmodel` containing draws of a BART model and associated sampling outputs.
#' @return JSON string
#' @return in-memory JSON string
#' @export
#'
#' @examples
Expand Down Expand Up @@ -1460,6 +1466,12 @@ createBARTModelFromJson <- function(json_object){
output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0)
}

# Unpack covariate preprocessor
preprocessor_metadata_string <- json_object$get_string("preprocessor_metadata")
output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
preprocessor_metadata_string
)

class(output) <- "bartmodel"
return(output)
}
Expand Down Expand Up @@ -1686,6 +1698,12 @@ createBARTModelFromCombinedJson <- function(json_object_list){
output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(json_object_list, 0)
}

# Unpack covariate preprocessor
preprocessor_metadata_string <- json_object$get_string("preprocessor_metadata")
output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
preprocessor_metadata_string
)

class(output) <- "bartmodel"
return(output)
}
Expand Down Expand Up @@ -1832,6 +1850,12 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){
output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(json_object_list, 0)
}

# Unpack covariate preprocessor
preprocessor_metadata_string <- json_object$get_string("preprocessor_metadata")
output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
preprocessor_metadata_string
)

class(output) <- "bartmodel"
return(output)
}
20 changes: 19 additions & 1 deletion R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -1708,6 +1708,12 @@ convertBCFModelToJson <- function(object){
jsonobj$add_string("bart_propensity_model", bart_propensity_string)
}

# Add covariate preprocessor metadata
preprocessor_metadata_string <- savePreprocessorToJsonString(
object$train_set_metadata
)
jsonobj$add_string("preprocessor_metadata", preprocessor_metadata_string)

return(jsonobj)
}

Expand All @@ -1716,7 +1722,7 @@ convertBCFModelToJson <- function(object){
#' @param object Object of type `bcf` containing draws of a Bayesian causal forest model and associated sampling outputs.
#' @param filename String of filepath, must end in ".json"
#'
#' @return NULL
#' @return in-memory JSON string
#' @export
#'
#' @examples
Expand Down Expand Up @@ -2018,6 +2024,12 @@ createBCFModelFromJson <- function(json_object){
)
}

# Unpack covariate preprocessor
preprocessor_metadata_string <- json_object$get_string("preprocessor_metadata")
output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
preprocessor_metadata_string
)

class(output) <- "bcf"
return(output)
}
Expand Down Expand Up @@ -2393,6 +2405,12 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){
output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(json_object_list, 0)
}

# Unpack covariate preprocessor
preprocessor_metadata_string <- json_object_default$get_string("preprocessor_metadata")
output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
preprocessor_metadata_string
)

class(output) <- "bcf"
return(output)
}
Expand Down
32 changes: 32 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,14 @@ json_add_double_cpp <- function(json_ptr, field_name, field_value) {
invisible(.Call(`_stochtree_json_add_double_cpp`, json_ptr, field_name, field_value))
}

json_add_integer_subfolder_cpp <- function(json_ptr, subfolder_name, field_name, field_value) {
invisible(.Call(`_stochtree_json_add_integer_subfolder_cpp`, json_ptr, subfolder_name, field_name, field_value))
}

json_add_integer_cpp <- function(json_ptr, field_name, field_value) {
invisible(.Call(`_stochtree_json_add_integer_cpp`, json_ptr, field_name, field_value))
}

json_add_bool_subfolder_cpp <- function(json_ptr, subfolder_name, field_name, field_value) {
invisible(.Call(`_stochtree_json_add_bool_subfolder_cpp`, json_ptr, subfolder_name, field_name, field_value))
}
Expand All @@ -628,6 +636,14 @@ json_add_vector_cpp <- function(json_ptr, field_name, field_vector) {
invisible(.Call(`_stochtree_json_add_vector_cpp`, json_ptr, field_name, field_vector))
}

json_add_integer_vector_subfolder_cpp <- function(json_ptr, subfolder_name, field_name, field_vector) {
invisible(.Call(`_stochtree_json_add_integer_vector_subfolder_cpp`, json_ptr, subfolder_name, field_name, field_vector))
}

json_add_integer_vector_cpp <- function(json_ptr, field_name, field_vector) {
invisible(.Call(`_stochtree_json_add_integer_vector_cpp`, json_ptr, field_name, field_vector))
}

json_add_string_vector_subfolder_cpp <- function(json_ptr, subfolder_name, field_name, field_vector) {
invisible(.Call(`_stochtree_json_add_string_vector_subfolder_cpp`, json_ptr, subfolder_name, field_name, field_vector))
}
Expand Down Expand Up @@ -660,6 +676,14 @@ json_extract_double_cpp <- function(json_ptr, field_name) {
.Call(`_stochtree_json_extract_double_cpp`, json_ptr, field_name)
}

json_extract_integer_subfolder_cpp <- function(json_ptr, subfolder_name, field_name) {
.Call(`_stochtree_json_extract_integer_subfolder_cpp`, json_ptr, subfolder_name, field_name)
}

json_extract_integer_cpp <- function(json_ptr, field_name) {
.Call(`_stochtree_json_extract_integer_cpp`, json_ptr, field_name)
}

json_extract_bool_subfolder_cpp <- function(json_ptr, subfolder_name, field_name) {
.Call(`_stochtree_json_extract_bool_subfolder_cpp`, json_ptr, subfolder_name, field_name)
}
Expand All @@ -684,6 +708,14 @@ json_extract_vector_cpp <- function(json_ptr, field_name) {
.Call(`_stochtree_json_extract_vector_cpp`, json_ptr, field_name)
}

json_extract_integer_vector_subfolder_cpp <- function(json_ptr, subfolder_name, field_name) {
.Call(`_stochtree_json_extract_integer_vector_subfolder_cpp`, json_ptr, subfolder_name, field_name)
}

json_extract_integer_vector_cpp <- function(json_ptr, field_name) {
.Call(`_stochtree_json_extract_integer_vector_cpp`, json_ptr, field_name)
}

json_extract_string_vector_subfolder_cpp <- function(json_ptr, subfolder_name, field_name) {
.Call(`_stochtree_json_extract_string_vector_subfolder_cpp`, json_ptr, subfolder_name, field_name)
}
Expand Down
63 changes: 62 additions & 1 deletion R/serialization.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,20 @@ CppJson <- R6::R6Class(
}
},

#' @description
#' Add a scalar to the json object under the name "field_name" (with optional subfolder "subfolder_name")
#' @param field_name The name of the field to be added to json
#' @param field_value Integer value of the field to be added to json
#' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value
#' @return NULL
add_integer = function(field_name, field_value, subfolder_name = NULL) {
if (is.null(subfolder_name)) {
json_add_integer_cpp(self$json_ptr, field_name, field_value)
} else {
json_add_integer_subfolder_cpp(self$json_ptr, subfolder_name, field_name, field_value)
}
},

#' @description
#' Add a boolean value to the json object under the name "field_name" (with optional subfolder "subfolder_name")
#' @param field_name The name of the field to be added to json
Expand Down Expand Up @@ -110,7 +124,7 @@ CppJson <- R6::R6Class(
},

#' @description
#' Add an array to the json object under the name "field_name" (with optional subfolder "subfolder_name")
#' Add a vector to the json object under the name "field_name" (with optional subfolder "subfolder_name")
#' @param field_name The name of the field to be added to json
#' @param field_vector Vector to be stored in json
#' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value
Expand All @@ -124,6 +138,21 @@ CppJson <- R6::R6Class(
}
},

#' @description
#' Add an integer vector to the json object under the name "field_name" (with optional subfolder "subfolder_name")
#' @param field_name The name of the field to be added to json
#' @param field_vector Vector to be stored in json
#' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value
#' @return NULL
add_integer_vector = function(field_name, field_vector, subfolder_name = NULL) {
field_vector <- as.numeric(field_vector)
if (is.null(subfolder_name)) {
json_add_integer_vector_cpp(self$json_ptr, field_name, field_vector)
} else {
json_add_integer_vector_subfolder_cpp(self$json_ptr, subfolder_name, field_name, field_vector)
}
},

#' @description
#' Add an array to the json object under the name "field_name" (with optional subfolder "subfolder_name")
#' @param field_name The name of the field to be added to json
Expand Down Expand Up @@ -184,6 +213,22 @@ CppJson <- R6::R6Class(
return(result)
},

#' @description
#' Retrieve a integer value from the json object under the name "field_name" (with optional subfolder "subfolder_name")
#' @param field_name The name of the field to be accessed from json
#' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored
#' @return NULL
get_integer = function(field_name, subfolder_name = NULL) {
if (is.null(subfolder_name)) {
stopifnot(json_contains_field_cpp(self$json_ptr, field_name))
result <- json_extract_integer_cpp(self$json_ptr, field_name)
} else {
stopifnot(json_contains_field_subfolder_cpp(self$json_ptr, subfolder_name, field_name))
result <- json_extract_integer_subfolder_cpp(self$json_ptr, subfolder_name, field_name)
}
return(result)
},

#' @description
#' Retrieve a boolean value from the json object under the name "field_name" (with optional subfolder "subfolder_name")
#' @param field_name The name of the field to be accessed from json
Expand Down Expand Up @@ -232,6 +277,22 @@ CppJson <- R6::R6Class(
return(result)
},

#' @description
#' Retrieve an integer vector from the json object under the name "field_name" (with optional subfolder "subfolder_name")
#' @param field_name The name of the field to be accessed from json
#' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored
#' @return NULL
get_integer_vector = function(field_name, subfolder_name = NULL) {
if (is.null(subfolder_name)) {
stopifnot(json_contains_field_cpp(self$json_ptr, field_name))
result <- json_extract_integer_vector_cpp(self$json_ptr, field_name)
} else {
stopifnot(json_contains_field_subfolder_cpp(self$json_ptr, subfolder_name, field_name))
result <- json_extract_integer_vector_subfolder_cpp(self$json_ptr, subfolder_name, field_name)
}
return(result)
},

#' @description
#' Retrieve a character vector from the json object under the name "field_name" (with optional subfolder "subfolder_name")
#' @param field_name The name of the field to be accessed from json
Expand Down
Loading
Loading