Skip to content

Commit 949c949

Browse files
committed
1.remove inf check
2.reformat file 3.add tests
1 parent 56005d5 commit 949c949

File tree

3 files changed

+104
-4
lines changed

3 files changed

+104
-4
lines changed

pandas/_libs/window/aggregations.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,8 @@ cdef inline void add_var(float64_t val, float64_t *nobs, float64_t *mean_x,
313313
if val == prev_value[0]:
314314
num_consecutive_same_value[0] += 1
315315
else:
316-
num_consecutive_same_value[0] = 1 # reset to 1 (include current value itself)
316+
# reset to 1 (include current value itself)
317+
num_consecutive_same_value[0] = 1
317318
prev_value[0] = val
318319

319320
# Welford's method for the online variance-calculation
@@ -361,8 +362,7 @@ def roll_var(const float64_t[:] values, ndarray[int64_t] start,
361362
"""
362363
cdef:
363364
float64_t mean_x, ssqdm_x, nobs, compensation_add,
364-
float64_t compensation_remove,
365-
float64_t val, prev, delta, mean_x_old, prev_value
365+
float64_t compensation_remove, prev_value
366366
int64_t s, e
367367
Py_ssize_t i, j, N = len(start), num_consecutive_same_value = 0
368368
ndarray[float64_t] output

pandas/core/_numba/kernels/var_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def add_var(
2626
) -> tuple[int, float, float, float, int, float]:
2727
if not np.isnan(val):
2828

29-
if val == prev_value and not np.isinf(val):
29+
if val == prev_value:
3030
num_consecutive_same_value += 1
3131
else:
3232
num_consecutive_same_value = 1

pandas/tests/window/test_rolling.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1781,3 +1781,103 @@ def test_step_not_integer_raises():
17811781
def test_step_not_positive_raises():
17821782
with pytest.raises(ValueError, match="step must be >= 0"):
17831783
DataFrame(range(2)).rolling(1, step=-1)
1784+
1785+
1786+
@pytest.mark.parametrize(
1787+
["values", "window", "min_periods", "expected"],
1788+
[
1789+
[
1790+
np.array([20, 10, 10, np.inf, 1, 1, 2, 3]),
1791+
3,
1792+
1,
1793+
np.array(
1794+
[
1795+
np.nan,
1796+
50.0,
1797+
33.33333333333333,
1798+
0.0,
1799+
40.5,
1800+
0.0,
1801+
0.3333333333333333,
1802+
1.0,
1803+
]
1804+
),
1805+
],
1806+
[
1807+
np.array([20, 10, 10, np.nan, 10, 1, 2, 3]),
1808+
3,
1809+
1,
1810+
np.array(
1811+
[
1812+
np.nan,
1813+
50.0,
1814+
33.33333333333333,
1815+
0.0,
1816+
0.0,
1817+
40.5,
1818+
24.333333333333332,
1819+
1.0,
1820+
]
1821+
),
1822+
],
1823+
[
1824+
np.array([np.nan, 5, 6, 7, 5, 5, 5]),
1825+
3,
1826+
3,
1827+
np.array([np.nan, np.nan, np.nan, 1.0, 1.0, 1.3333333333333335, 0.0]),
1828+
],
1829+
[
1830+
np.array([5, 7, 7, 7, np.nan, np.inf, 4, 3, 3, 3]),
1831+
3,
1832+
3,
1833+
np.array(
1834+
[
1835+
np.nan,
1836+
np.nan,
1837+
1.3333333333333335,
1838+
0.0,
1839+
np.nan,
1840+
np.nan,
1841+
np.nan,
1842+
np.nan,
1843+
0.33333333333333337,
1844+
0.0,
1845+
]
1846+
),
1847+
],
1848+
[
1849+
np.array([5, 7, 7, 7, np.nan, np.inf, 7, 3, 3, 3]),
1850+
3,
1851+
3,
1852+
np.array(
1853+
[
1854+
np.nan,
1855+
np.nan,
1856+
1.3333333333333335,
1857+
0.0,
1858+
np.nan,
1859+
np.nan,
1860+
np.nan,
1861+
np.nan,
1862+
5.333333333333333,
1863+
0.0,
1864+
]
1865+
),
1866+
],
1867+
],
1868+
)
1869+
def test_rolling_var_same_value_count_logic(values, window, min_periods, expected):
1870+
# GH 42064
1871+
1872+
sr = Series(values)
1873+
result_var = sr.rolling(window, min_periods=min_periods).var()
1874+
# 1. result should be close to correct value
1875+
# non-zero values can still differ slightly as the result of online algorithm
1876+
assert np.isclose(result_var, expected, equal_nan=True).all()
1877+
# 2. zeros should be exactly the same since the new algo takes effect here
1878+
assert (result_var[expected == 0] == 0).all()
1879+
1880+
# std should also pass as it's just a sqrt of var
1881+
result_std = sr.rolling(window, min_periods=min_periods).std()
1882+
assert np.isclose(result_std, np.sqrt(expected), equal_nan=True).all()
1883+
assert (result_std[expected == 0] == 0).all()

0 commit comments

Comments
 (0)