Skip to content

Commit 5b295a2

Browse files
committed
Comprehensive tests for all groupby rank args
1 parent 67ca634 commit 5b295a2

File tree

1 file changed

+162
-0
lines changed

1 file changed

+162
-0
lines changed

pandas/tests/groupby/test_groupby.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1895,6 +1895,168 @@ def test_rank_apply(self):
18951895
expected = expected.reindex(result.index)
18961896
assert_series_equal(result, expected)
18971897

1898+
@pytest.mark.parametrize("vals", [
1899+
[2, 2, 8, 2, 6], ['bar', 'bar', 'foo', 'bar', 'baz']])
1900+
@pytest.mark.parametrize("ties_method,ascending,pct,exp", [
1901+
('average', True, False, DataFrame(
1902+
[2., 2., 5., 2., 4.], columns=['val'])),
1903+
('average', True, True, DataFrame(
1904+
[0.4, 0.4, 1.0, 0.4, 0.8], columns=['val'])),
1905+
('average', False, False, DataFrame(
1906+
[4., 4., 1., 4., 2.], columns=['val'])),
1907+
('average', False, True, DataFrame(
1908+
[.8, .8, .2, .8, .4], columns=['val'])),
1909+
('min', True, False, DataFrame(
1910+
[1., 1., 5., 1., 4.], columns=['val'])),
1911+
('min', True, True, DataFrame(
1912+
[0.2, 0.2, 1.0, 0.2, 0.8], columns=['val'])),
1913+
('min', False, False, DataFrame(
1914+
[3., 3., 1., 3., 2.], columns=['val'])),
1915+
('min', False, True, DataFrame(
1916+
[.6, .6, .2, .6, .4], columns=['val'])),
1917+
('max', True, False, DataFrame(
1918+
[3., 3., 5., 3., 4.], columns=['val'])),
1919+
('max', True, True, DataFrame(
1920+
[0.6, 0.6, 1.0, 0.6, 0.8], columns=['val'])),
1921+
('max', False, False, DataFrame(
1922+
[5., 5., 1., 5., 2.], columns=['val'])),
1923+
('max', False, True, DataFrame(
1924+
[1., 1., .2, 1., .4], columns=['val'])),
1925+
('first', True, False, DataFrame(
1926+
[1., 2., 5., 3., 4.], columns=['val'])),
1927+
('first', True, True, DataFrame(
1928+
[0.2, 0.4, 1.0, 0.6, 0.8], columns=['val'])),
1929+
('first', False, False, DataFrame(
1930+
[3., 4., 1., 5., 2.], columns=['val'])),
1931+
('first', False, True, DataFrame(
1932+
[.6, .8, .2, 1., .4], columns=['val'])),
1933+
('dense', True, False, DataFrame(
1934+
[1., 1., 3., 1., 2.], columns=['val'])),
1935+
('dense', True, True, DataFrame(
1936+
[0.2, 0.2, 0.6, 0.2, 0.4], columns=['val'])),
1937+
('dense', False, False, DataFrame(
1938+
[3., 3., 1., 3., 2.], columns=['val'])),
1939+
('dense', False, True, DataFrame(
1940+
[.6, .6, .2, .6, .4], columns=['val'])),
1941+
])
1942+
def test_rank_args(self, vals, ties_method, ascending, pct, exp):
1943+
if ties_method == 'first' and vals[0] == 'bar':
1944+
pytest.xfail("See GH 19482")
1945+
df = DataFrame({'key': ['foo']*5, 'val': vals})
1946+
result = df.groupby('key').rank(method=ties_method, ascending=ascending,
1947+
pct=pct)
1948+
1949+
assert_frame_equal(result, exp)
1950+
1951+
@pytest.mark.parametrize("vals", [
1952+
[2, 2, np.nan, 8, 2, 6, np.nan, np.nan], # floats
1953+
['bar', 'bar', np.nan, 'foo', 'bar', 'baz', np.nan, np.nan] # objects
1954+
])
1955+
@pytest.mark.parametrize("ties_method,ascending,na_option,pct,exp", [
1956+
('average', True, 'keep', False, DataFrame(
1957+
[2., 2., np.nan, 5., 2., 4., np.nan, np.nan], columns=['val'])),
1958+
('average', True, 'keep', True, DataFrame(
1959+
[0.4, 0.4, np.nan, 1.0, 0.4, 0.8, np.nan, np.nan],
1960+
columns=['val'])),
1961+
('average', False, 'keep', False, DataFrame(
1962+
[4., 4., np.nan, 1., 4., 2., np.nan, np.nan], columns=['val'])),
1963+
('average', False, 'keep', True, DataFrame(
1964+
[.8, 0.8, np.nan, 0.2, 0.8, 0.4, np.nan, np.nan], columns=['val'])),
1965+
('min', True, 'keep', False, DataFrame(
1966+
[1., 1., np.nan, 5., 1., 4., np.nan, np.nan], columns=['val'])),
1967+
('min', True, 'keep', True, DataFrame(
1968+
[0.2, 0.2, np.nan, 1.0, 0.2, 0.8, np.nan, np.nan],
1969+
columns=['val'])),
1970+
('min', False, 'keep', False, DataFrame(
1971+
[3., 3., np.nan, 1., 3., 2., np.nan, np.nan], columns=['val'])),
1972+
('min', False, 'keep', True, DataFrame(
1973+
[.6, 0.6, np.nan, 0.2, 0.6, 0.4, np.nan, np.nan], columns=['val'])),
1974+
('max', True, 'keep', False, DataFrame(
1975+
[3., 3., np.nan, 5., 3., 4., np.nan, np.nan], columns=['val'])),
1976+
('max', True, 'keep', True, DataFrame(
1977+
[0.6, 0.6, np.nan, 1.0, 0.6, 0.8, np.nan, np.nan],
1978+
columns=['val'])),
1979+
('max', False, 'keep', False, DataFrame(
1980+
[5., 5., np.nan, 1., 5., 2., np.nan, np.nan], columns=['val'])),
1981+
('max', False, 'keep', True, DataFrame(
1982+
[1., 1., np.nan, 0.2, 1., 0.4, np.nan, np.nan], columns=['val'])),
1983+
('first', True, 'keep', False, DataFrame(
1984+
[1., 2., np.nan, 5., 3., 4., np.nan, np.nan], columns=['val'])),
1985+
('first', True, 'keep', True, DataFrame(
1986+
[0.2, 0.4, np.nan, 1.0, 0.6, 0.8, np.nan, np.nan],
1987+
columns=['val'])),
1988+
('first', False, 'keep', False, DataFrame(
1989+
[3., 4., np.nan, 1., 5., 2., np.nan, np.nan], columns=['val'])),
1990+
('first', False, 'keep', True, DataFrame(
1991+
[.6, 0.8, np.nan, 0.2, 1., 0.4, np.nan, np.nan], columns=['val'])),
1992+
('dense', True, 'keep', False, DataFrame(
1993+
[1., 1., np.nan, 3., 1., 2., np.nan, np.nan], columns=['val'])),
1994+
('dense', True, 'keep', True, DataFrame(
1995+
[0.2, 0.2, np.nan, 0.6, 0.2, 0.4, np.nan, np.nan],
1996+
columns=['val'])),
1997+
('dense', False, 'keep', False, DataFrame(
1998+
[3., 3., np.nan, 1., 3., 2., np.nan, np.nan], columns=['val'])),
1999+
('dense', False, 'keep', True, DataFrame(
2000+
[.6, 0.6, np.nan, 0.2, 0.6, 0.4, np.nan, np.nan], columns=['val'])),
2001+
('average', True, 'no_na', False, DataFrame(
2002+
[2., 2., 7., 5., 2., 4., 7., 7.], columns=['val'])),
2003+
('average', True, 'no_na', True, DataFrame(
2004+
[0.25, 0.25, 0.875, 0.625, 0.25, 0.5, 0.875, 0.875],
2005+
columns=['val'])),
2006+
('average', False, 'no_na', False, DataFrame(
2007+
[4., 4., 7.0, 1., 4., 2., 7.0, 7.0], columns=['val'])),
2008+
('average', False, 'no_na', True, DataFrame(
2009+
[0.5, 0.5, 0.875, 0.125, 0.5, 0.25, 0.875, 0.875], columns=['val'])),
2010+
('min', True, 'no_na', False, DataFrame(
2011+
[1., 1., 6., 5., 1., 4., 6., 6.], columns=['val'])),
2012+
('min', True, 'no_na', True, DataFrame(
2013+
[0.125, 0.125, 0.75, 0.625, 0.125, 0.5, 0.75, 0.75],
2014+
columns=['val'])),
2015+
('min', False, 'no_na', False, DataFrame(
2016+
[3., 3., 6., 1., 3., 2., 6., 6.], columns=['val'])),
2017+
('min', False, 'no_na', True, DataFrame(
2018+
[0.375, 0.375, 0.75, 0.125, 0.375, 0.25, 0.75, 0.75],
2019+
columns=['val'])),
2020+
('max', True, 'no_na', False, DataFrame(
2021+
[3., 3., 8., 5., 3., 4., 8., 8.], columns=['val'])),
2022+
('max', True, 'no_na', True, DataFrame(
2023+
[0.375, 0.375, 1., 0.625, 0.375, 0.5, 1., 1.], columns=['val'])),
2024+
('max', False, 'no_na', False, DataFrame(
2025+
[5., 5., 8., 1., 5., 2., 8., 8.], columns=['val'])),
2026+
('max', False, 'no_na', True, DataFrame(
2027+
[0.625, 0.625, 1., 0.125, 0.625, 0.25, 1., 1.], columns=['val'])),
2028+
('first', True, 'no_na', False, DataFrame(
2029+
[1., 2., 6., 5., 3., 4., 7., 8.], columns=['val'])),
2030+
('first', True, 'no_na', True, DataFrame(
2031+
[0.125, 0.25, 0.75, 0.625, 0.375, 0.5, 0.875, 1.],
2032+
columns=['val'])),
2033+
('first', False, 'no_na', False, DataFrame(
2034+
[3., 4., 6., 1., 5., 2., 7., 8.], columns=['val'])),
2035+
('first', False, 'no_na', True, DataFrame(
2036+
[0.375, 0.5, 0.75, 0.125, 0.625, 0.25, 0.875, 1.],
2037+
columns=['val'])),
2038+
('dense', True, 'no_na', False, DataFrame(
2039+
[1., 1., 4., 3., 1., 2., 4., 4.], columns=['val'])),
2040+
('dense', True, 'no_na', True, DataFrame(
2041+
[0.125, 0.125, 0.5, 0.375, 0.125, 0.25, 0.5, 0.5],
2042+
columns=['val'])),
2043+
('dense', False, 'no_na', False, DataFrame(
2044+
[3., 3., 4., 1., 3., 2., 4., 4.], columns=['val'])),
2045+
('dense', False, 'no_na', True, DataFrame(
2046+
[0.375, 0.375, 0.5, 0.125, 0.375, 0.25, 0.5, 0.5],
2047+
columns=['val'])),
2048+
])
2049+
def test_rank_args_missing(self, vals, ties_method, ascending, na_option,
2050+
pct, exp):
2051+
if ties_method == 'first' and vals[0] == 'bar':
2052+
pytest.xfail("See GH 19482")
2053+
2054+
df = DataFrame({'key': ['foo']*8, 'val': vals})
2055+
result = df.groupby('key').rank(method=ties_method, ascending=ascending,
2056+
na_option=na_option, pct=pct)
2057+
2058+
assert_frame_equal(result, exp)
2059+
18982060
def test_dont_clobber_name_column(self):
18992061
df = DataFrame({'key': ['a', 'a', 'a', 'b', 'b', 'b'],
19002062
'name': ['foo', 'bar', 'baz'] * 2})

0 commit comments

Comments
 (0)