diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b81b9ba070a44..b2c463f9f4a22 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -301,3 +301,60 @@ repos: language: python files: ^doc/source/whatsnew/v exclude: ^doc/source/whatsnew/v(0|1|2\.0\.0) + - id: enforce-match-arg-in-assert-produces-warning + name: Enforce the usage of match arg + entry: python scripts/enforce_match_arg_in_assert_produces_warning.py + language: python + files: ^pandas/tests + exclude: | + (?x) + ^( + pandas/tests/computation/test_eval.py| + pandas/tests/frame/test_query_eval.py| + pandas/tests/frame/methods/test_drop.py| + pandas/tests/plotting/test_datetimelike.py| + pandas/tests/plotting/frame/test_frame_color.py| + pandas/tests/plotting/test_hist_method.py| + pandas/tests/util/test_deprecate_kwarg.py| + pandas/tests/util/test_deprecate.py| + pandas/tests/util/test_assert_produces_warning.py| + pandas/tests/util/test_deprecate_nonkeyword_arguments.py| + pandas/tests/api/test_types.py| + pandas/tests/api/test_api.py| + pandas/tests/strings/test_find_replace.py| + pandas/tests/arithmetic/test_datetime64.py| + pandas/tests/arithmetic/test_timedelta64.py| + pandas/tests/arithmetic/test_period.py| + pandas/tests/tseries/offsets/test_offsets.py| + pandas/tests/indexes/datetimes/methods/test_shift.py| + pandas/tests/extension/test_period.py| + pandas/tests/io/pytables/test_retain_attributes.py| + pandas/tests/scalar/timestamp/methods/test_to_pydatetime.py| + pandas/tests/io/test_sql.py| + pandas/tests/plotting/frame/test_hist_box_by.py| + pandas/tests/plotting/test_boxplot_method.py| + pandas/tests/frame/methods/test_reindex_like.py| + pandas/tests/series/accessors/test_cat_accessor.py| + pandas/tests/indexing/multiindex/test_multiindex.py| + pandas/tests/indexes/multi/test_drop.py| + pandas/tests/io/parser/test_unsupported.py| + pandas/tests/apply/test_str.py| + pandas/tests/indexes/multi/test_indexing.py| + pandas/tests/io/pytables/test_store.py| + pandas/tests/io/pytables/test_put.py| + pandas/tests/io/pytables/test_round_trip.py| + pandas/tests/resample/test_datetime_index.py| + pandas/tests/io/parser/test_c_parser_only.py| + pandas/tests/io/test_stata.py| + pandas/tests/plotting/test_misc.py| + pandas/tests/series/methods/test_equals.py| + pandas/tests/frame/test_block_internals.py| + pandas/tests/indexes/multi/test_sorting.py| + pandas/tests/series/methods/test_reindex_like.py| + pandas/tests/extension/test_sparse.py| + pandas/tests/indexes/test_common.py| + pandas/tests/indexing/multiindex/test_loc.py| + pandas/tests/frame/indexing/test_insert.py| + pandas/tests/groupby/test_groupby.py| + )$ + types: [python] diff --git a/scripts/enforce_match_arg_in_assert_produces_warning.py b/scripts/enforce_match_arg_in_assert_produces_warning.py new file mode 100755 index 0000000000000..d1a2e9cc2e6dc --- /dev/null +++ b/scripts/enforce_match_arg_in_assert_produces_warning.py @@ -0,0 +1,88 @@ +""" +Enforce that all usages of tm.assert_produces_warning use +the "match" argument. This will help ensure that users always see +the correct warning message. + +tm.assert_produces_warning(None), tm.assert_produces_warning() +and tm.assert_produces_warning(False) are excluded as no warning is +produced. + +This is meant to be run as a pre-commit hook - to run it manually, you can do: + + pre-commit run enforce-match-arg-in-assert-produces-warning --all-files +""" +from __future__ import annotations + +import argparse +import ast +import sys +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Sequence + +ERROR_MESSAGE = ( + "{path}:{lineno}:{col_offset}: " + '"match" argument missing in tm.assert_produces_warning' + "\n" +) + + +class MatchArgForWarningsChecker(ast.NodeVisitor): + def __init__(self) -> None: + self.error_set = [] + + def visit_Call(self, node) -> None: + if ( isinstance(node.func, ast.Attribute) and + node.func.attr == "assert_produces_warning"): + # only check for attribute function of class/module tm + if ( isinstance(node.func.value, ast.Name) and + node.func.value.id == "tm" ): + # ignore tm.assert_produces_warning(None),tm.assert_produces_warning() + # and tm.assert_produces_warning(False) + if ( len(node.args) == 0 or + (isinstance(node.args[0], ast.Constant) and + ( node.args[0].value is None or node.args[0].value is False))): + return + if not any(keyword.arg == "match" for keyword in node.keywords): + self.error_set.append((node.lineno, node.col_offset)) + + +# Returns true if a file fails the check +def check_for_match_arg(content: str, filename: str) -> bool: + tree = ast.parse(content) + visitor = MatchArgForWarningsChecker() + visitor.visit(tree) + + if len(visitor.error_set) == 0: + return False + + for error in visitor.error_set: + msg = ERROR_MESSAGE.format( + lineno=error[0], + col_offset=error[1], + path=filename, + ) + sys.stdout.write(msg) + + return True + + +def main(argv: Sequence[str] | None = None) -> None: + parser = argparse.ArgumentParser() + parser.add_argument("paths", nargs="*") + + args = parser.parse_args(argv) + is_match_missing = False + + for filename in args.paths: + with open(filename, encoding="utf-8") as fd: + content = fd.read() + is_match_missing = check_for_match_arg(content, filename) | is_match_missing + + if is_match_missing: + sys.exit(1) + + +if __name__ == "__main__": + main()