Skip to content

Commit 681a2a1

Browse files
committed
Compare unicode character points instead of strcmp
1 parent b06a05a commit 681a2a1

File tree

3 files changed

+39
-4
lines changed

3 files changed

+39
-4
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ string_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self),
295295
// codepoint for the next character, returning the size of the character in
296296
// bytes. Does not do any validation or error checking: assumes *c* is valid
297297
// utf-8
298-
static size_t
298+
size_t
299299
utf8_char_to_ucs4_code(unsigned char *c, Py_UCS4 *code)
300300
{
301301
if (c[0] <= 0x7F) {

stringdtype/stringdtype/src/casts.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,7 @@
1313
PyArrayMethod_Spec **
1414
get_casts(void);
1515

16+
size_t
17+
utf8_char_to_ucs4_code(unsigned char *, Py_UCS4 *);
18+
1619
#endif /* _NPY_CASTS_H */

stringdtype/stringdtype/src/dtype.c

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,45 @@ stringdtype_getitem(StringDTypeObject *descr, char **dataptr)
156156
}
157157

158158
// Implementation of PyArray_CompareFunc.
159+
// Compares unicode strings by their code points.
159160
int
160-
compare_strings(char **a, char **b, PyArrayObject *arr)
161+
compare_strings(char **a, char **b, PyArrayObject *NPY_UNUSED(arr))
161162
{
162163
ss *ss_a = (ss *)*a;
163164
ss *ss_b = (ss *)*b;
164-
return strcmp(ss_a->buf, ss_b->buf);
165+
166+
// Index into utf8 byte array
167+
int i_a = 0;
168+
int i_b = 0;
169+
170+
Py_UCS4 code_a;
171+
Py_UCS4 code_b;
172+
173+
while (i_a < ss_a->len && i_b < ss_b->len) {
174+
unsigned char ca = ss_a->buf[i_a];
175+
unsigned char cb = ss_b->buf[i_b];
176+
177+
i_a += utf8_char_to_ucs4_code(&ca, &code_a);
178+
i_b += utf8_char_to_ucs4_code(&cb, &code_b);
179+
180+
// Only compare next code point if these are identical
181+
if (code_a > code_b) {
182+
return 1;
183+
}
184+
else if (code_a < code_b) {
185+
return -1;
186+
}
187+
}
188+
189+
if (i_a == ss_a->len) {
190+
if (i_b == ss_b->len) {
191+
return 0;
192+
}
193+
return -1;
194+
}
195+
else {
196+
return 1;
197+
}
165198
}
166199

167200
static StringDTypeObject *
@@ -321,7 +354,6 @@ init_string_dtype(void)
321354
/* Loaded dynamically, so may need to be set here: */
322355
((PyObject *)&StringDType)->ob_type = &PyArrayDTypeMeta_Type;
323356
((PyTypeObject *)&StringDType)->tp_base = &PyArrayDescr_Type;
324-
325357
if (PyType_Ready((PyTypeObject *)&StringDType) < 0) {
326358
return -1;
327359
}

0 commit comments

Comments
 (0)