Skip to content

Commit 99c7e27

Browse files
committed
types, cleanup, let cython handle dtype dispatch
1 parent 43c085b commit 99c7e27

File tree

2 files changed

+44
-61
lines changed

2 files changed

+44
-61
lines changed

pandas/_libs/groupby_helper.pxi.in

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -735,11 +735,6 @@ def group_min(ndarray[groupby_t, ndim=2] out,
735735
out[i, j] = minx[i, j]
736736

737737

738-
group_min_float64 = group_min["float64_t"]
739-
group_min_float32 = group_min["float32_t"]
740-
group_min_int64 = group_min["int64_t"]
741-
742-
743738
@cython.boundscheck(False)
744739
@cython.wraparound(False)
745740
def group_cummin(groupby_t[:, :] out,
@@ -788,11 +783,6 @@ def group_cummin(groupby_t[:, :] out,
788783
out[i, j] = mval
789784

790785

791-
group_cummin_float64 = group_cummin["float64_t"]
792-
group_cummin_float32 = group_cummin["float32_t"]
793-
group_cummin_int64 = group_cummin["int64_t"]
794-
795-
796786
@cython.boundscheck(False)
797787
@cython.wraparound(False)
798788
def group_cummax(groupby_t[:, :] out,
@@ -838,8 +828,3 @@ def group_cummax(groupby_t[:, :] out,
838828
if val > mval:
839829
accum[lab, j] = mval = val
840830
out[i, j] = mval
841-
842-
843-
group_cummax_float64 = group_cummax["float64_t"]
844-
group_cummax_float32 = group_cummax["float32_t"]
845-
group_cummax_int64 = group_cummax["int64_t"]

pandas/_libs/lib.pyx

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -263,10 +263,10 @@ def fast_unique_multiple_list_gen(object gen, bint sort=True):
263263

264264
@cython.wraparound(False)
265265
@cython.boundscheck(False)
266-
def dicts_to_array(list dicts, list columns):
266+
def dicts_to_array(dicts: list, columns: list):
267267
cdef:
268268
Py_ssize_t i, j, k, n
269-
ndarray[object, ndim=2] result
269+
object[:, :] result
270270
dict row
271271
object col, onan = np.nan
272272

@@ -284,7 +284,7 @@ def dicts_to_array(list dicts, list columns):
284284
else:
285285
result[i, j] = onan
286286

287-
return result
287+
return result.base # `.base` to access underlying np.ndarray
288288

289289

290290
def fast_zip(list ndarrays):
@@ -343,17 +343,17 @@ def get_reverse_indexer(int64_t[:] indexer, Py_ssize_t length):
343343

344344
cdef:
345345
Py_ssize_t i, n = len(indexer)
346-
ndarray[int64_t] rev_indexer
346+
int64_t[:] rev_indexer
347347
int64_t idx
348348

349349
rev_indexer = np.empty(length, dtype=np.int64)
350-
rev_indexer.fill(-1)
350+
rev_indexer[:] = -1
351351
for i in range(n):
352352
idx = indexer[i]
353353
if idx != -1:
354354
rev_indexer[idx] = i
355355

356-
return rev_indexer
356+
return rev_indexer.base # `.base` to access underlying np.ndarray
357357

358358

359359
@cython.wraparound(False)
@@ -460,7 +460,7 @@ def maybe_booleans_to_slice(ndarray[uint8_t] mask):
460460

461461
@cython.wraparound(False)
462462
@cython.boundscheck(False)
463-
def array_equivalent_object(left: object[:], right: object[:]) -> bint:
463+
def array_equivalent_object(left: object[:], right: object[:]) -> bool:
464464
""" perform an element by element comparion on 1-d object arrays
465465
taking into account nan positions """
466466
cdef:
@@ -484,7 +484,7 @@ def array_equivalent_object(left: object[:], right: object[:]) -> bint:
484484
def astype_intsafe(object[:] arr, new_dtype):
485485
cdef:
486486
Py_ssize_t i, n = len(arr)
487-
object v
487+
object val
488488
bint is_datelike
489489
ndarray result
490490

@@ -493,11 +493,11 @@ def astype_intsafe(object[:] arr, new_dtype):
493493

494494
result = np.empty(n, dtype=new_dtype)
495495
for i in range(n):
496-
v = arr[i]
497-
if is_datelike and checknull(v):
496+
val = arr[i]
497+
if is_datelike and checknull(val):
498498
result[i] = NPY_NAT
499499
else:
500-
result[i] = v
500+
result[i] = val
501501

502502
return result
503503

@@ -524,7 +524,7 @@ def astype_unicode(arr: ndarray, skipna: bool=False) -> ndarray[object]:
524524
cdef:
525525
object arr_i
526526
Py_ssize_t i, n = arr.size
527-
ndarray[object] result = np.empty(n, dtype=object)
527+
object[:] result = np.empty(n, dtype=object)
528528

529529
for i in range(n):
530530
arr_i = arr[i]
@@ -534,7 +534,7 @@ def astype_unicode(arr: ndarray, skipna: bool=False) -> ndarray[object]:
534534

535535
result[i] = arr_i
536536

537-
return result
537+
return result.base # `.base` to access underlying np.ndarray
538538

539539

540540
@cython.wraparound(False)
@@ -559,7 +559,7 @@ def astype_str(arr: ndarray, skipna: bool = False) -> ndarray[object]:
559559
cdef:
560560
object arr_i
561561
Py_ssize_t i, n = arr.size
562-
ndarray[object] result = np.empty(n, dtype=object)
562+
object[:] result = np.empty(n, dtype=object)
563563

564564
for i in range(n):
565565
arr_i = arr[i]
@@ -569,24 +569,24 @@ def astype_str(arr: ndarray, skipna: bool = False) -> ndarray[object]:
569569

570570
result[i] = arr_i
571571

572-
return result
572+
return result.base # `.base` to access underlying np.ndarray
573573

574574

575575
@cython.wraparound(False)
576576
@cython.boundscheck(False)
577-
def clean_index_list(list obj):
577+
def clean_index_list(obj: list):
578578
"""
579579
Utility used in pandas.core.index.ensure_index
580580
"""
581581
cdef:
582582
Py_ssize_t i, n = len(obj)
583-
object v
583+
object val
584584
bint all_arrays = 1
585585

586586
for i in range(n):
587-
v = obj[i]
588-
if not (isinstance(v, list) or
589-
util.is_array(v) or hasattr(v, '_data')):
587+
val = obj[i]
588+
if not (isinstance(val, list) or
589+
util.is_array(val) or hasattr(val, '_data')):
590590
all_arrays = 0
591591
break
592592

@@ -595,11 +595,9 @@ def clean_index_list(list obj):
595595

596596
# don't force numpy coerce with nan's
597597
inferred = infer_dtype(obj)
598-
if inferred in ['string', 'bytes', 'unicode',
599-
'mixed', 'mixed-integer']:
598+
if inferred in ['string', 'bytes', 'unicode', 'mixed', 'mixed-integer']:
600599
return np.asarray(obj, dtype=object), 0
601600
elif inferred in ['integer']:
602-
603601
# TODO: we infer an integer but it *could* be a uint64
604602
try:
605603
return np.asarray(obj, dtype='int64'), 0
@@ -680,13 +678,13 @@ def generate_bins_dt64(ndarray[int64_t] values, int64_t[:] binner,
680678

681679
@cython.boundscheck(False)
682680
@cython.wraparound(False)
683-
def row_bool_subset(ndarray[float64_t, ndim=2] values,
681+
def row_bool_subset(float64_t[:, :] values,
684682
ndarray[uint8_t, cast=True] mask):
685683
cdef:
686684
Py_ssize_t i, j, n, k, pos = 0
687-
ndarray[float64_t, ndim=2] out
685+
float64_t[:, :] out
688686

689-
n, k = (<object> values).shape
687+
n, k = (<object>values).shape
690688
assert (n == len(mask))
691689

692690
out = np.empty((mask.sum(), k), dtype=np.float64)
@@ -697,7 +695,7 @@ def row_bool_subset(ndarray[float64_t, ndim=2] values,
697695
out[pos, j] = values[i, j]
698696
pos += 1
699697

700-
return out
698+
return out.base # `.base` to access underlying np.ndarray
701699

702700

703701
@cython.boundscheck(False)
@@ -706,7 +704,7 @@ def row_bool_subset_object(object[:, :] values,
706704
ndarray[uint8_t, cast=True] mask):
707705
cdef:
708706
Py_ssize_t i, j, n, k, pos = 0
709-
ndarray[object, ndim=2] out
707+
object[:, :] out
710708

711709
n, k = (<object>values).shape
712710
assert (n == len(mask))
@@ -719,7 +717,7 @@ def row_bool_subset_object(object[:, :] values,
719717
out[pos, j] = values[i, j]
720718
pos += 1
721719

722-
return out
720+
return out.base # `.base` to access underlying np.ndarray
723721

724722

725723
@cython.boundscheck(False)
@@ -846,19 +844,19 @@ def indices_fast(object index, int64_t[:] labels, list keys,
846844

847845
# core.common import for fast inference checks
848846

849-
def is_float(obj: object) -> bint:
847+
def is_float(obj: object) -> bool:
850848
return util.is_float_object(obj)
851849

852850

853-
def is_integer(obj: object) -> bint:
851+
def is_integer(obj: object) -> bool:
854852
return util.is_integer_object(obj)
855853

856854

857-
def is_bool(obj: object) -> bint:
855+
def is_bool(obj: object) -> bool:
858856
return util.is_bool_object(obj)
859857

860858

861-
def is_complex(obj: object) -> bint:
859+
def is_complex(obj: object) -> bool:
862860
return util.is_complex_object(obj)
863861

864862

@@ -870,7 +868,7 @@ cpdef bint is_interval(object obj):
870868
return getattr(obj, '_typ', '_typ') == 'interval'
871869

872870

873-
def is_period(val: object) -> bint:
871+
def is_period(val: object) -> bool:
874872
""" Return a boolean if this is a Period object """
875873
return util.is_period_object(val)
876874

@@ -1352,7 +1350,7 @@ def infer_datetimelike_array(arr: object) -> object:
13521350
seen_datetime = 1
13531351
elif PyDate_Check(v):
13541352
seen_date = 1
1355-
elif is_timedelta(v) or util.is_timedelta64_object(v):
1353+
elif is_timedelta(v):
13561354
# timedelta, or timedelta64
13571355
seen_timedelta = 1
13581356
else:
@@ -1633,7 +1631,7 @@ cpdef bint is_datetime64_array(ndarray values):
16331631

16341632
@cython.wraparound(False)
16351633
@cython.boundscheck(False)
1636-
def is_datetime_with_singletz_array(values: ndarray) -> bint:
1634+
def is_datetime_with_singletz_array(values: ndarray) -> bool:
16371635
"""
16381636
Check values have the same tzinfo attribute.
16391637
Doesn't check values are datetime-like types.
@@ -2138,7 +2136,7 @@ def map_infer_mask(ndarray arr, object f, uint8_t[:] mask, bint convert=1):
21382136
"""
21392137
cdef:
21402138
Py_ssize_t i, n
2141-
ndarray[object] result
2139+
object[:] result
21422140
object val
21432141

21442142
n = len(arr)
@@ -2163,7 +2161,7 @@ def map_infer_mask(ndarray arr, object f, uint8_t[:] mask, bint convert=1):
21632161
convert_datetime=0,
21642162
convert_timedelta=0)
21652163

2166-
return result
2164+
return result.base # `.base` to access underlying np.ndarray
21672165

21682166

21692167
@cython.wraparound(False)
@@ -2208,7 +2206,7 @@ def map_infer(ndarray arr, object f, bint convert=1):
22082206
return result.base # `.base` to access underlying np.ndarray
22092207

22102208

2211-
def to_object_array(list rows, int min_width=0):
2209+
def to_object_array(rows: list, min_width: int = 0):
22122210
"""
22132211
Convert a list of lists into an object array.
22142212
@@ -2228,7 +2226,7 @@ def to_object_array(list rows, int min_width=0):
22282226
"""
22292227
cdef:
22302228
Py_ssize_t i, j, n, k, tmp
2231-
ndarray[object, ndim=2] result
2229+
object[:, :] result
22322230
list row
22332231

22342232
n = len(rows)
@@ -2247,13 +2245,13 @@ def to_object_array(list rows, int min_width=0):
22472245
for j in range(len(row)):
22482246
result[i, j] = row[j]
22492247

2250-
return result
2248+
return result.base # `.base` to access underlying np.ndarray
22512249

22522250

22532251
def tuples_to_object_array(ndarray[object] tuples):
22542252
cdef:
22552253
Py_ssize_t i, j, n, k, tmp
2256-
ndarray[object, ndim=2] result
2254+
object[:, :] result
22572255
tuple tup
22582256

22592257
n = len(tuples)
@@ -2264,13 +2262,13 @@ def tuples_to_object_array(ndarray[object] tuples):
22642262
for j in range(k):
22652263
result[i, j] = tup[j]
22662264

2267-
return result
2265+
return result.base # `.base` to access underlying np.ndarray
22682266

22692267

2270-
def to_object_array_tuples(list rows):
2268+
def to_object_array_tuples(rows: list):
22712269
cdef:
22722270
Py_ssize_t i, j, n, k, tmp
2273-
ndarray[object, ndim=2] result
2271+
object[:, :] result
22742272
tuple row
22752273

22762274
n = len(rows)
@@ -2295,7 +2293,7 @@ def to_object_array_tuples(list rows):
22952293
for j in range(len(row)):
22962294
result[i, j] = row[j]
22972295

2298-
return result
2296+
return result.base # `.base` to access underlying np.ndarray
22992297

23002298

23012299
@cython.wraparound(False)

0 commit comments

Comments
 (0)