diff --git a/pandas/core/apply.py b/pandas/core/apply.py index 94b7c59b93563..071cd116ea982 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -1,6 +1,6 @@ import abc import inspect -from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Dict, Iterator, Tuple, Type, Union import numpy as np @@ -72,7 +72,9 @@ def series_generator(self) -> Iterator["Series"]: pass @abc.abstractmethod - def wrap_results_for_axis(self, results: ResType) -> Union["Series", "DataFrame"]: + def wrap_results_for_axis( + self, results: ResType, res_index: "Index" + ) -> Union["Series", "DataFrame"]: pass # --------------------------------------------------------------- @@ -112,15 +114,6 @@ def f(x): self.f = f - # results - self.result = None - self._res_index: Optional["Index"] = None - - @property - def res_index(self) -> "Index": - assert self._res_index is not None - return self._res_index - @property def res_columns(self) -> "Index": return self.result_columns @@ -313,12 +306,12 @@ def apply_standard(self): return self.obj._constructor_sliced(result, index=labels) # compute the result using the series generator - results = self.apply_series_generator() + results, res_index = self.apply_series_generator() # wrap results - return self.wrap_results(results) + return self.wrap_results(results, res_index) - def apply_series_generator(self) -> ResType: + def apply_series_generator(self) -> Tuple[ResType, "Index"]: series_gen = self.series_generator res_index = self.result_index @@ -345,19 +338,20 @@ def apply_series_generator(self) -> ResType: results[i] = self.f(v) keys.append(v.name) - self._res_index = res_index - return results + return results, res_index - def wrap_results(self, results: ResType) -> Union["Series", "DataFrame"]: + def wrap_results( + self, results: ResType, res_index: "Index" + ) -> Union["Series", "DataFrame"]: # see if we can infer the results if len(results) > 0 and 0 in results and is_sequence(results[0]): - return self.wrap_results_for_axis(results) + return self.wrap_results_for_axis(results, res_index) # dict of scalars result = self.obj._constructor_sliced(results) - result.index = self.res_index + result.index = res_index return result @@ -380,7 +374,9 @@ def result_index(self) -> "Index": def result_columns(self) -> "Index": return self.index - def wrap_results_for_axis(self, results: ResType) -> "DataFrame": + def wrap_results_for_axis( + self, results: ResType, res_index: "Index" + ) -> "DataFrame": """ return the results for the rows """ result = self.obj._constructor(data=results) @@ -389,8 +385,8 @@ def wrap_results_for_axis(self, results: ResType) -> "DataFrame": if len(result.index) == len(self.res_columns): result.index = self.res_columns - if len(result.columns) == len(self.res_index): - result.columns = self.res_index + if len(result.columns) == len(res_index): + result.columns = res_index return result @@ -418,35 +414,37 @@ def result_index(self) -> "Index": def result_columns(self) -> "Index": return self.columns - def wrap_results_for_axis(self, results: ResType) -> Union["Series", "DataFrame"]: + def wrap_results_for_axis( + self, results: ResType, res_index: "Index" + ) -> Union["Series", "DataFrame"]: """ return the results for the columns """ result: Union["Series", "DataFrame"] # we have requested to expand if self.result_type == "expand": - result = self.infer_to_same_shape(results) + result = self.infer_to_same_shape(results, res_index) # we have a non-series and don't want inference elif not isinstance(results[0], ABCSeries): from pandas import Series result = Series(results) - result.index = self.res_index + result.index = res_index # we may want to infer results else: - result = self.infer_to_same_shape(results) + result = self.infer_to_same_shape(results, res_index) return result - def infer_to_same_shape(self, results: ResType) -> "DataFrame": + def infer_to_same_shape(self, results: ResType, res_index: "Index") -> "DataFrame": """ infer the results to the same shape as the input object """ result = self.obj._constructor(data=results) result = result.T # set the index - result.index = self.res_index + result.index = res_index # infer dtypes result = result.infer_objects()