diff --git a/db_dtypes/__init__.py b/db_dtypes/__init__.py index 8a58666..9495f0c 100644 --- a/db_dtypes/__init__.py +++ b/db_dtypes/__init__.py @@ -16,8 +16,10 @@ """ import datetime +import re import numpy +import packaging.version import pandas import pandas.compat.numpy.function import pandas.core.algorithms @@ -35,6 +37,8 @@ date_dtype_name = "date" time_dtype_name = "time" +pandas_release = packaging.version.parse(pandas.__version__).release + @pandas.core.dtypes.dtypes.register_extension_dtype class TimeDtype(core.BaseDatetimeDtype): @@ -61,15 +65,33 @@ class TimeArray(core.BaseDatetimeArray): _npepoch = numpy.datetime64(_epoch) @classmethod - def _datetime(cls, scalar): + def _datetime( + cls, + scalar, + match_fn=re.compile( + r"\s*(?P\d+)(?::(?P\d+)(?::(?P\d+(?:[.]\d+)?)?)?)?\s*$" + ).match, + ): if isinstance(scalar, datetime.time): return datetime.datetime.combine(cls._epoch, scalar) elif isinstance(scalar, str): # iso string - h, m, s = map(float, scalar.split(":")) - s, us = divmod(s, 1) + match = match_fn(scalar) + if not match: + raise ValueError(f"Bad time string: {repr(scalar)}") + + hour = match.group("hour") + minute = match.group("minute") + second = match.group("second") + second, microsecond = divmod(float(second if second else 0), 1) return datetime.datetime( - 1970, 1, 1, int(h), int(m), int(s), int(us * 1000000) + 1970, + 1, + 1, + int(hour), + int(minute if minute else 0), + int(second), + int(microsecond * 1_000_000), ) else: raise TypeError("Invalid value type", scalar) @@ -96,6 +118,11 @@ def astype(self, dtype, copy=True): else: return super().astype(dtype, copy=copy) + if pandas_release < (1,): + + def to_numpy(self, dtype="object"): + return self.astype(dtype) + def __arrow_array__(self, type=None): return pyarrow.array( self.to_numpy(), type=type if type is not None else pyarrow.time64("ns"), @@ -125,12 +152,20 @@ class DateArray(core.BaseDatetimeArray): dtype = DateDtype() @staticmethod - def _datetime(scalar): + def _datetime( + scalar, + match_fn=re.compile(r"\s*(?P\d+)-(?P\d+)-(?P\d+)\s*$").match, + ): if isinstance(scalar, datetime.date): return datetime.datetime(scalar.year, scalar.month, scalar.day) elif isinstance(scalar, str): - # iso string - return datetime.datetime(*map(int, scalar.split("-"))) + match = match_fn(scalar) + if not match: + raise ValueError(f"Bad date string: {repr(scalar)}") + year = int(match.group("year")) + month = int(match.group("month")) + day = int(match.group("day")) + return datetime.datetime(year, month, day) else: raise TypeError("Invalid value type", scalar) diff --git a/db_dtypes/core.py b/db_dtypes/core.py index 3b05ad6..fbc784e 100644 --- a/db_dtypes/core.py +++ b/db_dtypes/core.py @@ -84,6 +84,12 @@ def astype(self, dtype, copy=True): return super().astype(dtype, copy=copy) def _cmp_method(self, other, op): + oshape = getattr(other, "shape", None) + if oshape != self.shape and oshape != (1,) and self.shape != (1,): + raise TypeError( + "Can't compare arrays with different shapes", self.shape, oshape + ) + if type(other) != type(self): return NotImplemented return op(self._ndarray, other._ndarray) diff --git a/db_dtypes/pandas_backports.py b/db_dtypes/pandas_backports.py index bfeb148..003224f 100644 --- a/db_dtypes/pandas_backports.py +++ b/db_dtypes/pandas_backports.py @@ -31,8 +31,17 @@ def import_default(module_name, force=False, default=None): + """ + Provide an implementation for a class or function when it can't be imported + + or when force is True. + + This is used to replicate Pandas APIs that are missing or insufficient + (thus the force option) in early pandas versions. + """ + if default is None: - return lambda func: import_default(module_name, force, func) + return lambda func_or_class: import_default(module_name, force, func_or_class) if force: return default diff --git a/testing/constraints-3.6.txt b/testing/constraints-3.6.txt index fd89d90..a7388cd 100644 --- a/testing/constraints-3.6.txt +++ b/testing/constraints-3.6.txt @@ -5,6 +5,7 @@ # # e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev", packaging==17.0 +# Make sure we test with pandas 0.24.2. The Python version isn't that relevant. pandas==0.24.2 pyarrow==3.0.0 numpy==1.16.6 diff --git a/testing/constraints-3.7.txt b/testing/constraints-3.7.txt index 684864f..0b3b309 100644 --- a/testing/constraints-3.7.txt +++ b/testing/constraints-3.7.txt @@ -1 +1,2 @@ +# Make sure we test with pandas 1.1.0. The Python version isn't that relevant. pandas==1.1.0 diff --git a/testing/constraints-3.8.txt b/testing/constraints-3.8.txt index 3fd8886..2e7f354 100644 --- a/testing/constraints-3.8.txt +++ b/testing/constraints-3.8.txt @@ -1 +1,2 @@ +# Make sure we test with pandas 1.2.0. The Python version isn't that relevant. pandas==1.2.0 diff --git a/tests/unit/test_dtypes.py b/tests/unit/test_dtypes.py index 4991639..eca3a31 100644 --- a/tests/unit/test_dtypes.py +++ b/tests/unit/test_dtypes.py @@ -15,6 +15,7 @@ import datetime import packaging.version +import pyarrow.lib import pytest pd = pytest.importorskip("pandas") @@ -171,18 +172,14 @@ def test_timearray_comparisons( # Bad shape for bad_shape in ([], [1, 2, 3]): - if op == "==": - assert not comparisons[op](left, np.array(bad_shape)) - assert complements[op](left, np.array(bad_shape)) - else: - with pytest.raises( - ValueError, match="operands could not be broadcast together", - ): - comparisons[op](left, np.array(bad_shape)) - with pytest.raises( - ValueError, match="operands could not be broadcast together", - ): - complements[op](left, np.array(bad_shape)) + with pytest.raises( + TypeError, match="Can't compare arrays with different shapes" + ): + comparisons[op](left, np.array(bad_shape)) + with pytest.raises( + TypeError, match="Can't compare arrays with different shapes" + ): + complements[op](left, np.array(bad_shape)) # Bad items for bad_items in ( @@ -478,8 +475,10 @@ def test_asdatetime(dtype, same): ) def test_astimedelta(dtype): t = "01:02:03.123456" - expect = pd.to_timedelta([t]).array.astype( - "timedelta64[ns]" if dtype == "timedelta" else dtype + expect = ( + pd.to_timedelta([t]) + .to_numpy() + .astype("timedelta64[ns]" if dtype == "timedelta" else dtype) ) a = _cls("time")([t, None]) @@ -543,7 +542,10 @@ def test_min_max_median(dtype): assert empty.min(skipna=False) is None assert empty.max(skipna=False) is None if pandas_release >= (1, 2): - assert empty.median() is None + with pytest.warns(RuntimeWarning, match="empty slice"): + # It's weird that we get the warning here, and not + # below. :/ + assert empty.median() is None assert empty.median(skipna=False) is None a = _make_one(dtype) @@ -620,3 +622,61 @@ def test_date_sub(): do = pd.Series([pd.DateOffset(days=i) for i in range(4)]) expect = dates.astype("object") - do assert np.array_equal(dates - do, expect) + + +@pytest.mark.parametrize( + "value, expected", [("1", datetime.time(1)), ("1:2", datetime.time(1, 2))], +) +def test_short_time_parsing(value, expected): + assert _cls("time")([value])[0] == expected + + +@pytest.mark.parametrize( + "value, error", + [ + ("thursday", "Bad time string: 'thursday'"), + ("1:2:3thursday", "Bad time string: '1:2:3thursday'"), + ("1:2:3:4", "Bad time string: '1:2:3:4'"), + ("1:2:3.f", "Bad time string: '1:2:3.f'"), + ("1:d:3", "Bad time string: '1:d:3'"), + ("1:2.3", "Bad time string: '1:2.3'"), + ("", "Bad time string: ''"), + ("1:2:99", "second must be in 0[.][.]59"), + ("1:99", "minute must be in 0[.][.]59"), + ("99", "hour must be in 0[.][.]23"), + ], +) +def test_bad_time_parsing(value, error): + with pytest.raises(ValueError, match=error): + _cls("time")([value]) + + +@pytest.mark.parametrize( + "value, error", + [ + ("thursday", "Bad date string: 'thursday'"), + ("1-2-thursday", "Bad date string: '1-2-thursday'"), + ("1-2-3-4", "Bad date string: '1-2-3-4'"), + ("1-2-3.f", "Bad date string: '1-2-3.f'"), + ("1-d-3", "Bad date string: '1-d-3'"), + ("1-3", "Bad date string: '1-3'"), + ("1", "Bad date string: '1'"), + ("", "Bad date string: ''"), + ("2021-2-99", "day is out of range for month"), + ("2021-99-1", "month must be in 1[.][.]12"), + ("10000-1-1", "year 10000 is out of range"), + ], +) +def test_bad_date_parsing(value, error): + with pytest.raises(ValueError, match=error): + _cls("date")([value]) + + +@for_date_and_time +def test_date___arrow__array__(dtype): + a = _make_one(dtype) + ar = a.__arrow_array__() + assert isinstance( + ar, pyarrow.Date32Array if dtype == "date" else pyarrow.Time64Array, + ) + assert [v.as_py() for v in ar] == list(a)