@@ -355,14 +355,14 @@ cdef class {{name}}HashTable(HashTable):
355
355
356
356
return np.asarray(locs)
357
357
358
- def factorize(self, {{dtype}}_t values):
358
+ def factorize(self, {{dtype}}_t[:] values):
359
359
uniques = {{name}}Vector()
360
- labels = self.get_labels(values, uniques, 0, 0 )
360
+ labels = self.get_labels(values, uniques, 0)
361
361
return uniques.to_array(), labels
362
362
363
363
@cython.boundscheck(False)
364
364
def get_labels(self, const {{dtype}}_t[:] values, {{name}}Vector uniques,
365
- Py_ssize_t count_prior, Py_ssize_t na_sentinel,
365
+ Py_ssize_t count_prior=0 , Py_ssize_t na_sentinel=-1 ,
366
366
object na_value=None):
367
367
cdef:
368
368
Py_ssize_t i, n = len(values)
@@ -399,9 +399,11 @@ cdef class {{name}}HashTable(HashTable):
399
399
k = kh_get_{{dtype}}(self.table, val)
400
400
401
401
if k != self.table.n_buckets:
402
+ # k falls into a previous bucket
402
403
idx = self.table.vals[k]
403
404
labels[i] = idx
404
405
else:
406
+ # k hasn't been seen yet
405
407
k = kh_put_{{dtype}}(self.table, val, &ret)
406
408
self.table.vals[k] = count
407
409
@@ -464,27 +466,42 @@ cdef class {{name}}HashTable(HashTable):
464
466
return np.asarray(labels), arr_uniques
465
467
466
468
@cython.boundscheck(False)
467
- def unique(self, const {{dtype}}_t[:] values):
469
+ def unique(self, const {{dtype}}_t[:] values, bint return_inverse=False ):
468
470
cdef:
469
- Py_ssize_t i, n = len(values)
470
- int ret = 0
471
- {{dtype}}_t val
472
- khiter_t k
473
- {{name}}Vector uniques = {{name}}Vector()
474
- {{name}}VectorData *ud
471
+ Py_ssize_t i, idx, count = 0, n = len(values)
472
+ int64_t[:] labels
473
+ int ret = 0
474
+ {{dtype}}_t val
475
+ khiter_t k
476
+ {{name}}Vector uniques = {{name}}Vector()
477
+ {{name}}VectorData *ud
475
478
476
479
ud = uniques.data
480
+ if return_inverse:
481
+ labels = np.empty(n, dtype=np.int64)
477
482
478
483
with nogil:
479
484
for i in range(n):
480
485
val = values[i]
481
486
k = kh_get_{{dtype}}(self.table, val)
482
- if k == self.table.n_buckets:
483
- kh_put_{{dtype}}(self.table, val, &ret)
487
+ if return_inverse and k != self.table.n_buckets:
488
+ # k falls into a previous bucket
489
+ idx = self.table.vals[k]
490
+ labels[i] = idx
491
+ elif k == self.table.n_buckets:
492
+ # k hasn't been seen yet
493
+ k = kh_put_{{dtype}}(self.table, val, &ret)
484
494
if needs_resize(ud):
485
495
with gil:
486
496
uniques.resize()
487
497
append_data_{{dtype}}(ud, val)
498
+ if return_inverse:
499
+ self.table.vals[k] = count
500
+ labels[i] = count
501
+ count += 1
502
+
503
+ if return_inverse:
504
+ return uniques.to_array(), np.asarray(labels)
488
505
return uniques.to_array()
489
506
490
507
{{endfor}}
@@ -567,45 +584,57 @@ cdef class StringHashTable(HashTable):
567
584
return labels
568
585
569
586
@cython.boundscheck(False)
570
- def unique(self, ndarray[object] values):
587
+ def unique(self, ndarray[object] values, bint return_inverse=False ):
571
588
cdef:
572
- Py_ssize_t i, count, n = len(values)
589
+ Py_ssize_t i, idx, count = 0, n = len(values)
590
+ int64_t[:] labels
573
591
int64_t[:] uindexer
574
592
int ret = 0
575
593
object val
576
- ObjectVector uniques
594
+ ObjectVector uniques = ObjectVector()
577
595
khiter_t k
578
596
const char *v
579
597
const char **vecs
580
598
581
- vecs = <const char **> malloc(n * sizeof(char *))
599
+ if return_inverse:
600
+ labels = np.zeros(n, dtype=np.int64)
582
601
uindexer = np.empty(n, dtype=np.int64)
602
+
603
+ # assign pointers
604
+ vecs = <const char **> malloc(n * sizeof(char *))
583
605
for i in range(n):
584
606
val = values[i]
585
607
v = util.get_c_string(val)
586
608
vecs[i] = v
587
609
588
- count = 0
610
+
611
+ # compute
589
612
with nogil:
590
613
for i in range(n):
591
614
v = vecs[i]
592
615
k = kh_get_str(self.table, v)
593
- if k == self.table.n_buckets:
594
- kh_put_str(self.table, v, &ret)
616
+ if return_inverse and k != self.table.n_buckets:
617
+ # k falls into a previous bucket
618
+ idx = self.table.vals[k]
619
+ labels[i] = <int64_t>idx
620
+ elif k == self.table.n_buckets:
621
+ # k hasn't been seen yet
622
+ k = kh_put_str(self.table, v, &ret)
595
623
uindexer[count] = i
624
+ if return_inverse:
625
+ self.table.vals[k] = count
626
+ labels[i] = <int64_t>count
596
627
count += 1
628
+
597
629
free(vecs)
598
630
599
631
# uniques
600
- uniques = ObjectVector()
601
632
for i in range(count):
602
633
uniques.append(values[uindexer[i]])
603
- return uniques.to_array()
604
634
605
- def factorize(self, ndarray[object] values):
606
- uniques = ObjectVector()
607
- labels = self.get_labels(values, uniques, 0, 0)
608
- return uniques.to_array(), labels
635
+ if return_inverse:
636
+ return uniques.to_array(), np.asarray(labels)
637
+ return uniques.to_array()
609
638
610
639
@cython.boundscheck(False)
611
640
def lookup(self, ndarray[object] values):
@@ -670,7 +699,7 @@ cdef class StringHashTable(HashTable):
670
699
671
700
@cython.boundscheck(False)
672
701
def get_labels(self, ndarray[object] values, ObjectVector uniques,
673
- Py_ssize_t count_prior, int64_t na_sentinel,
702
+ Py_ssize_t count_prior=0 , int64_t na_sentinel=-1 ,
674
703
object na_value=None):
675
704
cdef:
676
705
Py_ssize_t i, n = len(values)
@@ -814,26 +843,43 @@ cdef class PyObjectHashTable(HashTable):
814
843
815
844
return np.asarray(locs)
816
845
817
- def unique(self, ndarray[object] values):
846
+ @cython.boundscheck(False)
847
+ def unique(self, ndarray[object] values, bint return_inverse=False):
818
848
cdef:
819
- Py_ssize_t i, n = len(values)
849
+ Py_ssize_t i, idx, count = 0, n = len(values)
850
+ int64_t[:] labels
820
851
int ret = 0
821
852
object val
822
853
khiter_t k
823
854
ObjectVector uniques = ObjectVector()
824
855
856
+ if return_inverse:
857
+ labels = np.empty(n, dtype=np.int64)
858
+
825
859
for i in range(n):
826
860
val = values[i]
827
861
hash(val)
828
862
k = kh_get_pymap(self.table, <PyObject*>val)
829
- if k == self.table.n_buckets:
830
- kh_put_pymap(self.table, <PyObject*>val, &ret)
863
+ if return_inverse and k != self.table.n_buckets:
864
+ # k falls into a previous bucket
865
+ idx = self.table.vals[k]
866
+ labels[i] = <int64_t>idx
867
+ elif k == self.table.n_buckets:
868
+ # k hasn't been seen yet
869
+ k = kh_put_pymap(self.table, <PyObject*>val, &ret)
831
870
uniques.append(val)
871
+ if return_inverse:
872
+ self.table.vals[k] = count
873
+ labels[i] = <int64_t>count
874
+ count += 1
832
875
876
+ if return_inverse:
877
+ return uniques.to_array(), np.asarray(labels)
833
878
return uniques.to_array()
834
879
880
+ @cython.boundscheck(False)
835
881
def get_labels(self, ndarray[object] values, ObjectVector uniques,
836
- Py_ssize_t count_prior, int64_t na_sentinel,
882
+ Py_ssize_t count_prior=0 , int64_t na_sentinel=-1 ,
837
883
object na_value=None):
838
884
cdef:
839
885
Py_ssize_t i, n = len(values)
@@ -858,9 +904,11 @@ cdef class PyObjectHashTable(HashTable):
858
904
859
905
k = kh_get_pymap(self.table, <PyObject*>val)
860
906
if k != self.table.n_buckets:
907
+ # k falls into a previous bucket
861
908
idx = self.table.vals[k]
862
909
labels[i] = idx
863
910
else:
911
+ # k hasn't been seen yet
864
912
k = kh_put_pymap(self.table, <PyObject*>val, &ret)
865
913
self.table.vals[k] = count
866
914
uniques.append(val)
0 commit comments