Skip to content

Commit 8e8c41b

Browse files
committed
clean up broadcast tests
1 parent d88c131 commit 8e8c41b

File tree

6 files changed

+132
-27
lines changed

6 files changed

+132
-27
lines changed

doc/source/basics.rst

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -793,8 +793,12 @@ The :meth:`~DataFrame.apply` method will also dispatch on a string method name.
793793
df.apply('mean')
794794
df.apply('mean', axis=1)
795795
796-
Depending on the return type of the function passed to :meth:`~DataFrame.apply`,
797-
the result will either be of lower dimension or the same dimension.
796+
The return type of the function passed to :meth:`~DataFrame.apply` affects the
797+
type of the ultimate output from :meth:`~DataFrame.apply`:
798+
799+
* If the applied function returns a ``Series``, the ultimate output is a ``DataFrame``.
800+
The columns match the index of the ``Series`` returned by the applied function.
801+
* If the applied function returns any other type, the ultimate output is a ``Series``.
798802

799803
:meth:`~DataFrame.apply` combined with some cleverness can be used to answer many questions
800804
about a data set. For example, suppose we wanted to extract the date where the

doc/source/whatsnew/v0.23.0.txt

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,8 @@ where a list-like (e.g. ``tuple`` or ``list`` is returned), (:issue:`16353`, :is
347347
df = pd.DataFrame(np.tile(np.arange(3), 6).reshape(6, -1) + 1, columns=['A', 'B', 'C'])
348348
df
349349

350-
Previous Behavior. If the returned shape happened to match the index, this would return a list-like.
350+
Previous Behavior. If the returned shape happened to match the original columns, this would return a ``DataFrame``.
351+
If the return shape did not match, a ``Series`` with lists was returned.
351352

352353
.. code-block:: python
353354

@@ -379,12 +380,25 @@ New Behavior. The behavior is consistent. These will *always* return a ``Series`
379380
df.apply(lambda x: [1, 2, 3], axis=1)
380381
df.apply(lambda x: [1, 2], axis=1)
381382

382-
To have automatic inference, you can use ``result_type='infer'``
383+
To have expanded columns, you can use ``result_type='infer'``
383384

384385
.. ipython:: python
385386

386387
df.apply(lambda x: [1, 2, 3], axis=1, result_type='infer')
387388

389+
To broadcast the result across the original columns, you can use ``result_type='broadcast'``. The shape
390+
must match the original columns.
391+
392+
.. ipython:: python
393+
394+
df.apply(lambda x: [1, 2, 3], axis=1, result_type='broadcast')
395+
396+
Returning a ``Series`` allows one to control the exact return structure and column names:
397+
398+
.. ipython:: python
399+
400+
df.apply(lambda x: Series([1, 2, 3], index=x.index), axis=1)
401+
388402

389403
.. _whatsnew_0230.api_breaking.build_changes:
390404

pandas/core/apply.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def __init__(self, obj, func, broadcast, raw, reduce, result_type,
4343
if broadcast is not None:
4444
warnings.warn("The broadcast argument is deprecated and will "
4545
"be removed in a future version. You can specify "
46-
"result_type='broadcast' broadcast a scalar result",
46+
"result_type='broadcast' to broadcast the result "
47+
"to the original dimensions",
4748
FutureWarning, stacklevel=4)
4849
if broadcast:
4950
result_type = 'broadcast'
@@ -160,11 +161,32 @@ def apply_raw(self):
160161

161162
def apply_broadcast(self, target):
162163
result_values = np.empty_like(target.values)
163-
columns = target.columns
164-
for i, col in enumerate(columns):
165-
result_values[:, i] = self.f(target[col])
166164

167-
result = self.obj._constructor(result_values, index=target.index,
165+
# axis which we want to compare compliance
166+
result_compare = target.shape[0]
167+
168+
index = target.index
169+
for i, col in enumerate(target.columns):
170+
res = self.f(target[col])
171+
ares = np. asarray(res).ndim
172+
173+
# must be a scalar or 1d
174+
if ares > 1:
175+
raise ValueError("too many dims to broadcast")
176+
elif ares == 1:
177+
178+
# must match return dim
179+
if result_compare != len(res):
180+
raise ValueError("cannot broadcast result")
181+
182+
# if we have a Series result, then its index
183+
# is our result
184+
if isinstance(res, ABCSeries):
185+
index = res.index
186+
187+
result_values[:, i] = res
188+
189+
result = self.obj._constructor(result_values, index=index,
168190
columns=target.columns)
169191
return result
170192

pandas/core/frame.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4870,10 +4870,11 @@ def apply(self, func, axis=0, broadcast=None, raw=False, reduce=None,
48704870
while guessing, exceptions raised by func will be ignored). If
48714871
reduce is True a Series will always be returned, and if False a
48724872
DataFrame will always be returned.
4873+
48734874
result_type : {'infer', 'broadcast', None}
48744875
These only act when axis=1 {columns}
4875-
* infer : list-like results will be turned into columns
4876-
* broadcast : scalar results will be broadcast to all rows
4876+
* 'infer' : list-like results will be turned into columns
4877+
* 'broadcast' : scalar results will be broadcast to all columns
48774878
* None : list-like results will be returned as a list
48784879
in a single column. However if the apply function
48794880
returns a Series these are expanded to columns.

pandas/core/sparse/frame.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -855,8 +855,8 @@ def apply(self, func, axis=0, broadcast=None, reduce=False,
855855
856856
result_type : {'infer', 'broadcast', None}
857857
These only act when axis=1 {columns}
858-
* infer : list-like results will be turned into columns
859-
* broadcast : scalar results will be broadcast to all rows
858+
* 'infer' : list-like results will be turned into columns
859+
* 'broadcast' : scalar results will be broadcast to all columns
860860
* None : list-like results will be returned as a list
861861
in a single column. However if the apply function
862862
returns a Series these are expanded to columns.

pandas/tests/frame/test_apply.py

Lines changed: 78 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -121,22 +121,58 @@ def test_with_string_args(self):
121121
expected = getattr(self.frame, arg)(axis=1)
122122
tm.assert_series_equal(result, expected)
123123

124-
def test_apply_broadcast(self):
124+
def test_apply_broadcast_deprecated(self):
125125
with tm.assert_produces_warning(FutureWarning):
126-
broadcasted = self.frame.apply(np.mean, broadcast=True)
127-
agged = self.frame.apply(np.mean)
126+
self.frame.apply(np.mean, broadcast=True)
128127

129-
for col, ts in compat.iteritems(broadcasted):
130-
assert (ts == agged[col]).all()
128+
def test_apply_broadcast(self):
131129

132-
with tm.assert_produces_warning(FutureWarning):
133-
broadcasted = self.frame.apply(np.mean, axis=1, broadcast=True)
134-
agged = self.frame.apply(np.mean, axis=1)
135-
for idx in broadcasted.index:
136-
assert (broadcasted.xs(idx) == agged[idx]).all()
130+
# scalars
131+
result = self.frame.apply(np.mean, result_type='broadcast')
132+
expected = DataFrame([self.frame.mean()], index=self.frame.index)
133+
tm.assert_frame_equal(result, expected)
137134

138-
with tm.assert_produces_warning(FutureWarning):
139-
self.frame.apply(np.mean, axis=1, broadcast=False)
135+
result = self.frame.apply(np.mean, axis=1, result_type='broadcast')
136+
m = self.frame.mean(axis=1)
137+
expected = DataFrame({c: m for c in self.frame.columns})
138+
tm.assert_frame_equal(result, expected)
139+
140+
# lists
141+
result = self.frame.apply(
142+
lambda x: list(range(len(self.frame.columns))),
143+
axis=1,
144+
result_type='broadcast')
145+
m = list(range(len(self.frame.columns)))
146+
expected = DataFrame([m] * len(self.frame.index),
147+
dtype='float64',
148+
index=self.frame.index,
149+
columns=self.frame.columns)
150+
tm.assert_frame_equal(result, expected)
151+
152+
result = self.frame.apply(lambda x: list(range(len(self.frame.index))),
153+
result_type='broadcast')
154+
m = list(range(len(self.frame.index)))
155+
expected = DataFrame({c: m for c in self.frame.columns},
156+
dtype='float64',
157+
index=self.frame.index)
158+
tm.assert_frame_equal(result, expected)
159+
160+
def test_apply_broadcast_error(self):
161+
df = DataFrame(
162+
np.tile(np.arange(3, dtype='int64'), 6).reshape(6, -1) + 1,
163+
columns=['A', 'B', 'C'])
164+
165+
# > 1 ndim
166+
with pytest.raises(ValueError):
167+
df.apply(lambda x: np.array([1, 2]).reshape(-1, 2),
168+
axis=1,
169+
result_type='broadcast')
170+
171+
# cannot broadcast
172+
with pytest.raises(ValueError):
173+
df.apply(lambda x: [1, 2],
174+
axis=1,
175+
result_type='broadcast')
140176

141177
def test_apply_raw(self):
142178
result0 = self.frame.apply(np.mean, raw=True)
@@ -213,8 +249,7 @@ def _checkit(axis=0, raw=False):
213249
_check(no_index, lambda x: x)
214250
_check(no_index, lambda x: x.mean())
215251

216-
with tm.assert_produces_warning(FutureWarning):
217-
result = no_cols.apply(lambda x: x.mean(), broadcast=True)
252+
result = no_cols.apply(lambda x: x.mean(), result_type='broadcast')
218253
assert isinstance(result, DataFrame)
219254

220255
def test_apply_with_args_kwds(self):
@@ -680,6 +715,35 @@ def test_result_type(self):
680715
expected.columns = [0, 1]
681716
assert_frame_equal(result, expected)
682717

718+
# broadcast result
719+
result = df.apply(lambda x: [1, 2, 3], axis=1, result_type='broadcast')
720+
expected = df.copy()
721+
assert_frame_equal(result, expected)
722+
723+
columns = ['other', 'col', 'names']
724+
result = df.apply(
725+
lambda x: pd.Series([1, 2, 3],
726+
index=columns),
727+
axis=1,
728+
result_type='broadcast')
729+
expected = df.copy()
730+
expected.columns = columns
731+
assert_frame_equal(result, expected)
732+
733+
# series result
734+
result = df.apply(lambda x: Series([1, 2, 3], index=x.index), axis=1)
735+
expected = df.copy()
736+
assert_frame_equal(result, expected)
737+
738+
# series result with other index
739+
columns = ['other', 'col', 'names']
740+
result = df.apply(
741+
lambda x: pd.Series([1, 2, 3], index=columns),
742+
axis=1)
743+
expected = df.copy()
744+
expected.columns = columns
745+
assert_frame_equal(result, expected)
746+
683747
@pytest.mark.parametrize(
684748
"box",
685749
[lambda x: list(x),

0 commit comments

Comments
 (0)