From b700f581134d6f10162b624114accae9dbcd19c4 Mon Sep 17 00:00:00 2001 From: Jeffrey Tratner Date: Mon, 29 Jul 2013 23:10:19 -0400 Subject: [PATCH] TST: util/testing improvements assertRaises and assertRaisesRegexp are now with-statement-compatible, refactored assert_panel_equal and removed check_index_freq from assert_series_equal. --- pandas/io/tests/test_pytables.py | 5 +- pandas/util/testing.py | 199 ++++++++++++++++++++----------- 2 files changed, 135 insertions(+), 69 deletions(-) diff --git a/pandas/io/tests/test_pytables.py b/pandas/io/tests/test_pytables.py index d6eeb38076a42..c2564a6e12145 100644 --- a/pandas/io/tests/test_pytables.py +++ b/pandas/io/tests/test_pytables.py @@ -1763,7 +1763,10 @@ def test_index_types(self): values = np.random.randn(2) - func = lambda l, r: tm.assert_series_equal(l, r, True, True, True) + func = lambda l, r: tm.assert_series_equal(l, r, + check_dtype=True, + check_index_type=True, + check_series_type=True) with tm.assert_produces_warning(expected_warning=PerformanceWarning): ser = Series(values, [0, 'y']) diff --git a/pandas/util/testing.py b/pandas/util/testing.py index 8cb9138f4d2f6..82fdf45265e78 100644 --- a/pandas/util/testing.py +++ b/pandas/util/testing.py @@ -1,8 +1,8 @@ from __future__ import division - # pylint: disable-msg=W0402 import random +import re import string import sys import tempfile @@ -11,7 +11,7 @@ import os from datetime import datetime -from functools import wraps +from functools import wraps, partial from contextlib import contextmanager from distutils.version import LooseVersion @@ -130,8 +130,20 @@ def assert_isinstance(obj, class_type_or_tuple): "Expected object to be of type %r, found %r instead" % ( type(obj), class_type_or_tuple)) -def assert_equal(actual, expected, msg=""): - assert expected == actual, "%s: %r != %r" % (msg, actual, expected) +def assert_equal(a, b, msg=""): + """asserts that a equals b, like nose's assert_equal, but allows custom message to start. + Passes a and b to format string as well. So you can use '{0}' and '{1}' to display a and b. + + Examples + -------- + >>> assert_equal(2, 2, "apples") + >>> assert_equal(5.2, 1.2, "{0} was really a dead parrot") + Traceback (most recent call last): + ... + AssertionError: 5.2 was really a dead parrot: 5.2 != 1.2 + """ + assert a == b, "%s: %r != %r" % (msg.format(a,b), a, b) + def assert_index_equal(left, right): if not left.equals(right): @@ -139,18 +151,17 @@ def assert_index_equal(left, right): left, right, right.dtype)) + + def assert_attr_equal(attr, left, right): - left_attr = getattr(left, attr, None) - right_attr = getattr(right, attr, None) + """checks attributes are equal. Both objects must have attribute.""" + left_attr = getattr(left, attr) + right_attr = getattr(right, attr) assert_equal(left_attr,right_attr,"attr is not equal [{0}]" .format(attr)) def isiterable(obj): return hasattr(obj, '__iter__') -def assert_isinstance(obj, class_type_or_tuple): - """asserts that obj is an instance of class_type_or_tuple""" - assert isinstance(obj, class_type_or_tuple), ( - "Expected object to be of type %r, found %r instead" % (type(obj), class_type_or_tuple)) def assert_almost_equal(a, b, check_less_precise=False): @@ -221,7 +232,6 @@ def assert_dict_equal(a, b, compare_keys=True): def assert_series_equal(left, right, check_dtype=True, check_index_type=False, - check_index_freq=False, check_series_type=False, check_less_precise=False): if check_series_type: @@ -238,8 +248,6 @@ def assert_series_equal(left, right, check_dtype=True, assert_isinstance(left.index, type(right.index)) assert_attr_equal('dtype', left.index, right.index) assert_attr_equal('inferred_type', left.index, right.index) - if check_index_freq: - assert_attr_equal('freqstr', left.index, right.index) def assert_frame_equal(left, right, check_dtype=True, @@ -261,7 +269,7 @@ def assert_frame_equal(left, right, check_dtype=True, assert_index_equal(left.index, right.index) for i, col in enumerate(left.columns): - assert(col in right) + assert col in right lcol = left.icol(i) rcol = right.icol(i) assert_series_equal(lcol, rcol, @@ -282,44 +290,36 @@ def assert_frame_equal(left, right, check_dtype=True, assert_attr_equal('names', left.columns, right.columns) -def assert_panel_equal(left, right, - check_panel_type=False, - check_less_precise=False): +def assert_panelnd_equal(left, right, + check_panel_type=False, + check_less_precise=False, + assert_func=assert_frame_equal): if check_panel_type: assert_isinstance(left, type(right)) for axis in ['items', 'major_axis', 'minor_axis']: - assert_index_equal( - getattr(left, axis, None), getattr(right, axis, None)) + left_ind = getattr(left, axis) + right_ind = getattr(right, axis) + assert_index_equal(left_ind, right_ind) for col, series in compat.iteritems(left): - assert(col in right) - # TODO strangely check_names fails in py3 ? - assert_frame_equal( - series, right[col], check_less_precise=check_less_precise, check_names=False) + assert col in right, "non-matching column '%s'" % col + assert_func(series, right[col], check_less_precise=check_less_precise) for col in right: - assert(col in left) - - -def assert_panel4d_equal(left, right, - check_less_precise=False): - for axis in ['labels', 'items', 'major_axis', 'minor_axis']: - assert_index_equal( - getattr(left, axis, None), getattr(right, axis, None)) - - for col, series in compat.iteritems(left): - assert(col in right) - assert_panel_equal( - series, right[col], check_less_precise=check_less_precise) + assert col in left - for col in right: - assert(col in left) +# TODO: strangely check_names fails in py3 ? +_panel_frame_equal = partial(assert_frame_equal, check_names=False) +assert_panel_equal = partial(assert_panelnd_equal, + assert_func=_panel_frame_equal) +assert_panel4d_equal = partial(assert_panelnd_equal, + assert_func=assert_panel_equal) def assert_contains_all(iterable, dic): for k in iterable: - assert(k in dic) + assert k in dic, "Did not contain item: '%r'" % k def getCols(k): @@ -986,7 +986,45 @@ def stdin_encoding(encoding=None): sys.stdin = _stdin -def assertRaisesRegexp(exception, regexp, callable, *args, **kwargs): +def assertRaises(_exception, _callable=None, *args, **kwargs): + """assertRaises that is usable as context manager or in a with statement + + Exceptions that don't match the given Exception type fall through:: + + >>> with assertRaises(ValueError): + ... raise TypeError("banana") + ... + Traceback (most recent call last): + ... + TypeError: banana + + If it raises the given Exception type, the test passes + >>> with assertRaises(KeyError): + ... dct = dict() + ... dct["apple"] + + If the expected error doesn't occur, it raises an error. + >>> with assertRaises(KeyError): + ... dct = {'apple':True} + ... dct["apple"] + Traceback (most recent call last): + ... + AssertionError: KeyError not raised. + + In addition to using it as a contextmanager, you can also use it as a + function, just like the normal assertRaises + + >>> assertRaises(TypeError, ",".join, [1, 3, 5]); + """ + manager = _AssertRaisesContextmanager(exception=_exception) + # don't return anything if usedin function form + if _callable is not None: + with manager: + _callable(*args, **kwargs) + else: + return manager + +def assertRaisesRegexp(_exception, _regexp, _callable=None, *args, **kwargs): """ Port of assertRaisesRegexp from unittest in Python 2.7 - used in with statement. Explanation from standard library: @@ -997,46 +1035,71 @@ def assertRaisesRegexp(exception, regexp, callable, *args, **kwargs): You can pass either a regular expression or a compiled regular expression object. >>> assertRaisesRegexp(ValueError, 'invalid literal for.*XYZ', - ... int, 'XYZ') + ... int, 'XYZ'); >>> import re - >>> assertRaisesRegexp(ValueError, re.compile('literal'), int, 'XYZ') + >>> assertRaisesRegexp(ValueError, re.compile('literal'), int, 'XYZ'); If an exception of a different type is raised, it bubbles up. - >>> assertRaisesRegexp(TypeError, 'literal', int, 'XYZ') + >>> assertRaisesRegexp(TypeError, 'literal', int, 'XYZ'); Traceback (most recent call last): ... ValueError: invalid literal for int() with base 10: 'XYZ' - >>> dct = {} - >>> assertRaisesRegexp(KeyError, 'pear', dct.__getitem__, 'apple') + >>> dct = dict() + >>> assertRaisesRegexp(KeyError, 'pear', dct.__getitem__, 'apple'); Traceback (most recent call last): ... AssertionError: "pear" does not match "'apple'" - >>> assertRaisesRegexp(KeyError, 'apple', dct.__getitem__, 'apple') - >>> assertRaisesRegexp(Exception, 'operand type.*int.*dict', lambda : 2 + {}) - """ - - import re - try: - callable(*args, **kwargs) - except Exception as e: - if not issubclass(e.__class__, exception): - # mimics behavior of unittest - raise - # don't recompile - if hasattr(regexp, "search"): - expected_regexp = regexp - else: - expected_regexp = re.compile(regexp) - if not expected_regexp.search(str(e)): - raise AssertionError('"%s" does not match "%s"' % - (expected_regexp.pattern, str(e))) + You can also use this in a with statement. + >>> with assertRaisesRegexp(TypeError, 'unsupported operand type\(s\)'): + ... 1 + {} + >>> with assertRaisesRegexp(TypeError, 'banana'): + ... 'apple'[0] = 'b' + Traceback (most recent call last): + ... + AssertionError: "banana" does not match "'str' object does not support \ +item assignment" + """ + manager = _AssertRaisesContextmanager(exception=_exception, regexp=_regexp) + if _callable is not None: + with manager: + _callable(*args, **kwargs) else: - # Apparently some exceptions don't have a __name__ attribute? Just - # aping unittest library here - name = getattr(exception, "__name__", str(exception)) - raise AssertionError("{0} not raised".format(name)) + return manager + + +class _AssertRaisesContextmanager(object): + """handles the behind the scenes work for assertRaises and assertRaisesRegexp""" + def __init__(self, exception, regexp=None, *args, **kwargs): + self.exception = exception + if regexp is not None and not hasattr(regexp, "search"): + regexp = re.compile(regexp) + self.regexp = regexp + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + expected = self.exception + if not exc_type: + name = getattr(expected, "__name__", str(expected)) + raise AssertionError("{0} not raised.".format(name)) + if issubclass(exc_type, expected): + return self.handle_success(exc_type, exc_value, traceback) + return self.handle_failure(exc_type, exc_value, traceback) + + def handle_failure(*args, **kwargs): + # Failed, so allow Exception to bubble up + return False + + def handle_success(self, exc_type, exc_value, traceback): + if self.regexp is not None: + val = str(exc_value) + if not self.regexp.search(val): + raise AssertionError('"%s" does not match "%s"' % + (self.regexp.pattern, str(val))) + return True @contextmanager