Skip to content

PERF: use bisect_right_i8 in vectorized #46341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 15, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 62 additions & 73 deletions pandas/_libs/tslibs/vectorized.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ from cpython.datetime cimport (

import numpy as np

cimport numpy as cnp
from numpy cimport (
int64_t,
intp_t,
ndarray,
)

cnp.import_array()

from .conversion cimport normalize_i8_stamp

from .dtypes import Resolution
Expand All @@ -35,52 +38,13 @@ from .timezones cimport (
is_tzlocal,
is_utc,
)
from .tzconversion cimport tz_convert_utc_to_tzlocal
from .tzconversion cimport (
bisect_right_i8,
tz_convert_utc_to_tzlocal,
)

# -------------------------------------------------------------------------

cdef inline object create_datetime_from_ts(
    int64_t value,
    npy_datetimestruct dts,
    tzinfo tz,
    object freq,
    bint fold,
):
    """
    Build a datetime.datetime from the fields of an npy_datetimestruct.

    The year through microsecond fields are read from ``dts``; ``tz`` and
    ``fold`` are forwarded to the datetime constructor.  ``value`` and
    ``freq`` are accepted but not used in this function.
    """
    stamp = datetime(
        dts.year,
        dts.month,
        dts.day,
        dts.hour,
        dts.min,
        dts.sec,
        dts.us,
        tz,
        fold=fold,
    )
    return stamp


cdef inline object create_date_from_ts(
    int64_t value,
    npy_datetimestruct dts,
    tzinfo tz,
    object freq,
    bint fold
):
    """
    Build a datetime.date from the year/month/day fields of ``dts``.

    Only ``dts`` is read here; ``value``, ``tz``, ``freq`` and ``fold`` are
    accepted but not used.
    """
    # GH#25057 the fold argument is present so that this signature matches
    #  the other func_create functions
    day_only = date(dts.year, dts.month, dts.day)
    return day_only


cdef inline object create_time_from_ts(
    int64_t value,
    npy_datetimestruct dts,
    tzinfo tz,
    object freq,
    bint fold
):
    """
    Build a datetime.time from the hour/min/sec/us fields of ``dts``.

    ``tz`` and ``fold`` are forwarded to the time constructor.  ``value``
    and ``freq`` are accepted but not used in this function.
    """
    wall_time = time(
        dts.hour,
        dts.min,
        dts.sec,
        dts.us,
        tz,
        fold=fold,
    )
    return wall_time


@cython.wraparound(False)
@cython.boundscheck(False)
Expand Down Expand Up @@ -119,29 +83,29 @@ def ints_to_pydatetime(
ndarray[object] of type specified by box
"""
cdef:
Py_ssize_t i, n = len(stamps)
Py_ssize_t i, ntrans =- 1, n = len(stamps)
ndarray[int64_t] trans
int64_t[::1] deltas
intp_t[:] pos
int64_t* tdata = NULL
intp_t pos
npy_datetimestruct dts
object dt, new_tz
str typ
int64_t value, local_val, delta = NPY_NAT # dummy for delta
ndarray[object] result = np.empty(n, dtype=object)
object (*func_create)(int64_t, npy_datetimestruct, tzinfo, object, bint)
bint use_utc = False, use_tzlocal = False, use_fixed = False
bint use_pytz = False
bint use_date = False, use_time = False, use_ts = False, use_pydt = False

if box == "date":
assert (tz is None), "tz should be None when converting to date"

func_create = create_date_from_ts
use_date = True
elif box == "timestamp":
func_create = create_timestamp_from_ts
use_ts = True
elif box == "time":
func_create = create_time_from_ts
use_time = True
elif box == "datetime":
func_create = create_datetime_from_ts
use_pydt = True
else:
raise ValueError(
"box must be one of 'datetime', 'date', 'time' or 'timestamp'"
Expand All @@ -153,12 +117,13 @@ def ints_to_pydatetime(
use_tzlocal = True
else:
trans, deltas, typ = get_dst_info(tz)
ntrans = trans.shape[0]
if typ not in ["pytz", "dateutil"]:
# static/fixed; in this case we know that len(deltas) == 1
use_fixed = True
delta = deltas[0]
else:
pos = trans.searchsorted(stamps, side="right") - 1
tdata = <int64_t*>cnp.PyArray_DATA(trans)
use_pytz = typ == "pytz"

for i in range(n):
Expand All @@ -176,14 +141,26 @@ def ints_to_pydatetime(
elif use_fixed:
local_val = value + delta
else:
local_val = value + deltas[pos[i]]
pos = bisect_right_i8(tdata, value, ntrans) - 1
local_val = value + deltas[pos]

if use_pytz:
# find right representation of dst etc in pytz timezone
new_tz = tz._tzinfos[tz._transition_info[pos[i]]]
if use_pytz:
# find right representation of dst etc in pytz timezone
new_tz = tz._tzinfos[tz._transition_info[pos]]

dt64_to_dtstruct(local_val, &dts)
result[i] = func_create(value, dts, new_tz, freq, fold)

if use_ts:
result[i] = create_timestamp_from_ts(value, dts, new_tz, freq, fold)
elif use_pydt:
result[i] = datetime(
dts.year, dts.month, dts.day, dts.hour, dts.min, dts.sec, dts.us,
new_tz, fold=fold,
)
elif use_date:
result[i] = date(dts.year, dts.month, dts.day)
else:
result[i] = time(dts.hour, dts.min, dts.sec, dts.us, new_tz, fold=fold)

return result

Expand Down Expand Up @@ -219,12 +196,13 @@ cdef inline int _reso_stamp(npy_datetimestruct *dts):

def get_resolution(const int64_t[:] stamps, tzinfo tz=None) -> Resolution:
cdef:
Py_ssize_t i, n = len(stamps)
Py_ssize_t i, ntrans=-1, n = len(stamps)
npy_datetimestruct dts
int reso = RESO_DAY, curr_reso
ndarray[int64_t] trans
int64_t[::1] deltas
intp_t[:] pos
int64_t* tdata = NULL
intp_t pos
int64_t local_val, delta = NPY_NAT
bint use_utc = False, use_tzlocal = False, use_fixed = False

Expand All @@ -234,12 +212,13 @@ def get_resolution(const int64_t[:] stamps, tzinfo tz=None) -> Resolution:
use_tzlocal = True
else:
trans, deltas, typ = get_dst_info(tz)
ntrans = trans.shape[0]
if typ not in ["pytz", "dateutil"]:
# static/fixed; in this case we know that len(deltas) == 1
use_fixed = True
delta = deltas[0]
else:
pos = trans.searchsorted(stamps, side="right") - 1
tdata = <int64_t*>cnp.PyArray_DATA(trans)

for i in range(n):
if stamps[i] == NPY_NAT:
Expand All @@ -252,7 +231,8 @@ def get_resolution(const int64_t[:] stamps, tzinfo tz=None) -> Resolution:
elif use_fixed:
local_val = stamps[i] + delta
else:
local_val = stamps[i] + deltas[pos[i]]
pos = bisect_right_i8(tdata, stamps[i], ntrans) - 1
local_val = stamps[i] + deltas[pos]

dt64_to_dtstruct(local_val, &dts)
curr_reso = _reso_stamp(&dts)
Expand Down Expand Up @@ -282,12 +262,13 @@ cpdef ndarray[int64_t] normalize_i8_timestamps(const int64_t[:] stamps, tzinfo t
result : int64 ndarray of converted of normalized nanosecond timestamps
"""
cdef:
Py_ssize_t i, n = len(stamps)
Py_ssize_t i, ntrans =- 1, n = len(stamps)
int64_t[:] result = np.empty(n, dtype=np.int64)
ndarray[int64_t] trans
int64_t[::1] deltas
int64_t* tdata = NULL
str typ
Py_ssize_t[:] pos
Py_ssize_t pos
int64_t local_val, delta = NPY_NAT
bint use_utc = False, use_tzlocal = False, use_fixed = False

Expand All @@ -297,12 +278,13 @@ cpdef ndarray[int64_t] normalize_i8_timestamps(const int64_t[:] stamps, tzinfo t
use_tzlocal = True
else:
trans, deltas, typ = get_dst_info(tz)
ntrans = trans.shape[0]
if typ not in ["pytz", "dateutil"]:
# static/fixed; in this case we know that len(deltas) == 1
use_fixed = True
delta = deltas[0]
else:
pos = trans.searchsorted(stamps, side="right") - 1
tdata = <int64_t*>cnp.PyArray_DATA(trans)

for i in range(n):
if stamps[i] == NPY_NAT:
Expand All @@ -316,7 +298,8 @@ cpdef ndarray[int64_t] normalize_i8_timestamps(const int64_t[:] stamps, tzinfo t
elif use_fixed:
local_val = stamps[i] + delta
else:
local_val = stamps[i] + deltas[pos[i]]
pos = bisect_right_i8(tdata, stamps[i], ntrans) - 1
local_val = stamps[i] + deltas[pos]

result[i] = normalize_i8_stamp(local_val)

Expand All @@ -341,10 +324,11 @@ def is_date_array_normalized(const int64_t[:] stamps, tzinfo tz=None) -> bool:
is_normalized : bool True if all stamps are normalized
"""
cdef:
Py_ssize_t i, n = len(stamps)
Py_ssize_t i, ntrans =- 1, n = len(stamps)
ndarray[int64_t] trans
int64_t[::1] deltas
intp_t[:] pos
int64_t* tdata = NULL
intp_t pos
int64_t local_val, delta = NPY_NAT
str typ
int64_t day_nanos = 24 * 3600 * 1_000_000_000
Expand All @@ -356,12 +340,13 @@ def is_date_array_normalized(const int64_t[:] stamps, tzinfo tz=None) -> bool:
use_tzlocal = True
else:
trans, deltas, typ = get_dst_info(tz)
ntrans = trans.shape[0]
if typ not in ["pytz", "dateutil"]:
# static/fixed; in this case we know that len(deltas) == 1
use_fixed = True
delta = deltas[0]
else:
pos = trans.searchsorted(stamps, side="right") - 1
tdata = <int64_t*>cnp.PyArray_DATA(trans)

for i in range(n):
if use_utc:
Expand All @@ -371,7 +356,8 @@ def is_date_array_normalized(const int64_t[:] stamps, tzinfo tz=None) -> bool:
elif use_fixed:
local_val = stamps[i] + delta
else:
local_val = stamps[i] + deltas[pos[i]]
pos = bisect_right_i8(tdata, stamps[i], ntrans) - 1
local_val = stamps[i] + deltas[pos]

if local_val % day_nanos != 0:
return False
Expand All @@ -386,11 +372,12 @@ def is_date_array_normalized(const int64_t[:] stamps, tzinfo tz=None) -> bool:
@cython.boundscheck(False)
def dt64arr_to_periodarr(const int64_t[:] stamps, int freq, tzinfo tz):
cdef:
Py_ssize_t i, n = len(stamps)
Py_ssize_t i, ntrans =- 1, n = len(stamps)
int64_t[:] result = np.empty(n, dtype=np.int64)
ndarray[int64_t] trans
int64_t[::1] deltas
Py_ssize_t[:] pos
int64_t* tdata = NULL
intp_t pos
npy_datetimestruct dts
int64_t local_val, delta = NPY_NAT
bint use_utc = False, use_tzlocal = False, use_fixed = False
Expand All @@ -401,12 +388,13 @@ def dt64arr_to_periodarr(const int64_t[:] stamps, int freq, tzinfo tz):
use_tzlocal = True
else:
trans, deltas, typ = get_dst_info(tz)
ntrans = trans.shape[0]
if typ not in ["pytz", "dateutil"]:
# static/fixed; in this case we know that len(deltas) == 1
use_fixed = True
delta = deltas[0]
else:
pos = trans.searchsorted(stamps, side="right") - 1
tdata = <int64_t*>cnp.PyArray_DATA(trans)

for i in range(n):
if stamps[i] == NPY_NAT:
Expand All @@ -420,7 +408,8 @@ def dt64arr_to_periodarr(const int64_t[:] stamps, int freq, tzinfo tz):
elif use_fixed:
local_val = stamps[i] + delta
else:
local_val = stamps[i] + deltas[pos[i]]
pos = bisect_right_i8(tdata, stamps[i], ntrans) - 1
local_val = stamps[i] + deltas[pos]

dt64_to_dtstruct(local_val, &dts)
result[i] = get_period_ordinal(&dts, freq)
Expand Down