diff --git a/pandas/tests/test_nanops.py b/pandas/tests/test_nanops.py index f7e652eb78e2d..852e1ce489893 100644 --- a/pandas/tests/test_nanops.py +++ b/pandas/tests/test_nanops.py @@ -19,6 +19,14 @@ has_c16 = hasattr(np, "complex128") +@pytest.fixture(params=[True, False]) +def skipna(request): + """ + Fixture to pass skipna to nanops functions. + """ + return request.param + + class TestnanopsDataFrame: def setup_method(self, method): np.random.seed(11235) @@ -89,28 +97,14 @@ def teardown_method(self, method): def check_results(self, targ, res, axis, check_dtype=True): res = getattr(res, "asm8", res) - res = getattr(res, "values", res) - - # timedeltas are a beast here - def _coerce_tds(targ, res): - if hasattr(targ, "dtype") and targ.dtype == "m8[ns]": - if len(targ) == 1: - targ = targ[0].item() - res = res.item() - else: - targ = targ.view("i8") - return targ, res - try: - if ( - axis != 0 - and hasattr(targ, "shape") - and targ.ndim - and targ.shape != res.shape - ): - res = np.split(res, [targ.shape[0]], axis=0)[0] - except (ValueError, IndexError): - targ, res = _coerce_tds(targ, res) + if ( + axis != 0 + and hasattr(targ, "shape") + and targ.ndim + and targ.shape != res.shape + ): + res = np.split(res, [targ.shape[0]], axis=0)[0] try: tm.assert_almost_equal(targ, res, check_dtype=check_dtype) @@ -118,9 +112,7 @@ def _coerce_tds(targ, res): # handle timedelta dtypes if hasattr(targ, "dtype") and targ.dtype == "m8[ns]": - targ, res = _coerce_tds(targ, res) - tm.assert_almost_equal(targ, res, check_dtype=check_dtype) - return + raise # There are sometimes rounding errors with # complex and object dtypes. @@ -149,29 +141,29 @@ def check_fun_data( targfunc, testarval, targarval, + skipna, check_dtype=True, empty_targfunc=None, **kwargs, ): for axis in list(range(targarval.ndim)) + [None]: - for skipna in [False, True]: - targartempval = targarval if skipna else testarval - if skipna and empty_targfunc and isna(targartempval).all(): - targ = empty_targfunc(targartempval, axis=axis, **kwargs) - else: - targ = targfunc(targartempval, axis=axis, **kwargs) + targartempval = targarval if skipna else testarval + if skipna and empty_targfunc and isna(targartempval).all(): + targ = empty_targfunc(targartempval, axis=axis, **kwargs) + else: + targ = targfunc(targartempval, axis=axis, **kwargs) - res = testfunc(testarval, axis=axis, skipna=skipna, **kwargs) + res = testfunc(testarval, axis=axis, skipna=skipna, **kwargs) + self.check_results(targ, res, axis, check_dtype=check_dtype) + if skipna: + res = testfunc(testarval, axis=axis, **kwargs) + self.check_results(targ, res, axis, check_dtype=check_dtype) + if axis is None: + res = testfunc(testarval, skipna=skipna, **kwargs) + self.check_results(targ, res, axis, check_dtype=check_dtype) + if skipna and axis is None: + res = testfunc(testarval, **kwargs) self.check_results(targ, res, axis, check_dtype=check_dtype) - if skipna: - res = testfunc(testarval, axis=axis, **kwargs) - self.check_results(targ, res, axis, check_dtype=check_dtype) - if axis is None: - res = testfunc(testarval, skipna=skipna, **kwargs) - self.check_results(targ, res, axis, check_dtype=check_dtype) - if skipna and axis is None: - res = testfunc(testarval, **kwargs) - self.check_results(targ, res, axis, check_dtype=check_dtype) if testarval.ndim <= 1: return @@ -184,12 +176,15 @@ def check_fun_data( targfunc, testarval2, targarval2, + skipna=skipna, check_dtype=check_dtype, empty_targfunc=empty_targfunc, **kwargs, ) - def check_fun(self, testfunc, targfunc, testar, empty_targfunc=None, **kwargs): + def check_fun( + self, testfunc, targfunc, testar, skipna, empty_targfunc=None, **kwargs + ): targar = testar if testar.endswith("_nan") and hasattr(self, testar[:-4]): @@ -202,6 +197,7 @@ def check_fun(self, testfunc, targfunc, testar, empty_targfunc=None, **kwargs): targfunc, testarval, targarval, + skipna=skipna, empty_targfunc=empty_targfunc, **kwargs, ) @@ -210,6 +206,7 @@ def check_funs( self, testfunc, targfunc, + skipna, allow_complex=True, allow_all_nan=True, allow_date=True, @@ -217,10 +214,10 @@ def check_funs( allow_obj=True, **kwargs, ): - self.check_fun(testfunc, targfunc, "arr_float", **kwargs) - self.check_fun(testfunc, targfunc, "arr_float_nan", **kwargs) - self.check_fun(testfunc, targfunc, "arr_int", **kwargs) - self.check_fun(testfunc, targfunc, "arr_bool", **kwargs) + self.check_fun(testfunc, targfunc, "arr_float", skipna, **kwargs) + self.check_fun(testfunc, targfunc, "arr_float_nan", skipna, **kwargs) + self.check_fun(testfunc, targfunc, "arr_int", skipna, **kwargs) + self.check_fun(testfunc, targfunc, "arr_bool", skipna, **kwargs) objs = [ self.arr_float.astype("O"), self.arr_int.astype("O"), @@ -228,18 +225,18 @@ def check_funs( ] if allow_all_nan: - self.check_fun(testfunc, targfunc, "arr_nan", **kwargs) + self.check_fun(testfunc, targfunc, "arr_nan", skipna, **kwargs) if allow_complex: - self.check_fun(testfunc, targfunc, "arr_complex", **kwargs) - self.check_fun(testfunc, targfunc, "arr_complex_nan", **kwargs) + self.check_fun(testfunc, targfunc, "arr_complex", skipna, **kwargs) + self.check_fun(testfunc, targfunc, "arr_complex_nan", skipna, **kwargs) if allow_all_nan: - self.check_fun(testfunc, targfunc, "arr_nan_nanj", **kwargs) + self.check_fun(testfunc, targfunc, "arr_nan_nanj", skipna, **kwargs) objs += [self.arr_complex.astype("O")] if allow_date: targfunc(self.arr_date) - self.check_fun(testfunc, targfunc, "arr_date", **kwargs) + self.check_fun(testfunc, targfunc, "arr_date", skipna, **kwargs) objs += [self.arr_date.astype("O")] if allow_tdelta: @@ -248,7 +245,7 @@ def check_funs( except TypeError: pass else: - self.check_fun(testfunc, targfunc, "arr_tdelta", **kwargs) + self.check_fun(testfunc, targfunc, "arr_tdelta", skipna, **kwargs) objs += [self.arr_tdelta.astype("O")] if allow_obj: @@ -260,7 +257,7 @@ def check_funs( targfunc = partial( self._badobj_wrap, func=targfunc, allow_complex=allow_complex ) - self.check_fun(testfunc, targfunc, "arr_obj", **kwargs) + self.check_fun(testfunc, targfunc, "arr_obj", skipna, **kwargs) def _badobj_wrap(self, value, func, allow_complex=True, **kwargs): if value.dtype.kind == "O": @@ -273,28 +270,22 @@ def _badobj_wrap(self, value, func, allow_complex=True, **kwargs): @pytest.mark.parametrize( "nan_op,np_op", [(nanops.nanany, np.any), (nanops.nanall, np.all)] ) - def test_nan_funcs(self, nan_op, np_op): - # TODO: allow tdelta, doesn't break tests - self.check_funs( - nan_op, np_op, allow_all_nan=False, allow_date=False, allow_tdelta=False - ) + def test_nan_funcs(self, nan_op, np_op, skipna): + self.check_funs(nan_op, np_op, skipna, allow_all_nan=False, allow_date=False) - def test_nansum(self): + def test_nansum(self, skipna): self.check_funs( nanops.nansum, np.sum, + skipna, allow_date=False, check_dtype=False, empty_targfunc=np.nansum, ) - def test_nanmean(self): + def test_nanmean(self, skipna): self.check_funs( - nanops.nanmean, - np.mean, - allow_complex=False, # TODO: allow this, doesn't break test - allow_obj=False, - allow_date=False, + nanops.nanmean, np.mean, skipna, allow_obj=False, allow_date=False, ) def test_nanmean_overflow(self): @@ -336,22 +327,24 @@ def test_returned_dtype(self, dtype): else: assert result.dtype == dtype - def test_nanmedian(self): + def test_nanmedian(self, skipna): with warnings.catch_warnings(record=True): warnings.simplefilter("ignore", RuntimeWarning) self.check_funs( nanops.nanmedian, np.median, + skipna, allow_complex=False, allow_date=False, allow_obj="convert", ) @pytest.mark.parametrize("ddof", range(3)) - def test_nanvar(self, ddof): + def test_nanvar(self, ddof, skipna): self.check_funs( nanops.nanvar, np.var, + skipna, allow_complex=False, allow_date=False, allow_obj="convert", @@ -359,10 +352,11 @@ def test_nanvar(self, ddof): ) @pytest.mark.parametrize("ddof", range(3)) - def test_nanstd(self, ddof): + def test_nanstd(self, ddof, skipna): self.check_funs( nanops.nanstd, np.std, + skipna, allow_complex=False, allow_date=False, allow_obj="convert", @@ -371,13 +365,14 @@ def test_nanstd(self, ddof): @td.skip_if_no_scipy @pytest.mark.parametrize("ddof", range(3)) - def test_nansem(self, ddof): + def test_nansem(self, ddof, skipna): from scipy.stats import sem with np.errstate(invalid="ignore"): self.check_funs( nanops.nansem, sem, + skipna, allow_complex=False, allow_date=False, allow_tdelta=False, @@ -388,10 +383,10 @@ def test_nansem(self, ddof): @pytest.mark.parametrize( "nan_op,np_op", [(nanops.nanmin, np.min), (nanops.nanmax, np.max)] ) - def test_nanops_with_warnings(self, nan_op, np_op): + def test_nanops_with_warnings(self, nan_op, np_op, skipna): with warnings.catch_warnings(record=True): warnings.simplefilter("ignore", RuntimeWarning) - self.check_funs(nan_op, np_op, allow_obj=False) + self.check_funs(nan_op, np_op, skipna, allow_obj=False) def _argminmax_wrap(self, value, axis=None, func=None): res = func(value, axis) @@ -408,17 +403,17 @@ def _argminmax_wrap(self, value, axis=None, func=None): res = -1 return res - def test_nanargmax(self): + def test_nanargmax(self, skipna): with warnings.catch_warnings(record=True): warnings.simplefilter("ignore", RuntimeWarning) func = partial(self._argminmax_wrap, func=np.argmax) - self.check_funs(nanops.nanargmax, func, allow_obj=False) + self.check_funs(nanops.nanargmax, func, skipna, allow_obj=False) - def test_nanargmin(self): + def test_nanargmin(self, skipna): with warnings.catch_warnings(record=True): warnings.simplefilter("ignore", RuntimeWarning) func = partial(self._argminmax_wrap, func=np.argmin) - self.check_funs(nanops.nanargmin, func, allow_obj=False) + self.check_funs(nanops.nanargmin, func, skipna, allow_obj=False) def _skew_kurt_wrap(self, values, axis=None, func=None): if not isinstance(values.dtype.type, np.floating): @@ -433,7 +428,7 @@ def _skew_kurt_wrap(self, values, axis=None, func=None): return result @td.skip_if_no_scipy - def test_nanskew(self): + def test_nanskew(self, skipna): from scipy.stats import skew func = partial(self._skew_kurt_wrap, func=skew) @@ -441,13 +436,14 @@ def test_nanskew(self): self.check_funs( nanops.nanskew, func, + skipna, allow_complex=False, allow_date=False, allow_tdelta=False, ) @td.skip_if_no_scipy - def test_nankurt(self): + def test_nankurt(self, skipna): from scipy.stats import kurtosis func1 = partial(kurtosis, fisher=True) @@ -456,15 +452,17 @@ def test_nankurt(self): self.check_funs( nanops.nankurt, func, + skipna, allow_complex=False, allow_date=False, allow_tdelta=False, ) - def test_nanprod(self): + def test_nanprod(self, skipna): self.check_funs( nanops.nanprod, np.prod, + skipna, allow_date=False, allow_tdelta=False, empty_targfunc=np.nanprod,