Skip to content

Commit ec2b663

Browse files
committed
Standardize cast_str behavior in all datetimelike fill_value validators (pandas-dev#36746)
* Standardize cast_str behavior in all datetimelike fill_value validators
* CLN: remove cast_str kwarg
1 parent fa8e066 commit ec2b663

File tree

5 files changed

+42
-17
lines changed

5 files changed

+42
-17
lines changed

pandas/core/arrays/datetimelike.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -752,9 +752,7 @@ def _validate_shift_value(self, fill_value):
752752

753753
return self._unbox(fill_value)
754754

755-
def _validate_scalar(
756-
self, value, msg: Optional[str] = None, cast_str: bool = False
757-
):
755+
def _validate_scalar(self, value, msg: Optional[str] = None):
758756
"""
759757
Validate that the input value can be cast to our scalar_type.
760758
@@ -765,14 +763,12 @@ def _validate_scalar(
765763
Message to raise in TypeError on invalid input.
766764
If not provided, `value` is cast to a str and used
767765
as the message.
768-
cast_str : bool, default False
769-
Whether to try to parse string input to scalar_type.
770766
771767
Returns
772768
-------
773769
self._scalar_type or NaT
774770
"""
775-
if cast_str and isinstance(value, str):
771+
if isinstance(value, str):
776772
# NB: Careful about tzawareness
777773
try:
778774
value = self._scalar_from_string(value)
@@ -794,9 +790,7 @@ def _validate_scalar(
794790

795791
return value
796792

797-
def _validate_listlike(
798-
self, value, opname: str, cast_str: bool = False, allow_object: bool = False
799-
):
793+
def _validate_listlike(self, value, opname: str, allow_object: bool = False):
800794
if isinstance(value, type(self)):
801795
return value
802796

@@ -805,7 +799,7 @@ def _validate_listlike(
805799
value = array(value)
806800
value = extract_array(value, extract_numpy=True)
807801

808-
if cast_str and is_dtype_equal(value.dtype, "string"):
802+
if is_dtype_equal(value.dtype, "string"):
809803
# We got a StringArray
810804
try:
811805
# TODO: Could use from_sequence_of_strings if implemented
@@ -835,9 +829,9 @@ def _validate_listlike(
835829
def _validate_searchsorted_value(self, value):
836830
msg = "searchsorted requires compatible dtype or scalar"
837831
if not is_list_like(value):
838-
value = self._validate_scalar(value, msg, cast_str=True)
832+
value = self._validate_scalar(value, msg)
839833
else:
840-
value = self._validate_listlike(value, "searchsorted", cast_str=True)
834+
value = self._validate_listlike(value, "searchsorted")
841835

842836
rv = self._unbox(value)
843837
return self._rebox_native(rv)
@@ -848,15 +842,15 @@ def _validate_setitem_value(self, value):
848842
f"or array of those. Got '{type(value).__name__}' instead."
849843
)
850844
if is_list_like(value):
851-
value = self._validate_listlike(value, "setitem", cast_str=True)
845+
value = self._validate_listlike(value, "setitem")
852846
else:
853-
value = self._validate_scalar(value, msg, cast_str=True)
847+
value = self._validate_scalar(value, msg)
854848

855849
return self._unbox(value, setitem=True)
856850

857851
def _validate_insert_value(self, value):
858852
msg = f"cannot insert {type(self).__name__} with incompatible label"
859-
value = self._validate_scalar(value, msg, cast_str=False)
853+
value = self._validate_scalar(value, msg)
860854

861855
self._check_compatible_with(value, setitem=True)
862856
# TODO: if we dont have compat, should we raise or astype(object)?

pandas/core/indexes/datetimelike.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ def _wrap_joined_index(self, joined: np.ndarray, other):
648648
def _convert_arr_indexer(self, keyarr):
649649
try:
650650
return self._data._validate_listlike(
651-
keyarr, "convert_arr_indexer", cast_str=True, allow_object=True
651+
keyarr, "convert_arr_indexer", allow_object=True
652652
)
653653
except (ValueError, TypeError):
654654
return com.asarray_tuplesafe(keyarr)

pandas/core/indexes/timedeltas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def get_loc(self, key, method=None, tolerance=None):
217217
raise InvalidIndexError(key)
218218

219219
try:
220-
key = self._data._validate_scalar(key, cast_str=True)
220+
key = self._data._validate_scalar(key)
221221
except TypeError as err:
222222
raise KeyError(key) from err
223223

pandas/tests/arrays/test_datetimelike.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,16 @@ def test_take_fill(self):
160160
result = arr.take([-1, 1], allow_fill=True, fill_value=pd.NaT)
161161
assert result[0] is pd.NaT
162162

163+
def test_take_fill_str(self, arr1d):
164+
# Cast str fill_value matching other fill_value-taking methods
165+
result = arr1d.take([-1, 1], allow_fill=True, fill_value=str(arr1d[-1]))
166+
expected = arr1d[[-1, 1]]
167+
tm.assert_equal(result, expected)
168+
169+
msg = r"'fill_value' should be a <.*>\. Got 'foo'"
170+
with pytest.raises(ValueError, match=msg):
171+
arr1d.take([-1, 1], allow_fill=True, fill_value="foo")
172+
163173
def test_concat_same_type(self):
164174
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
165175

pandas/tests/indexes/datetimelike.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,24 @@ def test_not_equals_numeric(self):
115115
assert not index.equals(pd.Index(index.asi8))
116116
assert not index.equals(pd.Index(index.asi8.astype("u8")))
117117
assert not index.equals(pd.Index(index.asi8).astype("f8"))
118+
119+
def test_where_cast_str(self):
120+
index = self.create_index()
121+
122+
mask = np.ones(len(index), dtype=bool)
123+
mask[-1] = False
124+
125+
result = index.where(mask, str(index[0]))
126+
expected = index.where(mask, index[0])
127+
tm.assert_index_equal(result, expected)
128+
129+
result = index.where(mask, [str(index[0])])
130+
tm.assert_index_equal(result, expected)
131+
132+
msg = "Where requires matching dtype, not foo"
133+
with pytest.raises(TypeError, match=msg):
134+
index.where(mask, "foo")
135+
136+
msg = r"Where requires matching dtype, not \['foo'\]"
137+
with pytest.raises(TypeError, match=msg):
138+
index.where(mask, ["foo"])

0 commit comments

Comments
 (0)