Skip to content

Commit 8b1a93f

Browse files
authored
Improve error and memory handling in C extension (#89)
* Fix error and memory handling * Fix tobytes calls * Fix dict key memory leak
1 parent e2363ce commit 8b1a93f

File tree

2 files changed

+88
-40
lines changed

2 files changed

+88
-40
lines changed

accel.c

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1950,7 +1950,6 @@ static PyObject *read_row_from_packet(
19501950
case ACCEL_OUT_DICTS:
19511951
case ACCEL_OUT_ARROW:
19521952
PyDict_SetItem(py_result, py_state->py_names[i], py_item);
1953-
Py_INCREF(py_state->py_names[i]);
19541953
Py_DECREF(py_item);
19551954
break;
19561955
default:
@@ -2678,8 +2677,19 @@ static PyObject *load_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
26782677

26792678
exit:
26802679
if (ctypes) free(ctypes);
2681-
if (out_cols) free(out_cols);
2682-
if (mask_cols) free(mask_cols);
2680+
if (out_cols) {
2681+
for (i = 0; i < n_cols; i++) {
2682+
if (out_cols[i]) free(out_cols[i]);
2683+
}
2684+
free(out_cols);
2685+
}
2686+
if (mask_cols) {
2687+
for (i = 0; i < n_cols; i++) {
2688+
if (mask_cols[i]) free(mask_cols[i]);
2689+
}
2690+
free(mask_cols);
2691+
}
2692+
if (out_row_ids) free(out_row_ids);
26832693
if (data_formats) free(data_formats);
26842694
if (item_sizes) free(item_sizes);
26852695

@@ -2943,11 +2953,17 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
29432953
out_l = 256 * n_cols;
29442954
out_idx = 0;
29452955
out = malloc(out_l);
2946-
if (!out) goto error;
2956+
if (!out) {
2957+
PyErr_SetString(PyExc_MemoryError, "failed to allocate output buffer");
2958+
goto error;
2959+
}
29472960

29482961
// Get return types
29492962
returns = malloc(sizeof(int) * n_cols);
2950-
if (!returns) goto error;
2963+
if (!returns) {
2964+
PyErr_SetString(PyExc_MemoryError, "failed to allocate returns array");
2965+
goto error;
2966+
}
29512967

29522968
for (i = 0; i < n_cols; i++) {
29532969
PyObject *py_item = PySequence_GetItem(py_returns, i);
@@ -2959,11 +2975,20 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
29592975

29602976
// Get column array memory
29612977
cols = calloc(sizeof(char*), n_cols);
2962-
if (!cols) goto error;
2978+
if (!cols) {
2979+
PyErr_SetString(PyExc_MemoryError, "failed to allocate cols array");
2980+
goto error;
2981+
}
29632982
col_types = calloc(sizeof(NumpyColType), n_cols);
2964-
if (!col_types) goto error;
2983+
if (!col_types) {
2984+
PyErr_SetString(PyExc_MemoryError, "failed to allocate col_types array");
2985+
goto error;
2986+
}
29652987
masks = calloc(sizeof(char*), n_cols);
2966-
if (!masks) goto error;
2988+
if (!masks) {
2989+
PyErr_SetString(PyExc_MemoryError, "failed to allocate masks array");
2990+
goto error;
2991+
}
29672992
for (i = 0; i < n_cols; i++) {
29682993
PyObject *py_item = PyList_GetItem(py_cols, i);
29692994
if (!py_item) goto error;
@@ -2996,8 +3021,12 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
29963021
#define CHECKMEM(x) \
29973022
if ((out_idx + x) > out_l) { \
29983023
out_l = out_l * 2 + x; \
2999-
out = realloc(out, out_l); \
3000-
if (!out) goto error; \
3024+
char *new_out = realloc(out, out_l); \
3025+
if (!new_out) { \
3026+
PyErr_SetString(PyExc_MemoryError, "failed to reallocate output buffer"); \
3027+
goto error; \
3028+
} \
3029+
out = new_out; \
30013030
}
30023031

30033032
for (j = 0; j < n_rows; j++) {
@@ -4079,10 +4108,10 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
40794108
}
40804109
}
40814110

4082-
py_out = PyMemoryView_FromMemory(out, out_idx, PyBUF_WRITE);
4083-
if (!py_out) goto error;
4111+
py_out = PyBytes_FromStringAndSize(out, out_idx);
40844112

40854113
exit:
4114+
if (out) free(out);
40864115
if (returns) free(returns);
40874116
if (masks) free(masks);
40884117
if (cols) free(cols);
@@ -4091,7 +4120,6 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
40914120
return py_out;
40924121

40934122
error:
4094-
if (!py_out && out) free(out);
40954123
Py_XDECREF(py_out);
40964124
py_out = NULL;
40974125

@@ -4471,8 +4499,12 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
44714499
#define CHECKMEM(x) \
44724500
if ((out_idx + x) > out_l) { \
44734501
out_l = out_l * 2 + x; \
4474-
out = realloc(out, out_l); \
4475-
if (!out) goto error; \
4502+
char *new_out = realloc(out, out_l); \
4503+
if (!new_out) { \
4504+
PyErr_SetString(PyExc_MemoryError, "failed to reallocate output buffer"); \
4505+
goto error; \
4506+
} \
4507+
out = new_out; \
44764508
}
44774509

44784510
py_rows_iter = PyObject_GetIter(py_rows);
@@ -4483,12 +4515,20 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
44834515

44844516
while ((py_row = PyIter_Next(py_rows_iter))) {
44854517
py_row_iter = PyObject_GetIter(py_row);
4486-
if (!py_row_iter) goto error;
4518+
if (!py_row_iter) {
4519+
Py_DECREF(py_row);
4520+
goto error;
4521+
}
44874522

44884523
// First item is always a row ID
44894524
py_item = PyIter_Next(py_row_ids_iter);
4490-
if (!py_item) goto error;
4525+
if (!py_item) {
4526+
Py_DECREF(py_row_iter);
4527+
Py_DECREF(py_row);
4528+
goto error;
4529+
}
44914530
row_id = (int64_t)PyLong_AsLongLong(py_item);
4531+
Py_DECREF(py_item);
44924532

44934533
CHECKMEM(8);
44944534
memcpy(out+out_idx, &row_id, 8);
@@ -4631,12 +4671,16 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
46314671
out_idx += 8;
46324672
} else {
46334673
PyObject *py_bytes = PyUnicode_AsEncodedString(py_item, "utf-8", "strict");
4634-
if (!py_bytes) goto error;
4674+
if (!py_bytes) {
4675+
Py_DECREF(py_item);
4676+
goto error;
4677+
}
46354678

46364679
char *str = NULL;
46374680
Py_ssize_t str_l = 0;
46384681
if (PyBytes_AsStringAndSize(py_bytes, &str, &str_l) < 0) {
46394682
Py_DECREF(py_bytes);
4683+
Py_DECREF(py_item);
46404684
goto error;
46414685
}
46424686

@@ -4671,6 +4715,7 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
46714715
char *str = NULL;
46724716
Py_ssize_t str_l = 0;
46734717
if (PyBytes_AsStringAndSize(py_item, &str, &str_l) < 0) {
4718+
Py_DECREF(py_item);
46744719
goto error;
46754720
}
46764721

@@ -4684,6 +4729,7 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
46844729
break;
46854730

46864731
default:
4732+
Py_DECREF(py_item);
46874733
goto error;
46884734
}
46894735

@@ -4693,14 +4739,17 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
46934739
i++;
46944740
}
46954741

4742+
Py_DECREF(py_row_iter);
46964743
Py_DECREF(py_row);
4744+
py_row_iter = NULL;
46974745
py_row = NULL;
46984746
}
46994747

4700-
py_out = PyMemoryView_FromMemory(out, out_idx, PyBUF_WRITE);
4701-
if (!py_out) goto error;
4748+
// Convert the output buffer to a Python bytes object and free the buffer
4749+
py_out = PyBytes_FromStringAndSize(out, out_idx);
47024750

47034751
exit:
4752+
if (out) free(out);
47044753
if (returns) free(returns);
47054754

47064755
Py_XDECREF(py_item);
@@ -4712,7 +4761,6 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
47124761
return py_out;
47134762

47144763
error:
4715-
if (!py_out && out) free(out);
47164764
Py_XDECREF(py_out);
47174765
py_out = NULL;
47184766

@@ -4839,7 +4887,7 @@ PyMODINIT_FUNC PyInit__singlestoredb_accel(void) {
48394887

48404888
PyObj.create_numpy_array_kwargs = PyDict_New();
48414889
if (!PyObj.create_numpy_array_kwargs) goto error;
4842-
if (PyDict_SetItemString(PyObj.create_numpy_array_kwargs, "copy", Py_False)) {
4890+
if (PyDict_SetItemString(PyObj.create_numpy_array_kwargs, "copy", Py_True)) {
48434891
goto error;
48444892
}
48454893

singlestoredb/tests/test_ext_func_data.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ class TestRowdat1(unittest.TestCase):
269269
def test_numpy_accel(self):
270270
dump_res = rowdat_1._dump_numpy_accel(
271271
col_types, numpy_row_ids, numpy_data,
272-
).tobytes()
272+
)
273273
load_res = rowdat_1._load_numpy_accel(col_spec, dump_res)
274274

275275
ids = load_res[0]
@@ -294,7 +294,7 @@ def test_numpy_accel(self):
294294
def test_numpy(self):
295295
dump_res = rowdat_1._dump_numpy(
296296
col_types, numpy_row_ids, numpy_data,
297-
).tobytes()
297+
)
298298
load_res = rowdat_1._load_numpy(col_spec, dump_res)
299299

300300
ids = load_res[0]
@@ -387,7 +387,7 @@ def test_numpy_accel_limits(self, name, dtype, data, res):
387387
with self.assertRaises(res, msg=f'Expected {res} for {data} in {dtype}'):
388388
rowdat_1._dump_numpy_accel(
389389
[dtype], numpy_row_ids, [(arr, None)],
390-
).tobytes()
390+
)
391391

392392
# Pure Python
393393
if 'mediumint exceeds' in name:
@@ -396,21 +396,21 @@ def test_numpy_accel_limits(self, name, dtype, data, res):
396396
with self.assertRaises(res, msg=f'Expected {res} for {data} in {dtype}'):
397397
rowdat_1._dump_numpy(
398398
[dtype], numpy_row_ids, [(arr, None)],
399-
).tobytes()
399+
)
400400

401401
else:
402402
# Accelerated
403403
dump_res = rowdat_1._dump_numpy_accel(
404404
[dtype], numpy_row_ids, [(arr, None)],
405-
).tobytes()
405+
)
406406
load_res = rowdat_1._load_numpy_accel([('x', dtype)], dump_res)
407407
assert load_res[1][0][0] == res, \
408408
f'Expected {res} for {data}, but got {load_res[1][0][0]} in {dtype}'
409409

410410
# Pure Python
411411
dump_res = rowdat_1._dump_numpy(
412412
[dtype], numpy_row_ids, [(arr, None)],
413-
).tobytes()
413+
)
414414
load_res = rowdat_1._load_numpy([('x', dtype)], dump_res)
415415
assert load_res[1][0][0] == res, \
416416
f'Expected {res} for {data}, but got {load_res[1][0][0]} in {dtype}'
@@ -788,7 +788,7 @@ def test_numpy_accel_casts(self, name, dtype, data, res):
788788
# Accelerated
789789
dump_res = rowdat_1._dump_numpy_accel(
790790
[dtype], numpy_row_ids, [(data, None)],
791-
).tobytes()
791+
)
792792
load_res = rowdat_1._load_numpy_accel([('x', dtype)], dump_res)
793793

794794
if name == 'double from float32':
@@ -800,7 +800,7 @@ def test_numpy_accel_casts(self, name, dtype, data, res):
800800
# Pure Python
801801
dump_res = rowdat_1._dump_numpy(
802802
[dtype], numpy_row_ids, [(data, None)],
803-
).tobytes()
803+
)
804804
load_res = rowdat_1._load_numpy([('x', dtype)], dump_res)
805805

806806
if name == 'double from float32':
@@ -812,7 +812,7 @@ def test_numpy_accel_casts(self, name, dtype, data, res):
812812
def test_python(self):
813813
dump_res = rowdat_1._dump(
814814
col_types, py_row_ids, py_col_data,
815-
).tobytes()
815+
)
816816
load_res = rowdat_1._load(col_spec, dump_res)
817817

818818
ids = load_res[0]
@@ -824,7 +824,7 @@ def test_python(self):
824824
def test_python_accel(self):
825825
dump_res = rowdat_1._dump_accel(
826826
col_types, py_row_ids, py_col_data,
827-
).tobytes()
827+
)
828828
load_res = rowdat_1._load_accel(col_spec, dump_res)
829829

830830
ids = load_res[0]
@@ -836,7 +836,7 @@ def test_python_accel(self):
836836
def test_polars(self):
837837
dump_res = rowdat_1._dump_polars(
838838
col_types, polars_row_ids, polars_data,
839-
).tobytes()
839+
)
840840
load_res = rowdat_1._load_polars(col_spec, dump_res)
841841

842842
ids = load_res[0]
@@ -861,7 +861,7 @@ def test_polars(self):
861861
def test_polars_accel(self):
862862
dump_res = rowdat_1._dump_polars_accel(
863863
col_types, polars_row_ids, polars_data,
864-
).tobytes()
864+
)
865865
load_res = rowdat_1._load_polars_accel(col_spec, dump_res)
866866

867867
ids = load_res[0]
@@ -886,7 +886,7 @@ def test_polars_accel(self):
886886
def test_pandas(self):
887887
dump_res = rowdat_1._dump_pandas(
888888
col_types, pandas_row_ids, pandas_data,
889-
).tobytes()
889+
)
890890
load_res = rowdat_1._load_pandas(col_spec, dump_res)
891891

892892
ids = load_res[0]
@@ -911,7 +911,7 @@ def test_pandas(self):
911911
def test_pandas_accel(self):
912912
dump_res = rowdat_1._dump_pandas_accel(
913913
col_types, pandas_row_ids, pandas_data,
914-
).tobytes()
914+
)
915915
load_res = rowdat_1._load_pandas_accel(col_spec, dump_res)
916916

917917
ids = load_res[0]
@@ -936,7 +936,7 @@ def test_pandas_accel(self):
936936
def test_pyarrow(self):
937937
dump_res = rowdat_1._dump_arrow(
938938
col_types, pyarrow_row_ids, pyarrow_data,
939-
).tobytes()
939+
)
940940
load_res = rowdat_1._load_arrow(col_spec, dump_res)
941941

942942
ids = load_res[0]
@@ -961,7 +961,7 @@ def test_pyarrow(self):
961961
def test_pyarrow_accel(self):
962962
dump_res = rowdat_1._dump_arrow_accel(
963963
col_types, pyarrow_row_ids, pyarrow_data,
964-
).tobytes()
964+
)
965965
load_res = rowdat_1._load_arrow_accel(col_spec, dump_res)
966966

967967
ids = load_res[0]
@@ -1053,7 +1053,7 @@ def test_polars(self):
10531053
def test_pandas(self):
10541054
dump_res = rowdat_1._dump_pandas(
10551055
col_types, pandas_row_ids, pandas_data,
1056-
).tobytes()
1056+
)
10571057
load_res = rowdat_1._load_pandas(col_spec, dump_res)
10581058

10591059
ids = load_res[0]
@@ -1078,7 +1078,7 @@ def test_pandas(self):
10781078
def test_pyarrow(self):
10791079
dump_res = rowdat_1._dump_arrow(
10801080
col_types, pyarrow_row_ids, pyarrow_data,
1081-
).tobytes()
1081+
)
10821082
load_res = rowdat_1._load_arrow(col_spec, dump_res)
10831083

10841084
ids = load_res[0]

0 commit comments

Comments
 (0)