Skip to content

Commit b6dd8f9

Browse files
extract trendline function API
1 parent dace44a commit b6dd8f9

File tree

3 files changed

+72
-55
lines changed

3 files changed

+72
-55
lines changed

packages/python/plotly/plotly/express/_chart_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def scatter(
4646
marginal_x=None,
4747
marginal_y=None,
4848
trendline=None,
49+
trendline_options=None,
4950
trendline_color_override=None,
5051
log_x=False,
5152
log_y=False,
@@ -90,6 +91,7 @@ def density_contour(
9091
marginal_x=None,
9192
marginal_y=None,
9293
trendline=None,
94+
trendline_options=None,
9395
trendline_color_override=None,
9496
log_x=False,
9597
log_y=False,

packages/python/plotly/plotly/express/_core.py

Lines changed: 66 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,56 @@ def make_mapping(args, variable):
229229
)
230230

231231

232+
def lowess(options, x, y, x_label, y_label, non_missing):
233+
import statsmodels.api as sm
234+
235+
frac = options.get("frac", 0.6666666)
236+
# missing ='drop' is the default value for lowess but not for OLS (None)
237+
# we force it here in case statsmodels change their defaults
238+
y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1]
239+
hover_header = "<b>LOWESS trendline</b><br><br>"
240+
return y_out, hover_header, None
241+
242+
243+
def ma(options, x, y, x_label, y_label, non_missing):
244+
y_out = pd.Series(y, index=x).rolling(**options).mean()[non_missing]
245+
hover_header = "<b>Moving Average trendline</b><br><br>"
246+
return y_out, hover_header, None
247+
248+
249+
def ewm(options, x, y, x_label, y_label, non_missing):
250+
y_out = pd.Series(y, index=x).ewm(**options).mean()[non_missing]
251+
hover_header = "<b>EWM trendline</b><br><br>"
252+
return y_out, hover_header, None
253+
254+
255+
def ols(options, x, y, x_label, y_label, non_missing):
256+
import statsmodels.api as sm
257+
258+
add_constant = options.get("add_constant", True)
259+
fit_results = sm.OLS(
260+
y, sm.add_constant(x) if add_constant else x, missing="drop"
261+
).fit()
262+
y_out = fit_results.predict()
263+
hover_header = "<b>OLS trendline</b><br>"
264+
if len(fit_results.params) == 2:
265+
hover_header += "%s = %g * %s + %g<br>" % (
266+
y_label,
267+
fit_results.params[1],
268+
x_label,
269+
fit_results.params[0],
270+
)
271+
elif not add_constant:
272+
hover_header += "%s = %g* %s<br>" % (y_label, fit_results.params[0], x_label,)
273+
else:
274+
hover_header += "%s = %g<br>" % (y_label, fit_results.params[0],)
275+
hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
276+
return y_out, hover_header, fit_results
277+
278+
279+
trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols)
280+
281+
232282
def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
233283
"""Populates a dict with arguments to update trace
234284
@@ -303,12 +353,11 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
303353
mapping_labels["count"] = "%{x}"
304354
elif attr_name == "trendline":
305355
if (
306-
attr_value[0] in ["ols", "lowess", "ma", "ewm"]
356+
attr_value in trendline_functions
307357
and args["x"]
308358
and args["y"]
309359
and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
310360
):
311-
import statsmodels.api as sm
312361

313362
# sorting is bad but trace_specs with "trendline" have no other attrs
314363
sorted_trace_data = trace_data.sort_values(by=args["x"])
@@ -339,56 +388,19 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
339388
np.logical_or(np.isnan(y), np.isnan(x))
340389
)
341390
trace_patch["x"] = sorted_trace_data[args["x"]][non_missing]
342-
343-
if attr_value[0] == "lowess":
344-
alpha = attr_value[1] or 0.6666666
345-
# missing ='drop' is the default value for lowess but not for OLS (None)
346-
# we force it here in case statsmodels change their defaults
347-
trendline = sm.nonparametric.lowess(
348-
y, x, missing="drop", frac=alpha
349-
)
350-
trace_patch["y"] = trendline[:, 1]
351-
hover_header = "<b>LOWESS trendline</b><br><br>"
352-
elif attr_value[0] == "ma":
353-
trace_patch["y"] = (
354-
pd.Series(y[non_missing])
355-
.rolling(window=attr_value[1] or 3)
356-
.mean()
357-
)
358-
elif attr_value[0] == "ewm":
359-
trace_patch["y"] = (
360-
pd.Series(y[non_missing])
361-
.ewm(alpha=attr_value[1] or 0.5)
362-
.mean()
363-
)
364-
elif attr_value[0] == "ols":
365-
add_constant = attr_value[1] is not False
366-
fit_results = sm.OLS(
367-
y, sm.add_constant(x) if add_constant else x, missing="drop"
368-
).fit()
369-
trace_patch["y"] = fit_results.predict()
370-
hover_header = "<b>OLS trendline</b><br>"
371-
if len(fit_results.params) == 2:
372-
hover_header += "%s = %g * %s + %g<br>" % (
373-
args["y"],
374-
fit_results.params[1],
375-
args["x"],
376-
fit_results.params[0],
377-
)
378-
elif not add_constant:
379-
hover_header += "%s = %g* %s<br>" % (
380-
args["y"],
381-
fit_results.params[0],
382-
args["x"],
383-
)
384-
else:
385-
hover_header += "%s = %g<br>" % (
386-
args["y"],
387-
fit_results.params[0],
388-
)
389-
hover_header += (
390-
"R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
391-
)
391+
trendline_function = trendline_functions[attr_value]
392+
y_out, hover_header, fit_results = trendline_function(
393+
args["trendline_options"],
394+
x,
395+
y,
396+
args["x"],
397+
args["y"],
398+
non_missing,
399+
)
400+
assert len(y_out) == len(
401+
trace_patch["x"]
402+
), "missing-data-handling failure in trendline code"
403+
trace_patch["y"] = y_out
392404
mapping_labels[get_label(args, args["x"])] = "%{x}"
393405
mapping_labels[get_label(args, args["y"])] = "%{y} <b>(trend)</b>"
394406
elif attr_name.startswith("error"):
@@ -1822,9 +1834,8 @@ def infer_config(args, constructor, trace_patch, layout_patch):
18221834
):
18231835
args["facet_col_wrap"] = 0
18241836

1825-
if args.get("trendline", None) is not None:
1826-
if isinstance(args["trendline"], str):
1827-
args["trendline"] = (args["trendline"], None)
1837+
if "trendline_options" in args and args["trendline_options"] is None:
1838+
args["trendline_options"] = dict()
18281839

18291840
# Compute applicable grouping attributes
18301841
for k in group_attrables:

packages/python/plotly/plotly/express/_doc.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,10 @@
388388
"If `'ols'`, an Ordinary Least Squares regression line will be drawn for each discrete-color/symbol group.",
389389
"If `'lowess`', a Locally Weighted Scatterplot Smoothing line will be drawn for each discrete-color/symbol group.",
390390
],
391+
trendline_options=[
392+
"dict",
393+
"Options passed to the function named in the `trendline` argument.",
394+
],
391395
trendline_color_override=[
392396
"str",
393397
"Valid CSS color.",

0 commit comments

Comments
 (0)