Skip to content

Commit 1db1357

Browse files
authored
fix: dbdate and dbtime support set item with null values (#85)
feat: dbdate and dbtime support numpy.datetime64 values in array constructor
1 parent 38ac28d commit 1db1357

File tree

5 files changed

+91
-27
lines changed

5 files changed

+91
-27
lines changed

db_dtypes/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ def _datetime(
106106
r"(?:\.(?P<fraction>\d*))?)?)?\s*$"
107107
).match,
108108
) -> Optional[numpy.datetime64]:
109+
if isinstance(scalar, numpy.datetime64):
110+
return scalar
111+
109112
# Convert pyarrow values to datetime.time.
110113
if isinstance(scalar, (pyarrow.Time32Scalar, pyarrow.Time64Scalar)):
111114
scalar = (
@@ -116,7 +119,7 @@ def _datetime(
116119
)
117120

118121
if pandas.isna(scalar):
119-
return None
122+
return numpy.datetime64("NaT")
120123
if isinstance(scalar, datetime.time):
121124
return pandas.Timestamp(
122125
year=1970,
@@ -238,12 +241,15 @@ def _datetime(
238241
scalar,
239242
match_fn=re.compile(r"\s*(?P<year>\d+)-(?P<month>\d+)-(?P<day>\d+)\s*$").match,
240243
) -> Optional[numpy.datetime64]:
244+
if isinstance(scalar, numpy.datetime64):
245+
return scalar
246+
241247
# Convert pyarrow values to datetime.date.
242248
if isinstance(scalar, (pyarrow.Date32Scalar, pyarrow.Date64Scalar)):
243249
scalar = scalar.as_py()
244250

245251
if pandas.isna(scalar):
246-
return None
252+
return numpy.datetime64("NaT")
247253
elif isinstance(scalar, datetime.date):
248254
return pandas.Timestamp(
249255
year=scalar.year, month=scalar.month, day=scalar.day

db_dtypes/core.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,6 @@ def _cmp_method(self, other, op):
100100
return NotImplemented
101101
return op(self._ndarray, other._ndarray)
102102

103-
def __setitem__(self, key, value):
104-
if is_list_like(value):
105-
_datetime = self._datetime
106-
value = [_datetime(v) for v in value]
107-
elif not pandas.isna(value):
108-
value = self._datetime(value)
109-
return super().__setitem__(key, value)
110-
111103
def _from_factorized(self, unique, original):
112104
return self.__class__(unique)
113105

@@ -121,6 +113,16 @@ def _validate_scalar(self, value):
121113
"""
122114
return self._datetime(value)
123115

116+
def _validate_setitem_value(self, value):
117+
"""
118+
Convert a value for use in setting a value in the backing numpy array.
119+
"""
120+
if is_list_like(value):
121+
_datetime = self._datetime
122+
return [_datetime(v) for v in value]
123+
124+
return self._datetime(value)
125+
124126
def any(
125127
self,
126128
*,

db_dtypes/pandas_backports.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def __getitem__(self, index):
126126
return self.__class__(value, self._dtype)
127127

128128
def __setitem__(self, index, value):
129-
self._ndarray[index] = value
129+
self._ndarray[index] = self._validate_setitem_value(value)
130130

131131
def __len__(self):
132132
return len(self._ndarray)

tests/unit/test_date.py

Lines changed: 66 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,33 @@
2424
from db_dtypes import pandas_backports
2525

2626

27+
VALUE_PARSING_TEST_CASES = [
28+
# Min/Max values for pandas.Timestamp.
29+
("1677-09-22", datetime.date(1677, 9, 22)),
30+
("2262-04-11", datetime.date(2262, 4, 11)),
31+
# Typical "zero" values.
32+
("1900-01-01", datetime.date(1900, 1, 1)),
33+
("1970-01-01", datetime.date(1970, 1, 1)),
34+
# Assorted values.
35+
("1993-10-31", datetime.date(1993, 10, 31)),
36+
(datetime.date(1993, 10, 31), datetime.date(1993, 10, 31)),
37+
("2012-02-29", datetime.date(2012, 2, 29)),
38+
(numpy.datetime64("2012-02-29"), datetime.date(2012, 2, 29)),
39+
("2021-12-17", datetime.date(2021, 12, 17)),
40+
(pandas.Timestamp("2021-12-17"), datetime.date(2021, 12, 17)),
41+
("2038-01-19", datetime.date(2038, 1, 19)),
42+
]
43+
44+
NULL_VALUE_TEST_CASES = [
45+
None,
46+
pandas.NaT,
47+
float("nan"),
48+
]
49+
50+
if hasattr(pandas, "NA"):
51+
NULL_VALUE_TEST_CASES.append(pandas.NA)
52+
53+
2754
def test_box_func():
2855
input_array = db_dtypes.DateArray([])
2956
input_datetime = datetime.datetime(2022, 3, 16)
@@ -58,26 +85,49 @@ def test__cmp_method_with_scalar():
5885
assert got[0]
5986

6087

61-
@pytest.mark.parametrize(
62-
"value, expected",
63-
[
64-
# Min/Max values for pandas.Timestamp.
65-
("1677-09-22", datetime.date(1677, 9, 22)),
66-
("2262-04-11", datetime.date(2262, 4, 11)),
67-
# Typical "zero" values.
68-
("1900-01-01", datetime.date(1900, 1, 1)),
69-
("1970-01-01", datetime.date(1970, 1, 1)),
70-
# Assorted values.
71-
("1993-10-31", datetime.date(1993, 10, 31)),
72-
("2012-02-29", datetime.date(2012, 2, 29)),
73-
("2021-12-17", datetime.date(2021, 12, 17)),
74-
("2038-01-19", datetime.date(2038, 1, 19)),
75-
],
76-
)
88+
@pytest.mark.parametrize("value, expected", VALUE_PARSING_TEST_CASES)
7789
def test_date_parsing(value, expected):
7890
assert pandas.Series([value], dtype="dbdate")[0] == expected
7991

8092

93+
@pytest.mark.parametrize("value", NULL_VALUE_TEST_CASES)
94+
def test_date_parsing_null(value):
95+
assert pandas.Series([value], dtype="dbdate")[0] is pandas.NaT
96+
97+
98+
@pytest.mark.parametrize("value, expected", VALUE_PARSING_TEST_CASES)
99+
def test_date_set_item(value, expected):
100+
series = pandas.Series([None], dtype="dbdate")
101+
series[0] = value
102+
assert series[0] == expected
103+
104+
105+
@pytest.mark.parametrize("value", NULL_VALUE_TEST_CASES)
106+
def test_date_set_item_null(value):
107+
series = pandas.Series(["1970-01-01"], dtype="dbdate")
108+
series[0] = value
109+
assert series[0] is pandas.NaT
110+
111+
112+
def test_date_set_slice():
113+
series = pandas.Series([None, None, None], dtype="dbdate")
114+
series[:] = [
115+
datetime.date(2022, 3, 21),
116+
"2011-12-13",
117+
numpy.datetime64("1998-09-04"),
118+
]
119+
assert series[0] == datetime.date(2022, 3, 21)
120+
assert series[1] == datetime.date(2011, 12, 13)
121+
assert series[2] == datetime.date(1998, 9, 4)
122+
123+
124+
def test_date_set_slice_null():
125+
series = pandas.Series(["1970-01-01"] * len(NULL_VALUE_TEST_CASES), dtype="dbdate")
126+
series[:] = NULL_VALUE_TEST_CASES
127+
for row_index in range(len(NULL_VALUE_TEST_CASES)):
128+
assert series[row_index] is pandas.NaT
129+
130+
81131
@pytest.mark.parametrize(
82132
"value, error",
83133
[

tests/unit/test_time.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,14 @@ def test_box_func():
7373
# Fractional seconds can cause rounding problems if cast to float. See:
7474
# https://github.com/googleapis/python-db-dtypes-pandas/issues/18
7575
("0:0:59.876543", datetime.time(0, 0, 59, 876543)),
76+
(
77+
numpy.datetime64("1970-01-01 00:00:59.876543"),
78+
datetime.time(0, 0, 59, 876543),
79+
),
7680
("01:01:01.010101", datetime.time(1, 1, 1, 10101)),
81+
(pandas.Timestamp("1970-01-01 01:01:01.010101"), datetime.time(1, 1, 1, 10101)),
7782
("09:09:09.090909", datetime.time(9, 9, 9, 90909)),
83+
(datetime.time(9, 9, 9, 90909), datetime.time(9, 9, 9, 90909)),
7884
("11:11:11.111111", datetime.time(11, 11, 11, 111111)),
7985
("19:16:23.987654", datetime.time(19, 16, 23, 987654)),
8086
# Microsecond precision

0 commit comments

Comments
 (0)