From 8d0165fd134efa0107055d7e8aad1cd9b8b2cdff Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 18 Apr 2023 08:07:28 -0700 Subject: [PATCH 1/2] PERF: dtype checks --- pandas/_testing/asserters.py | 3 +-- pandas/core/algorithms.py | 7 +++---- pandas/core/arrays/categorical.py | 9 ++++----- pandas/core/arrays/datetimes.py | 7 +------ pandas/core/arrays/masked.py | 27 +++++++++++++------------- pandas/core/arrays/timedeltas.py | 9 ++++----- pandas/core/dtypes/cast.py | 13 +++++-------- pandas/core/dtypes/common.py | 13 ++++++------- pandas/core/indexes/base.py | 2 +- pandas/core/internals/array_manager.py | 3 +-- pandas/core/methods/describe.py | 8 ++++---- pandas/core/methods/to_dict.py | 9 ++++----- pandas/core/ops/array_ops.py | 23 +++++++++++----------- pandas/io/pytables.py | 2 +- pandas/plotting/_matplotlib/core.py | 2 +- 15 files changed, 61 insertions(+), 76 deletions(-) diff --git a/pandas/_testing/asserters.py b/pandas/_testing/asserters.py index c2ea82061dc63..c0d1f1eba9e09 100644 --- a/pandas/_testing/asserters.py +++ b/pandas/_testing/asserters.py @@ -16,7 +16,6 @@ from pandas.core.dtypes.common import ( is_bool, - is_extension_array_dtype, is_integer_dtype, is_number, is_numeric_dtype, @@ -316,7 +315,7 @@ def _get_ilevel_values(index, level): if not left.equals(right): mismatch = left._values != right._values - if is_extension_array_dtype(mismatch): + if not isinstance(mismatch, np.ndarray): mismatch = cast("ExtensionArray", mismatch).fillna(True) diff = np.sum(mismatch.astype(int)) * 100.0 / len(left) diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 67b7dc0ac709d..4f771b3c80791 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -49,7 +49,6 @@ is_integer, is_integer_dtype, is_list_like, - is_numeric_dtype, is_object_dtype, is_scalar, is_signed_integer_dtype, @@ -471,7 +470,7 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> npt.NDArray[np.bool_]: if ( len(values) > 0 - and is_numeric_dtype(values.dtype) + and values.dtype.kind in "iufcb" and not is_signed_integer_dtype(comps) ): # GH#46485 Use object to avoid upcast to float64 later @@ -1403,7 +1402,7 @@ def diff(arr, n: int, axis: AxisInt = 0): ) is_timedelta = False - if needs_i8_conversion(arr.dtype): + if arr.dtype.kind in "mM": dtype = np.int64 arr = arr.view("i8") na = iNaT @@ -1413,7 +1412,7 @@ def diff(arr, n: int, axis: AxisInt = 0): # We have to cast in order to be able to hold np.nan dtype = np.object_ - elif is_integer_dtype(dtype): + elif dtype.kind in "iu": # We have to cast in order to be able to hold np.nan # int8, int16 are incompatible with float64, diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index a435cb2e4eb33..12bac08175a31 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -40,7 +40,6 @@ is_bool_dtype, is_dict_like, is_dtype_equal, - is_extension_array_dtype, is_hashable, is_integer_dtype, is_list_like, @@ -618,7 +617,7 @@ def _from_inferred_categories( if known_categories: # Convert to a specialized type with `dtype` if specified. - if is_any_real_numeric_dtype(dtype.categories): + if is_any_real_numeric_dtype(dtype.categories.dtype): cats = to_numeric(inferred_categories, errors="coerce") elif lib.is_np_dtype(dtype.categories.dtype, "M"): cats = to_datetime(inferred_categories, errors="coerce") @@ -701,7 +700,7 @@ def from_codes( ) raise ValueError(msg) - if is_extension_array_dtype(codes) and is_integer_dtype(codes): + if isinstance(codes, ExtensionArray) and is_integer_dtype(codes.dtype): # Avoid the implicit conversion of Int to object if isna(codes).any(): raise ValueError("codes cannot contain NA values") @@ -1598,7 +1597,7 @@ def _internal_get_values(self): # if we are a datetime and period index, return Index to keep metadata if needs_i8_conversion(self.categories.dtype): return self.categories.take(self._codes, fill_value=NaT) - elif is_integer_dtype(self.categories) and -1 in self._codes: + elif is_integer_dtype(self.categories.dtype) and -1 in self._codes: return self.categories.astype("object").take(self._codes, fill_value=np.nan) return np.array(self) @@ -1809,7 +1808,7 @@ def _values_for_rank(self) -> np.ndarray: if mask.any(): values = values.astype("float64") values[mask] = np.nan - elif is_any_real_numeric_dtype(self.categories): + elif is_any_real_numeric_dtype(self.categories.dtype): values = np.array(self) else: # reorder the categories (so rank can use the float codes) diff --git a/pandas/core/arrays/datetimes.py b/pandas/core/arrays/datetimes.py index 5c651bc3674da..a6ef01c3a956f 100644 --- a/pandas/core/arrays/datetimes.py +++ b/pandas/core/arrays/datetimes.py @@ -53,7 +53,6 @@ is_datetime64_any_dtype, is_dtype_equal, is_float_dtype, - is_object_dtype, is_sparse, is_string_dtype, pandas_dtype, @@ -2038,11 +2037,7 @@ def _sequence_to_dt64ns( if out_unit is not None: out_dtype = np.dtype(f"M8[{out_unit}]") - if ( - is_object_dtype(data_dtype) - or is_string_dtype(data_dtype) - or is_sparse(data_dtype) - ): + if data_dtype == object or is_string_dtype(data_dtype) or is_sparse(data_dtype): # TODO: We do not have tests specific to string-dtypes, # also complex or categorical or other extension copy = False diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index 9861cb61282c3..8ca5c362079b8 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -41,12 +41,9 @@ from pandas.core.dtypes.base import ExtensionDtype from pandas.core.dtypes.common import ( is_bool, - is_bool_dtype, is_dtype_equal, - is_float_dtype, is_integer_dtype, is_list_like, - is_object_dtype, is_scalar, is_string_dtype, pandas_dtype, @@ -408,9 +405,11 @@ def to_numpy( na_value = libmissing.NA if dtype is None: dtype = object + else: + dtype = np.dtype(dtype) if self._hasna: if ( - not is_object_dtype(dtype) + dtype != object and not is_string_dtype(dtype) and na_value is libmissing.NA ): @@ -545,7 +544,7 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): else: inputs2.append(x) - def reconstruct(x): + def reconstruct(x: np.ndarray): # we don't worry about scalar `x` here, since we # raise for reduce up above. from pandas.core.arrays import ( @@ -554,13 +553,13 @@ def reconstruct(x): IntegerArray, ) - if is_bool_dtype(x.dtype): + if x.dtype.kind == "b": m = mask.copy() return BooleanArray(x, m) - elif is_integer_dtype(x.dtype): + elif x.dtype.kind in "iu": m = mask.copy() return IntegerArray(x, m) - elif is_float_dtype(x.dtype): + elif x.dtype.kind == "f": m = mask.copy() if x.dtype == np.float16: # reached in e.g. np.sqrt on BooleanArray @@ -763,7 +762,9 @@ def _cmp_method(self, other, op) -> BooleanArray: mask = self._propagate_mask(mask, other) return BooleanArray(result, mask, copy=False) - def _maybe_mask_result(self, result, mask): + def _maybe_mask_result( + self, result: np.ndarray | tuple[np.ndarray, np.ndarray], mask: np.ndarray + ): """ Parameters ---------- @@ -778,12 +779,12 @@ def _maybe_mask_result(self, result, mask): self._maybe_mask_result(mod, mask), ) - if is_float_dtype(result.dtype): + if result.dtype.kind == "f": from pandas.core.arrays import FloatingArray return FloatingArray(result, mask, copy=False) - elif is_bool_dtype(result.dtype): + elif result.dtype.kind == "b": from pandas.core.arrays import BooleanArray return BooleanArray(result, mask, copy=False) @@ -800,7 +801,7 @@ def _maybe_mask_result(self, result, mask): result[mask] = result.dtype.type("NaT") return result - elif is_integer_dtype(result.dtype): + elif result.dtype.kind in "iu": from pandas.core.arrays import IntegerArray return IntegerArray(result, mask, copy=False) @@ -875,7 +876,7 @@ def isin(self, values) -> BooleanArray: # type: ignore[override] result = isin(self._data, values_arr) if self._hasna: - values_have_NA = is_object_dtype(values_arr.dtype) and any( + values_have_NA = values_arr.dtype == object and any( val is self.dtype.na_value for val in values_arr ) diff --git a/pandas/core/arrays/timedeltas.py b/pandas/core/arrays/timedeltas.py index 5bc9a7c6b51ab..553a8ecc4dc4c 100644 --- a/pandas/core/arrays/timedeltas.py +++ b/pandas/core/arrays/timedeltas.py @@ -47,15 +47,14 @@ from pandas.core.dtypes.common import ( TD64NS_DTYPE, is_dtype_equal, - is_extension_array_dtype, is_float_dtype, is_integer_dtype, is_object_dtype, is_scalar, is_string_dtype, - is_timedelta64_dtype, pandas_dtype, ) +from pandas.core.dtypes.dtypes import ExtensionDtype from pandas.core.dtypes.missing import isna from pandas.core import ( @@ -137,7 +136,7 @@ class TimedeltaArray(dtl.TimelikeOps): _typ = "timedeltaarray" _internal_fill_value = np.timedelta64("NaT", "ns") _recognized_scalars = (timedelta, np.timedelta64, Tick) - _is_recognized_dtype = is_timedelta64_dtype + _is_recognized_dtype = lambda x: lib.is_np_dtype(x, "m") _infer_matches = ("timedelta", "timedelta64") @property @@ -912,7 +911,7 @@ def sequence_to_td64ns( inferred_freq = data.freq # Convert whatever we have into timedelta64[ns] dtype - if is_object_dtype(data.dtype) or is_string_dtype(data.dtype): + if data.dtype == object or is_string_dtype(data.dtype): # no need to make a copy, need to convert if string-dtyped data = _objects_to_td64ns(data, unit=unit, errors=errors) copy = False @@ -925,7 +924,7 @@ def sequence_to_td64ns( elif is_float_dtype(data.dtype): # cast the unit, multiply base/frac separately # to avoid precision issues from float -> int - if is_extension_array_dtype(data.dtype): + if isinstance(data.dtype, ExtensionDtype): mask = data._mask data = data._data else: diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 3e41fdf5a7634..3929775283f6a 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -53,8 +53,6 @@ is_extension_array_dtype, is_float, is_integer, - is_integer_dtype, - is_numeric_dtype, is_object_dtype, is_scalar, is_string_dtype, @@ -472,7 +470,7 @@ def maybe_cast_pointwise_result( else: result = maybe_cast_to_extension_array(cls, result) - elif (numeric_only and is_numeric_dtype(dtype)) or not numeric_only: + elif (numeric_only and dtype.kind in "iufcb") or not numeric_only: result = maybe_downcast_to_dtype(result, dtype) return result @@ -1041,13 +1039,13 @@ def convert_dtypes( if convert_integer: target_int_dtype = pandas_dtype_func("Int64") - if is_integer_dtype(input_array.dtype): + if input_array.dtype.kind in "iu": from pandas.core.arrays.integer import INT_STR_TO_DTYPE inferred_dtype = INT_STR_TO_DTYPE.get( input_array.dtype.name, target_int_dtype ) - elif is_numeric_dtype(input_array.dtype): + elif input_array.dtype.kind in "fcb": # TODO: de-dup with maybe_cast_to_integer_array? arr = input_array[notna(input_array)] if (arr.astype(int) == arr).all(): @@ -1062,9 +1060,8 @@ def convert_dtypes( inferred_dtype = target_int_dtype if convert_floating: - if not is_integer_dtype(input_array.dtype) and is_numeric_dtype( - input_array.dtype - ): + if input_array.dtype.kind in "fcb": + # i.e. numeric but not integer from pandas.core.arrays.floating import FLOAT_STR_TO_DTYPE inferred_float_dtype: DtypeObj = FLOAT_STR_TO_DTYPE.get( diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index dcf4c25f14e2f..67fb5a81ecabe 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -122,7 +122,7 @@ def classes(*klasses) -> Callable: return lambda tipo: issubclass(tipo, klasses) -def classes_and_not_datetimelike(*klasses) -> Callable: +def _classes_and_not_datetimelike(*klasses) -> Callable: """ Evaluate if the tipo is a subclass of the klasses and not a datetimelike. @@ -654,7 +654,7 @@ def is_integer_dtype(arr_or_dtype) -> bool: False """ return _is_dtype_type( - arr_or_dtype, classes_and_not_datetimelike(np.integer) + arr_or_dtype, _classes_and_not_datetimelike(np.integer) ) or _is_dtype( arr_or_dtype, lambda typ: isinstance(typ, ExtensionDtype) and typ.kind in "iu" ) @@ -713,7 +713,7 @@ def is_signed_integer_dtype(arr_or_dtype) -> bool: False """ return _is_dtype_type( - arr_or_dtype, classes_and_not_datetimelike(np.signedinteger) + arr_or_dtype, _classes_and_not_datetimelike(np.signedinteger) ) or _is_dtype( arr_or_dtype, lambda typ: isinstance(typ, ExtensionDtype) and typ.kind == "i" ) @@ -763,7 +763,7 @@ def is_unsigned_integer_dtype(arr_or_dtype) -> bool: True """ return _is_dtype_type( - arr_or_dtype, classes_and_not_datetimelike(np.unsignedinteger) + arr_or_dtype, _classes_and_not_datetimelike(np.unsignedinteger) ) or _is_dtype( arr_or_dtype, lambda typ: isinstance(typ, ExtensionDtype) and typ.kind == "u" ) @@ -1087,7 +1087,7 @@ def is_numeric_dtype(arr_or_dtype) -> bool: False """ return _is_dtype_type( - arr_or_dtype, classes_and_not_datetimelike(np.number, np.bool_) + arr_or_dtype, _classes_and_not_datetimelike(np.number, np.bool_) ) or _is_dtype( arr_or_dtype, lambda typ: isinstance(typ, ExtensionDtype) and typ._is_numeric ) @@ -1490,7 +1490,7 @@ def infer_dtype_from_object(dtype) -> type: except TypeError: pass - if is_extension_array_dtype(dtype): + if isinstance(dtype, ExtensionDtype): return dtype.type elif isinstance(dtype, str): # TODO(jreback) @@ -1644,7 +1644,6 @@ def is_all_strings(value: ArrayLike) -> bool: __all__ = [ "classes", - "classes_and_not_datetimelike", "DT64NS_DTYPE", "ensure_float64", "ensure_python_int", diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index c13ce8079b669..dfd8840aafa5a 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -6138,7 +6138,7 @@ def _is_comparable_dtype(self, dtype: DtypeObj) -> bool: elif is_numeric_dtype(self.dtype): return is_numeric_dtype(dtype) # TODO: this was written assuming we only get here with object-dtype, - # which is nom longer correct. Can we specialize for EA? + # which is no longer correct. Can we specialize for EA? return True @final diff --git a/pandas/core/internals/array_manager.py b/pandas/core/internals/array_manager.py index 203fc9c7f78cb..44f49bac9eea6 100644 --- a/pandas/core/internals/array_manager.py +++ b/pandas/core/internals/array_manager.py @@ -29,7 +29,6 @@ ensure_platform_int, is_datetime64_ns_dtype, is_dtype_equal, - is_extension_array_dtype, is_integer, is_numeric_dtype, is_object_dtype, @@ -1125,7 +1124,7 @@ def as_array( dtype = dtype.subtype elif isinstance(dtype, PandasDtype): dtype = dtype.numpy_dtype - elif is_extension_array_dtype(dtype): + elif isinstance(dtype, ExtensionDtype): dtype = "object" elif is_dtype_equal(dtype, str): dtype = "object" diff --git a/pandas/core/methods/describe.py b/pandas/core/methods/describe.py index 2fcb0de6b5451..4c997f2f0847c 100644 --- a/pandas/core/methods/describe.py +++ b/pandas/core/methods/describe.py @@ -30,11 +30,10 @@ from pandas.core.dtypes.common import ( is_bool_dtype, - is_complex_dtype, is_datetime64_any_dtype, - is_extension_array_dtype, is_numeric_dtype, ) +from pandas.core.dtypes.dtypes import ExtensionDtype from pandas.core.arrays.arrow.dtype import ArrowDtype from pandas.core.arrays.floating import Float64Dtype @@ -229,14 +228,15 @@ def describe_numeric_1d(series: Series, percentiles: Sequence[float]) -> Series: ) # GH#48340 - always return float on non-complex numeric data dtype: DtypeObj | None - if is_extension_array_dtype(series.dtype): + if isinstance(series.dtype, ExtensionDtype): if isinstance(series.dtype, ArrowDtype): import pyarrow as pa dtype = ArrowDtype(pa.float64()) else: dtype = Float64Dtype() - elif is_numeric_dtype(series.dtype) and not is_complex_dtype(series.dtype): + elif series.dtype.kind in "iufb": + # i.e. numeric but exclude complex dtype dtype = np.dtype("float") else: dtype = None diff --git a/pandas/core/methods/to_dict.py b/pandas/core/methods/to_dict.py index 5614b612660b9..e89f641e17296 100644 --- a/pandas/core/methods/to_dict.py +++ b/pandas/core/methods/to_dict.py @@ -6,13 +6,12 @@ ) import warnings +import numpy as np + from pandas.util._exceptions import find_stack_level from pandas.core.dtypes.cast import maybe_box_native -from pandas.core.dtypes.common import ( - is_extension_array_dtype, - is_object_dtype, -) +from pandas.core.dtypes.dtypes import ExtensionDtype from pandas.core import common as com @@ -99,7 +98,7 @@ def to_dict( box_native_indices = [ i for i, col_dtype in enumerate(df.dtypes.values) - if is_object_dtype(col_dtype) or is_extension_array_dtype(col_dtype) + if col_dtype == np.dtype(object) or isinstance(col_dtype, ExtensionDtype) ] are_all_object_dtype_cols = len(box_native_indices) == len(df.dtypes) diff --git a/pandas/core/ops/array_ops.py b/pandas/core/ops/array_ops.py index d8960d4faf0cc..c4f873ef079f5 100644 --- a/pandas/core/ops/array_ops.py +++ b/pandas/core/ops/array_ops.py @@ -32,7 +32,6 @@ from pandas.core.dtypes.common import ( ensure_object, is_bool_dtype, - is_integer_dtype, is_list_like, is_numeric_v_string_like, is_object_dtype, @@ -213,7 +212,9 @@ def _na_arithmetic_op(left: np.ndarray, right, op, is_cmp: bool = False): try: result = func(left, right) except TypeError: - if not is_cmp and (is_object_dtype(left.dtype) or is_object_dtype(right)): + if not is_cmp and ( + left.dtype == object or getattr(right, "dtype", None) == object + ): # For object dtype, fallback to a masked operation (only operating # on the non-missing values) # Don't do this for comparisons, as that will handle complex numbers @@ -316,7 +317,7 @@ def comparison_op(left: ArrayLike, right: Any, op) -> ArrayLike: if should_extension_dispatch(lvalues, rvalues) or ( (isinstance(rvalues, (Timedelta, BaseOffset, Timestamp)) or right is NaT) - and not is_object_dtype(lvalues.dtype) + and lvalues.dtype != object ): # Call the method on lvalues res_values = op(lvalues, rvalues) @@ -332,7 +333,7 @@ def comparison_op(left: ArrayLike, right: Any, op) -> ArrayLike: # GH#36377 going through the numexpr path would incorrectly raise return invalid_comparison(lvalues, rvalues, op) - elif is_object_dtype(lvalues.dtype) or isinstance(rvalues, str): + elif lvalues.dtype == object or isinstance(rvalues, str): res_values = comp_method_OBJECT_ARRAY(op, lvalues, rvalues) else: @@ -355,7 +356,7 @@ def na_logical_op(x: np.ndarray, y, op): except TypeError: if isinstance(y, np.ndarray): # bool-bool dtype operations should be OK, should not get here - assert not (is_bool_dtype(x.dtype) and is_bool_dtype(y.dtype)) + assert not (x.dtype.kind == "b" and y.dtype.kind == "b") x = ensure_object(x) y = ensure_object(y) result = libops.vec_binop(x.ravel(), y.ravel(), op) @@ -408,7 +409,7 @@ def fill_bool(x, left=None): x = x.astype(object) x[mask] = False - if left is None or is_bool_dtype(left.dtype): + if left is None or left.dtype.kind == "b": x = x.astype(bool) return x @@ -435,7 +436,7 @@ def fill_bool(x, left=None): else: if isinstance(rvalues, np.ndarray): - is_other_int_dtype = is_integer_dtype(rvalues.dtype) + is_other_int_dtype = rvalues.dtype.kind in "iu" if not is_other_int_dtype: rvalues = fill_bool(rvalues, lvalues) @@ -447,7 +448,7 @@ def fill_bool(x, left=None): # For int vs int `^`, `|`, `&` are bitwise operators and return # integer dtypes. Otherwise these are boolean ops - if not (is_integer_dtype(left.dtype) and is_other_int_dtype): + if not (left.dtype.kind in "iu" and is_other_int_dtype): res_values = fill_bool(res_values) return res_values @@ -565,15 +566,13 @@ def maybe_prepare_scalar_for_op(obj, shape: Shape): } -def _bool_arith_check(op, a, b): +def _bool_arith_check(op, a: np.ndarray, b): """ In contrast to numpy, pandas raises an error for certain operations with booleans. """ if op in _BOOL_OP_NOT_ALLOWED: - if is_bool_dtype(a.dtype) and ( - is_bool_dtype(b) or isinstance(b, (bool, np.bool_)) - ): + if a.dtype.kind == "b" and (is_bool_dtype(b) or lib.is_bool(b)): op_name = op.__name__.strip("_").lstrip("r") raise NotImplementedError( f"operator '{op_name}' not implemented for bool dtypes" diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index 7ba3fe98f59bc..2a522ef6b5171 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -2555,7 +2555,7 @@ class DataIndexableCol(DataCol): is_data_indexable = True def validate_names(self) -> None: - if not is_object_dtype(Index(self.values)): + if not is_object_dtype(Index(self.values).dtype): # TODO: should the message here be more specifically non-str? raise ValueError("cannot have non-object label DataIndexableCol") diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index f667de6a5a34c..d1e8f92ffc36b 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -844,7 +844,7 @@ def _get_xticks(self, convert_period: bool = False): if convert_period and isinstance(index, ABCPeriodIndex): self.data = self.data.reindex(index=index.sort_values()) x = self.data.index.to_timestamp()._mpl_repr() - elif is_any_real_numeric_dtype(index): + elif is_any_real_numeric_dtype(index.dtype): # Matplotlib supports numeric values or datetime objects as # xaxis values. Taking LBYL approach here, by the time # matplotlib raises exception when using non numeric/datetime From 6ed7b2dfc13069ae263ca7033866ed0411beedf5 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 18 Apr 2023 12:25:31 -0700 Subject: [PATCH 2/2] mypy fixup --- pandas/core/arrays/masked.py | 5 +++-- pandas/core/ops/array_ops.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index 8ca5c362079b8..f1df86788ac44 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -795,10 +795,11 @@ def _maybe_mask_result( # e.g. test_numeric_arr_mul_tdscalar_numexpr_path from pandas.core.arrays import TimedeltaArray + result[mask] = result.dtype.type("NaT") + if not isinstance(result, TimedeltaArray): - result = TimedeltaArray._simple_new(result, dtype=result.dtype) + return TimedeltaArray._simple_new(result, dtype=result.dtype) - result[mask] = result.dtype.type("NaT") return result elif result.dtype.kind in "iu": diff --git a/pandas/core/ops/array_ops.py b/pandas/core/ops/array_ops.py index c4f873ef079f5..8b39089bfb1d5 100644 --- a/pandas/core/ops/array_ops.py +++ b/pandas/core/ops/array_ops.py @@ -269,7 +269,9 @@ def arithmetic_op(left: ArrayLike, right: Any, op): else: # TODO we should handle EAs consistently and move this check before the if/else # (https://github.com/pandas-dev/pandas/issues/41165) - _bool_arith_check(op, left, right) + # error: Argument 2 to "_bool_arith_check" has incompatible type + # "Union[ExtensionArray, ndarray[Any, Any]]"; expected "ndarray[Any, Any]" + _bool_arith_check(op, left, right) # type: ignore[arg-type] # error: Argument 1 to "_na_arithmetic_op" has incompatible type # "Union[ExtensionArray, ndarray[Any, Any]]"; expected "ndarray[Any, Any]"