Skip to content

BUG: DTA/TDA constructors with mismatched values/dtype resolutions #55658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,31 @@ def view(self, dtype: Dtype | None = None) -> ArrayLike:
dtype = pandas_dtype(dtype)
arr = self._ndarray

if isinstance(dtype, (PeriodDtype, DatetimeTZDtype)):
if isinstance(dtype, PeriodDtype):
cls = dtype.construct_array_type()
return cls(arr.view("i8"), dtype=dtype)
elif isinstance(dtype, DatetimeTZDtype):
# error: Incompatible types in assignment (expression has type
# "type[DatetimeArray]", variable has type "type[PeriodArray]")
cls = dtype.construct_array_type() # type: ignore[assignment]
dt64_values = arr.view(f"M8[{dtype.unit}]")
return cls(dt64_values, dtype=dtype)
elif dtype == "M8[ns]":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do just the ns cases need special casing?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also have a branch to extend this check to other supported dt64/td64 dtypes; I'm just trying to keep this branch targeted.

from pandas.core.arrays import DatetimeArray

return DatetimeArray(arr.view("i8"), dtype=dtype)
# error: Argument 1 to "view" of "ndarray" has incompatible type
# "ExtensionDtype | dtype[Any]"; expected "dtype[Any] | type[Any]
# | _SupportsDType[dtype[Any]]"
dt64_values = arr.view(dtype) # type: ignore[arg-type]
return DatetimeArray(dt64_values, dtype=dtype)
elif dtype == "m8[ns]":
from pandas.core.arrays import TimedeltaArray

return TimedeltaArray(arr.view("i8"), dtype=dtype)
# error: Argument 1 to "view" of "ndarray" has incompatible type
# "ExtensionDtype | dtype[Any]"; expected "dtype[Any] | type[Any]
# | _SupportsDType[dtype[Any]]"
td64_values = arr.view(dtype) # type: ignore[arg-type]
return TimedeltaArray(td64_values, dtype=dtype)

# error: Argument "dtype" to "view" of "_ArrayOrScalarCommon" has incompatible
# type "Union[ExtensionDtype, dtype[Any]]"; expected "Union[dtype[Any], None,
Expand Down
27 changes: 19 additions & 8 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -1918,6 +1918,9 @@ class TimelikeOps(DatetimeLikeArrayMixin):
def __init__(
self, values, dtype=None, freq=lib.no_default, copy: bool = False
) -> None:
if dtype is not None:
dtype = pandas_dtype(dtype)

values = extract_array(values, extract_numpy=True)
if isinstance(values, IntegerArray):
values = values.to_numpy("int64", na_value=iNaT)
Expand All @@ -1936,13 +1939,11 @@ def __init__(
freq = to_offset(freq)
freq, _ = validate_inferred_freq(freq, values.freq, False)

if dtype is not None:
dtype = pandas_dtype(dtype)
if dtype != values.dtype:
# TODO: we only have tests for this for DTA, not TDA (2022-07-01)
raise TypeError(
f"dtype={dtype} does not match data dtype {values.dtype}"
)
if dtype is not None and dtype != values.dtype:
# TODO: we only have tests for this for DTA, not TDA (2022-07-01)
raise TypeError(
f"dtype={dtype} does not match data dtype {values.dtype}"
)

dtype = values.dtype
values = values._ndarray
Expand All @@ -1952,6 +1953,8 @@ def __init__(
dtype = values.dtype
else:
dtype = self._default_dtype
if isinstance(values, np.ndarray) and values.dtype == "i8":
values = values.view(dtype)

if not isinstance(values, np.ndarray):
raise ValueError(
Expand All @@ -1966,7 +1969,15 @@ def __init__(
# for compat with datetime/timedelta/period shared methods,
# we can sometimes get here with int64 values. These represent
# nanosecond UTC (or tz-naive) unix timestamps
values = values.view(self._default_dtype)
if dtype is None:
dtype = self._default_dtype
values = values.view(self._default_dtype)
elif lib.is_np_dtype(dtype, "mM"):
values = values.view(dtype)
elif isinstance(dtype, DatetimeTZDtype):
kind = self._default_dtype.kind
new_dtype = f"{kind}8[{dtype.unit}]"
values = values.view(new_dtype)

dtype = self._validate_dtype(values, dtype)

Expand Down
9 changes: 8 additions & 1 deletion pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,15 @@ def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
@classmethod
def _validate_dtype(cls, values, dtype):
# used in TimeLikeOps.__init__
_validate_dt64_dtype(values.dtype)
dtype = _validate_dt64_dtype(dtype)
_validate_dt64_dtype(values.dtype)
if isinstance(dtype, np.dtype):
if values.dtype != dtype:
raise ValueError("Values resolution does not match dtype.")
else:
vunit = np.datetime_data(values.dtype)[0]
if vunit != dtype.unit:
raise ValueError("Values resolution does not match dtype.")
return dtype

# error: Signature of "_simple_new" incompatible with supertype "NDArrayBacked"
Expand Down
14 changes: 7 additions & 7 deletions pandas/core/arrays/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,10 @@ def dtype(self) -> np.dtype[np.timedelta64]: # type: ignore[override]
@classmethod
def _validate_dtype(cls, values, dtype):
# used in TimeLikeOps.__init__
_validate_td64_dtype(values.dtype)
dtype = _validate_td64_dtype(dtype)
_validate_td64_dtype(values.dtype)
if dtype != values.dtype:
raise ValueError("Values resolution does not match dtype.")
return dtype

# error: Signature of "_simple_new" incompatible with supertype "NDArrayBacked"
Expand Down Expand Up @@ -1202,11 +1204,9 @@ def _validate_td64_dtype(dtype) -> DtypeObj:
)
raise ValueError(msg)

if (
not isinstance(dtype, np.dtype)
or dtype.kind != "m"
or not is_supported_unit(get_unit_from_dtype(dtype))
):
raise ValueError(f"dtype {dtype} cannot be converted to timedelta64[ns]")
if not lib.is_np_dtype(dtype, "m"):
raise ValueError(f"dtype '{dtype}' is invalid, should be np.timedelta64 dtype")
elif not is_supported_unit(get_unit_from_dtype(dtype)):
raise ValueError("Supported timedelta64 resolutions are 's', 'ms', 'us', 'ns'")

return dtype
3 changes: 2 additions & 1 deletion pandas/core/internals/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2307,7 +2307,8 @@ def make_na_array(dtype: DtypeObj, shape: Shape, fill_value) -> ArrayLike:
# NB: exclude e.g. pyarrow[dt64tz] dtypes
ts = Timestamp(fill_value).as_unit(dtype.unit)
i8values = np.full(shape, ts._value)
return DatetimeArray(i8values, dtype=dtype)
dt64values = i8values.view(f"M8[{dtype.unit}]")
return DatetimeArray(dt64values, dtype=dtype)

elif is_1d_only_ea_dtype(dtype):
dtype = cast(ExtensionDtype, dtype)
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/tools/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,9 @@ def _convert_listlike_datetimes(
if tz_parsed is not None:
# We can take a shortcut since the datetime64 numpy array
# is in UTC
dta = DatetimeArray(result, dtype=tz_to_dtype(tz_parsed))
dtype = cast(DatetimeTZDtype, tz_to_dtype(tz_parsed))
dt64_values = result.view(f"M8[{dtype.unit}]")
dta = DatetimeArray(dt64_values, dtype=dtype)
return DatetimeIndex._simple_new(dta, name=name)

return _box_as_indexlike(result, utc=utc, name=name)
Expand Down
18 changes: 18 additions & 0 deletions pandas/tests/arrays/datetimes/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,24 @@ def test_incorrect_dtype_raises(self):
with pytest.raises(ValueError, match="Unexpected value for 'dtype'."):
DatetimeArray(np.array([1, 2, 3], dtype="i8"), dtype="category")

with pytest.raises(ValueError, match="Unexpected value for 'dtype'."):
DatetimeArray(np.array([1, 2, 3], dtype="i8"), dtype="m8[s]")

with pytest.raises(ValueError, match="Unexpected value for 'dtype'."):
DatetimeArray(np.array([1, 2, 3], dtype="i8"), dtype="M8[D]")

def test_mismatched_values_dtype_units(self):
arr = np.array([1, 2, 3], dtype="M8[s]")
dtype = np.dtype("M8[ns]")
msg = "Values resolution does not match dtype."

with pytest.raises(ValueError, match=msg):
DatetimeArray(arr, dtype=dtype)

dtype2 = DatetimeTZDtype(tz="UTC", unit="ns")
with pytest.raises(ValueError, match=msg):
DatetimeArray(arr, dtype=dtype2)

def test_freq_infer_raises(self):
with pytest.raises(ValueError, match="Frequency inference"):
DatetimeArray(np.array([1, 2, 3], dtype="i8"), freq="infer")
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arrays/test_datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def test_repeat_preserves_tz(self):
repeated = arr.repeat([1, 1])

# preserves tz and values, but not freq
expected = DatetimeArray(arr.asi8, freq=None, dtype=arr.dtype)
expected = DatetimeArray._from_sequence(arr.asi8, dtype=arr.dtype)
tm.assert_equal(repeated, expected)

def test_value_counts_preserves_tz(self):
Expand Down
37 changes: 28 additions & 9 deletions pandas/tests/arrays/timedeltas/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,40 @@ def test_non_array_raises(self):
TimedeltaArray([1, 2, 3])

def test_other_type_raises(self):
with pytest.raises(ValueError, match="dtype bool cannot be converted"):
msg = "dtype 'bool' is invalid, should be np.timedelta64 dtype"
with pytest.raises(ValueError, match=msg):
TimedeltaArray(np.array([1, 2, 3], dtype="bool"))

def test_incorrect_dtype_raises(self):
# TODO: why TypeError for 'category' but ValueError for i8?
with pytest.raises(
ValueError, match=r"category cannot be converted to timedelta64\[ns\]"
):
msg = "dtype 'category' is invalid, should be np.timedelta64 dtype"
with pytest.raises(ValueError, match=msg):
TimedeltaArray(np.array([1, 2, 3], dtype="i8"), dtype="category")

with pytest.raises(
ValueError, match=r"dtype int64 cannot be converted to timedelta64\[ns\]"
):
msg = "dtype 'int64' is invalid, should be np.timedelta64 dtype"
with pytest.raises(ValueError, match=msg):
TimedeltaArray(np.array([1, 2, 3], dtype="i8"), dtype=np.dtype("int64"))

msg = r"dtype 'datetime64\[ns\]' is invalid, should be np.timedelta64 dtype"
with pytest.raises(ValueError, match=msg):
TimedeltaArray(np.array([1, 2, 3], dtype="i8"), dtype=np.dtype("M8[ns]"))

msg = (
r"dtype 'datetime64\[us, UTC\]' is invalid, should be np.timedelta64 dtype"
)
with pytest.raises(ValueError, match=msg):
TimedeltaArray(np.array([1, 2, 3], dtype="i8"), dtype="M8[us, UTC]")

msg = "Supported timedelta64 resolutions are 's', 'ms', 'us', 'ns'"
with pytest.raises(ValueError, match=msg):
TimedeltaArray(np.array([1, 2, 3], dtype="i8"), dtype=np.dtype("m8[Y]"))

def test_mismatched_values_dtype_units(self):
arr = np.array([1, 2, 3], dtype="m8[s]")
dtype = np.dtype("m8[ns]")
msg = r"Values resolution does not match dtype"
with pytest.raises(ValueError, match=msg):
TimedeltaArray(arr, dtype=dtype)

def test_copy(self):
data = np.array([1, 2, 3], dtype="m8[ns]")
arr = TimedeltaArray(data, copy=False)
Expand All @@ -58,6 +77,6 @@ def test_copy(self):
assert arr._ndarray.base is not data

def test_from_sequence_dtype(self):
msg = "dtype .*object.* cannot be converted to timedelta64"
msg = "dtype 'object' is invalid, should be np.timedelta64 dtype"
with pytest.raises(ValueError, match=msg):
TimedeltaArray._from_sequence([], dtype=object)
4 changes: 3 additions & 1 deletion pandas/tests/base/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ def test_array_multiindex_raises():
),
# Timedelta
(
TimedeltaArray(np.array([0, 3600000000000], dtype="i8"), freq="h"),
TimedeltaArray(
np.array([0, 3600000000000], dtype="i8").view("m8[ns]"), freq="h"
),
np.array([0, 3600000000000], dtype="m8[ns]"),
),
# GH#26406 tz is preserved in Categorical[dt64tz]
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/indexes/timedeltas/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def test_constructor_no_precision_raises(self):
pd.Index(["2000"], dtype="timedelta64")

def test_constructor_wrong_precision_raises(self):
msg = r"dtype timedelta64\[D\] cannot be converted to timedelta64\[ns\]"
msg = "Supported timedelta64 resolutions are 's', 'ms', 'us', 'ns'"
with pytest.raises(ValueError, match=msg):
TimedeltaIndex(["2000"], dtype="timedelta64[D]")

Expand Down