@@ -8,7 +8,34 @@ WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
8
8
{{py:
9
9
10
10
# name
11
- cimported_types = ['float32',
11
+ complex_types = ['complex64',
12
+ 'complex128']
13
+ }}
14
+
15
+ {{for name in complex_types}}
16
+ cdef kh{{name}}_t to_kh{{name}}_t({{name}}_t val) nogil:  # convert a complex64_t/complex128_t value to its kh-prefixed counterpart
17
+ cdef kh{{name}}_t res
18
+ res.real = val.real  # copy the real component
19
+ res.imag = val.imag  # copy the imaginary component
20
+ return res
21
+
22
+ cdef {{name}}_t to_{{name}}(kh{{name}}_t val) nogil:  # inverse conversion: kh-prefixed complex value back to complex64_t/complex128_t
23
+ cdef {{name}}_t res
24
+ res.real = val.real  # copy the real component
25
+ res.imag = val.imag  # copy the imaginary component
26
+ return res
27
+
28
+ cdef bint is_nan_kh{{name}}_t(kh{{name}}_t val) nogil:  # true when the real or the imaginary component is NaN
29
+ return val.real != val.real or val.imag != val.imag  # a component compares unequal to itself only when it is NaN
30
+ {{endfor}}
31
+
32
+
33
+ {{py:
34
+
35
+ # name
36
+ cimported_types = ['complex64',
37
+ 'complex128',
38
+ 'float32',
12
39
'float64',
13
40
'int8',
14
41
'int16',
@@ -48,7 +75,9 @@ from pandas._libs.missing cimport C_NA
48
75
# but is included for completeness (rather ObjectVector is used
49
76
# for uniques in hashtables)
50
77
51
- dtypes = [('Float64', 'float64', 'float64_t'),
78
+ dtypes = [('Complex128', 'complex128', 'khcomplex128_t'),
79
+ ('Complex64', 'complex64', 'khcomplex64_t'),
80
+ ('Float64', 'float64', 'float64_t'),
52
81
('Float32', 'float32', 'float32_t'),
53
82
('Int64', 'int64', 'int64_t'),
54
83
('Int32', 'int32', 'int32_t'),
@@ -94,6 +123,8 @@ ctypedef fused vector_data:
94
123
UInt8VectorData
95
124
Float64VectorData
96
125
Float32VectorData
126
+ Complex128VectorData
127
+ Complex64VectorData
97
128
StringVectorData
98
129
99
130
cdef inline bint needs_resize(vector_data *data) nogil:
@@ -106,7 +137,9 @@ cdef inline bint needs_resize(vector_data *data) nogil:
106
137
{{py:
107
138
108
139
# name, dtype, c_type
109
- dtypes = [('Float64', 'float64', 'float64_t'),
140
+ dtypes = [('Complex128', 'complex128', 'khcomplex128_t'),
141
+ ('Complex64', 'complex64', 'khcomplex64_t'),
142
+ ('Float64', 'float64', 'float64_t'),
110
143
('UInt64', 'uint64', 'uint64_t'),
111
144
('Int64', 'int64', 'int64_t'),
112
145
('Float32', 'float32', 'float32_t'),
@@ -303,22 +336,24 @@ cdef class HashTable:
303
336
304
337
{{py:
305
338
306
- # name, dtype, c_type, float_group
307
- dtypes = [('Float64', 'float64', 'float64_t', True),
308
- ('UInt64', 'uint64', 'uint64_t', False),
309
- ('Int64', 'int64', 'int64_t', False),
310
- ('Float32', 'float32', 'float32_t', True),
311
- ('UInt32', 'uint32', 'uint32_t', False),
312
- ('Int32', 'int32', 'int32_t', False),
313
- ('UInt16', 'uint16', 'uint16_t', False),
314
- ('Int16', 'int16', 'int16_t', False),
315
- ('UInt8', 'uint8', 'uint8_t', False),
316
- ('Int8', 'int8', 'int8_t', False)]
339
+ # name, dtype, c_type, float_group, complex_group
340
+ dtypes = [('Complex128', 'complex128', 'khcomplex128_t', True, True),
341
+ ('Float64', 'float64', 'float64_t', True, False),
342
+ ('UInt64', 'uint64', 'uint64_t', False, False),
343
+ ('Int64', 'int64', 'int64_t', False, False),
344
+ ('Complex64', 'complex64', 'khcomplex64_t', True, True),
345
+ ('Float32', 'float32', 'float32_t', True, False),
346
+ ('UInt32', 'uint32', 'uint32_t', False, False),
347
+ ('Int32', 'int32', 'int32_t', False, False),
348
+ ('UInt16', 'uint16', 'uint16_t', False, False),
349
+ ('Int16', 'int16', 'int16_t', False, False),
350
+ ('UInt8', 'uint8', 'uint8_t', False, False),
351
+ ('Int8', 'int8', 'int8_t', False, False)]
317
352
318
353
}}
319
354
320
355
321
- {{for name, dtype, c_type, float_group in dtypes}}
356
+ {{for name, dtype, c_type, float_group, complex_group in dtypes}}
322
357
323
358
cdef class {{name}}HashTable(HashTable):
324
359
@@ -339,7 +374,13 @@ cdef class {{name}}HashTable(HashTable):
339
374
def __contains__(self, object key):
340
375
cdef:
341
376
khiter_t k
342
- k = kh_get_{{dtype}}(self.table, key)
377
+ {{c_type}} ckey
378
+ {{if complex_group}}
379
+ ckey = to_{{c_type}}(key)
380
+ {{else}}
381
+ ckey = key
382
+ {{endif}}
383
+ k = kh_get_{{dtype}}(self.table, ckey)
343
384
return k != self.table.n_buckets
344
385
345
386
def sizeof(self, deep=False):
@@ -353,7 +394,13 @@ cdef class {{name}}HashTable(HashTable):
353
394
cpdef get_item(self, {{dtype}}_t val):
354
395
cdef:
355
396
khiter_t k
356
- k = kh_get_{{dtype}}(self.table, val)
397
+ {{c_type}} cval
398
+ {{if complex_group}}
399
+ cval = to_{{c_type}}(val)
400
+ {{else}}
401
+ cval = val
402
+ {{endif}}
403
+ k = kh_get_{{dtype}}(self.table, cval)
357
404
if k != self.table.n_buckets:
358
405
return self.table.vals[k]
359
406
else:
@@ -363,8 +410,13 @@ cdef class {{name}}HashTable(HashTable):
363
410
cdef:
364
411
khiter_t k
365
412
int ret = 0
366
-
367
- k = kh_put_{{dtype}}(self.table, key, &ret)
413
+ {{c_type}} ckey
414
+ {{if complex_group}}
415
+ ckey = to_{{c_type}}(key)
416
+ {{else}}
417
+ ckey = key
418
+ {{endif}}
419
+ k = kh_put_{{dtype}}(self.table, ckey, &ret)
368
420
if kh_exist_{{dtype}}(self.table, k):
369
421
self.table.vals[k] = val
370
422
else:
@@ -486,9 +538,17 @@ cdef class {{name}}HashTable(HashTable):
486
538
# We use None, to make it optional, which requires `object` type
487
539
# for the parameter. To please the compiler, we use na_value2,
488
540
# which is only used if it's *specified*.
489
- na_value2 = <{{dtype}}_t>na_value
541
+ {{if complex_group}}
542
+ na_value2 = to_{{c_type}}(na_value)
543
+ {{else}}
544
+ na_value2 = na_value
545
+ {{endif}}
490
546
else:
547
+ {{if complex_group}}
548
+ na_value2 = to_{{c_type}}(0)
549
+ {{else}}
491
550
na_value2 = 0
551
+ {{endif}}
492
552
493
553
with nogil:
494
554
for i in range(n):
@@ -499,10 +559,14 @@ cdef class {{name}}HashTable(HashTable):
499
559
labels[i] = na_sentinel
500
560
continue
501
561
elif ignore_na and (
502
- {{if not name.lower().startswith(("uint", "int"))}}
503
- val != val or
504
- {{endif}}
562
+ {{if complex_group}}
563
+ is_nan_{{c_type}}(val) or  # a NaN in either component marks the value as missing, mirroring `val != val` in the float branch
564
+ (use_na_value and are_equal_{{c_type}}(val, na_value2))
565
+ {{elif float_group}}
566
+ val != val or (use_na_value and val == na_value2)
567
+ {{else}}
505
568
(use_na_value and val == na_value2)
569
+ {{endif}}
506
570
):
507
571
# if missing values do not count as unique values (i.e. if
508
572
# ignore_na is True), skip the hashtable entry for them,
@@ -625,7 +689,12 @@ cdef class {{name}}HashTable(HashTable):
625
689
val = values[i]
626
690
627
691
# specific for groupby
628
- {{if dtype != 'uint64'}}
692
+ {{if dtype == 'complex64' or dtype== 'complex128'}}
693
+ # TODO: decide how negative values should be treated for complex numbers; for now only the sign of the real part is checked
694
+ if val.real < 0:
695
+ labels[i] = -1
696
+ continue
697
+ {{elif dtype != 'uint64'}}
629
698
if val < 0:
630
699
labels[i] = -1
631
700
continue
0 commit comments