-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
BUG/CLN: Decouple Series/DataFrame.transform #35964
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
7b6ab94
04c1238
052df6e
7b13811
133bfaa
a5d4a19
25c4457
8454d91
c37ef68
9eee0cb
cf4f80b
69e6807
f66a806
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10750,9 +10750,48 @@ def transform(self, func, *args, **kwargs): | |
1 1.000000 2.718282 | ||
2 1.414214 7.389056 | ||
""" | ||
result = self.agg(func, *args, **kwargs) | ||
if is_scalar(result) or len(result) != len(self): | ||
raise ValueError("transforms cannot produce aggregated results") | ||
raise NotImplementedError | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we remove this? Or is the docstring used? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe just move the docstring to shared_docs? |
||
|
||
def _transform(self, func, *args, **kwargs): | ||
rhshadrach marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if isinstance(func, dict): | ||
results = {} | ||
for name, how in func.items(): | ||
colg = self._gotitem(name, ndim=1) | ||
try: | ||
results[name] = colg.transform(how, *args, **kwargs) | ||
except Exception as e: | ||
if str(e) == "Function did not transform": | ||
raise e | ||
|
||
# combine results | ||
if len(results) == 0: | ||
raise ValueError("Transform function failed") | ||
from pandas.core.reshape.concat import concat | ||
|
||
return concat(results, axis=1) | ||
|
||
try: | ||
if isinstance(func, str): | ||
result = self._try_aggregate_string_function(func, *args, **kwargs) | ||
else: | ||
f = self._get_cython_func(func) | ||
if f and not args and not kwargs: | ||
result = getattr(self, f)() | ||
else: | ||
try: | ||
result = self.apply(func, args=args, **kwargs) | ||
except Exception: | ||
result = func(self, *args, **kwargs) | ||
|
||
except Exception: | ||
raise ValueError("Transform function failed") | ||
|
||
# Functions that transform may return empty Series/DataFrame | ||
# when the dtype is not appropriate | ||
if isinstance(result, NDFrame) and result.empty: | ||
raise ValueError("Transform function failed") | ||
if not isinstance(result, NDFrame) or not result.index.equals(self.index): | ||
raise ValueError("Function did not transform") | ||
|
||
return result | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
from pandas import DataFrame, Index, MultiIndex, Series, isna | ||
import pandas._testing as tm | ||
from pandas.core.base import SpecificationError | ||
from pandas.core.groupby.base import transformation_kernels | ||
|
||
|
||
class TestSeriesApply: | ||
|
@@ -222,7 +223,7 @@ def test_transform(self, string_series): | |
expected.columns = ["sqrt"] | ||
tm.assert_frame_equal(result, expected) | ||
|
||
result = string_series.transform([np.sqrt]) | ||
result = string_series.apply([np.sqrt]) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The line |
||
tm.assert_frame_equal(result, expected) | ||
|
||
result = string_series.transform(["sqrt"]) | ||
|
@@ -248,9 +249,34 @@ def test_transform(self, string_series): | |
result = string_series.apply({"foo": np.sqrt, "bar": np.abs}) | ||
tm.assert_series_equal(result.reindex_like(expected), expected) | ||
|
||
expected = pd.concat([f_sqrt, f_abs], axis=1) | ||
expected.columns = ["foo", "bar"] | ||
result = string_series.transform({"foo": np.sqrt, "bar": np.abs}) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
# UDF via apply | ||
def func(x): | ||
if isinstance(x, Series): | ||
raise ValueError | ||
return x + 1 | ||
|
||
result = string_series.transform(func) | ||
expected = string_series + 1 | ||
tm.assert_series_equal(result, expected) | ||
|
||
# UDF that maps Series -> Series | ||
def func(x): | ||
if not isinstance(x, Series): | ||
raise ValueError | ||
return x + 1 | ||
|
||
result = string_series.transform(func) | ||
expected = string_series + 1 | ||
tm.assert_series_equal(result, expected) | ||
|
||
def test_transform_and_agg_error(self, string_series): | ||
# we are trying to transform with an aggregator | ||
msg = "transforms cannot produce aggregated results" | ||
msg = "Function did not transform" | ||
with pytest.raises(ValueError, match=msg): | ||
string_series.transform(["min", "max"]) | ||
|
||
|
@@ -259,6 +285,7 @@ def test_transform_and_agg_error(self, string_series): | |
with np.errstate(all="ignore"): | ||
string_series.agg(["sqrt", "max"]) | ||
|
||
msg = "Function did not transform" | ||
with pytest.raises(ValueError, match=msg): | ||
with np.errstate(all="ignore"): | ||
string_series.transform(["sqrt", "max"]) | ||
|
@@ -467,11 +494,73 @@ def test_transform_none_to_type(self): | |
# GH34377 | ||
df = pd.DataFrame({"a": [None]}) | ||
|
||
msg = "DataFrame constructor called with incompatible data and dtype" | ||
with pytest.raises(TypeError, match=msg): | ||
msg = "Transform function failed.*" | ||
with pytest.raises(ValueError, match=msg): | ||
df.transform({"a": int}) | ||
|
||
|
||
def test_transform_reducer_raises(all_reductions):
    """A reduction passed to ``Series.transform`` must raise in every spelling."""
    reducer = all_reductions
    ser = pd.Series([1, 2, 3])
    # bare function, list, dict, and dict-of-list forms, in that order
    specs = (reducer, [reducer], {"A": reducer}, {"A": [reducer]})
    for spec in specs:
        with pytest.raises(ValueError, match="Function did not transform"):
            ser.transform(spec)
|
||
|
||
# mypy doesn't allow adding lists of different types
# https://github.com/python/mypy/issues/5492
@pytest.mark.parametrize("op", [*transformation_kernels, lambda x: x + 1])
def test_transform_bad_dtype(op):
    """Transforming a Series of ``object`` classes must report a failed transform."""
    if op in ("backfill", "shift", "pad", "bfill", "ffill"):
        pytest.xfail("Transform function works on any datatype")

    # Series that will fail on most transforms
    ser = pd.Series(3 * [object])
    expected_msg = "Transform function failed"
    # bare function, list, dict, and dict-of-list forms, in that order
    for spec in (op, [op], {"A": op}, {"A": [op]}):
        with pytest.raises(ValueError, match=expected_msg):
            ser.transform(spec)
|
||
|
||
@pytest.mark.parametrize("use_apply", [True, False])
def test_transform_passes_args(use_apply):
    """Positional and keyword arguments must reach the UDF on both call paths."""
    # transform uses UDF either via apply or passing the entire Series
    positional = [1, 2]
    keyword = {"c": 3}

    def f(x, a, b, c):
        # transform is using apply iff x is not a Series
        got_series = isinstance(x, Series)
        if got_series == use_apply:
            # Force transform to fallback
            raise ValueError
        assert [a, b] == positional
        assert c == keyword["c"]
        return x

    ser = pd.Series([1])
    ser.transform(f, 0, *positional, **keyword)
|
||
|
||
def test_transform_axis_1_raises():
    """``axis=1`` is rejected when transforming a Series."""
    ser = pd.Series([1])
    with pytest.raises(ValueError, match="No axis named 1 for object type Series"):
        ser.transform("sum", axis=1)
|
||
|
||
def test_transform_nested_renamer():
    """A dict nested inside the func dict must raise ``SpecificationError``."""
    ser = pd.Series([1])
    nested_spec = {"A": {"B": ["sum"]}}
    with pytest.raises(SpecificationError, match="nested renamer is not supported"):
        ser.transform(nested_spec)
|
||
|
||
class TestSeriesMap: | ||
def test_map(self, datetime_series): | ||
index, data = tm.getMixedTypeDict() | ||
|
Uh oh!
There was an error while loading. Please reload this page.