From 03f7fe63638faed0fdb1869ab6b9586ae65788fa Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 10 Feb 2021 13:47:40 -0800 Subject: [PATCH] CLN: remove ndarray cases from maybe_promote --- pandas/core/dtypes/cast.py | 29 +++++++----------------- pandas/tests/dtypes/cast/test_promote.py | 28 ----------------------- 2 files changed, 8 insertions(+), 49 deletions(-) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index ed36beb80986e..58f2ab0de489a 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -484,7 +484,7 @@ def maybe_upcast_putmask(result: np.ndarray, mask: np.ndarray) -> np.ndarray: return result -def maybe_promote(dtype, fill_value=np.nan): +def maybe_promote(dtype: DtypeObj, fill_value=np.nan): """ Find the minimal dtype that can hold both the given dtype and fill_value. @@ -505,28 +505,15 @@ def maybe_promote(dtype, fill_value=np.nan): ValueError If fill_value is a non-scalar and dtype is not object. """ - if not is_scalar(fill_value) and not is_object_dtype(dtype): + if not is_scalar(fill_value): # with object dtype there is nothing to promote, and the user can # pass pretty much any weird fill_value they like - raise ValueError("fill_value must be a scalar") - - # if we passed an array here, determine the fill value by dtype - if isinstance(fill_value, np.ndarray): - if issubclass(fill_value.dtype.type, (np.datetime64, np.timedelta64)): - fill_value = fill_value.dtype.type("NaT", "ns") - else: - - # we need to change to object type as our - # fill_value is of object type - if fill_value.dtype == np.object_: - dtype = np.dtype(np.object_) - fill_value = np.nan - - if dtype == np.object_ or dtype.kind in ["U", "S"]: - # We treat string-like dtypes as object, and _always_ fill - # with np.nan - fill_value = np.nan - dtype = np.dtype(np.object_) + if not is_object_dtype(dtype): + # with object dtype there is nothing to promote, and the user can + # pass pretty much any weird fill_value they like + raise ValueError("fill_value must be a scalar") + dtype = np.dtype(object) + return dtype, fill_value # returns tuple of (dtype, fill_value) if issubclass(dtype.type, np.datetime64): diff --git a/pandas/tests/dtypes/cast/test_promote.py b/pandas/tests/dtypes/cast/test_promote.py index 89b45890458c5..16caf935652cb 100644 --- a/pandas/tests/dtypes/cast/test_promote.py +++ b/pandas/tests/dtypes/cast/test_promote.py @@ -605,31 +605,3 @@ def test_maybe_promote_any_numpy_dtype_with_na(any_numpy_dtype_reduced, nulls_fi exp_val_for_scalar = np.nan _check_promote(dtype, fill_value, expected_dtype, exp_val_for_scalar) - - -@pytest.mark.parametrize("dim", [0, 2, 3]) -def test_maybe_promote_dimensions(any_numpy_dtype_reduced, dim): - dtype = np.dtype(any_numpy_dtype_reduced) - - # create 0-dim array of given dtype; casts "1" to correct dtype - fill_array = np.array(1, dtype=dtype) - - # expand to desired dimension: - for _ in range(dim): - fill_array = np.expand_dims(fill_array, 0) - - if dtype != object: - # test against 1-dimensional case - with pytest.raises(ValueError, match="fill_value must be a scalar"): - maybe_promote(dtype, np.array([1], dtype=dtype)) - - with pytest.raises(ValueError, match="fill_value must be a scalar"): - maybe_promote(dtype, fill_array) - - else: - expected_dtype, expected_missing_value = maybe_promote( - dtype, np.array([1], dtype=dtype) - ) - result_dtype, result_missing_value = maybe_promote(dtype, fill_array) - assert result_dtype == expected_dtype - _assert_match(result_missing_value, expected_missing_value)