diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index b92e414f2055e..60c9c200407e8 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -277,6 +277,7 @@ Other enhancements - :meth:`Series.replace` will now cast results to ``PeriodDtype`` where possible instead of ``object`` dtype (:issue:`41526`) - Improved error message in ``corr`` and ``cov`` methods on :class:`.Rolling`, :class:`.Expanding`, and :class:`.ExponentialMovingWindow` when ``other`` is not a :class:`DataFrame` or :class:`Series` (:issue:`41741`) - :meth:`DataFrame.explode` now supports exploding multiple columns. Its ``column`` argument now also accepts a list of str or tuples for exploding on multiple columns at the same time (:issue:`39240`) +- :meth:`DataFrame.sample` now accepts the ``ignore_index`` argument to reset the index after sampling, similar to :meth:`DataFrame.drop_duplicates` and :meth:`DataFrame.sort_values` (:issue:`38581`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/generic.py b/pandas/core/generic.py index adc722e770cff..f770dd5831e84 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -5144,11 +5144,12 @@ def tail(self: FrameOrSeries, n: int = 5) -> FrameOrSeries: def sample( self: FrameOrSeries, n=None, - frac=None, - replace=False, + frac: float | None = None, + replace: bool_t = False, weights=None, random_state: RandomState | None = None, - axis=None, + axis: Axis | None = None, + ignore_index: bool_t = False, ) -> FrameOrSeries: """ Return a random sample of items from an axis of object. @@ -5189,6 +5190,10 @@ def sample( axis : {0 or ‘index’, 1 or ‘columns’, None}, default None Axis to sample. Accepts axis number or name. Default is stat axis for given data type (0 for Series and DataFrames). + ignore_index : bool, default False + If True, the resulting index will be labeled 0, 1, …, n - 1. + + .. versionadded:: 1.3.0 Returns ------- @@ -5350,7 +5355,11 @@ def sample( ) locs = rs.choice(axis_length, size=n, replace=replace, p=weights) - return self.take(locs, axis=axis) + result = self.take(locs, axis=axis) + if ignore_index: + result.index = ibase.default_index(len(result)) + + return result @final @doc(klass=_shared_doc_kwargs["klass"]) diff --git a/pandas/tests/frame/methods/test_sample.py b/pandas/tests/frame/methods/test_sample.py index 55ef665c55241..604788ba91633 100644 --- a/pandas/tests/frame/methods/test_sample.py +++ b/pandas/tests/frame/methods/test_sample.py @@ -5,6 +5,7 @@ from pandas import ( DataFrame, + Index, Series, ) import pandas._testing as tm @@ -326,3 +327,12 @@ def test_sample_is_copy(self): with tm.assert_produces_warning(None): df2["d"] = 1 + + def test_sample_ignore_index(self): + # GH 38581 + df = DataFrame( + {"col1": range(10, 20), "col2": range(20, 30), "colString": ["a"] * 10} + ) + result = df.sample(3, ignore_index=True) + expected_index = Index([0, 1, 2]) + tm.assert_index_equal(result.index, expected_index)