diff --git a/pandas/core/expressions.py b/pandas/core/expressions.py index 4199c6f7f890c..de93394872e12 100644 --- a/pandas/core/expressions.py +++ b/pandas/core/expressions.py @@ -15,9 +15,11 @@ _USE_NUMEXPR = _NUMEXPR_INSTALLED _evaluate = None +_where = None # the set of dtypes that we will allow pass to numexpr -_ALLOWED_DTYPES = set(['int64','int32','float64','float32','bool']) +_ALLOWED_DTYPES = dict(evaluate = set(['int64','int32','float64','float32','bool']), + where = set(['int64','float64','bool'])) # the minimum prod shape that we will use numexpr _MIN_ELEMENTS = 10000 @@ -26,17 +28,16 @@ def set_use_numexpr(v = True): # set/unset to use numexpr global _USE_NUMEXPR if _NUMEXPR_INSTALLED: - #print "setting use_numexpr : was->%s, now->%s" % (_USE_NUMEXPR,v) _USE_NUMEXPR = v # choose what we are going to do - global _evaluate + global _evaluate, _where if not _USE_NUMEXPR: _evaluate = _evaluate_standard + _where = _where_standard else: _evaluate = _evaluate_numexpr - - #print "evaluate -> %s" % _evaluate + _where = _where_numexpr def set_numexpr_threads(n = None): # if we are using numexpr, set the threads to n @@ -54,7 +55,7 @@ def _evaluate_standard(op, op_str, a, b, raise_on_error=True): """ standard evaluation """ return op(a,b) -def _can_use_numexpr(op, op_str, a, b): +def _can_use_numexpr(op, op_str, a, b, dtype_check): """ return a boolean if we WILL be using numexpr """ if op_str is not None: @@ -73,7 +74,7 @@ def _can_use_numexpr(op, op_str, a, b): dtypes |= set([o.dtype.name]) # allowed are a superset - if not len(dtypes) or _ALLOWED_DTYPES >= dtypes: + if not len(dtypes) or _ALLOWED_DTYPES[dtype_check] >= dtypes: return True return False @@ -81,7 +82,7 @@ def _can_use_numexpr(op, op_str, a, b): def _evaluate_numexpr(op, op_str, a, b, raise_on_error = False): result = None - if _can_use_numexpr(op, op_str, a, b): + if _can_use_numexpr(op, op_str, a, b, 'evaluate'): try: a_value, b_value = a, b if hasattr(a_value,'values'): @@ -104,6 +105,40 @@ def _evaluate_numexpr(op, op_str, a, b, raise_on_error = False): return result +def _where_standard(cond, a, b, raise_on_error=True): + return np.where(cond, a, b) + +def _where_numexpr(cond, a, b, raise_on_error = False): + result = None + + if _can_use_numexpr(None, 'where', a, b, 'where'): + + try: + cond_value, a_value, b_value = cond, a, b + if hasattr(cond_value,'values'): + cond_value = cond_value.values + if hasattr(a_value,'values'): + a_value = a_value.values + if hasattr(b_value,'values'): + b_value = b_value.values + result = ne.evaluate('where(cond_value,a_value,b_value)', + local_dict={ 'cond_value' : cond_value, + 'a_value' : a_value, + 'b_value' : b_value }, + casting='safe') + except (ValueError), detail: + if 'unknown type object' in str(detail): + pass + except (Exception), detail: + if raise_on_error: + raise TypeError(str(detail)) + + if result is None: + result = _where_standard(cond,a,b,raise_on_error) + + return result + + # turn myself on set_use_numexpr(True) @@ -126,4 +161,20 @@ def evaluate(op, op_str, a, b, raise_on_error=False, use_numexpr=True): return _evaluate(op, op_str, a, b, raise_on_error=raise_on_error) return _evaluate_standard(op, op_str, a, b, raise_on_error=raise_on_error) - +def where(cond, a, b, raise_on_error=False, use_numexpr=True): + """ evaluate the where condition cond on a and b + + Parameters + ---------- + + cond : a boolean array + a : return if cond is True + b : return if cond is False + raise_on_error : pass the error to the higher level if indicated (default is False), + otherwise evaluate the op with and return the results + use_numexpr : whether to try to use numexpr (default True) + """ + + if use_numexpr: + return _where(cond, a, b, raise_on_error=raise_on_error) + return _where_standard(cond, a, b, raise_on_error=raise_on_error) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index afb698221c48b..c3dc38d5d7187 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -3720,7 +3720,8 @@ def combine_first(self, other): ------- combined : DataFrame """ - combiner = lambda x, y: np.where(isnull(x), y, x) + def combiner(x, y): + return expressions.where(isnull(x), y, x, raise_on_error=True) return self.combine(other, combiner, overwrite=False) def update(self, other, join='left', overwrite=True, filter_func=None, @@ -3771,7 +3772,7 @@ def update(self, other, join='left', overwrite=True, filter_func=None, else: mask = notnull(this) - self[col] = np.where(mask, this, that) + self[col] = expressions.where(mask, this, that, raise_on_error=True) #---------------------------------------------------------------------- # Misc methods diff --git a/pandas/core/internals.py b/pandas/core/internals.py index 385695ec6cc50..f7c560481cc5f 100644 --- a/pandas/core/internals.py +++ b/pandas/core/internals.py @@ -10,6 +10,7 @@ import pandas.core.common as com import pandas.lib as lib import pandas.tslib as tslib +import pandas.core.expressions as expressions from pandas.tslib import Timestamp from pandas.util import py3compat @@ -506,7 +507,7 @@ def func(c,v,o): return v try: - return np.where(c,v,o) + return expressions.where(c, v, o, raise_on_error=True) except (Exception), detail: if raise_on_error: raise TypeError('Could not operate [%s] with block values [%s]' diff --git a/pandas/tests/test_expressions.py b/pandas/tests/test_expressions.py index a0321d2dbe55f..a496785b0aed3 100644 --- a/pandas/tests/test_expressions.py +++ b/pandas/tests/test_expressions.py @@ -46,19 +46,19 @@ def setUp(self): def test_invalid(self): # no op - result = expr._can_use_numexpr(operator.add, None, self.frame, self.frame) + result = expr._can_use_numexpr(operator.add, None, self.frame, self.frame, 'evaluate') self.assert_(result == False) # mixed - result = expr._can_use_numexpr(operator.add, '+', self.mixed, self.frame) + result = expr._can_use_numexpr(operator.add, '+', self.mixed, self.frame, 'evaluate') self.assert_(result == False) # min elements - result = expr._can_use_numexpr(operator.add, '+', self.frame2, self.frame2) + result = expr._can_use_numexpr(operator.add, '+', self.frame2, self.frame2, 'evaluate') self.assert_(result == False) # ok, we only check on first part of expression - result = expr._can_use_numexpr(operator.add, '+', self.frame, self.frame2) + result = expr._can_use_numexpr(operator.add, '+', self.frame, self.frame2, 'evaluate') self.assert_(result == True) def test_binary_ops(self): @@ -70,14 +70,14 @@ def testit(): for op, op_str in [('add','+'),('sub','-'),('mul','*'),('div','/'),('pow','**')]: op = getattr(operator,op) - result = expr._can_use_numexpr(op, op_str, f, f) + result = expr._can_use_numexpr(op, op_str, f, f, 'evaluate') self.assert_(result == (not f._is_mixed_type)) result = expr.evaluate(op, op_str, f, f, use_numexpr=True) expected = expr.evaluate(op, op_str, f, f, use_numexpr=False) assert_array_equal(result,expected.values) - result = expr._can_use_numexpr(op, op_str, f2, f2) + result = expr._can_use_numexpr(op, op_str, f2, f2, 'evaluate') self.assert_(result == False) @@ -105,14 +105,14 @@ def testit(): op = getattr(operator,op) - result = expr._can_use_numexpr(op, op_str, f11, f12) + result = expr._can_use_numexpr(op, op_str, f11, f12, 'evaluate') self.assert_(result == (not f11._is_mixed_type)) result = expr.evaluate(op, op_str, f11, f12, use_numexpr=True) expected = expr.evaluate(op, op_str, f11, f12, use_numexpr=False) assert_array_equal(result,expected.values) - result = expr._can_use_numexpr(op, op_str, f21, f22) + result = expr._can_use_numexpr(op, op_str, f21, f22, 'evaluate') self.assert_(result == False) expr.set_use_numexpr(False) @@ -123,6 +123,28 @@ def testit(): expr.set_numexpr_threads() testit() + def test_where(self): + + def testit(): + for f in [ self.frame, self.frame2, self.mixed, self.mixed2 ]: + + + for cond in [ True, False ]: + + c = np.empty(f.shape,dtype=np.bool_) + c.fill(cond) + result = expr.where(c, f.values, f.values+1) + expected = np.where(c, f.values, f.values+1) + assert_array_equal(result,expected) + + expr.set_use_numexpr(False) + testit() + expr.set_use_numexpr(True) + expr.set_numexpr_threads(1) + testit() + expr.set_numexpr_threads() + testit() + if __name__ == '__main__': # unittest.main() import nose