Skip to content

Commit 4e9a063

Browse files
committed
Set squeeze=None for Dataset too
1 parent c2e576e commit 4e9a063

File tree

5 files changed

+41
-25
lines changed

5 files changed

+41
-25
lines changed

xarray/core/dataarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6709,7 +6709,7 @@ def groupby_bins(
67096709
labels: ArrayLike | Literal[False] | None = None,
67106710
precision: int = 3,
67116711
include_lowest: bool = False,
6712-
squeeze: bool = True,
6712+
squeeze: bool | None = None,
67136713
restore_coord_dims: bool = False,
67146714
) -> DataArrayGroupBy:
67156715
"""Returns a DataArrayGroupBy object for performing grouped operations.

xarray/core/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10052,7 +10052,7 @@ def interp_calendar(
1005210052
def groupby(
1005310053
self,
1005410054
group: Hashable | DataArray | IndexVariable,
10055-
squeeze: bool = True,
10055+
squeeze: bool | None = None,
1005610056
restore_coord_dims: bool = False,
1005710057
) -> DatasetGroupBy:
1005810058
"""Returns a DatasetGroupBy object for performing grouped operations.
@@ -10120,7 +10120,7 @@ def groupby_bins(
1012010120
labels: ArrayLike | None = None,
1012110121
precision: int = 3,
1012210122
include_lowest: bool = False,
10123-
squeeze: bool = True,
10123+
squeeze: bool | None = None,
1012410124
restore_coord_dims: bool = False,
1012510125
) -> DatasetGroupBy:
1012610126
"""Returns a DatasetGroupBy object for performing grouped operations.

xarray/core/groupby.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def _maybe_squeeze_indices(
7777
indices, squeeze: bool | None, grouper: ResolvedGrouper, warn: bool
7878
):
7979
if squeeze in [None, True] and grouper.can_squeeze:
80-
if squeeze is None and warn:
80+
if (squeeze is None and warn) or squeeze is True:
8181
emit_user_level_warning(
8282
"The `squeeze` kwarg to GroupBy is being removed."
8383
"Pass .groupby(..., squeeze=False) to disable squeezing,"
@@ -727,7 +727,7 @@ def __init__(
727727
self,
728728
obj: T_Xarray,
729729
groupers: tuple[ResolvedGrouper],
730-
squeeze: bool = False,
730+
squeeze: bool | None = False,
731731
restore_coord_dims: bool = True,
732732
) -> None:
733733
"""Create a GroupBy object
@@ -859,7 +859,7 @@ def _iter_grouped(self) -> Iterator[T_Xarray]:
859859
(grouper,) = self.groupers
860860
for idx, indices in enumerate(self._group_indices):
861861
indices = _maybe_squeeze_indices(
862-
indices, self._squeeze, grouper, warn=idx > 0
862+
indices, self._squeeze, grouper, warn=idx == 0
863863
)
864864
yield self._obj.isel({self._group_dim: indices})
865865

@@ -1363,7 +1363,7 @@ def _iter_grouped_shortcut(self):
13631363
(grouper,) = self.groupers
13641364
for idx, indices in enumerate(self._group_indices):
13651365
indices = _maybe_squeeze_indices(
1366-
indices, self._squeeze, grouper, warn=idx > 0
1366+
indices, self._squeeze, grouper, warn=idx == 0
13671367
)
13681368
yield var[{self._group_dim: indices}]
13691369

xarray/tests/test_groupby.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -59,27 +59,34 @@ def test_consolidate_slices() -> None:
5959
_consolidate_slices([slice(3), 4]) # type: ignore[list-item]
6060

6161

62-
def test_groupby_dims_property(dataset) -> None:
63-
assert dataset.groupby("x").dims == dataset.isel(x=1).dims
64-
assert dataset.groupby("y").dims == dataset.isel(y=1).dims
62+
def test_groupby_dims_property(dataset, recwarn) -> None:
63+
# dims is sensitive to squeeze, always warn
64+
with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
65+
assert dataset.groupby("x").dims == dataset.isel(x=1).dims
66+
assert dataset.groupby("y").dims == dataset.isel(y=1).dims
6567

68+
# when squeeze=False, no warning should be raised
6669
assert dataset.groupby("x", squeeze=False).dims == dataset.isel(x=slice(1, 2)).dims
6770
assert dataset.groupby("y", squeeze=False).dims == dataset.isel(y=slice(1, 2)).dims
71+
assert len(recwarn) == 0
6872

6973
stacked = dataset.stack({"xy": ("x", "y")})
7074
assert stacked.groupby("xy", squeeze=False).dims == stacked.isel(xy=[0]).dims
75+
assert len(recwarn) == 0
7176

7277

7378
def test_multi_index_groupby_map(dataset) -> None:
7479
# regression test for GH873
7580
ds = dataset.isel(z=1, drop=True)[["foo"]]
7681
expected = 2 * ds
77-
actual = (
78-
ds.stack(space=["x", "y"])
79-
.groupby("space")
80-
.map(lambda x: 2 * x)
81-
.unstack("space")
82-
)
82+
# The function in `map` may be sensitive to squeeze, always warn
83+
with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
84+
actual = (
85+
ds.stack(space=["x", "y"])
86+
.groupby("space")
87+
.map(lambda x: 2 * x)
88+
.unstack("space")
89+
)
8390
assert_equal(expected, actual)
8491

8592

@@ -202,7 +209,9 @@ def func(arg1, arg2, arg3=0):
202209

203210
dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]})
204211
expected = xr.Dataset({"foo": ("x", [3, 3, 3])}, {"x": [1, 2, 3]})
205-
actual = dataset.groupby("x").map(func, args=(1,), arg3=1)
212+
# The function in `map` may be sensitive to squeeze, always warn
213+
with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
214+
actual = dataset.groupby("x").map(func, args=(1,), arg3=1)
206215
assert_identical(expected, actual)
207216

208217

@@ -887,7 +896,7 @@ def test_groupby_bins_cut_kwargs(use_flox: bool) -> None:
887896

888897
with xr.set_options(use_flox=use_flox):
889898
actual = da.groupby_bins(
890-
"x", bins=x_bins, include_lowest=True, right=False
899+
"x", bins=x_bins, include_lowest=True, right=False, squeeze=False
891900
).mean()
892901
expected = xr.DataArray(
893902
np.array([[1.0, 2.0], [5.0, 6.0], [9.0, 10.0]]),
@@ -1135,8 +1144,8 @@ def test_groupby_properties(self):
11351144
"by, use_da", [("x", False), ("y", False), ("y", True), ("abc", False)]
11361145
)
11371146
@pytest.mark.parametrize("shortcut", [True, False])
1138-
@pytest.mark.parametrize("squeeze", [True, False])
1139-
def test_groupby_map_identity(self, by, use_da, shortcut, squeeze) -> None:
1147+
@pytest.mark.parametrize("squeeze", [None])
1148+
def test_groupby_map_identity(self, by, use_da, shortcut, squeeze, recwarn) -> None:
11401149
expected = self.da
11411150
if use_da:
11421151
by = expected.coords[by]
@@ -1148,6 +1157,10 @@ def identity(x):
11481157
actual = grouped.map(identity, shortcut=shortcut)
11491158
assert_identical(expected, actual)
11501159

1160+
# abc is not a dim coordinate so no warnings expected!
1161+
if (by.name if use_da else by) != "abc":
1162+
assert len(recwarn) == (1 if squeeze in [None, True] else 0)
1163+
11511164
def test_groupby_sum(self):
11521165
array = self.da
11531166
grouped = array.groupby("abc")
@@ -1508,7 +1521,7 @@ def test_groupby_bins_ellipsis(self):
15081521
da = xr.DataArray(np.ones((2, 3, 4)))
15091522
bins = [-1, 0, 1, 2]
15101523
with xr.set_options(use_flox=False):
1511-
actual = da.groupby_bins("dim_0", bins).mean(...)
1524+
actual = da.groupby_bins("dim_0", bins, squeeze=False).mean(...)
15121525
with xr.set_options(use_flox=True):
15131526
expected = da.groupby_bins("dim_0", bins).mean(...)
15141527
assert_allclose(actual, expected)

xarray/tests/test_units.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3933,9 +3933,12 @@ def test_grouped_operations(self, func, variant, dtype):
39333933
for key, value in func.kwargs.items()
39343934
}
39353935
expected = attach_units(
3936-
func(strip_units(data_array).groupby("y"), **stripped_kwargs), units
3936+
func(
3937+
strip_units(data_array).groupby("y", squeeze=False), **stripped_kwargs
3938+
),
3939+
units,
39373940
)
3938-
actual = func(data_array.groupby("y"))
3941+
actual = func(data_array.groupby("y", squeeze=False))
39393942

39403943
assert_units_equal(expected, actual)
39413944
assert_identical(expected, actual)
@@ -5440,9 +5443,9 @@ def test_grouped_operations(self, func, variant, dtype):
54405443
name: strip_units(value) for name, value in func.kwargs.items()
54415444
}
54425445
expected = attach_units(
5443-
func(strip_units(ds).groupby("y"), **stripped_kwargs), units
5446+
func(strip_units(ds).groupby("y", squeeze=False), **stripped_kwargs), units
54445447
)
5445-
actual = func(ds.groupby("y"))
5448+
actual = func(ds.groupby("y", squeeze=False))
54465449

54475450
assert_units_equal(expected, actual)
54485451
assert_equal(expected, actual)

0 commit comments

Comments
 (0)