diff --git a/doc/source/user_guide/missing_data.rst b/doc/source/user_guide/missing_data.rst index 1cc485a229123..1bfe196cb2f89 100644 --- a/doc/source/user_guide/missing_data.rst +++ b/doc/source/user_guide/missing_data.rst @@ -822,6 +822,18 @@ For example, ``pd.NA`` propagates in arithmetic operations, similarly to pd.NA + 1 "a" * pd.NA +There are a few special cases when the result is known, even when one of the +operands is ``NA``. + + +================ ====== +Operation Result +================ ====== +``pd.NA ** 0`` 0 +``1 ** pd.NA`` 1 +``-1 ** pd.NA`` -1 +================ ====== + In equality and comparison operations, ``pd.NA`` also propagates. This deviates from the behaviour of ``np.nan``, where comparisons with ``np.nan`` always return ``False``. diff --git a/pandas/_libs/missing.pyx b/pandas/_libs/missing.pyx index 30832a8e4daab..63aa5501c5250 100644 --- a/pandas/_libs/missing.pyx +++ b/pandas/_libs/missing.pyx @@ -365,8 +365,6 @@ class NAType(C_NAType): __rmod__ = _create_binary_propagating_op("__rmod__") __divmod__ = _create_binary_propagating_op("__divmod__", divmod=True) __rdivmod__ = _create_binary_propagating_op("__rdivmod__", divmod=True) - __pow__ = _create_binary_propagating_op("__pow__") - __rpow__ = _create_binary_propagating_op("__rpow__") # __lshift__ and __rshift__ are not implemented __eq__ = _create_binary_propagating_op("__eq__") @@ -383,6 +381,30 @@ class NAType(C_NAType): __abs__ = _create_unary_propagating_op("__abs__") __invert__ = _create_unary_propagating_op("__invert__") + # pow has special + def __pow__(self, other): + if other is C_NA: + return NA + elif isinstance(other, (numbers.Number, np.bool_)): + if other == 0: + # returning positive is correct for +/- 0. + return type(other)(1) + else: + return NA + + return NotImplemented + + def __rpow__(self, other): + if other is C_NA: + return NA + elif isinstance(other, (numbers.Number, np.bool_)): + if other == 1 or other == -1: + return other + else: + return NA + + return NotImplemented + # Logical ops using Kleene logic def __and__(self, other): diff --git a/pandas/tests/scalar/test_na_scalar.py b/pandas/tests/scalar/test_na_scalar.py index 586433698a587..40db617c64717 100644 --- a/pandas/tests/scalar/test_na_scalar.py +++ b/pandas/tests/scalar/test_na_scalar.py @@ -38,11 +38,14 @@ def test_arithmetic_ops(all_arithmetic_functions): op = all_arithmetic_functions for other in [NA, 1, 1.0, "a", np.int64(1), np.nan]: - if op.__name__ == "rmod" and isinstance(other, str): + if op.__name__ in ("pow", "rpow", "rmod") and isinstance(other, str): continue if op.__name__ in ("divmod", "rdivmod"): assert op(NA, other) is (NA, NA) else: + if op.__name__ == "rpow": + # avoid special case + other += 1 assert op(NA, other) is NA @@ -69,6 +72,49 @@ def test_comparison_ops(): assert (other <= NA) is NA +@pytest.mark.parametrize( + "value", + [ + 0, + 0.0, + -0, + -0.0, + False, + np.bool_(False), + np.int_(0), + np.float_(0), + np.int_(-0), + np.float_(-0), + ], +) +def test_pow_special(value): + result = pd.NA ** value + assert isinstance(result, type(value)) + assert result == 1 + + +@pytest.mark.parametrize( + "value", + [ + 1, + 1.0, + -1, + -1.0, + True, + np.bool_(True), + np.int_(1), + np.float_(1), + np.int_(-1), + np.float_(-1), + ], +) +def test_rpow_special(value): + result = value ** pd.NA + assert result == value + if not isinstance(value, (np.float_, np.bool_, np.int_)): + assert isinstance(result, type(value)) + + def test_unary_ops(): assert +NA is NA assert -NA is NA