From adca80933af1e593b795db3505912d9ed8cf2f06 Mon Sep 17 00:00:00 2001 From: jreback Date: Wed, 18 Jun 2014 20:22:35 -0400 Subject: [PATCH] BUG: Bug in DataFrame.where with a symmetric shaped frame and a passed other of a DataFrame (GH7506) --- doc/source/v0.14.1.txt | 4 ++-- pandas/core/internals.py | 8 ++++++- pandas/tests/test_frame.py | 48 +++++++++++++++++++++++++++----------- 3 files changed, 44 insertions(+), 16 deletions(-) diff --git a/doc/source/v0.14.1.txt b/doc/source/v0.14.1.txt index c1e5877d09004..af859b7be2558 100644 --- a/doc/source/v0.14.1.txt +++ b/doc/source/v0.14.1.txt @@ -163,12 +163,12 @@ Experimental Bug Fixes ~~~~~~~~~ +- Bug in ``DataFrame.where`` with a symmetric shaped frame and a passed other of a DataFrame (:issue:`7506`) - -- Bug in ``value_counts`` where ``NaT`` did not qualify as missing (``NaN``) (:issue:`7423`) +- Bug in ``value_counts`` where ``NaT`` did not qualify as missing (``NaN``) (:issue:`7423`) diff --git a/pandas/core/internals.py b/pandas/core/internals.py index 75ec53c95869a..6b2d6bcfe3c80 100644 --- a/pandas/core/internals.py +++ b/pandas/core/internals.py @@ -921,9 +921,13 @@ def where(self, other, cond, align=True, raise_on_error=True, if hasattr(other, 'ndim') and hasattr(values, 'ndim'): if values.ndim != other.ndim or values.shape == other.shape[::-1]: + # if its symmetric are ok, no reshaping needed (GH 7506) + if (values.shape[0] == np.array(values.shape)).all(): + pass + # pseodo broadcast (its a 2d vs 1d say and where needs it in a # specific direction) - if (other.ndim >= 1 and values.ndim - 1 == other.ndim and + elif (other.ndim >= 1 and values.ndim - 1 == other.ndim and values.shape[0] != other.shape[0]): other = _block_shape(other).T else: @@ -941,9 +945,11 @@ def where(self, other, cond, align=True, raise_on_error=True, # may need to undo transpose of values if hasattr(values, 'ndim'): if values.ndim != cond.ndim or values.shape == cond.shape[::-1]: + values = values.T is_transposed = not is_transposed + # our where function def func(c, v, o): if c.ravel().all(): diff --git a/pandas/tests/test_frame.py b/pandas/tests/test_frame.py index 8ed1d2d2d4f95..6848b130dee3a 100644 --- a/pandas/tests/test_frame.py +++ b/pandas/tests/test_frame.py @@ -5564,27 +5564,27 @@ def test_to_csv_from_csv(self): with ensure_clean(pname) as path: self.frame['A'][:5] = nan - + self.frame.to_csv(path) self.frame.to_csv(path, columns=['A', 'B']) self.frame.to_csv(path, header=False) self.frame.to_csv(path, index=False) - + # test roundtrip self.tsframe.to_csv(path) recons = DataFrame.from_csv(path) - + assert_frame_equal(self.tsframe, recons) - + self.tsframe.to_csv(path, index_label='index') recons = DataFrame.from_csv(path, index_col=None) assert(len(recons.columns) == len(self.tsframe.columns) + 1) - + # no index self.tsframe.to_csv(path, index=False) recons = DataFrame.from_csv(path, index_col=None) assert_almost_equal(self.tsframe.values, recons.values) - + # corner case dm = DataFrame({'s1': Series(lrange(3), lrange(3)), 's2': Series(lrange(2), lrange(2))}) @@ -5600,7 +5600,7 @@ def test_to_csv_from_csv(self): df.to_csv(path) result = DataFrame.from_csv(path) assert_frame_equal(result, df) - + midx = MultiIndex.from_tuples([('A', 1, 2), ('A', 1, 2), ('B', 1, 2)]) df = DataFrame(np.random.randn(3, 3), index=midx, columns=['x', 'y', 'z']) @@ -5608,16 +5608,16 @@ def test_to_csv_from_csv(self): result = DataFrame.from_csv(path, index_col=[0, 1, 2], parse_dates=False) assert_frame_equal(result, df, check_names=False) # TODO from_csv names index ['Unnamed: 1', 'Unnamed: 2'] should it ? - + # column aliases col_aliases = Index(['AA', 'X', 'Y', 'Z']) self.frame2.to_csv(path, header=col_aliases) rs = DataFrame.from_csv(path) xp = self.frame2.copy() xp.columns = col_aliases - + assert_frame_equal(xp, rs) - + self.assertRaises(ValueError, self.frame2.to_csv, path, header=['AA', 'X']) @@ -5881,7 +5881,7 @@ def test_to_csv_from_csv_w_some_infs(self): with ensure_clean() as path: self.frame.to_csv(path) recons = DataFrame.from_csv(path) - + assert_frame_equal(self.frame, recons, check_names=False) # TODO to_csv drops column name assert_frame_equal(np.isinf(self.frame), np.isinf(recons), check_names=False) @@ -5940,11 +5940,11 @@ def test_to_csv_multiindex(self): frame.to_csv(path, header=False) frame.to_csv(path, columns=['A', 'B']) - + # round trip frame.to_csv(path) df = DataFrame.from_csv(path, index_col=[0, 1], parse_dates=False) - + assert_frame_equal(frame, df, check_names=False) # TODO to_csv drops column name self.assertEqual(frame.index.names, df.index.names) self.frame.index = old_index # needed if setUP becomes a classmethod @@ -9155,6 +9155,28 @@ def test_where_bug(self): result.where(result > 2, np.nan, inplace=True) assert_frame_equal(result, expected) + # transpositional issue + # GH7506 + a = DataFrame({ 0 : [1,2], 1 : [3,4], 2 : [5,6]}) + b = DataFrame({ 0 : [np.nan,8], 1:[9,np.nan], 2:[np.nan,np.nan]}) + do_not_replace = b.isnull() | (a > b) + + expected = a.copy() + expected[~do_not_replace] = b + + result = a.where(do_not_replace,b) + assert_frame_equal(result,expected) + + a = DataFrame({ 0 : [4,6], 1 : [1,0]}) + b = DataFrame({ 0 : [np.nan,3],1:[3,np.nan]}) + do_not_replace = b.isnull() | (a > b) + + expected = a.copy() + expected[~do_not_replace] = b + + result = a.where(do_not_replace,b) + assert_frame_equal(result,expected) + def test_where_datetime(self): # GH 3311