diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index c6f2fc099a0..7a08d2e6ac4 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -12,6 +12,7 @@ def scatter( size=None, hover_name=None, hover_data=None, + custom_data=None, text=None, facet_row=None, facet_col=None, @@ -174,6 +175,7 @@ def line( line_dash=None, hover_name=None, hover_data=None, + custom_data=None, text=None, facet_row=None, facet_col=None, @@ -217,6 +219,7 @@ def area( color=None, hover_name=None, hover_data=None, + custom_data=None, text=None, facet_row=None, facet_col=None, @@ -262,6 +265,7 @@ def bar( facet_col=None, hover_name=None, hover_data=None, + custom_data=None, text=None, error_x=None, error_x_minus=None, @@ -368,6 +372,7 @@ def violin( facet_col=None, hover_name=None, hover_data=None, + custom_data=None, animation_frame=None, animation_group=None, category_orders={}, @@ -418,6 +423,7 @@ def box( facet_col=None, hover_name=None, hover_data=None, + custom_data=None, animation_frame=None, animation_group=None, category_orders={}, @@ -463,6 +469,7 @@ def strip( facet_col=None, hover_name=None, hover_data=None, + custom_data=None, animation_frame=None, animation_group=None, category_orders={}, @@ -514,6 +521,7 @@ def scatter_3d( text=None, hover_name=None, hover_data=None, + custom_data=None, error_x=None, error_x_minus=None, error_y=None, @@ -564,6 +572,7 @@ def line_3d( line_group=None, hover_name=None, hover_data=None, + custom_data=None, error_x=None, error_x_minus=None, error_y=None, @@ -609,6 +618,7 @@ def scatter_ternary( text=None, hover_name=None, hover_data=None, + custom_data=None, animation_frame=None, animation_group=None, category_orders={}, @@ -646,6 +656,7 @@ def line_ternary( line_group=None, hover_name=None, hover_data=None, + custom_data=None, text=None, animation_frame=None, animation_group=None, @@ -679,6 +690,7 @@ def scatter_polar( size=None, hover_name=None, hover_data=None, + custom_data=None, text=None, animation_frame=None, animation_group=None, @@ -721,6 +733,7 @@ def line_polar( line_dash=None, hover_name=None, hover_data=None, + custom_data=None, line_group=None, text=None, animation_frame=None, @@ -759,6 +772,7 @@ def bar_polar( color=None, hover_name=None, hover_data=None, + custom_data=None, animation_frame=None, animation_group=None, category_orders={}, @@ -798,6 +812,7 @@ def choropleth( color=None, hover_name=None, hover_data=None, + custom_data=None, size=None, animation_frame=None, animation_group=None, @@ -838,6 +853,7 @@ def scatter_geo( text=None, hover_name=None, hover_data=None, + custom_data=None, size=None, animation_frame=None, animation_group=None, @@ -882,6 +898,7 @@ def line_geo( text=None, hover_name=None, hover_data=None, + custom_data=None, line_group=None, animation_frame=None, animation_group=None, @@ -920,6 +937,7 @@ def scatter_mapbox( text=None, hover_name=None, hover_data=None, + custom_data=None, size=None, animation_frame=None, animation_group=None, @@ -955,6 +973,7 @@ def line_mapbox( text=None, hover_name=None, hover_data=None, + custom_data=None, line_group=None, animation_frame=None, animation_group=None, @@ -985,6 +1004,7 @@ def scatter_matrix( size=None, hover_name=None, hover_data=None, + custom_data=None, category_orders={}, labels={}, color_discrete_sequence=None, diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 69b3c9ac382..3d55d3004da 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -6,6 +6,7 @@ from .colors import qualitative, sequential import math import pandas +import numpy as np from plotly.subplots import ( make_subplots, @@ -137,12 +138,35 @@ def make_mapping(args, variable): def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref): - + """Populates a dict with arguments to update trace + + Parameters + ---------- + args : dict + args to be used for the trace + trace_spec : NamedTuple + which kind of trace to be used (has constructor, marginal etc. + attributes) + g : pandas DataFrame + data + mapping_labels : dict + to be used for hovertemplate + sizeref : float + marker sizeref + + Returns + ------- + result : dict + dict to be used to update trace + fit_results : dict + fit information to be used for trendlines + """ if "line_close" in args and args["line_close"]: g = g.append(g.iloc[0]) result = trace_spec.trace_patch.copy() or {} fit_results = None hover_header = "" + custom_data_len = 0 for k in trace_spec.attrs: v = args[k] v_label = get_decorated_label(args, v, k) @@ -194,7 +218,6 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref): elif k == "trendline": if v in ["ols", "lowess"] and args["x"] and args["y"] and len(g) > 1: import statsmodels.api as sm - import numpy as np # sorting is bad but trace_specs with "trendline" have no other attrs g2 = g.sort_values(by=args["x"]) @@ -231,6 +254,9 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref): if error_xy not in result: result[error_xy] = {} result[error_xy][arr] = g[v] + elif k == "custom_data": + result["customdata"] = g[v].values + custom_data_len = len(v) # number of custom data columns elif k == "hover_name": if trace_spec.constructor not in [ go.Histogram, @@ -246,10 +272,20 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref): go.Histogram2d, go.Histogram2dContour, ]: - result["customdata"] = g[v].values - for i, col in enumerate(v): + for col in v: + try: + position = args["custom_data"].index(col) + except (ValueError, AttributeError, KeyError): + position = custom_data_len + custom_data_len += 1 + if "customdata" in result: + result["customdata"] = np.hstack( + (result["customdata"], g[col].values[:, None]) + ) + else: + result["customdata"] = g[col].values[:, None] v_label_col = get_decorated_label(args, col, None) - mapping_labels[v_label_col] = "%%{customdata[%d]}" % i + mapping_labels[v_label_col] = "%%{customdata[%d]}" % (position) elif k == "color": if trace_spec.constructor == go.Choropleth: result["z"] = g[v] @@ -721,12 +757,13 @@ def apply_default_cascade(args): def infer_config(args, constructor, trace_patch): # Declare all supported attributes, across all plot types attrables = ( - ["x", "y", "z", "a", "b", "c", "r", "theta", "size"] - + ["dimensions", "hover_name", "hover_data", "text", "error_x", "error_x_minus"] + ["x", "y", "z", "a", "b", "c", "r", "theta", "size", "dimensions"] + + ["custom_data", "hover_name", "hover_data", "text"] + + ["error_x", "error_x_minus"] + ["error_y", "error_y_minus", "error_z", "error_z_minus"] + ["lat", "lon", "locations", "animation_group"] ) - array_attrables = ["dimensions", "hover_data"] + array_attrables = ["dimensions", "custom_data", "hover_data"] group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] # Validate that the strings provided as attribute values reference columns @@ -916,6 +953,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): if constructor_to_use == go.Scatter else go.Scatterpolargl ) + # Create the trace trace = constructor_to_use(name=trace_name) if trace_spec.constructor not in [ go.Parcats, diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index 8c54f89a026..fbefe4e3860 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -111,6 +111,10 @@ colref_list, "Values from these columns appear as extra data in the hover tooltip.", ], + custom_data=[ + colref_list, + "Values from these columns are extra data, to be used in widgets or Dash callbacks for example. This data is not user-visible but is included in events emitted by the figure (lasso selection etc.)", + ], text=[colref, "Values from this column appear in the figure as text labels."], locationmode=[ "(string, one of 'ISO-3', 'USA-states', 'country names')", diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py index 001bdb6997a..588bfa3d18a 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py @@ -10,3 +10,44 @@ def test_scatter(): assert np.all(fig.data[0].y == iris.sepal_length) # test defaults assert fig.data[0].mode == "markers" + + +def test_custom_data_scatter(): + iris = px.data.iris() + # No hover, no custom data + fig = px.scatter(iris, x="sepal_width", y="sepal_length", color="species") + assert fig.data[0].customdata is None + # Hover, no custom data + fig = px.scatter( + iris, + x="sepal_width", + y="sepal_length", + color="species", + hover_data=["petal_length", "petal_width"], + ) + for data in fig.data: + assert np.all(np.in1d(data.customdata[:, 1], iris.petal_width)) + # Hover and custom data, no repeated arguments + fig = px.scatter( + iris, + x="sepal_width", + y="sepal_length", + hover_data=["petal_length", "petal_width"], + custom_data=["species_id", "species"], + ) + assert np.all(fig.data[0].customdata[:, 0] == iris.species_id) + assert fig.data[0].customdata.shape[1] == 4 + # Hover and custom data, with repeated arguments + fig = px.scatter( + iris, + x="sepal_width", + y="sepal_length", + hover_data=["petal_length", "petal_width", "species_id"], + custom_data=["species_id", "species"], + ) + assert np.all(fig.data[0].customdata[:, 0] == iris.species_id) + assert fig.data[0].customdata.shape[1] == 4 + assert ( + fig.data[0].hovertemplate + == "sepal_width=%{x}
sepal_length=%{y}
petal_length=%{customdata[2]}
petal_width=%{customdata[3]}
species_id=%{customdata[0]}" + )