Skip to content

Commit 6fe568e

Browse files
committed
introducing complex hash tables
1 parent 06f7a07 commit 6fe568e

File tree

7 files changed

+199
-39
lines changed

7 files changed

+199
-39
lines changed

pandas/_libs/hashtable.pxd

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
from numpy cimport intp_t, ndarray
22

33
from pandas._libs.khash cimport (
4+
complex64_t,
5+
complex128_t,
46
float32_t,
57
float64_t,
68
int8_t,
79
int16_t,
810
int32_t,
911
int64_t,
12+
kh_complex64_t,
13+
kh_complex128_t,
1014
kh_float32_t,
1115
kh_float64_t,
1216
kh_int8_t,
@@ -19,6 +23,8 @@ from pandas._libs.khash cimport (
1923
kh_uint16_t,
2024
kh_uint32_t,
2125
kh_uint64_t,
26+
khcomplex64_t,
27+
khcomplex128_t,
2228
uint8_t,
2329
uint16_t,
2430
uint32_t,
@@ -90,6 +96,18 @@ cdef class Float32HashTable(HashTable):
9096
cpdef get_item(self, float32_t val)
9197
cpdef set_item(self, float32_t key, Py_ssize_t val)
9298

99+
cdef class Complex64HashTable(HashTable):
100+
cdef kh_complex64_t *table
101+
102+
cpdef get_item(self, complex64_t val)
103+
cpdef set_item(self, complex64_t key, Py_ssize_t val)
104+
105+
cdef class Complex128HashTable(HashTable):
106+
cdef kh_complex128_t *table
107+
108+
cpdef get_item(self, complex128_t val)
109+
cpdef set_item(self, complex128_t key, Py_ssize_t val)
110+
93111
cdef class PyObjectHashTable(HashTable):
94112
cdef kh_pymap_t *table
95113

pandas/_libs/hashtable.pyx

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,15 @@ cnp.import_array()
1313

1414

1515
from pandas._libs cimport util
16-
from pandas._libs.khash cimport KHASH_TRACE_DOMAIN, kh_str_t, khiter_t
16+
from pandas._libs.khash cimport (
17+
KHASH_TRACE_DOMAIN,
18+
are_equal_khcomplex64_t,
19+
are_equal_khcomplex128_t,
20+
kh_str_t,
21+
khcomplex64_t,
22+
khcomplex128_t,
23+
khiter_t,
24+
)
1725
from pandas._libs.missing cimport checknull
1826

1927

pandas/_libs/hashtable_class_helper.pxi.in

Lines changed: 93 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,34 @@ WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
88
{{py:
99

1010
# name
11-
cimported_types = ['float32',
11+
complex_types = ['complex64',
12+
'complex128']
13+
}}
14+
15+
{{for name in complex_types}}
16+
cdef kh{{name}}_t to_kh{{name}}_t({{name}}_t val) nogil:
17+
cdef kh{{name}}_t res
18+
res.real = val.real
19+
res.imag = val.imag
20+
return res
21+
22+
cdef {{name}}_t to_{{name}}(kh{{name}}_t val) nogil:
23+
cdef {{name}}_t res
24+
res.real = val.real
25+
res.imag = val.imag
26+
return res
27+
28+
cdef bint is_nan_kh{{name}}_t(kh{{name}}_t val) nogil:
29+
return val.real != val.real or val.imag != val.imag
30+
{{endfor}}
31+
32+
33+
{{py:
34+
35+
# name
36+
cimported_types = ['complex64',
37+
'complex128',
38+
'float32',
1239
'float64',
1340
'int8',
1441
'int16',
@@ -48,7 +75,9 @@ from pandas._libs.missing cimport C_NA
4875
# but is included for completeness (rather ObjectVector is used
4976
# for uniques in hashtables)
5077

51-
dtypes = [('Float64', 'float64', 'float64_t'),
78+
dtypes = [('Complex128', 'complex128', 'khcomplex128_t'),
79+
('Complex64', 'complex64', 'khcomplex64_t'),
80+
('Float64', 'float64', 'float64_t'),
5281
('Float32', 'float32', 'float32_t'),
5382
('Int64', 'int64', 'int64_t'),
5483
('Int32', 'int32', 'int32_t'),
@@ -94,6 +123,8 @@ ctypedef fused vector_data:
94123
UInt8VectorData
95124
Float64VectorData
96125
Float32VectorData
126+
Complex128VectorData
127+
Complex64VectorData
97128
StringVectorData
98129

99130
cdef inline bint needs_resize(vector_data *data) nogil:
@@ -106,7 +137,9 @@ cdef inline bint needs_resize(vector_data *data) nogil:
106137
{{py:
107138

108139
# name, dtype, c_type
109-
dtypes = [('Float64', 'float64', 'float64_t'),
140+
dtypes = [('Complex128', 'complex128', 'khcomplex128_t'),
141+
('Complex64', 'complex64', 'khcomplex64_t'),
142+
('Float64', 'float64', 'float64_t'),
110143
('UInt64', 'uint64', 'uint64_t'),
111144
('Int64', 'int64', 'int64_t'),
112145
('Float32', 'float32', 'float32_t'),
@@ -303,22 +336,24 @@ cdef class HashTable:
303336

304337
{{py:
305338

306-
# name, dtype, c_type, float_group
307-
dtypes = [('Float64', 'float64', 'float64_t', True),
308-
('UInt64', 'uint64', 'uint64_t', False),
309-
('Int64', 'int64', 'int64_t', False),
310-
('Float32', 'float32', 'float32_t', True),
311-
('UInt32', 'uint32', 'uint32_t', False),
312-
('Int32', 'int32', 'int32_t', False),
313-
('UInt16', 'uint16', 'uint16_t', False),
314-
('Int16', 'int16', 'int16_t', False),
315-
('UInt8', 'uint8', 'uint8_t', False),
316-
('Int8', 'int8', 'int8_t', False)]
339+
# name, dtype, c_type, float_group, complex_group
340+
dtypes = [('Complex128', 'complex128', 'khcomplex128_t', True, True),
341+
('Float64', 'float64', 'float64_t', True, False),
342+
('UInt64', 'uint64', 'uint64_t', False, False),
343+
('Int64', 'int64', 'int64_t', False, False),
344+
('Complex64', 'complex64', 'khcomplex64_t', True, True),
345+
('Float32', 'float32', 'float32_t', True, False),
346+
('UInt32', 'uint32', 'uint32_t', False, False),
347+
('Int32', 'int32', 'int32_t', False, False),
348+
('UInt16', 'uint16', 'uint16_t', False, False),
349+
('Int16', 'int16', 'int16_t', False, False),
350+
('UInt8', 'uint8', 'uint8_t', False, False),
351+
('Int8', 'int8', 'int8_t', False, False)]
317352

318353
}}
319354

320355

321-
{{for name, dtype, c_type, float_group in dtypes}}
356+
{{for name, dtype, c_type, float_group, complex_group in dtypes}}
322357

323358
cdef class {{name}}HashTable(HashTable):
324359

@@ -339,7 +374,13 @@ cdef class {{name}}HashTable(HashTable):
339374
def __contains__(self, object key):
340375
cdef:
341376
khiter_t k
342-
k = kh_get_{{dtype}}(self.table, key)
377+
{{c_type}} ckey
378+
{{if complex_group}}
379+
ckey = to_{{c_type}}(key)
380+
{{else}}
381+
ckey = key
382+
{{endif}}
383+
k = kh_get_{{dtype}}(self.table, ckey)
343384
return k != self.table.n_buckets
344385

345386
def sizeof(self, deep=False):
@@ -353,7 +394,13 @@ cdef class {{name}}HashTable(HashTable):
353394
cpdef get_item(self, {{dtype}}_t val):
354395
cdef:
355396
khiter_t k
356-
k = kh_get_{{dtype}}(self.table, val)
397+
{{c_type}} cval
398+
{{if complex_group}}
399+
cval = to_{{c_type}}(val)
400+
{{else}}
401+
cval = val
402+
{{endif}}
403+
k = kh_get_{{dtype}}(self.table, cval)
357404
if k != self.table.n_buckets:
358405
return self.table.vals[k]
359406
else:
@@ -363,8 +410,13 @@ cdef class {{name}}HashTable(HashTable):
363410
cdef:
364411
khiter_t k
365412
int ret = 0
366-
367-
k = kh_put_{{dtype}}(self.table, key, &ret)
413+
{{c_type}} ckey
414+
{{if complex_group}}
415+
ckey = to_{{c_type}}(key)
416+
{{else}}
417+
ckey = key
418+
{{endif}}
419+
k = kh_put_{{dtype}}(self.table, ckey, &ret)
368420
if kh_exist_{{dtype}}(self.table, k):
369421
self.table.vals[k] = val
370422
else:
@@ -486,9 +538,17 @@ cdef class {{name}}HashTable(HashTable):
486538
# We use None, to make it optional, which requires `object` type
487539
# for the parameter. To please the compiler, we use na_value2,
488540
# which is only used if it's *specified*.
489-
na_value2 = <{{dtype}}_t>na_value
541+
{{if complex_group}}
542+
na_value2 = to_{{c_type}}(na_value)
543+
{{else}}
544+
na_value2 = na_value
545+
{{endif}}
490546
else:
547+
{{if complex_group}}
548+
na_value2 = to_{{c_type}}(0)
549+
{{else}}
491550
na_value2 = 0
551+
{{endif}}
492552

493553
with nogil:
494554
for i in range(n):
@@ -499,10 +559,14 @@ cdef class {{name}}HashTable(HashTable):
499559
labels[i] = na_sentinel
500560
continue
501561
elif ignore_na and (
502-
{{if not name.lower().startswith(("uint", "int"))}}
503-
val != val or
504-
{{endif}}
562+
{{if complex_group}}
563+
is_nan_{{c_type}}(val) or  # NaN is the missing value, as in the float branch
564+
(use_na_value and are_equal_{{c_type}}(val,na_value2))
565+
{{elif float_group}}
566+
val != val or (use_na_value and val == na_value2)
567+
{{else}}
505568
(use_na_value and val == na_value2)
569+
{{endif}}
506570
):
507571
# if missing values do not count as unique values (i.e. if
508572
# ignore_na is True), skip the hashtable entry for them,
@@ -625,7 +689,12 @@ cdef class {{name}}HashTable(HashTable):
625689
val = values[i]
626690

627691
# specific for groupby
628-
{{if dtype != 'uint64'}}
692+
{{if dtype == 'complex64' or dtype == 'complex128'}}
693+
# TODO: complex values have no ordering; for now only the real part is tested for negativity -- decide what groupby should do here
694+
if val.real < 0:
695+
labels[i] = -1
696+
continue
697+
{{elif dtype != 'uint64'}}
629698
if val < 0:
630699
labels[i] = -1
631700
continue

pandas/_libs/hashtable_func_helper.pxi.in

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,24 @@ WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
66

77
{{py:
88

9-
# dtype, ttype, c_type
10-
dtypes = [('float64', 'float64', 'float64_t'),
11-
('float32', 'float32', 'float32_t'),
12-
('uint64', 'uint64', 'uint64_t'),
13-
('uint32', 'uint32', 'uint32_t'),
14-
('uint16', 'uint16', 'uint16_t'),
15-
('uint8', 'uint8', 'uint8_t'),
16-
('object', 'pymap', 'object'),
17-
('int64', 'int64', 'int64_t'),
18-
('int32', 'int32', 'int32_t'),
19-
('int16', 'int16', 'int16_t'),
20-
('int8', 'int8', 'int8_t')]
9+
# dtype, ttype, c_type, complex_group
10+
dtypes = [('complex128', 'complex128', 'khcomplex128_t', True),
11+
('complex64', 'complex64', 'khcomplex64_t', True),
12+
('float64', 'float64', 'float64_t', False),
13+
('float32', 'float32', 'float32_t', False),
14+
('uint64', 'uint64', 'uint64_t', False),
15+
('uint32', 'uint32', 'uint32_t', False),
16+
('uint16', 'uint16', 'uint16_t', False),
17+
('uint8', 'uint8', 'uint8_t', False),
18+
('object', 'pymap', 'object', False),
19+
('int64', 'int64', 'int64_t', False),
20+
('int32', 'int32', 'int32_t', False),
21+
('int16', 'int16', 'int16_t', False),
22+
('int8', 'int8', 'int8_t', False)]
2123

2224
}}
2325

24-
{{for dtype, ttype, c_type in dtypes}}
26+
{{for dtype, ttype, c_type, complex_group in dtypes}}
2527

2628

2729
@cython.wraparound(False)
@@ -63,6 +65,8 @@ cdef build_count_table_{{dtype}}({{c_type}}[:] values,
6365

6466
{{if dtype == 'float64' or dtype == 'float32'}}
6567
if val == val or not dropna:
68+
{{elif complex_group}}
69+
if not is_nan_{{c_type}}(val) or not dropna:
6670
{{else}}
6771
if True:
6872
{{endif}}
@@ -114,7 +118,11 @@ cpdef value_count_{{dtype}}({{c_type}}[:] values, bint dropna):
114118
with nogil:
115119
for k in range(table.n_buckets):
116120
if kh_exist_{{ttype}}(table, k):
121+
{{if complex_group}}
122+
result_keys[i] = to_{{dtype}}(table.keys[k])
123+
{{else}}
117124
result_keys[i] = table.keys[k]
125+
{{endif}}
118126
result_counts[i] = table.vals[k]
119127
i += 1
120128
{{endif}}
@@ -279,7 +287,9 @@ def ismember_{{dtype}}(const {{c_type}}[:] arr, const {{c_type}}[:] values):
279287
{{py:
280288

281289
# dtype, ctype, table_type, npy_dtype
282-
dtypes = [('float64', 'float64_t', 'float64', 'float64'),
290+
dtypes = [('complex128', 'khcomplex128_t', 'complex128', 'complex128'),
291+
('complex64', 'khcomplex64_t', 'complex64', 'complex64'),
292+
('float64', 'float64_t', 'float64', 'float64'),
283293
('float32', 'float32_t', 'float32', 'float32'),
284294
('int64', 'int64_t', 'int64', 'int64'),
285295
('int32', 'int32_t', 'int32', 'int32'),

pandas/_libs/khash.pxd

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from cpython.object cimport PyObject
22
from numpy cimport (
3+
complex64_t,
4+
complex128_t,
35
float32_t,
46
float64_t,
57
int8_t,
@@ -19,6 +21,18 @@ cdef extern from "khash_python.h":
1921
ctypedef uint32_t khint_t
2022
ctypedef khint_t khiter_t
2123

24+
ctypedef struct khcomplex128_t:
25+
double real
26+
double imag
27+
28+
bint are_equal_khcomplex128_t "kh_complex_hash_equal" (khcomplex128_t a, khcomplex128_t b) nogil
29+
30+
ctypedef struct khcomplex64_t:
31+
float real
32+
float imag
33+
34+
bint are_equal_khcomplex64_t "kh_complex_hash_equal" (khcomplex64_t a, khcomplex64_t b) nogil
35+
2236
ctypedef struct kh_pymap_t:
2337
khint_t n_buckets, size, n_occupied, upper_bound
2438
uint32_t *flags

pandas/_libs/khash_for_primitive_helper.pxi.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ primitive_types = [('int64', 'int64_t'),
1717
('uint16', 'uint16_t'),
1818
('int8', 'int8_t'),
1919
('uint8', 'uint8_t'),
20+
('complex64', 'khcomplex64_t'),
21+
('complex128', 'khcomplex128_t'),
2022
]
2123
}}
2224

0 commit comments

Comments
 (0)