Skip to content

Commit b06a05a

Browse files
committed
Add string comparison support so that x.sort() works
1 parent aa1d862 commit b06a05a

File tree

5 files changed

+48
-8
lines changed

5 files changed

+48
-8
lines changed

stringdtype/README.md

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,26 @@ NumPy.
1010
Ensure Meson and the other build tools are installed in the Python environment you would like to use (NumPy itself is installed separately, as described below):
1111

1212
```
13-
$ python3 -m pip install meson meson-python numpy build patchelf
13+
$ python3 -m pip install meson meson-python build patchelf
1414
```
1515

16-
Build with meson, create a wheel, and install it
16+
It is important to have the latest development version of numpy installed.
17+
Nightly wheels work well for this purpose, and can be installed easily:
1718

19+
```bash
20+
$ pip install -i https://pypi.anaconda.org/scipy-wheels-nightly/simple numpy
1821
```
22+
23+
Build with meson and create a wheel. The wheel is written to `dist/` and can then be installed with `python -m pip install dist/*.whl`.
24+
25+
```bash
1926
$ rm -r dist/
2027
$ meson build
2128
$ python -m build --wheel -Cbuilddir=build
22-
$ python -m pip install dist/asciidtype*.whl
2329
```
2430

25-
The `mesonpy` build backend for pip [does not currently support editable
26-
installs](https://github.com/mesonbuild/meson-python/issues/47), so `pip install
27-
-e .` will not work.
31+
Or simply install directly, taking care to disable build isolation so that the build uses the NumPy already installed in your environment:
32+
33+
```bash
34+
$ pip install -v . --no-build-isolation
35+
```

stringdtype/stringdtype/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from .scalar import StringScalar # isort: skip
66
from ._main import StringDType, _memory_usage
77

8-
98
__all__ = [
109
"StringDType",
1110
"StringScalar",

stringdtype/stringdtype/src/dtype.c

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,15 @@ stringdtype_getitem(StringDTypeObject *descr, char **dataptr)
155155
return res;
156156
}
157157

158+
// Implementation of PyArray_CompareFunc.
159+
int
160+
compare_strings(char **a, char **b, PyArrayObject *arr)
161+
{
162+
ss *ss_a = (ss *)*a;
163+
ss *ss_b = (ss *)*b;
164+
return strcmp(ss_a->buf, ss_b->buf);
165+
}
166+
158167
static StringDTypeObject *
159168
stringdtype_ensure_canonical(StringDTypeObject *self)
160169
{
@@ -170,6 +179,7 @@ static PyType_Slot StringDType_Slots[] = {
170179
{NPY_DT_setitem, &stringdtype_setitem},
171180
{NPY_DT_getitem, &stringdtype_getitem},
172181
{NPY_DT_ensure_canonical, &stringdtype_ensure_canonical},
182+
{NPY_DT_PyArray_ArrFuncs_compare, &compare_strings},
173183
{0, NULL}};
174184

175185
static PyObject *
@@ -311,6 +321,7 @@ init_string_dtype(void)
311321
/* Loaded dynamically, so may need to be set here: */
312322
((PyObject *)&StringDType)->ob_type = &PyArrayDTypeMeta_Type;
313323
((PyTypeObject *)&StringDType)->tp_base = &PyArrayDescr_Type;
324+
314325
if (PyType_Ready((PyTypeObject *)&StringDType) < 0) {
315326
return -1;
316327
}

stringdtype/stringdtype/src/main.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ PyInit__main(void)
9090
if (_import_array() < 0) {
9191
return NULL;
9292
}
93-
if (import_experimental_dtype_api(7) < 0) {
93+
if (import_experimental_dtype_api(8) < 0) {
9494
return NULL;
9595
}
9696

stringdtype/tests/test_stringdtype.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import concurrent.futures
22
import os
33
import pickle
4+
import string
45
import tempfile
56

67
import numpy as np
@@ -14,6 +15,17 @@ def string_list():
1415
return ["abc", "def", "ghi"]
1516

1617

18+
@pytest.fixture
def string_list_long():
    """Return 26 two-letter strings pairing each letter with its successor.

    The successor wraps around at the end of the alphabet, so the list runs
    ``"ab"``, ``"bc"``, ..., ``"yz"``, ``"za"``.
    """
    letters = string.ascii_lowercase
    # Rotate the alphabet left by one so every letter lines up with the next.
    rotated = letters[1:] + letters[0]
    return [first + second for first, second in zip(letters, rotated)]
27+
28+
1729
def test_scalar_creation():
1830
assert str(StringScalar("abc")) == "abc"
1931

@@ -161,3 +173,13 @@ def test_pickle(string_list):
161173
assert res[1] == dtype
162174

163175
os.remove(f.name)
176+
177+
178+
def test_sort(string_list_long):
    """Test that sorting matches python's internal sorting.

    The input array is shuffled in place, sorted with ``ndarray.sort`` and
    compared against an array built from ``sorted()`` on the same strings.
    """
    arr = np.array(string_list_long, dtype=StringDType())
    arr_sorted = np.array(sorted(string_list_long), dtype=StringDType())

    # Use a fixed seed so that a failing permutation can be reproduced
    # exactly; an unseeded generator yields a different order on every run.
    rng = np.random.default_rng(0)
    rng.shuffle(arr)
    arr.sort()
    np.testing.assert_array_equal(arr, arr_sorted)

0 commit comments

Comments
 (0)