@@ -69,25 +69,27 @@ def __call__(
69
69
max_fname_arg_count = None ,
70
70
method : str | None = None ,
71
71
) -> None :
72
- if args or kwargs :
73
- fname = self .fname if fname is None else fname
74
- max_fname_arg_count = (
75
- self .max_fname_arg_count
76
- if max_fname_arg_count is None
77
- else max_fname_arg_count
72
+ if not args and not kwargs :
73
+ return None
74
+
75
+ fname = self .fname if fname is None else fname
76
+ max_fname_arg_count = (
77
+ self .max_fname_arg_count
78
+ if max_fname_arg_count is None
79
+ else max_fname_arg_count
80
+ )
81
+ method = self .method if method is None else method
82
+
83
+ if method == "args" :
84
+ validate_args (fname , args , max_fname_arg_count , self .defaults )
85
+ elif method == "kwargs" :
86
+ validate_kwargs (fname , kwargs , self .defaults )
87
+ elif method == "both" :
88
+ validate_args_and_kwargs (
89
+ fname , args , kwargs , max_fname_arg_count , self .defaults
78
90
)
79
- method = self .method if method is None else method
80
-
81
- if method == "args" :
82
- validate_args (fname , args , max_fname_arg_count , self .defaults )
83
- elif method == "kwargs" :
84
- validate_kwargs (fname , kwargs , self .defaults )
85
- elif method == "both" :
86
- validate_args_and_kwargs (
87
- fname , args , kwargs , max_fname_arg_count , self .defaults
88
- )
89
- else :
90
- raise ValueError (f"invalid validation method '{ method } '" )
91
+ else :
92
+ raise ValueError (f"invalid validation method '{ method } '" )
91
93
92
94
93
95
ARGMINMAX_DEFAULTS = {"out" : None }
@@ -247,7 +249,7 @@ def validate_cum_func_with_skipna(skipna: bool, args, kwargs, name) -> bool:
247
249
LOGICAL_FUNC_DEFAULTS = {"out" : None , "keepdims" : False }
248
250
validate_logical_func = CompatValidator (LOGICAL_FUNC_DEFAULTS , method = "kwargs" )
249
251
250
- MINMAX_DEFAULTS = {"axis" : None , "out" : None , "keepdims" : False }
252
+ MINMAX_DEFAULTS = {"axis" : None , "dtype" : None , " out" : None , "keepdims" : False }
251
253
validate_min = CompatValidator (
252
254
MINMAX_DEFAULTS , fname = "min" , method = "both" , max_fname_arg_count = 1
253
255
)
@@ -285,10 +287,9 @@ def validate_cum_func_with_skipna(skipna: bool, args, kwargs, name) -> bool:
285
287
SUM_DEFAULTS ["keepdims" ] = False
286
288
SUM_DEFAULTS ["initial" ] = None
287
289
288
- PROD_DEFAULTS = STAT_FUNC_DEFAULTS .copy ()
289
- PROD_DEFAULTS ["axis" ] = None
290
- PROD_DEFAULTS ["keepdims" ] = False
291
- PROD_DEFAULTS ["initial" ] = None
290
+ PROD_DEFAULTS = SUM_DEFAULTS .copy ()
291
+
292
+ MEAN_DEFAULTS = SUM_DEFAULTS .copy ()
292
293
293
294
MEDIAN_DEFAULTS = STAT_FUNC_DEFAULTS .copy ()
294
295
MEDIAN_DEFAULTS ["overwrite_input" ] = False
@@ -304,7 +305,7 @@ def validate_cum_func_with_skipna(skipna: bool, args, kwargs, name) -> bool:
304
305
PROD_DEFAULTS , fname = "prod" , method = "both" , max_fname_arg_count = 1
305
306
)
306
307
validate_mean = CompatValidator (
307
- STAT_FUNC_DEFAULTS , fname = "mean" , method = "both" , max_fname_arg_count = 1
308
+ MEAN_DEFAULTS , fname = "mean" , method = "both" , max_fname_arg_count = 1
308
309
)
309
310
validate_median = CompatValidator (
310
311
MEDIAN_DEFAULTS , fname = "median" , method = "both" , max_fname_arg_count = 1
@@ -395,3 +396,21 @@ def validate_minmax_axis(axis: AxisInt | None, ndim: int = 1) -> None:
395
396
return
396
397
if axis >= ndim or (axis < 0 and ndim + axis < 0 ):
397
398
raise ValueError (f"`axis` must be fewer than the number of dimensions ({ ndim } )" )
399
+
400
+
401
+ _validation_funcs = {
402
+ "median" : validate_median ,
403
+ "mean" : validate_mean ,
404
+ "min" : validate_min ,
405
+ "max" : validate_max ,
406
+ "sum" : validate_sum ,
407
+ "prod" : validate_prod ,
408
+ }
409
+
410
+
411
+ def validate_func (fname , args , kwargs ) -> None :
412
+ if fname not in _validation_funcs :
413
+ return validate_stat_func (args , kwargs , fname = fname )
414
+
415
+ validation_func = _validation_funcs [fname ]
416
+ return validation_func (args , kwargs )
0 commit comments