diff --git a/stringdtype/stringdtype/src/casts.c b/stringdtype/stringdtype/src/casts.c index 8ad331ce..7b951183 100644 --- a/stringdtype/stringdtype/src/casts.c +++ b/stringdtype/stringdtype/src/casts.c @@ -40,24 +40,26 @@ string_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self), } static int -string_to_string(PyArrayMethod_Context *context, char *const data[], - npy_intp const dimensions[], npy_intp const strides[], - NpyAuxData *NPY_UNUSED(auxdata)) +string_to_string(PyArrayMethod_Context *NPY_UNUSED(context), + char *const data[], npy_intp const dimensions[], + npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata)) { npy_intp N = dimensions[0]; - ss **in = (ss **)data[0]; - ss **out = (ss **)data[1]; - // strides are in bytes but pointer offsets are in pointer widths, so - // divide by the element size (one pointer width) to get the pointer offset - npy_intp in_stride = strides[0] / context->descriptors[0]->elsize; - npy_intp out_stride = strides[1] / context->descriptors[1]->elsize; + char *in = data[0]; + char *out = data[1]; + npy_intp in_stride = strides[0]; + npy_intp out_stride = strides[1]; + + ss *s = NULL, *os = NULL; while (N--) { - out[0] = ssdup(in[0]); - if (out[0] == NULL) { + load_string(in, &s); + os = (ss *)out; + if (ssdup(s, os) < 0) { gil_error(PyExc_MemoryError, "ssdup failed"); return -1; } + in += in_stride; out += out_stride; } @@ -114,7 +116,7 @@ unicode_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self), // is the number of codepoints that are not trailing null codepoints. Returns // 0 on success and -1 when an invalid code point is found. static int -utf8_size(Py_UCS4 *codepoints, long max_length, size_t *num_codepoints, +utf8_size(const Py_UCS4 *codepoints, long max_length, size_t *num_codepoints, size_t *utf8_bytes) { size_t ucs4len = max_length; @@ -126,7 +128,7 @@ utf8_size(Py_UCS4 *codepoints, long max_length, size_t *num_codepoints, size_t num_bytes = 0; - for (int i = 0; i < ucs4len; i++) { + for (size_t i = 0; i < ucs4len; i++) { Py_UCS4 code = codepoints[i]; if (code <= 0x7F) { @@ -201,13 +203,11 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[], npy_intp N = dimensions[0]; Py_UCS4 *in = (Py_UCS4 *)data[0]; - ss **out = (ss **)data[1]; + char *out = data[1]; // 4 bytes per UCS4 character npy_intp in_stride = strides[0] / 4; - // strides are in bytes but pointer offsets are in pointer widths, so - // divide by the element size (one pointer width) to get the pointer offset - npy_intp out_stride = strides[1] / context->descriptors[1]->elsize; + npy_intp out_stride = strides[1]; while (N--) { size_t out_num_bytes = 0; @@ -217,9 +217,9 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[], gil_error(PyExc_TypeError, "Invalid unicode code point found"); return -1; } - ss *out_ss = ssnewempty(out_num_bytes); - if (out_ss == NULL) { - gil_error(PyExc_MemoryError, "ssnewempty failed"); + ss *out_ss = (ss *)out; + if (ssnewemptylen(out_num_bytes, out_ss) < 0) { + gil_error(PyExc_MemoryError, "ssnewemptylen failed"); return -1; } char *out_buf = out_ss->buf; @@ -247,9 +247,6 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[], // pad string with null character out_buf[out_num_bytes] = '\0'; - // set out to the address of the beginning of the string - out[0] = out_ss; - in += in_stride; out += out_stride; } @@ -329,19 +326,20 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[], NpyAuxData *NPY_UNUSED(auxdata)) { npy_intp N = dimensions[0]; - ss **in = (ss **)data[0]; + char *in = data[0]; Py_UCS4 *out = (Py_UCS4 *)data[1]; - // strides are in bytes but pointer offsets are in pointer widths, so - // divide by the element size (one pointer width) to get the pointer offset - npy_intp in_stride = strides[0] / context->descriptors[0]->elsize; + npy_intp in_stride = strides[0]; // 4 bytes per UCS4 character npy_intp out_stride = strides[1] / 4; // max number of 4 byte UCS4 characters that can fit in the output long max_out_size = (context->descriptors[1]->elsize) / 4; + ss *s = NULL; + while (N--) { - unsigned char *this_string = (unsigned char *)((*in)->buf); - size_t n_bytes = (*in)->len; + load_string(in, &s); + unsigned char *this_string = (unsigned char *)(s->buf); + size_t n_bytes = s->len; size_t tot_n_bytes = 0; for (int i = 0; i < max_out_size; i++) { @@ -363,6 +361,7 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[], break; } } + in += in_stride; out += out_stride; } diff --git a/stringdtype/stringdtype/src/dtype.c b/stringdtype/stringdtype/src/dtype.c index ff7aafc1..44fc328d 100644 --- a/stringdtype/stringdtype/src/dtype.c +++ b/stringdtype/stringdtype/src/dtype.c @@ -16,8 +16,8 @@ new_stringdtype_instance(void) if (new == NULL) { return NULL; } - new->base.elsize = sizeof(ss *); - new->base.alignment = _Alignof(ss *); + new->base.elsize = sizeof(ss); + new->base.alignment = _Alignof(ss); new->base.flags |= NPY_NEEDS_INIT; new->base.flags |= NPY_LIST_PICKLE; new->base.flags |= NPY_ITEM_REFCOUNT; @@ -119,26 +119,42 @@ stringdtype_setitem(StringDTypeObject *NPY_UNUSED(descr), PyObject *obj, return -1; } - ss *str_val = ssnewlen(val, length); - if (str_val == NULL) { - PyErr_SetString(PyExc_MemoryError, "ssnewlen failed"); + // free if dataptr holds preexisting string data, + // ssfree does a NULL check + ssfree((ss *)dataptr); + + // copies contents of val into item_val->buf + int res = ssnewlen(val, length, (ss *)dataptr); + + // val_obj must stay alive until here to ensure *val* doesn't get + // deallocated + Py_DECREF(val_obj); + + if (res == -1) { + PyErr_NoMemory(); return -1; } - // the dtype instance has the NPY_NEEDS_INIT flag set, - // so if *dataptr is NULL, that means we're initializing - // the array and don't need to free an existing string - if (*dataptr != NULL) { - free((ss *)*dataptr); + else if (res == -2) { + // this should never happen + assert(0); } - *dataptr = (char *)str_val; - Py_DECREF(val_obj); + return 0; } static PyObject * -stringdtype_getitem(StringDTypeObject *descr, char **dataptr) +stringdtype_getitem(StringDTypeObject *NPY_UNUSED(descr), char **dataptr) { - PyObject *val_obj = PyUnicode_FromString(((ss *)*dataptr)->buf); + char *data; + + if (*dataptr == NULL) { + data = "\0"; + } + else { + data = ((ss *)dataptr)->buf; + } + + PyObject *val_obj = PyUnicode_FromString(data); if (val_obj == NULL) { return NULL; @@ -147,10 +163,6 @@ stringdtype_getitem(StringDTypeObject *descr, char **dataptr) PyObject *res = PyObject_CallFunctionObjArgs((PyObject *)StringScalar_Type, val_obj, NULL); - if (res == NULL) { - return NULL; - } - Py_DECREF(val_obj); return res; @@ -161,8 +173,8 @@ stringdtype_getitem(StringDTypeObject *descr, char **dataptr) int compare_strings(char **a, char **b, PyArrayObject *NPY_UNUSED(arr)) { - ss *ss_a = (ss *)*a; - ss *ss_b = (ss *)*b; + ss *ss_a = (ss *)a; + ss *ss_b = (ss *)b; return strcmp(ss_a->buf, ss_b->buf); } @@ -181,8 +193,8 @@ stringdtype_clear_loop(void *NPY_UNUSED(traverse_context), { while (size--) { if (data != NULL) { - free(*(ss **)data); - *(ss **)data = NULL; + ssfree((ss *)data); + memset(data, 0, sizeof(ss)); } data += stride; } diff --git a/stringdtype/stringdtype/src/main.c b/stringdtype/stringdtype/src/main.c index 24648a08..f5691bb0 100644 --- a/stringdtype/stringdtype/src/main.c +++ b/stringdtype/stringdtype/src/main.c @@ -50,16 +50,15 @@ _memory_usage(PyObject *NPY_UNUSED(self), PyObject *obj) // initialize with the size of the internal buffer size_t memory_usage = PyArray_NBYTES(arr); - size_t struct_size = sizeof(ss); do { - ss **in = (ss **)*dataptr; - npy_intp stride = *strideptr / descr->elsize; + char *in = dataptr[0]; + npy_intp stride = *strideptr; npy_intp count = *innersizeptr; while (count--) { // +1 byte for the null terminator - memory_usage += (*in)->len + struct_size + 1; + memory_usage += ((ss *)in)->len + 1; in += stride; } diff --git a/stringdtype/stringdtype/src/static_string.c b/stringdtype/stringdtype/src/static_string.c index 0defdfea..143d6eaa 100644 --- a/stringdtype/stringdtype/src/static_string.c +++ b/stringdtype/stringdtype/src/static_string.c @@ -1,41 +1,77 @@ #include "static_string.h" -// allocates a new ss string of length len, filling with the contents of init -ss * -ssnewlen(const char *init, size_t len) +int +ssnewlen(const char *init, size_t len, ss *to_init) { + if ((to_init->buf != NULL) || (to_init->len != 0)) { + return -2; + } + // one extra byte for null terminator - ss *ret = (ss *)malloc(sizeof(ss) + sizeof(char) * (len + 1)); + char *ret_buf = (char *)malloc(sizeof(char) * (len + 1)); - if (ret == NULL) { - return NULL; + if (ret_buf == NULL) { + return -1; } - ret->len = len; + to_init->len = len; if (len > 0) { - memcpy(ret->buf, init, len); + memcpy(ret_buf, init, len); } - ret->buf[len] = '\0'; + ret_buf[len] = '\0'; + + to_init->buf = ret_buf; + + return 0; +} - return ret; +void +ssfree(ss *str) +{ + if (str->buf != NULL) { + free(str->buf); + str->buf = NULL; + } + str->len = 0; } -// returns a new heap-allocated copy of input string *s* -ss * -ssdup(const ss *s) +int +ssdup(ss *in, ss *out) { - return ssnewlen(s->buf, s->len); + return ssnewlen(in->buf, in->len, out); } -// returns a new, empty string of length len -// does not do any initialization, the caller must -// initialize and null-terminate the string -ss * -ssnewempty(size_t len) +int +ssnewemptylen(size_t num_bytes, ss *out) { - ss *ret = (ss *)malloc(sizeof(ss) + sizeof(char) * (len + 1)); - ret->len = len; - return ret; + if (out->len != 0 || out->buf != NULL) { + return -2; + } + + char *buf = (char *)malloc(sizeof(char) * (num_bytes + 1)); + + if (buf == NULL) { + return -1; + } + + out->buf = buf; + out->len = num_bytes; + + return 0; +} + +static ss EMPTY = {0, "\0"}; + +void +load_string(char *data, ss **out) +{ + ss *ss_d = (ss *)data; + if (ss_d->len == 0) { + *out = &EMPTY; + } + else { + *out = ss_d; + } } diff --git a/stringdtype/stringdtype/src/static_string.h b/stringdtype/stringdtype/src/static_string.h index ee4fc4fa..c4a956d2 100644 --- a/stringdtype/stringdtype/src/static_string.h +++ b/stringdtype/stringdtype/src/static_string.h @@ -6,21 +6,41 @@ typedef struct ss { size_t len; - char buf[]; + char *buf; } ss; -// allocates a new ss string of length len, filling with the contents of init -ss * -ssnewlen(const char *init, size_t len); +// Allocates a new buffer for *to_init*, filling with the copied contents of +// *init* and sets *to_init->len* to *len*. Returns -1 if malloc fails and -2 +// if *to_init* is not empty. Returns 0 on success. +int +ssnewlen(const char *init, size_t len, ss *to_init); -// returns a new heap-allocated copy of input string *s* -ss * -ssdup(const ss *s); +// Sets len to 0 and if str->buf is not already NULL, frees it and sets it to +// NULL. Cannot fail. +void +ssfree(ss *str); -// returns a new, empty string of length len -// does not do any initialization, the caller must -// initialize and null-terminate the string -ss * -ssnewempty(size_t len); +// copies the contents out *in* into *out*. Allocates a new string buffer for +// *out*. Returns -1 if malloc fails and -2 if *out* is not empty. Returns 0 on +// success. +int +ssdup(ss *in, ss *out); + +// Allocates a new string buffer for *out* with enough capacity to store +// *num_bytes* of text. The actual allocation is num_bytes + 1 bytes, to +// account for the null terminator. Does not do any initialization, the caller +// must initialize and null-terminate the string buffer. Returns -1 if malloc +// fails and -2 if *out* is not empty. Returns 0 on success. +int +ssnewemptylen(size_t num_bytes, ss *out); + +// Interpret the contents of buffer *data* as an ss struct and set *out* to +// that struct. If *data* is NULL, set *out* to point to a statically +// allocated, empty SS struct. Since this function may set *out* to point to +// statically allocated data, do not ever free memory owned by an output of +// this function. That means this function is most useful for read-only +// applications. +void +load_string(char *data, ss **out); #endif /*_NPY_STATIC_STRING_H */ diff --git a/stringdtype/stringdtype/src/umath.c b/stringdtype/stringdtype/src/umath.c index 7ef8dc12..52da9ee0 100644 --- a/stringdtype/stringdtype/src/umath.c +++ b/stringdtype/stringdtype/src/umath.c @@ -14,24 +14,24 @@ #include "umath.h" static int -string_equal_strided_loop(PyArrayMethod_Context *context, char *const data[], - npy_intp const dimensions[], +string_equal_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context), + char *const data[], npy_intp const dimensions[], npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata)) { npy_intp N = dimensions[0]; - ss **in1 = (ss **)data[0]; - ss **in2 = (ss **)data[1]; + char *in1 = data[0]; + char *in2 = data[1]; npy_bool *out = (npy_bool *)data[2]; - // strides are in bytes but pointer offsets are in pointer widths, so - // divide by the element size (one pointer width) to get the pointer offset - npy_intp in1_stride = strides[0] / context->descriptors[0]->elsize; - npy_intp in2_stride = strides[1] / context->descriptors[1]->elsize; + npy_intp in1_stride = strides[0]; + npy_intp in2_stride = strides[1]; npy_intp out_stride = strides[2]; + ss *s1 = NULL, *s2 = NULL; + while (N--) { - ss *s1 = *in1; - ss *s2 = *in2; + load_string(in1, &s1); + load_string(in2, &s2); if (s1->len == s2->len && strncmp(s1->buf, s2->buf, s1->len) == 0) { *out = (npy_bool)1; @@ -39,6 +39,7 @@ string_equal_strided_loop(PyArrayMethod_Context *context, char *const data[], else { *out = (npy_bool)0; } + in1 += in1_stride; in2 += in2_stride; out += out_stride; diff --git a/stringdtype/tests/test_stringdtype.py b/stringdtype/tests/test_stringdtype.py index ccdfe403..82b82456 100644 --- a/stringdtype/tests/test_stringdtype.py +++ b/stringdtype/tests/test_stringdtype.py @@ -181,3 +181,16 @@ def test_sort(strings): np.random.default_rng().shuffle(arr) arr.sort() np.testing.assert_array_equal(arr, arr_sorted) + + +def test_creation_functions(): + np.testing.assert_array_equal( + np.zeros(3, dtype=StringDType()), ["", "", ""] + ) + + np.testing.assert_array_equal( + np.empty(3, dtype=StringDType()), ["", "", ""] + ) + + # make sure getitem works too + assert np.empty(3, dtype=StringDType())[0] == ""