Skip to content

Commit bce48f6

Browse files
committed
Compare unicode character points instead of strcmp
1 parent b06a05a commit bce48f6

File tree

4 files changed

+44
-16
lines changed

4 files changed

+44
-16
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ string_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self),
295295
// codepoint for the next character, returning the size of the character in
296296
// bytes. Does not do any validation or error checking: assumes *c* is valid
297297
// utf-8
298-
static size_t
298+
size_t
299299
utf8_char_to_ucs4_code(unsigned char *c, Py_UCS4 *code)
300300
{
301301
if (c[0] <= 0x7F) {

stringdtype/stringdtype/src/casts.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,7 @@
1313
PyArrayMethod_Spec **
1414
get_casts(void);
1515

16+
size_t
17+
utf8_char_to_ucs4_code(unsigned char *, Py_UCS4 *);
18+
1619
#endif /* _NPY_CASTS_H */

stringdtype/stringdtype/src/dtype.c

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,45 @@ stringdtype_getitem(StringDTypeObject *descr, char **dataptr)
156156
}
157157

158158
// Implementation of PyArray_CompareFunc.
159+
// Compares unicode strings by their code points.
159160
int
160-
compare_strings(char **a, char **b, PyArrayObject *arr)
161+
compare_strings(char **a, char **b, PyArrayObject *NPY_UNUSED(arr))
161162
{
162163
ss *ss_a = (ss *)*a;
163164
ss *ss_b = (ss *)*b;
164-
return strcmp(ss_a->buf, ss_b->buf);
165+
166+
// Index into utf8 byte array
167+
int i_a = 0;
168+
int i_b = 0;
169+
170+
Py_UCS4 code_a;
171+
Py_UCS4 code_b;
172+
173+
while (i_a < ss_a->len && i_b < ss_b->len) {
174+
unsigned char ca = ss_a->buf[i_a];
175+
unsigned char cb = ss_b->buf[i_b];
176+
177+
i_a += utf8_char_to_ucs4_code(&ca, &code_a);
178+
i_b += utf8_char_to_ucs4_code(&cb, &code_b);
179+
180+
// Only compare next code point if these are identical
181+
if (code_a > code_b) {
182+
return 1;
183+
}
184+
else if (code_a < code_b) {
185+
return -1;
186+
}
187+
}
188+
189+
if (i_a == ss_a->len) {
190+
if (i_b == ss_b->len) {
191+
return 0;
192+
}
193+
return -1;
194+
}
195+
else {
196+
return 1;
197+
}
165198
}
166199

167200
static StringDTypeObject *
@@ -321,7 +354,6 @@ init_string_dtype(void)
321354
/* Loaded dynamically, so may need to be set here: */
322355
((PyObject *)&StringDType)->ob_type = &PyArrayDTypeMeta_Type;
323356
((PyTypeObject *)&StringDType)->tp_base = &PyArrayDescr_Type;
324-
325357
if (PyType_Ready((PyTypeObject *)&StringDType) < 0) {
326358
return -1;
327359
}

stringdtype/tests/test_stringdtype.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import concurrent.futures
22
import os
33
import pickle
4-
import string
54
import tempfile
65

76
import numpy as np
@@ -16,14 +15,8 @@ def string_list():
1615

1716

1817
@pytest.fixture
19-
def string_list_long():
20-
abcs = string.ascii_lowercase
21-
22-
pairs = []
23-
for pair in zip(abcs, abcs[1:] + abcs[0]):
24-
pairs.append("".join(pair))
25-
26-
return pairs
18+
def string_list_similar():
    """Return short strings with shared prefixes for exercising sort order.

    Several entries are prefixes of one another ("left"/"leftovers",
    "right"/"righty"), so sorting them requires comparing beyond the first
    differing length.
    """
    # Each word is its own element; a missing comma would make Python
    # silently concatenate adjacent literals ("up" "down" -> "updown").
    return ["left", "right", "leftovers", "righty", "up", "down"]
2720

2821

2922
def test_scalar_creation():
@@ -175,10 +168,10 @@ def test_pickle(string_list):
175168
os.remove(f.name)
176169

177170

178-
def test_sort(string_list_long):
171+
def test_sort(string_list_similar):
179172
"""Test that sorting matches python's internal sorting."""
180-
arr = np.array(string_list_long, dtype=StringDType())
181-
arr_sorted = np.array(sorted(string_list_long), dtype=StringDType())
173+
arr = np.array(string_list_similar, dtype=StringDType())
174+
arr_sorted = np.array(sorted(string_list_similar), dtype=StringDType())
182175

183176
np.random.default_rng().shuffle(arr)
184177
arr.sort()

0 commit comments

Comments
 (0)