Skip to content

Commit 19fb779

Browse files
authored
Merge pull request #111 from StochasticTree/python-serialization-expansion
Added demo for json file roundtrip
2 parents 242bea7 + b2e940e commit 19fb779

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

demo/notebooks/serialization.ipynb

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"metadata": {},
2828
"outputs": [],
2929
"source": [
30+
"import json\n",
3031
"import numpy as np\n",
3132
"import pandas as pd\n",
3233
"import seaborn as sns\n",
@@ -242,6 +243,84 @@
242243
"plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n",
243244
"plt.show()"
244245
]
246+
},
247+
{
248+
"cell_type": "markdown",
249+
"metadata": {},
250+
"source": [
251+
"Save to JSON file"
252+
]
253+
},
254+
{
255+
"cell_type": "code",
256+
"execution_count": null,
257+
"metadata": {},
258+
"outputs": [],
259+
"source": [
260+
"with open('bart.json', 'w') as f:\n",
261+
" bart_json_python = json.loads(bart_json_string)\n",
262+
" json.dump(bart_json_python, f)"
263+
]
264+
},
265+
{
266+
"cell_type": "markdown",
267+
"metadata": {},
268+
"source": [
269+
"Reload from JSON file"
270+
]
271+
},
272+
{
273+
"cell_type": "code",
274+
"execution_count": null,
275+
"metadata": {},
276+
"outputs": [],
277+
"source": [
278+
"with open('bart.json', 'r') as f:\n",
279+
" bart_json_python_reload = json.load(f)\n",
280+
"bart_json_string_reload = json.dumps(bart_json_python_reload)\n",
281+
"bart_model_file_deserialized = BARTModel()\n",
282+
"bart_model_file_deserialized.from_json(bart_json_string_reload)"
283+
]
284+
},
285+
{
286+
"cell_type": "markdown",
287+
"metadata": {},
288+
"source": [
289+
"Compare predictions"
290+
]
291+
},
292+
{
293+
"cell_type": "code",
294+
"execution_count": null,
295+
"metadata": {},
296+
"outputs": [],
297+
"source": [
298+
"y_hat_file_deserialized = bart_model_file_deserialized.predict(X_test, basis_test)\n",
299+
"y_avg_mcmc_file_deserialized = np.squeeze(y_hat_file_deserialized).mean(axis = 1, keepdims = True)\n",
300+
"y_df = pd.DataFrame(np.concatenate((y_avg_mcmc, y_avg_mcmc_file_deserialized), axis = 1), columns=[\"Original model\", \"Deserialized model\"])\n",
301+
"sns.scatterplot(data=y_df, x=\"Original model\", y=\"Deserialized model\")\n",
302+
"plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n",
303+
"plt.show()"
304+
]
305+
},
306+
{
307+
"cell_type": "markdown",
308+
"metadata": {},
309+
"source": [
310+
"Compare parameter samples"
311+
]
312+
},
313+
{
314+
"cell_type": "code",
315+
"execution_count": null,
316+
"metadata": {},
317+
"outputs": [],
318+
"source": [
319+
"sigma2_df = pd.DataFrame(np.c_[bart_model.global_var_samples, bart_model_file_deserialized.global_var_samples], columns=[\"Original model\", \"Deserialized model\"])\n",
320+
"sns.scatterplot(data=sigma2_df, x=\"Original model\", y=\"Deserialized model\")\n",
321+
"plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n",
322+
"plt.show()"
323+
]
245324
}
246325
],
247326
"metadata": {

0 commit comments

Comments
 (0)