diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 9eab385a7a2a5..bee0817692bce 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -2700,7 +2700,7 @@ def plot_group(group, ax): return fig -def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, +def hist_frame(data, column=None, weights=None, by=None, grid=True, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None, ax=None, sharex=False, sharey=False, figsize=None, layout=None, bins=10, **kwds): """ @@ -2711,6 +2711,8 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, data : DataFrame column : string or sequence If passed, will be used to limit data to a subset of columns + weights : string or sequence + If passed, will be used to weight the data by : object, optional If passed, then used to form histograms for separate groups grid : boolean, default True @@ -2742,7 +2744,7 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, """ if by is not None: - axes = grouped_hist(data, column=column, by=by, ax=ax, grid=grid, figsize=figsize, + axes = grouped_hist(data, column=column, weights=weights, by=by, ax=ax, grid=grid, figsize=figsize, sharex=sharex, sharey=sharey, layout=layout, bins=bins, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot, **kwds) @@ -2846,7 +2848,7 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None, return axes -def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None, +def grouped_hist(data, column=None, weights=None, by=None, ax=None, bins=50, figsize=None, layout=None, sharex=False, sharey=False, rot=90, grid=True, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None, **kwargs): @@ -2857,6 +2859,7 @@ def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None, ---------- data: Series/DataFrame column: object, optional + weights: object, optional by: object, optional ax: axes, optional bins: int, default 50 @@ -2872,12 
+2875,14 @@ def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None, ------- axes: collection of Matplotlib Axes """ - def plot_group(group, ax): - ax.hist(group.dropna().values, bins=bins, **kwargs) + def plot_group(group, ax, weights=None): + if weights is not None: + weights=weights.dropna().values + ax.hist(group.dropna().values, weights=weights, bins=bins, **kwargs) xrot = xrot or rot - fig, axes = _grouped_plot(plot_group, data, column=column, + fig, axes = _grouped_plot(plot_group, data, column=column, weights=weights, by=by, sharex=sharex, sharey=sharey, ax=ax, figsize=figsize, layout=layout, rot=rot) @@ -2964,7 +2969,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None, return ret -def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, +def _grouped_plot(plotf, data, column=None, weights=None, by=None, numeric_only=True, figsize=None, sharex=True, sharey=True, layout=None, rot=0, ax=None, **kwargs): from pandas import DataFrame @@ -2977,6 +2982,8 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, grouped = data.groupby(by) if column is not None: + if weights is not None: + weights = grouped[weights] grouped = grouped[column] naxes = len(grouped) @@ -2986,11 +2993,20 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, _axes = _flatten(axes) + weight = None for i, (key, group) in enumerate(grouped): ax = _axes[i] + if weights is not None: + weight = weights.get_group(key) if numeric_only and isinstance(group, DataFrame): group = group._get_numeric_data() - plotf(group, ax, **kwargs) + if weight is not None: + weight = weight._get_numeric_data() + if weight is not None: + plotf(group, ax, weight, **kwargs) + else: + # some plotf callbacks (scatter plot etc.) do not accept a weight argument + plotf(group, ax, **kwargs) ax.set_title(com.pprint_thing(key)) return fig, axes