diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py
index c6f2fc099a0..7a08d2e6ac4 100644
--- a/packages/python/plotly/plotly/express/_chart_types.py
+++ b/packages/python/plotly/plotly/express/_chart_types.py
@@ -12,6 +12,7 @@ def scatter(
size=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
text=None,
facet_row=None,
facet_col=None,
@@ -174,6 +175,7 @@ def line(
line_dash=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
text=None,
facet_row=None,
facet_col=None,
@@ -217,6 +219,7 @@ def area(
color=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
text=None,
facet_row=None,
facet_col=None,
@@ -262,6 +265,7 @@ def bar(
facet_col=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
text=None,
error_x=None,
error_x_minus=None,
@@ -368,6 +372,7 @@ def violin(
facet_col=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
animation_frame=None,
animation_group=None,
category_orders={},
@@ -418,6 +423,7 @@ def box(
facet_col=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
animation_frame=None,
animation_group=None,
category_orders={},
@@ -463,6 +469,7 @@ def strip(
facet_col=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
animation_frame=None,
animation_group=None,
category_orders={},
@@ -514,6 +521,7 @@ def scatter_3d(
text=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
error_x=None,
error_x_minus=None,
error_y=None,
@@ -564,6 +572,7 @@ def line_3d(
line_group=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
error_x=None,
error_x_minus=None,
error_y=None,
@@ -609,6 +618,7 @@ def scatter_ternary(
text=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
animation_frame=None,
animation_group=None,
category_orders={},
@@ -646,6 +656,7 @@ def line_ternary(
line_group=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
text=None,
animation_frame=None,
animation_group=None,
@@ -679,6 +690,7 @@ def scatter_polar(
size=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
text=None,
animation_frame=None,
animation_group=None,
@@ -721,6 +733,7 @@ def line_polar(
line_dash=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
line_group=None,
text=None,
animation_frame=None,
@@ -759,6 +772,7 @@ def bar_polar(
color=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
animation_frame=None,
animation_group=None,
category_orders={},
@@ -798,6 +812,7 @@ def choropleth(
color=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
size=None,
animation_frame=None,
animation_group=None,
@@ -838,6 +853,7 @@ def scatter_geo(
text=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
size=None,
animation_frame=None,
animation_group=None,
@@ -882,6 +898,7 @@ def line_geo(
text=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
line_group=None,
animation_frame=None,
animation_group=None,
@@ -920,6 +937,7 @@ def scatter_mapbox(
text=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
size=None,
animation_frame=None,
animation_group=None,
@@ -955,6 +973,7 @@ def line_mapbox(
text=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
line_group=None,
animation_frame=None,
animation_group=None,
@@ -985,6 +1004,7 @@ def scatter_matrix(
size=None,
hover_name=None,
hover_data=None,
+ custom_data=None,
category_orders={},
labels={},
color_discrete_sequence=None,
diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py
index 69b3c9ac382..3d55d3004da 100644
--- a/packages/python/plotly/plotly/express/_core.py
+++ b/packages/python/plotly/plotly/express/_core.py
@@ -6,6 +6,7 @@
from .colors import qualitative, sequential
import math
import pandas
+import numpy as np
from plotly.subplots import (
make_subplots,
@@ -137,12 +138,35 @@ def make_mapping(args, variable):
def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
-
+ """Populates a dict with arguments to update trace
+
+ Parameters
+ ----------
+ args : dict
+ args to be used for the trace
+ trace_spec : NamedTuple
+ which kind of trace to be used (has constructor, marginal etc.
+ attributes)
+ g : pandas DataFrame
+ data
+ mapping_labels : dict
+ to be used for hovertemplate
+ sizeref : float
+ marker sizeref
+
+ Returns
+ -------
+ result : dict
+ dict to be used to update trace
+ fit_results : dict
+ fit information to be used for trendlines
+ """
if "line_close" in args and args["line_close"]:
g = g.append(g.iloc[0])
result = trace_spec.trace_patch.copy() or {}
fit_results = None
hover_header = ""
+ custom_data_len = 0
for k in trace_spec.attrs:
v = args[k]
v_label = get_decorated_label(args, v, k)
@@ -194,7 +218,6 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
elif k == "trendline":
if v in ["ols", "lowess"] and args["x"] and args["y"] and len(g) > 1:
import statsmodels.api as sm
- import numpy as np
# sorting is bad but trace_specs with "trendline" have no other attrs
g2 = g.sort_values(by=args["x"])
@@ -231,6 +254,9 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
if error_xy not in result:
result[error_xy] = {}
result[error_xy][arr] = g[v]
+ elif k == "custom_data":
+ result["customdata"] = g[v].values
+ custom_data_len = len(v) # number of custom data columns
elif k == "hover_name":
if trace_spec.constructor not in [
go.Histogram,
@@ -246,10 +272,20 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
go.Histogram2d,
go.Histogram2dContour,
]:
- result["customdata"] = g[v].values
- for i, col in enumerate(v):
+ for col in v:
+ try:
+ position = args["custom_data"].index(col)
+ except (ValueError, AttributeError, KeyError):
+ position = custom_data_len
+ custom_data_len += 1
+ if "customdata" in result:
+ result["customdata"] = np.hstack(
+ (result["customdata"], g[col].values[:, None])
+ )
+ else:
+ result["customdata"] = g[col].values[:, None]
v_label_col = get_decorated_label(args, col, None)
- mapping_labels[v_label_col] = "%%{customdata[%d]}" % i
+ mapping_labels[v_label_col] = "%%{customdata[%d]}" % (position)
elif k == "color":
if trace_spec.constructor == go.Choropleth:
result["z"] = g[v]
@@ -721,12 +757,13 @@ def apply_default_cascade(args):
def infer_config(args, constructor, trace_patch):
# Declare all supported attributes, across all plot types
attrables = (
- ["x", "y", "z", "a", "b", "c", "r", "theta", "size"]
- + ["dimensions", "hover_name", "hover_data", "text", "error_x", "error_x_minus"]
+ ["x", "y", "z", "a", "b", "c", "r", "theta", "size", "dimensions"]
+ + ["custom_data", "hover_name", "hover_data", "text"]
+ + ["error_x", "error_x_minus"]
+ ["error_y", "error_y_minus", "error_z", "error_z_minus"]
+ ["lat", "lon", "locations", "animation_group"]
)
- array_attrables = ["dimensions", "hover_data"]
+ array_attrables = ["dimensions", "custom_data", "hover_data"]
group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"]
# Validate that the strings provided as attribute values reference columns
@@ -916,6 +953,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
if constructor_to_use == go.Scatter
else go.Scatterpolargl
)
+ # Create the trace
trace = constructor_to_use(name=trace_name)
if trace_spec.constructor not in [
go.Parcats,
diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py
index 8c54f89a026..fbefe4e3860 100644
--- a/packages/python/plotly/plotly/express/_doc.py
+++ b/packages/python/plotly/plotly/express/_doc.py
@@ -111,6 +111,10 @@
colref_list,
"Values from these columns appear as extra data in the hover tooltip.",
],
+ custom_data=[
+ colref_list,
+ "Values from these columns are extra data, to be used in widgets or Dash callbacks for example. This data is not user-visible but is included in events emitted by the figure (lasso selection etc.)",
+ ],
text=[colref, "Values from this column appear in the figure as text labels."],
locationmode=[
"(string, one of 'ISO-3', 'USA-states', 'country names')",
diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py
index 001bdb6997a..588bfa3d18a 100644
--- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py
+++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py
@@ -10,3 +10,44 @@ def test_scatter():
assert np.all(fig.data[0].y == iris.sepal_length)
# test defaults
assert fig.data[0].mode == "markers"
+
+
+def test_custom_data_scatter():
+ iris = px.data.iris()
+ # No hover, no custom data
+ fig = px.scatter(iris, x="sepal_width", y="sepal_length", color="species")
+ assert fig.data[0].customdata is None
+ # Hover, no custom data
+ fig = px.scatter(
+ iris,
+ x="sepal_width",
+ y="sepal_length",
+ color="species",
+ hover_data=["petal_length", "petal_width"],
+ )
+ for data in fig.data:
+ assert np.all(np.in1d(data.customdata[:, 1], iris.petal_width))
+ # Hover and custom data, no repeated arguments
+ fig = px.scatter(
+ iris,
+ x="sepal_width",
+ y="sepal_length",
+ hover_data=["petal_length", "petal_width"],
+ custom_data=["species_id", "species"],
+ )
+ assert np.all(fig.data[0].customdata[:, 0] == iris.species_id)
+ assert fig.data[0].customdata.shape[1] == 4
+ # Hover and custom data, with repeated arguments
+ fig = px.scatter(
+ iris,
+ x="sepal_width",
+ y="sepal_length",
+ hover_data=["petal_length", "petal_width", "species_id"],
+ custom_data=["species_id", "species"],
+ )
+ assert np.all(fig.data[0].customdata[:, 0] == iris.species_id)
+ assert fig.data[0].customdata.shape[1] == 4
+ assert (
+ fig.data[0].hovertemplate
+ == "sepal_width=%{x}
sepal_length=%{y}
petal_length=%{customdata[2]}
petal_width=%{customdata[3]}
species_id=%{customdata[0]}"
+ )