|
1 | 1 | import numpy as np
|
2 | 2 | import pytest
|
3 | 3 |
|
| 4 | +from pandas.core.dtypes.common import pandas_dtype |
| 5 | + |
4 | 6 | from pandas import (
|
5 | 7 | NA,
|
6 | 8 | DataFrame,
|
|
19 | 21 | def get_dtype(dtype, coerce_int=None):
|
20 | 22 | if coerce_int is False and "int" in dtype:
|
21 | 23 | return None
|
22 |
| - if dtype != "category": |
23 |
| - return np.dtype(dtype) |
24 |
| - return dtype |
| 24 | + return pandas_dtype(dtype) |
25 | 25 |
|
26 | 26 |
|
27 | 27 | @pytest.mark.parametrize(
|
@@ -66,21 +66,23 @@ def get_dtype(dtype, coerce_int=None):
|
66 | 66 | ],
|
67 | 67 | )
|
68 | 68 | def test_series_dtypes(method, data, expected_data, coerce_int, dtypes, min_periods):
|
69 |
| - s = Series(data, dtype=get_dtype(dtypes, coerce_int=coerce_int)) |
70 |
| - if dtypes in ("m8[ns]", "M8[ns]") and method != "count": |
| 69 | + ser = Series(data, dtype=get_dtype(dtypes, coerce_int=coerce_int)) |
| 70 | + rolled = ser.rolling(2, min_periods=min_periods) |
| 71 | + |
| 72 | + if dtypes in ("m8[ns]", "M8[ns]", "datetime64[ns, UTC]") and method != "count": |
71 | 73 | msg = "No numeric types to aggregate"
|
72 | 74 | with pytest.raises(DataError, match=msg):
|
73 |
| - getattr(s.rolling(2, min_periods=min_periods), method)() |
| 75 | + getattr(rolled, method)() |
74 | 76 | else:
|
75 |
| - result = getattr(s.rolling(2, min_periods=min_periods), method)() |
| 77 | + result = getattr(rolled, method)() |
76 | 78 | expected = Series(expected_data, dtype="float64")
|
77 | 79 | tm.assert_almost_equal(result, expected)
|
78 | 80 |
|
79 | 81 |
|
80 | 82 | def test_series_nullable_int(any_signed_int_ea_dtype):
|
81 | 83 | # GH 43016
|
82 |
| - s = Series([0, 1, NA], dtype=any_signed_int_ea_dtype) |
83 |
| - result = s.rolling(2).mean() |
| 84 | + ser = Series([0, 1, NA], dtype=any_signed_int_ea_dtype) |
| 85 | + result = ser.rolling(2).mean() |
84 | 86 | expected = Series([np.nan, 0.5, np.nan])
|
85 | 87 | tm.assert_series_equal(result, expected)
|
86 | 88 |
|
@@ -130,14 +132,15 @@ def test_series_nullable_int(any_signed_int_ea_dtype):
|
130 | 132 | ],
|
131 | 133 | )
|
132 | 134 | def test_dataframe_dtypes(method, expected_data, dtypes, min_periods):
|
133 |
| - if dtypes == "category": |
134 |
| - pytest.skip("Category dataframe testing not implemented.") |
| 135 | + |
135 | 136 | df = DataFrame(np.arange(10).reshape((5, 2)), dtype=get_dtype(dtypes))
|
136 |
| - if dtypes in ("m8[ns]", "M8[ns]") and method != "count": |
| 137 | + rolled = df.rolling(2, min_periods=min_periods) |
| 138 | + |
| 139 | + if dtypes in ("m8[ns]", "M8[ns]", "datetime64[ns, UTC]") and method != "count": |
137 | 140 | msg = "No numeric types to aggregate"
|
138 | 141 | with pytest.raises(DataError, match=msg):
|
139 |
| - getattr(df.rolling(2, min_periods=min_periods), method)() |
| 142 | + getattr(rolled, method)() |
140 | 143 | else:
|
141 |
| - result = getattr(df.rolling(2, min_periods=min_periods), method)() |
| 144 | + result = getattr(rolled, method)() |
142 | 145 | expected = DataFrame(expected_data, dtype="float64")
|
143 | 146 | tm.assert_frame_equal(result, expected)
|
0 commit comments