Skip to content

Commit dbf9ffe

Browse files
authored
REF: share MaskedArray methods (#45951)
1 parent bf2f6e5 commit dbf9ffe

File tree

3 files changed

+58
-63
lines changed

3 files changed

+58
-63
lines changed

pandas/core/arrays/boolean.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,3 @@ def _logical_method(self, other, op):
381381
# error: Argument 2 to "BooleanArray" has incompatible type "Optional[Any]";
382382
# expected "ndarray"
383383
return BooleanArray(result, mask) # type: ignore[arg-type]
384-
385-
def __abs__(self):
386-
return self.copy()

pandas/core/arrays/masked.py

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,9 @@
4343
is_bool_dtype,
4444
is_datetime64_dtype,
4545
is_dtype_equal,
46-
is_float,
4746
is_float_dtype,
4847
is_integer_dtype,
4948
is_list_like,
50-
is_numeric_dtype,
5149
is_object_dtype,
5250
is_scalar,
5351
is_string_dtype,
@@ -327,9 +325,53 @@ def ravel(self: BaseMaskedArrayT, *args, **kwargs) -> BaseMaskedArrayT:
327325
def T(self: BaseMaskedArrayT) -> BaseMaskedArrayT:
328326
return type(self)(self._data.T, self._mask.T)
329327

328+
def round(self, decimals: int = 0, *args, **kwargs):
329+
"""
330+
Round each value in the array a to the given number of decimals.
331+
332+
Parameters
333+
----------
334+
decimals : int, default 0
335+
Number of decimal places to round to. If decimals is negative,
336+
it specifies the number of positions to the left of the decimal point.
337+
*args, **kwargs
338+
Additional arguments and keywords have no effect but might be
339+
accepted for compatibility with NumPy.
340+
341+
Returns
342+
-------
343+
NumericArray
344+
Rounded values of the NumericArray.
345+
346+
See Also
347+
--------
348+
numpy.around : Round values of an np.array.
349+
DataFrame.round : Round values of a DataFrame.
350+
Series.round : Round values of a Series.
351+
"""
352+
nv.validate_round(args, kwargs)
353+
values = np.round(self._data, decimals=decimals, **kwargs)
354+
355+
# Usually we'll get same type as self, but ndarray[bool] casts to float
356+
return self._maybe_mask_result(values, self._mask.copy())
357+
358+
# ------------------------------------------------------------------
359+
# Unary Methods
360+
330361
def __invert__(self: BaseMaskedArrayT) -> BaseMaskedArrayT:
331362
return type(self)(~self._data, self._mask.copy())
332363

364+
def __neg__(self):
365+
return type(self)(-self._data, self._mask.copy())
366+
367+
def __pos__(self):
368+
return self.copy()
369+
370+
def __abs__(self):
371+
return type(self)(abs(self._data), self._mask.copy())
372+
373+
# ------------------------------------------------------------------
374+
333375
def to_numpy(
334376
self,
335377
dtype: npt.DTypeLike | None = None,
@@ -671,7 +713,7 @@ def _arith_method(self, other, op):
671713
# x ** 0 is 1.
672714
mask = np.where((self._data == 0) & ~self._mask, False, mask)
673715

674-
return self._maybe_mask_result(result, mask, other, op_name)
716+
return self._maybe_mask_result(result, mask)
675717

676718
def _cmp_method(self, other, op) -> BooleanArray:
677719
from pandas.core.arrays import BooleanArray
@@ -713,36 +755,27 @@ def _cmp_method(self, other, op) -> BooleanArray:
713755
mask = self._propagate_mask(mask, other)
714756
return BooleanArray(result, mask, copy=False)
715757

716-
def _maybe_mask_result(self, result, mask, other, op_name: str):
758+
def _maybe_mask_result(self, result, mask):
717759
"""
718760
Parameters
719761
----------
720-
result : array-like
762+
result : array-like or tuple[array-like]
721763
mask : array-like bool
722-
other : scalar or array-like
723-
op_name : str
724764
"""
725-
if op_name == "divmod":
726-
# divmod returns a tuple
765+
if isinstance(result, tuple):
766+
# i.e. divmod
727767
div, mod = result
728768
return (
729-
self._maybe_mask_result(div, mask, other, "floordiv"),
730-
self._maybe_mask_result(mod, mask, other, "mod"),
769+
self._maybe_mask_result(div, mask),
770+
self._maybe_mask_result(mod, mask),
731771
)
732772

733-
# if we have a float operand we are by-definition
734-
# a float result
735-
# or our op is a divide
736-
if (
737-
(is_float_dtype(other) or is_float(other))
738-
or (op_name in ["rtruediv", "truediv"])
739-
or (is_float_dtype(self.dtype) and is_numeric_dtype(result.dtype))
740-
):
773+
if is_float_dtype(result.dtype):
741774
from pandas.core.arrays import FloatingArray
742775

743776
return FloatingArray(result, mask, copy=False)
744777

745-
elif is_bool_dtype(result):
778+
elif is_bool_dtype(result.dtype):
746779
from pandas.core.arrays import BooleanArray
747780

748781
return BooleanArray(result, mask, copy=False)
@@ -757,7 +790,7 @@ def _maybe_mask_result(self, result, mask, other, op_name: str):
757790
result[mask] = result.dtype.type("NaT")
758791
return result
759792

760-
elif is_integer_dtype(result):
793+
elif is_integer_dtype(result.dtype):
761794
from pandas.core.arrays import IntegerArray
762795

763796
return IntegerArray(result, mask, copy=False)
@@ -980,6 +1013,9 @@ def _quantile(
9801013
out = np.asarray(res, dtype=np.float64) # type: ignore[assignment]
9811014
return out
9821015

1016+
# ------------------------------------------------------------------
1017+
# Reductions
1018+
9831019
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
9841020
if name in {"any", "all", "min", "max", "sum", "prod"}:
9851021
return getattr(self, name)(skipna=skipna, **kwargs)
@@ -1015,7 +1051,7 @@ def _wrap_reduction_result(self, name: str, result, skipna, **kwargs):
10151051
else:
10161052
mask = self._mask.any(axis=axis)
10171053

1018-
return self._maybe_mask_result(result, mask, other=None, op_name=name)
1054+
return self._maybe_mask_result(result, mask)
10191055
return result
10201056

10211057
def sum(self, *, skipna=True, min_count=0, axis: int | None = 0, **kwargs):

pandas/core/arrays/numeric.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
Dtype,
1717
DtypeObj,
1818
)
19-
from pandas.compat.numpy import function as nv
2019
from pandas.errors import AbstractMethodError
2120

2221
from pandas.core.dtypes.common import (
@@ -211,40 +210,3 @@ def _from_sequence_of_strings(
211210
return cls._from_sequence(scalars, dtype=dtype, copy=copy)
212211

213212
_HANDLED_TYPES = (np.ndarray, numbers.Number)
214-
215-
def __neg__(self):
216-
return type(self)(-self._data, self._mask.copy())
217-
218-
def __pos__(self):
219-
return self.copy()
220-
221-
def __abs__(self):
222-
return type(self)(abs(self._data), self._mask.copy())
223-
224-
def round(self: T, decimals: int = 0, *args, **kwargs) -> T:
225-
"""
226-
Round each value in the array a to the given number of decimals.
227-
228-
Parameters
229-
----------
230-
decimals : int, default 0
231-
Number of decimal places to round to. If decimals is negative,
232-
it specifies the number of positions to the left of the decimal point.
233-
*args, **kwargs
234-
Additional arguments and keywords have no effect but might be
235-
accepted for compatibility with NumPy.
236-
237-
Returns
238-
-------
239-
NumericArray
240-
Rounded values of the NumericArray.
241-
242-
See Also
243-
--------
244-
numpy.around : Round values of an np.array.
245-
DataFrame.round : Round values of a DataFrame.
246-
Series.round : Round values of a Series.
247-
"""
248-
nv.validate_round(args, kwargs)
249-
values = np.round(self._data, decimals=decimals, **kwargs)
250-
return type(self)(values, self._mask.copy())

0 commit comments

Comments
 (0)