Skip to content

Commit 31d0dc5

Browse files
committed
Add return_inverse to hashtable.unique
1 parent 640162f commit 31d0dc5

File tree

1 file changed

+79
-31
lines changed

1 file changed

+79
-31
lines changed

pandas/_libs/hashtable_class_helper.pxi.in

Lines changed: 79 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -355,14 +355,14 @@ cdef class {{name}}HashTable(HashTable):
355355

356356
return np.asarray(locs)
357357

358-
def factorize(self, {{dtype}}_t values):
358+
def factorize(self, {{dtype}}_t[:] values):
359359
uniques = {{name}}Vector()
360-
labels = self.get_labels(values, uniques, 0, 0)
360+
labels = self.get_labels(values, uniques, 0)
361361
return uniques.to_array(), labels
362362

363363
@cython.boundscheck(False)
364364
def get_labels(self, const {{dtype}}_t[:] values, {{name}}Vector uniques,
365-
Py_ssize_t count_prior, Py_ssize_t na_sentinel,
365+
Py_ssize_t count_prior=0, Py_ssize_t na_sentinel=-1,
366366
object na_value=None):
367367
cdef:
368368
Py_ssize_t i, n = len(values)
@@ -399,9 +399,11 @@ cdef class {{name}}HashTable(HashTable):
399399
k = kh_get_{{dtype}}(self.table, val)
400400

401401
if k != self.table.n_buckets:
402+
# k falls into a previous bucket
402403
idx = self.table.vals[k]
403404
labels[i] = idx
404405
else:
406+
# k hasn't been seen yet
405407
k = kh_put_{{dtype}}(self.table, val, &ret)
406408
self.table.vals[k] = count
407409

@@ -464,27 +466,42 @@ cdef class {{name}}HashTable(HashTable):
464466
return np.asarray(labels), arr_uniques
465467

466468
@cython.boundscheck(False)
467-
def unique(self, const {{dtype}}_t[:] values):
469+
def unique(self, const {{dtype}}_t[:] values, bint return_inverse=False):
468470
cdef:
469-
Py_ssize_t i, n = len(values)
470-
int ret = 0
471-
{{dtype}}_t val
472-
khiter_t k
473-
{{name}}Vector uniques = {{name}}Vector()
474-
{{name}}VectorData *ud
471+
Py_ssize_t i, idx, count = 0, n = len(values)
472+
int64_t[:] labels
473+
int ret = 0
474+
{{dtype}}_t val
475+
khiter_t k
476+
{{name}}Vector uniques = {{name}}Vector()
477+
{{name}}VectorData *ud
475478

476479
ud = uniques.data
480+
if return_inverse:
481+
labels = np.empty(n, dtype=np.int64)
477482

478483
with nogil:
479484
for i in range(n):
480485
val = values[i]
481486
k = kh_get_{{dtype}}(self.table, val)
482-
if k == self.table.n_buckets:
483-
kh_put_{{dtype}}(self.table, val, &ret)
487+
if return_inverse and k != self.table.n_buckets:
488+
# k falls into a previous bucket
489+
idx = self.table.vals[k]
490+
labels[i] = idx
491+
elif k == self.table.n_buckets:
492+
# k hasn't been seen yet
493+
k = kh_put_{{dtype}}(self.table, val, &ret)
484494
if needs_resize(ud):
485495
with gil:
486496
uniques.resize()
487497
append_data_{{dtype}}(ud, val)
498+
if return_inverse:
499+
self.table.vals[k] = count
500+
labels[i] = count
501+
count += 1
502+
503+
if return_inverse:
504+
return uniques.to_array(), np.asarray(labels)
488505
return uniques.to_array()
489506

490507
{{endfor}}
@@ -567,45 +584,57 @@ cdef class StringHashTable(HashTable):
567584
return labels
568585

569586
@cython.boundscheck(False)
570-
def unique(self, ndarray[object] values):
587+
def unique(self, ndarray[object] values, bint return_inverse=False):
571588
cdef:
572-
Py_ssize_t i, count, n = len(values)
589+
Py_ssize_t i, idx, count = 0, n = len(values)
590+
int64_t[:] labels
573591
int64_t[:] uindexer
574592
int ret = 0
575593
object val
576-
ObjectVector uniques
594+
ObjectVector uniques = ObjectVector()
577595
khiter_t k
578596
const char *v
579597
const char **vecs
580598

581-
vecs = <const char **> malloc(n * sizeof(char *))
599+
if return_inverse:
600+
labels = np.zeros(n, dtype=np.int64)
582601
uindexer = np.empty(n, dtype=np.int64)
602+
603+
# assign pointers
604+
vecs = <const char **> malloc(n * sizeof(char *))
583605
for i in range(n):
584606
val = values[i]
585607
v = util.get_c_string(val)
586608
vecs[i] = v
587609

588-
count = 0
610+
611+
# compute
589612
with nogil:
590613
for i in range(n):
591614
v = vecs[i]
592615
k = kh_get_str(self.table, v)
593-
if k == self.table.n_buckets:
594-
kh_put_str(self.table, v, &ret)
616+
if return_inverse and k != self.table.n_buckets:
617+
# k falls into a previous bucket
618+
idx = self.table.vals[k]
619+
labels[i] = <int64_t>idx
620+
elif k == self.table.n_buckets:
621+
# k hasn't been seen yet
622+
k = kh_put_str(self.table, v, &ret)
595623
uindexer[count] = i
624+
if return_inverse:
625+
self.table.vals[k] = count
626+
labels[i] = <int64_t>count
596627
count += 1
628+
597629
free(vecs)
598630

599631
# uniques
600-
uniques = ObjectVector()
601632
for i in range(count):
602633
uniques.append(values[uindexer[i]])
603-
return uniques.to_array()
604634

605-
def factorize(self, ndarray[object] values):
606-
uniques = ObjectVector()
607-
labels = self.get_labels(values, uniques, 0, 0)
608-
return uniques.to_array(), labels
635+
if return_inverse:
636+
return uniques.to_array(), np.asarray(labels)
637+
return uniques.to_array()
609638

610639
@cython.boundscheck(False)
611640
def lookup(self, ndarray[object] values):
@@ -670,7 +699,7 @@ cdef class StringHashTable(HashTable):
670699

671700
@cython.boundscheck(False)
672701
def get_labels(self, ndarray[object] values, ObjectVector uniques,
673-
Py_ssize_t count_prior, int64_t na_sentinel,
702+
Py_ssize_t count_prior=0, int64_t na_sentinel=-1,
674703
object na_value=None):
675704
cdef:
676705
Py_ssize_t i, n = len(values)
@@ -814,26 +843,43 @@ cdef class PyObjectHashTable(HashTable):
814843

815844
return np.asarray(locs)
816845

817-
def unique(self, ndarray[object] values):
846+
@cython.boundscheck(False)
847+
def unique(self, ndarray[object] values, bint return_inverse=False):
818848
cdef:
819-
Py_ssize_t i, n = len(values)
849+
Py_ssize_t i, idx, count = 0, n = len(values)
850+
int64_t[:] labels
820851
int ret = 0
821852
object val
822853
khiter_t k
823854
ObjectVector uniques = ObjectVector()
824855

856+
if return_inverse:
857+
labels = np.empty(n, dtype=np.int64)
858+
825859
for i in range(n):
826860
val = values[i]
827861
hash(val)
828862
k = kh_get_pymap(self.table, <PyObject*>val)
829-
if k == self.table.n_buckets:
830-
kh_put_pymap(self.table, <PyObject*>val, &ret)
863+
if return_inverse and k != self.table.n_buckets:
864+
# k falls into a previous bucket
865+
idx = self.table.vals[k]
866+
labels[i] = <int64_t>idx
867+
elif k == self.table.n_buckets:
868+
# k hasn't been seen yet
869+
k = kh_put_pymap(self.table, <PyObject*>val, &ret)
831870
uniques.append(val)
871+
if return_inverse:
872+
self.table.vals[k] = count
873+
labels[i] = <int64_t>count
874+
count += 1
832875

876+
if return_inverse:
877+
return uniques.to_array(), np.asarray(labels)
833878
return uniques.to_array()
834879

880+
@cython.boundscheck(False)
835881
def get_labels(self, ndarray[object] values, ObjectVector uniques,
836-
Py_ssize_t count_prior, int64_t na_sentinel,
882+
Py_ssize_t count_prior=0, int64_t na_sentinel=-1,
837883
object na_value=None):
838884
cdef:
839885
Py_ssize_t i, n = len(values)
@@ -858,9 +904,11 @@ cdef class PyObjectHashTable(HashTable):
858904

859905
k = kh_get_pymap(self.table, <PyObject*>val)
860906
if k != self.table.n_buckets:
907+
# k falls into a previous bucket
861908
idx = self.table.vals[k]
862909
labels[i] = idx
863910
else:
911+
# k hasn't been seen yet
864912
k = kh_put_pymap(self.table, <PyObject*>val, &ret)
865913
self.table.vals[k] = count
866914
uniques.append(val)

0 commit comments

Comments
 (0)