From 5b40a55f5f95b591f3720ad541161b5499b43354 Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Fri, 24 Feb 2023 15:14:03 -0700 Subject: [PATCH] Handle possible NULL pointers in the array buffer --- stringdtype/stringdtype/src/casts.c | 25 ++++++++++++++++----- stringdtype/stringdtype/src/dtype.c | 16 +++++++++++-- stringdtype/stringdtype/src/static_string.c | 19 +++++++++++++++- stringdtype/stringdtype/src/static_string.h | 9 +++++++- stringdtype/stringdtype/src/umath.c | 13 +++++++++-- stringdtype/tests/test_stringdtype.py | 13 +++++++++++ 6 files changed, 83 insertions(+), 12 deletions(-) diff --git a/stringdtype/stringdtype/src/casts.c b/stringdtype/stringdtype/src/casts.c index ea75e4fa..21900574 100644 --- a/stringdtype/stringdtype/src/casts.c +++ b/stringdtype/stringdtype/src/casts.c @@ -53,11 +53,17 @@ string_to_string(PyArrayMethod_Context *context, char *const data[], npy_intp out_stride = strides[1] / context->descriptors[1]->elsize; while (N--) { - out[0] = ssdup(in[0]); + ss *s = empty_if_null(in); + out[0] = ssdup(s); if (out[0] == NULL) { gil_error(PyExc_MemoryError, "ssdup failed"); return -1; } + + if (*in == NULL) { + free(s); + } + in += in_stride; out += out_stride; } @@ -217,12 +223,13 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[], gil_error(PyExc_TypeError, "Invalid unicode code point found"); return -1; } - ss *out_ss = ssnewempty(out_num_bytes); + ss *out_ss = ssnewemptylen(out_num_bytes); if (out_ss == NULL) { - gil_error(PyExc_MemoryError, "ssnewempty failed"); + gil_error(PyExc_MemoryError, "ssnewemptylen failed"); + return -1; } char *out_buf = out_ss->buf; - for (int i = 0; i < num_codepoints; i++) { + for (size_t i = 0; i < num_codepoints; i++) { // get code point Py_UCS4 code = in[i]; @@ -339,8 +346,9 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[], long max_out_size = (context->descriptors[1]->elsize) / 4; while (N--) { - unsigned char *this_string = (unsigned char *)((*in)->buf); - size_t n_bytes = (*in)->len; + ss *s = empty_if_null(in); + unsigned char *this_string = (unsigned char *)(s->buf); + size_t n_bytes = s->len; size_t tot_n_bytes = 0; for (int i = 0; i < max_out_size; i++) { @@ -362,6 +370,11 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[], break; } } + + if (*in == NULL) { + free(s); + } + in += in_stride; out += out_stride; } diff --git a/stringdtype/stringdtype/src/dtype.c b/stringdtype/stringdtype/src/dtype.c index 1a8b492a..75380823 100644 --- a/stringdtype/stringdtype/src/dtype.c +++ b/stringdtype/stringdtype/src/dtype.c @@ -135,9 +135,21 @@ stringdtype_setitem(StringDTypeObject *NPY_UNUSED(descr), PyObject *obj, } static PyObject * -stringdtype_getitem(StringDTypeObject *descr, char **dataptr) +stringdtype_getitem(StringDTypeObject *NPY_UNUSED(descr), char **dataptr) { - PyObject *val_obj = PyUnicode_FromString(((ss *)*dataptr)->buf); + char *data; + + // in the future this could represent missing data too, but we'd + // need to make it so np.empty and np.zeros take their initial value + // from an API hook that doesn't exist yet + if (*dataptr == NULL) { + data = "\0"; + } + else { + data = ((ss *)*dataptr)->buf; + } + + PyObject *val_obj = PyUnicode_FromString(data); if (val_obj == NULL) { return NULL; diff --git a/stringdtype/stringdtype/src/static_string.c b/stringdtype/stringdtype/src/static_string.c index 0defdfea..17c9a5bf 100644 --- a/stringdtype/stringdtype/src/static_string.c +++ b/stringdtype/stringdtype/src/static_string.c @@ -33,9 +33,26 @@ ssdup(const ss *s) // does not do any initialization, the caller must // initialize and null-terminate the string ss * -ssnewempty(size_t len) +ssnewemptylen(size_t len) { ss *ret = (ss *)malloc(sizeof(ss) + sizeof(char) * (len + 1)); ret->len = len; return ret; } + +ss * +ssempty(void) +{ + return ssnewlen("", 0); +} + +ss * +empty_if_null(ss **data) +{ + if (*data == NULL) { + return ssempty(); + } + else { + return *data; + } +} diff --git a/stringdtype/stringdtype/src/static_string.h b/stringdtype/stringdtype/src/static_string.h index ee4fc4fa..76af3875 100644 --- a/stringdtype/stringdtype/src/static_string.h +++ b/stringdtype/stringdtype/src/static_string.h @@ -21,6 +21,13 @@ ssdup(const ss *s); // does not do any initialization, the caller must // initialize and null-terminate the string ss * -ssnewempty(size_t len); +ssnewemptylen(size_t len); + +// returns an new heap-allocated empty string +ss * +ssnewempty(void); + +ss * +empty_if_null(ss **data); #endif /*_NPY_STATIC_STRING_H */ diff --git a/stringdtype/stringdtype/src/umath.c b/stringdtype/stringdtype/src/umath.c index 7ef8dc12..48a42508 100644 --- a/stringdtype/stringdtype/src/umath.c +++ b/stringdtype/stringdtype/src/umath.c @@ -30,8 +30,8 @@ string_equal_strided_loop(PyArrayMethod_Context *context, char *const data[], npy_intp out_stride = strides[2]; while (N--) { - ss *s1 = *in1; - ss *s2 = *in2; + ss *s1 = empty_if_null(in1); + ss *s2 = empty_if_null(in2); if (s1->len == s2->len && strncmp(s1->buf, s2->buf, s1->len) == 0) { *out = (npy_bool)1; @@ -39,6 +39,15 @@ string_equal_strided_loop(PyArrayMethod_Context *context, char *const data[], else { *out = (npy_bool)0; } + + if (*in1 == NULL) { + free(s1); + } + + if (*in2 == NULL) { + free(s2); + } + in1 += in1_stride; in2 += in2_stride; out += out_stride; diff --git a/stringdtype/tests/test_stringdtype.py b/stringdtype/tests/test_stringdtype.py index ccdfe403..82b82456 100644 --- a/stringdtype/tests/test_stringdtype.py +++ b/stringdtype/tests/test_stringdtype.py @@ -181,3 +181,16 @@ def test_sort(strings): np.random.default_rng().shuffle(arr) arr.sort() np.testing.assert_array_equal(arr, arr_sorted) + + +def test_creation_functions(): + np.testing.assert_array_equal( + np.zeros(3, dtype=StringDType()), ["", "", ""] + ) + + np.testing.assert_array_equal( + np.empty(3, dtype=StringDType()), ["", "", ""] + ) + + # make sure getitem works too + assert np.empty(3, dtype=StringDType())[0] == ""