diff --git a/pandas/compat/numpy/function.py b/pandas/compat/numpy/function.py index f6e80aba0c34f..a36e25a9df410 100644 --- a/pandas/compat/numpy/function.py +++ b/pandas/compat/numpy/function.py @@ -69,25 +69,27 @@ def __call__( max_fname_arg_count=None, method: str | None = None, ) -> None: - if args or kwargs: - fname = self.fname if fname is None else fname - max_fname_arg_count = ( - self.max_fname_arg_count - if max_fname_arg_count is None - else max_fname_arg_count + if not args and not kwargs: + return None + + fname = self.fname if fname is None else fname + max_fname_arg_count = ( + self.max_fname_arg_count + if max_fname_arg_count is None + else max_fname_arg_count + ) + method = self.method if method is None else method + + if method == "args": + validate_args(fname, args, max_fname_arg_count, self.defaults) + elif method == "kwargs": + validate_kwargs(fname, kwargs, self.defaults) + elif method == "both": + validate_args_and_kwargs( + fname, args, kwargs, max_fname_arg_count, self.defaults ) - method = self.method if method is None else method - - if method == "args": - validate_args(fname, args, max_fname_arg_count, self.defaults) - elif method == "kwargs": - validate_kwargs(fname, kwargs, self.defaults) - elif method == "both": - validate_args_and_kwargs( - fname, args, kwargs, max_fname_arg_count, self.defaults - ) - else: - raise ValueError(f"invalid validation method '{method}'") + else: + raise ValueError(f"invalid validation method '{method}'") ARGMINMAX_DEFAULTS = {"out": None} @@ -247,7 +249,7 @@ def validate_cum_func_with_skipna(skipna: bool, args, kwargs, name) -> bool: LOGICAL_FUNC_DEFAULTS = {"out": None, "keepdims": False} validate_logical_func = CompatValidator(LOGICAL_FUNC_DEFAULTS, method="kwargs") -MINMAX_DEFAULTS = {"axis": None, "out": None, "keepdims": False} +MINMAX_DEFAULTS = {"axis": None, "dtype": None, "out": None, "keepdims": False} validate_min = CompatValidator( MINMAX_DEFAULTS, fname="min", method="both", max_fname_arg_count=1 ) @@ -285,10 +287,9 @@ def validate_cum_func_with_skipna(skipna: bool, args, kwargs, name) -> bool: SUM_DEFAULTS["keepdims"] = False SUM_DEFAULTS["initial"] = None -PROD_DEFAULTS = STAT_FUNC_DEFAULTS.copy() -PROD_DEFAULTS["axis"] = None -PROD_DEFAULTS["keepdims"] = False -PROD_DEFAULTS["initial"] = None +PROD_DEFAULTS = SUM_DEFAULTS.copy() + +MEAN_DEFAULTS = SUM_DEFAULTS.copy() MEDIAN_DEFAULTS = STAT_FUNC_DEFAULTS.copy() MEDIAN_DEFAULTS["overwrite_input"] = False @@ -304,7 +305,7 @@ def validate_cum_func_with_skipna(skipna: bool, args, kwargs, name) -> bool: PROD_DEFAULTS, fname="prod", method="both", max_fname_arg_count=1 ) validate_mean = CompatValidator( - STAT_FUNC_DEFAULTS, fname="mean", method="both", max_fname_arg_count=1 + MEAN_DEFAULTS, fname="mean", method="both", max_fname_arg_count=1 ) validate_median = CompatValidator( MEDIAN_DEFAULTS, fname="median", method="both", max_fname_arg_count=1 @@ -395,3 +396,21 @@ def validate_minmax_axis(axis: AxisInt | None, ndim: int = 1) -> None: return if axis >= ndim or (axis < 0 and ndim + axis < 0): raise ValueError(f"`axis` must be fewer than the number of dimensions ({ndim})") + + +_validation_funcs = { + "median": validate_median, + "mean": validate_mean, + "min": validate_min, + "max": validate_max, + "sum": validate_sum, + "prod": validate_prod, +} + + +def validate_func(fname, args, kwargs) -> None: + if fname not in _validation_funcs: + return validate_stat_func(args, kwargs, fname=fname) + + validation_func = _validation_funcs[fname] + return validation_func(args, kwargs) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 3da4f96444215..f55800c7f44e1 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -11345,10 +11345,8 @@ def _stat_function( numeric_only: bool_t = False, **kwargs, ): - if name == "median": - nv.validate_median((), kwargs) - else: - nv.validate_stat_func((), kwargs, fname=name) + assert name in ["median", "mean", "min", "max", "kurt", "skew"], name + nv.validate_func(name, (), kwargs) validate_bool_kwarg(skipna, "skipna", none_allowed=False) @@ -11445,12 +11443,8 @@ def _min_count_stat_function( min_count: int = 0, **kwargs, ): - if name == "sum": - nv.validate_sum((), kwargs) - elif name == "prod": - nv.validate_prod((), kwargs) - else: - nv.validate_stat_func((), kwargs, fname=name) + assert name in ["sum", "prod"], name + nv.validate_func(name, (), kwargs) validate_bool_kwarg(skipna, "skipna", none_allowed=False)