Skip to content

Handle possible NULL pointers in the array buffer #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions stringdtype/stringdtype/src/casts.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,17 @@ string_to_string(PyArrayMethod_Context *context, char *const data[],
npy_intp out_stride = strides[1] / context->descriptors[1]->elsize;

while (N--) {
out[0] = ssdup(in[0]);
ss *s = empty_if_null(in);
out[0] = ssdup(s);
if (out[0] == NULL) {
gil_error(PyExc_MemoryError, "ssdup failed");
return -1;
}

if (*in == NULL) {
free(s);
}
Copy link
Member

@seberg seberg Feb 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, if we definitely don't do reference counting (each array that is not a view owns a copy of the string), then I do still think it makes sense to move the length information into the array data itself.

On this case, how about a pattern off:

    const ss *s = load_string(in);

And then return a statically allocated empty ss (or a once allocated global). That way you don't need the free() call. Freeing is only needed on store, and there you can do it unconditionally. The only thing is you must not free what get_string returns. I hope the const will protect a bit and it should at least be a nice hard crash that is easy to track down.


in += in_stride;
out += out_stride;
}
Expand Down Expand Up @@ -217,12 +223,13 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[],
gil_error(PyExc_TypeError, "Invalid unicode code point found");
return -1;
}
ss *out_ss = ssnewempty(out_num_bytes);
ss *out_ss = ssnewemptylen(out_num_bytes);
if (out_ss == NULL) {
gil_error(PyExc_MemoryError, "ssnewempty failed");
gil_error(PyExc_MemoryError, "ssnewemptylen failed");
return -1;
}
char *out_buf = out_ss->buf;
for (int i = 0; i < num_codepoints; i++) {
for (size_t i = 0; i < num_codepoints; i++) {
// get code point
Py_UCS4 code = in[i];

Expand Down Expand Up @@ -339,8 +346,9 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[],
long max_out_size = (context->descriptors[1]->elsize) / 4;

while (N--) {
unsigned char *this_string = (unsigned char *)((*in)->buf);
size_t n_bytes = (*in)->len;
ss *s = empty_if_null(in);
unsigned char *this_string = (unsigned char *)(s->buf);
size_t n_bytes = s->len;
size_t tot_n_bytes = 0;

for (int i = 0; i < max_out_size; i++) {
Expand All @@ -362,6 +370,11 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[],
break;
}
}

if (*in == NULL) {
free(s);
}

in += in_stride;
out += out_stride;
}
Expand Down
16 changes: 14 additions & 2 deletions stringdtype/stringdtype/src/dtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,21 @@ stringdtype_setitem(StringDTypeObject *NPY_UNUSED(descr), PyObject *obj,
}

static PyObject *
stringdtype_getitem(StringDTypeObject *descr, char **dataptr)
stringdtype_getitem(StringDTypeObject *NPY_UNUSED(descr), char **dataptr)
{
PyObject *val_obj = PyUnicode_FromString(((ss *)*dataptr)->buf);
char *data;

// in the future this could represent missing data too, but we'd
// need to make it so np.empty and np.zeros take their initial value
// from an API hook that doesn't exist yet
if (*dataptr == NULL) {
data = "\0";
}
else {
data = ((ss *)*dataptr)->buf;
}

PyObject *val_obj = PyUnicode_FromString(data);

if (val_obj == NULL) {
return NULL;
Expand Down
19 changes: 18 additions & 1 deletion stringdtype/stringdtype/src/static_string.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,26 @@ ssdup(const ss *s)
// does not do any initialization, the caller must
// initialize and null-terminate the string
ss *
ssnewempty(size_t len)
ssnewemptylen(size_t len)
{
ss *ret = (ss *)malloc(sizeof(ss) + sizeof(char) * (len + 1));
ret->len = len;
return ret;
}

ss *
ssempty(void)
{
return ssnewlen("", 0);
}

ss *
empty_if_null(ss **data)
{
if (*data == NULL) {
return ssempty();
}
else {
return *data;
}
}
9 changes: 8 additions & 1 deletion stringdtype/stringdtype/src/static_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ ssdup(const ss *s);
// does not do any initialization, the caller must
// initialize and null-terminate the string
ss *
ssnewempty(size_t len);
ssnewemptylen(size_t len);

// returns an new heap-allocated empty string
ss *
ssnewempty(void);

ss *
empty_if_null(ss **data);

#endif /*_NPY_STATIC_STRING_H */
13 changes: 11 additions & 2 deletions stringdtype/stringdtype/src/umath.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,24 @@ string_equal_strided_loop(PyArrayMethod_Context *context, char *const data[],
npy_intp out_stride = strides[2];

while (N--) {
ss *s1 = *in1;
ss *s2 = *in2;
ss *s1 = empty_if_null(in1);
ss *s2 = empty_if_null(in2);

if (s1->len == s2->len && strncmp(s1->buf, s2->buf, s1->len) == 0) {
*out = (npy_bool)1;
}
else {
*out = (npy_bool)0;
}

if (*in1 == NULL) {
free(s1);
}

if (*in2 == NULL) {
free(s2);
}

in1 += in1_stride;
in2 += in2_stride;
out += out_stride;
Expand Down
13 changes: 13 additions & 0 deletions stringdtype/tests/test_stringdtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,16 @@ def test_sort(strings):
np.random.default_rng().shuffle(arr)
arr.sort()
np.testing.assert_array_equal(arr, arr_sorted)


def test_creation_functions():
np.testing.assert_array_equal(
np.zeros(3, dtype=StringDType()), ["", "", ""]
)

np.testing.assert_array_equal(
np.empty(3, dtype=StringDType()), ["", "", ""]
)

# make sure getitem works too
assert np.empty(3, dtype=StringDType())[0] == ""