Skip to content

Commit 887d6e5

Browse files
committed
Handle possible NULL pointers in the array buffer
1 parent 21c5618 commit 887d6e5

File tree

6 files changed

+81
-11
lines changed

6 files changed

+81
-11
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,17 @@ string_to_string(PyArrayMethod_Context *context, char *const data[],
5353
npy_intp out_stride = strides[1] / context->descriptors[1]->elsize;
5454

5555
while (N--) {
56-
out[0] = ssdup(in[0]);
56+
ss *s = empty_if_null(in);
57+
out[0] = ssdup(s);
5758
if (out[0] == NULL) {
5859
gil_error(PyExc_MemoryError, "ssdup failed");
5960
return -1;
6061
}
62+
63+
if (*in == NULL) {
64+
free(s);
65+
}
66+
6167
in += in_stride;
6268
out += out_stride;
6369
}
@@ -217,9 +223,9 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[],
217223
gil_error(PyExc_TypeError, "Invalid unicode code point found");
218224
return -1;
219225
}
220-
ss *out_ss = ssnewempty(out_num_bytes);
226+
ss *out_ss = ssnewemptylen(out_num_bytes);
221227
if (out_ss == NULL) {
222-
gil_error(PyExc_MemoryError, "ssnewempty failed");
228+
gil_error(PyExc_MemoryError, "ssnewemptylen failed");
223229
return -1;
224230
}
225231
char *out_buf = out_ss->buf;
@@ -340,8 +346,9 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[],
340346
long max_out_size = (context->descriptors[1]->elsize) / 4;
341347

342348
while (N--) {
343-
unsigned char *this_string = (unsigned char *)((*in)->buf);
344-
size_t n_bytes = (*in)->len;
349+
ss *s = empty_if_null(in);
350+
unsigned char *this_string = (unsigned char *)(s->buf);
351+
size_t n_bytes = s->len;
345352
size_t tot_n_bytes = 0;
346353

347354
for (int i = 0; i < max_out_size; i++) {
@@ -363,6 +370,11 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[],
363370
break;
364371
}
365372
}
373+
374+
if (*in == NULL) {
375+
free(s);
376+
}
377+
366378
in += in_stride;
367379
out += out_stride;
368380
}

stringdtype/stringdtype/src/dtype.c

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,21 @@ stringdtype_setitem(StringDTypeObject *NPY_UNUSED(descr), PyObject *obj,
136136
}
137137

138138
static PyObject *
139-
stringdtype_getitem(StringDTypeObject *descr, char **dataptr)
139+
stringdtype_getitem(StringDTypeObject *NPY_UNUSED(descr), char **dataptr)
140140
{
141-
PyObject *val_obj = PyUnicode_FromString(((ss *)*dataptr)->buf);
141+
char *data;
142+
143+
// in the future this could represent missing data too, but we'd
144+
// need to make it so np.empty and np.zeros take their initial value
145+
// from an API hook that doesn't exist yet
146+
if (*dataptr == NULL) {
147+
data = "\0";
148+
}
149+
else {
150+
data = ((ss *)*dataptr)->buf;
151+
}
152+
153+
PyObject *val_obj = PyUnicode_FromString(data);
142154

143155
if (val_obj == NULL) {
144156
return NULL;

stringdtype/stringdtype/src/static_string.c

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,26 @@ ssdup(const ss *s)
3333
// does not do any initialization, the caller must
3434
// initialize and null-terminate the string
3535
ss *
36-
ssnewempty(size_t len)
36+
ssnewemptylen(size_t len)
3737
{
3838
ss *ret = (ss *)malloc(sizeof(ss) + sizeof(char) * (len + 1));
3939
ret->len = len;
4040
return ret;
4141
}
42+
43+
ss *
44+
ssempty(void)
45+
{
46+
return ssnewlen("", 0);
47+
}
48+
49+
ss *
50+
empty_if_null(ss **data)
51+
{
52+
if (*data == NULL) {
53+
return ssempty();
54+
}
55+
else {
56+
return *data;
57+
}
58+
}

stringdtype/stringdtype/src/static_string.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@ ssdup(const ss *s);
2121
// does not do any initialization, the caller must
2222
// initialize and null-terminate the string
2323
ss *
24-
ssnewempty(size_t len);
24+
ssnewemptylen(size_t len);
25+
26+
// returns a new heap-allocated empty string
27+
ss *
28+
ssnewempty(void);
29+
30+
ss *
31+
empty_if_null(ss **data);
2532

2633
#endif /*_NPY_STATIC_STRING_H */

stringdtype/stringdtype/src/umath.c

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,24 @@ string_equal_strided_loop(PyArrayMethod_Context *context, char *const data[],
3030
npy_intp out_stride = strides[2];
3131

3232
while (N--) {
33-
ss *s1 = *in1;
34-
ss *s2 = *in2;
33+
ss *s1 = empty_if_null(in1);
34+
ss *s2 = empty_if_null(in2);
3535

3636
if (s1->len == s2->len && strncmp(s1->buf, s2->buf, s1->len) == 0) {
3737
*out = (npy_bool)1;
3838
}
3939
else {
4040
*out = (npy_bool)0;
4141
}
42+
43+
if (*in1 == NULL) {
44+
free(s1);
45+
}
46+
47+
if (*in2 == NULL) {
48+
free(s2);
49+
}
50+
4251
in1 += in1_stride;
4352
in2 += in2_stride;
4453
out += out_stride;

stringdtype/tests/test_stringdtype.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,16 @@ def test_sort(strings):
181181
np.random.default_rng().shuffle(arr)
182182
arr.sort()
183183
np.testing.assert_array_equal(arr, arr_sorted)
184+
185+
186+
def test_creation_functions():
187+
np.testing.assert_array_equal(
188+
np.zeros(3, dtype=StringDType()), ["", "", ""]
189+
)
190+
191+
np.testing.assert_array_equal(
192+
np.empty(3, dtype=StringDType()), ["", "", ""]
193+
)
194+
195+
# make sure getitem works too
196+
assert np.empty(3, dtype=StringDType())[0] == ""

0 commit comments

Comments
 (0)