Skip to content

Commit 9243612

Browse files
committed
Fixed arrays tests to check dtypes
1 parent 28b45f2 commit 9243612

File tree

4 files changed

+33
-34
lines changed

4 files changed

+33
-34
lines changed

pandas/core/groupby/groupby.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -759,20 +759,19 @@ def _try_cast(self, result, obj, numeric_only=False):
759759
else:
760760
dtype = obj.dtype
761761

762-
if is_extension_array_dtype(dtype):
763-
# The function can return something of any type, so check
764-
# if the type is compatible with the calling EA.
765-
try:
766-
result = obj.values._from_sequence(result)
767-
except Exception:
768-
# https://github.com/pandas-dev/pandas/issues/22850
769-
# pandas has no control over what 3rd-party ExtensionArrays
770-
# do in _values_from_sequence. We still want ops to work
771-
# though, so we catch any regular Exception.
772-
pass
773-
774762
if not is_scalar(result):
775-
if numeric_only and is_numeric_dtype(dtype) or not numeric_only:
763+
if is_extension_array_dtype(dtype):
764+
# The function can return something of any type, so check
765+
# if the type is compatible with the calling EA.
766+
try:
767+
result = obj.values._from_sequence(result)
768+
except Exception:
769+
# https://github.com/pandas-dev/pandas/issues/22850
770+
# pandas has no control over what 3rd-party ExtensionArrays
771+
# do in _values_from_sequence. We still want ops to work
772+
# though, so we catch any regular Exception.
773+
pass
774+
elif numeric_only and is_numeric_dtype(dtype) or not numeric_only:
776775
result = maybe_downcast_to_dtype(result, dtype)
777776

778777
return result

pandas/tests/arrays/test_integer.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -651,11 +651,9 @@ def test_preserve_dtypes(op):
651651

652652
# groupby
653653
result = getattr(df.groupby("A"), op)()
654-
expected = pd.DataFrame({
655-
"B": np.array([1.0, 3.0]),
656-
"C": np.array([1, 3], dtype="int64")
657-
}, index=pd.Index(['a', 'b'], name='A'))
658-
tm.assert_frame_equal(result, expected)
654+
655+
assert result.dtypes['B'].name == 'float64'
656+
assert result.dtypes['C'].name == 'Int64'
659657

660658

661659
@pytest.mark.parametrize('op', ['mean'])
@@ -674,11 +672,23 @@ def test_reduce_to_float(op):
674672

675673
# groupby
676674
result = getattr(df.groupby("A"), op)()
677-
expected = pd.DataFrame({
678-
"B": np.array([1.0, 3.0]),
679-
"C": np.array([1, 3], dtype="float64")
680-
}, index=pd.Index(['a', 'b'], name='A'))
681-
tm.assert_frame_equal(result, expected)
675+
676+
assert result.dtypes['B'].name == 'float64'
677+
assert result.dtypes['C'].name == 'Int64'
678+
679+
680+
@pytest.mark.parametrize('op', ['sum'])
681+
def test_groupby_extension_array(op):
682+
# GH23227
683+
# groupby on an extension array should return the extension array type
684+
df = pd.DataFrame({
685+
'Int': pd.Series([1, 2, 3], dtype='Int64'),
686+
'A': [1, 2, 1]
687+
})
688+
689+
result = getattr(df.groupby('A').Int, op)()
690+
assert result is not None
691+
assert result.dtype.name == 'Int64'
682692

683693

684694
def test_astype_nansafe():

pandas/tests/groupby/test_grouping.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -282,16 +282,6 @@ def test_groupby_categorical_index_and_columns(self, observed):
282282
expected = DataFrame(data=expected_data.T, index=expected_columns)
283283
assert_frame_equal(result, expected)
284284

285-
def test_groupby_extension_array(self):
286-
287-
# GH23227
288-
# groupby on an extension array should return the extension array type
289-
df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int64'),
290-
'A': [1, 2, 1]})
291-
result = df.groupby('A').Int.sum()
292-
assert result is not None
293-
assert result.dtype.name == 'Int64'
294-
295285
def test_grouper_getting_correct_binner(self):
296286

297287
# GH 10063

pandas/tests/sparse/test_groupby.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,5 @@ def test_groupby_includes_fill_value(fill_value):
5555
sdf = df.to_sparse(fill_value=fill_value)
5656
result = sdf.groupby('a').sum()
5757
expected = df.groupby('a').sum()
58-
tm.assert_frame_equal(result, expected,
58+
tm.assert_frame_equal(result, expected.to_sparse(fill_value=fill_value),
5959
check_index_type=False)

0 commit comments

Comments
 (0)