Skip to content

ENH: added numexpr support for where operations #3154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 25, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 60 additions & 9 deletions pandas/core/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@

_USE_NUMEXPR = _NUMEXPR_INSTALLED
_evaluate = None
_where = None

# the set of dtypes that we will allow pass to numexpr
_ALLOWED_DTYPES = set(['int64','int32','float64','float32','bool'])
_ALLOWED_DTYPES = dict(evaluate = set(['int64','int32','float64','float32','bool']),
where = set(['int64','float64','bool']))

# the minimum prod shape that we will use numexpr
_MIN_ELEMENTS = 10000
Expand All @@ -26,17 +28,16 @@ def set_use_numexpr(v = True):
# set/unset to use numexpr
global _USE_NUMEXPR
if _NUMEXPR_INSTALLED:
#print "setting use_numexpr : was->%s, now->%s" % (_USE_NUMEXPR,v)
_USE_NUMEXPR = v

# choose what we are going to do
global _evaluate
global _evaluate, _where
if not _USE_NUMEXPR:
_evaluate = _evaluate_standard
_where = _where_standard
else:
_evaluate = _evaluate_numexpr

#print "evaluate -> %s" % _evaluate
_where = _where_numexpr

def set_numexpr_threads(n = None):
# if we are using numexpr, set the threads to n
Expand All @@ -54,7 +55,7 @@ def _evaluate_standard(op, op_str, a, b, raise_on_error=True):
""" standard evaluation """
return op(a,b)

def _can_use_numexpr(op, op_str, a, b):
def _can_use_numexpr(op, op_str, a, b, dtype_check):
""" return a boolean if we WILL be using numexpr """
if op_str is not None:

Expand All @@ -73,15 +74,15 @@ def _can_use_numexpr(op, op_str, a, b):
dtypes |= set([o.dtype.name])

# allowed are a superset
if not len(dtypes) or _ALLOWED_DTYPES >= dtypes:
if not len(dtypes) or _ALLOWED_DTYPES[dtype_check] >= dtypes:
return True

return False

def _evaluate_numexpr(op, op_str, a, b, raise_on_error = False):
result = None

if _can_use_numexpr(op, op_str, a, b):
if _can_use_numexpr(op, op_str, a, b, 'evaluate'):
try:
a_value, b_value = a, b
if hasattr(a_value,'values'):
Expand All @@ -104,6 +105,40 @@ def _evaluate_numexpr(op, op_str, a, b, raise_on_error = False):

return result

def _where_standard(cond, a, b, raise_on_error=True):
return np.where(cond, a, b)

def _where_numexpr(cond, a, b, raise_on_error = False):
result = None

if _can_use_numexpr(None, 'where', a, b, 'where'):

try:
cond_value, a_value, b_value = cond, a, b
if hasattr(cond_value,'values'):
cond_value = cond_value.values
if hasattr(a_value,'values'):
a_value = a_value.values
if hasattr(b_value,'values'):
b_value = b_value.values
result = ne.evaluate('where(cond_value,a_value,b_value)',
local_dict={ 'cond_value' : cond_value,
'a_value' : a_value,
'b_value' : b_value },
casting='safe')
except (ValueError), detail:
if 'unknown type object' in str(detail):
pass
except (Exception), detail:
if raise_on_error:
raise TypeError(str(detail))

if result is None:
result = _where_standard(cond,a,b,raise_on_error)

return result


# turn myself on
set_use_numexpr(True)

Expand All @@ -126,4 +161,20 @@ def evaluate(op, op_str, a, b, raise_on_error=False, use_numexpr=True):
return _evaluate(op, op_str, a, b, raise_on_error=raise_on_error)
return _evaluate_standard(op, op_str, a, b, raise_on_error=raise_on_error)


def where(cond, a, b, raise_on_error=False, use_numexpr=True):
""" evaluate the where condition cond on a and b

Parameters
----------

cond : a boolean array
a : return if cond is True
b : return if cond is False
raise_on_error : pass the error to the higher level if indicated (default is False),
otherwise evaluate the op with and return the results
use_numexpr : whether to try to use numexpr (default True)
"""

if use_numexpr:
return _where(cond, a, b, raise_on_error=raise_on_error)
return _where_standard(cond, a, b, raise_on_error=raise_on_error)
5 changes: 3 additions & 2 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3720,7 +3720,8 @@ def combine_first(self, other):
-------
combined : DataFrame
"""
combiner = lambda x, y: np.where(isnull(x), y, x)
def combiner(x, y):
return expressions.where(isnull(x), y, x, raise_on_error=True)
return self.combine(other, combiner, overwrite=False)

def update(self, other, join='left', overwrite=True, filter_func=None,
Expand Down Expand Up @@ -3771,7 +3772,7 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
else:
mask = notnull(this)

self[col] = np.where(mask, this, that)
self[col] = expressions.where(mask, this, that, raise_on_error=True)

#----------------------------------------------------------------------
# Misc methods
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pandas.core.common as com
import pandas.lib as lib
import pandas.tslib as tslib
import pandas.core.expressions as expressions

from pandas.tslib import Timestamp
from pandas.util import py3compat
Expand Down Expand Up @@ -506,7 +507,7 @@ def func(c,v,o):
return v

try:
return np.where(c,v,o)
return expressions.where(c, v, o, raise_on_error=True)
except (Exception), detail:
if raise_on_error:
raise TypeError('Could not operate [%s] with block values [%s]'
Expand Down
38 changes: 30 additions & 8 deletions pandas/tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,19 @@ def setUp(self):
def test_invalid(self):

# no op
result = expr._can_use_numexpr(operator.add, None, self.frame, self.frame)
result = expr._can_use_numexpr(operator.add, None, self.frame, self.frame, 'evaluate')
self.assert_(result == False)

# mixed
result = expr._can_use_numexpr(operator.add, '+', self.mixed, self.frame)
result = expr._can_use_numexpr(operator.add, '+', self.mixed, self.frame, 'evaluate')
self.assert_(result == False)

# min elements
result = expr._can_use_numexpr(operator.add, '+', self.frame2, self.frame2)
result = expr._can_use_numexpr(operator.add, '+', self.frame2, self.frame2, 'evaluate')
self.assert_(result == False)

# ok, we only check on first part of expression
result = expr._can_use_numexpr(operator.add, '+', self.frame, self.frame2)
result = expr._can_use_numexpr(operator.add, '+', self.frame, self.frame2, 'evaluate')
self.assert_(result == True)

def test_binary_ops(self):
Expand All @@ -70,14 +70,14 @@ def testit():
for op, op_str in [('add','+'),('sub','-'),('mul','*'),('div','/'),('pow','**')]:

op = getattr(operator,op)
result = expr._can_use_numexpr(op, op_str, f, f)
result = expr._can_use_numexpr(op, op_str, f, f, 'evaluate')
self.assert_(result == (not f._is_mixed_type))

result = expr.evaluate(op, op_str, f, f, use_numexpr=True)
expected = expr.evaluate(op, op_str, f, f, use_numexpr=False)
assert_array_equal(result,expected.values)

result = expr._can_use_numexpr(op, op_str, f2, f2)
result = expr._can_use_numexpr(op, op_str, f2, f2, 'evaluate')
self.assert_(result == False)


Expand Down Expand Up @@ -105,14 +105,14 @@ def testit():

op = getattr(operator,op)

result = expr._can_use_numexpr(op, op_str, f11, f12)
result = expr._can_use_numexpr(op, op_str, f11, f12, 'evaluate')
self.assert_(result == (not f11._is_mixed_type))

result = expr.evaluate(op, op_str, f11, f12, use_numexpr=True)
expected = expr.evaluate(op, op_str, f11, f12, use_numexpr=False)
assert_array_equal(result,expected.values)

result = expr._can_use_numexpr(op, op_str, f21, f22)
result = expr._can_use_numexpr(op, op_str, f21, f22, 'evaluate')
self.assert_(result == False)

expr.set_use_numexpr(False)
Expand All @@ -123,6 +123,28 @@ def testit():
expr.set_numexpr_threads()
testit()

def test_where(self):

def testit():
for f in [ self.frame, self.frame2, self.mixed, self.mixed2 ]:


for cond in [ True, False ]:

c = np.empty(f.shape,dtype=np.bool_)
c.fill(cond)
result = expr.where(c, f.values, f.values+1)
expected = np.where(c, f.values, f.values+1)
assert_array_equal(result,expected)

expr.set_use_numexpr(False)
testit()
expr.set_use_numexpr(True)
expr.set_numexpr_threads(1)
testit()
expr.set_numexpr_threads()
testit()

if __name__ == '__main__':
# unittest.main()
import nose
Expand Down