Skip to content

Commit f0bbf60

Browse files
extract trendline function API
1 parent 6eac898 commit f0bbf60

File tree

3 files changed

+72
-55
lines changed

3 files changed

+72
-55
lines changed

packages/python/plotly/plotly/express/_chart_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def scatter(
4646
marginal_x=None,
4747
marginal_y=None,
4848
trendline=None,
49+
trendline_options=None,
4950
trendline_color_override=None,
5051
log_x=False,
5152
log_y=False,
@@ -90,6 +91,7 @@ def density_contour(
9091
marginal_x=None,
9192
marginal_y=None,
9293
trendline=None,
94+
trendline_options=None,
9395
trendline_color_override=None,
9496
log_x=False,
9597
log_y=False,

packages/python/plotly/plotly/express/_core.py

Lines changed: 66 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,56 @@ def make_mapping(args, variable):
239239
)
240240

241241

242+
def lowess(options, x, y, x_label, y_label, non_missing):
    """Compute a LOWESS (locally weighted scatterplot smoothing) trendline.

    Parameters
    ----------
    options : dict
        Trendline options. Only ``"frac"`` is read; it is forwarded to
        statsmodels as the smoothing fraction and defaults to 0.6666666.
    x, y :
        Data to smooth, passed straight through to statsmodels
        (presumably 1-D numeric arrays -- confirm against the caller).
    x_label, y_label, non_missing :
        Accepted so every trendline function shares one signature;
        unused here.

    Returns
    -------
    tuple
        ``(y_out, hover_header, fit_results)`` where ``y_out`` is the second
        column of the lowess result, ``hover_header`` is the HTML hover title
        and ``fit_results`` is always ``None``.
    """
    # Imported lazily so statsmodels is only needed when a trendline is drawn.
    import statsmodels.api as sm

    # missing="drop" is already the lowess default (unlike OLS, whose default
    # is None); it is spelled out so the behaviour survives a change of
    # defaults in statsmodels.
    smoothed = sm.nonparametric.lowess(
        y, x, missing="drop", frac=options.get("frac", 0.6666666)
    )
    return smoothed[:, 1], "<b>LOWESS trendline</b><br><br>", None
251+
252+
253+
def ma(options, x, y, x_label, y_label, non_missing):
    """Compute a moving-average trendline.

    Parameters
    ----------
    options : dict or None
        Keyword arguments forwarded to ``pandas.Series.rolling``. When
        ``"window"`` is not supplied it defaults to 3, so requesting the
        trendline without any options works instead of raising a TypeError
        from pandas.
    x, y :
        Data for the trendline; ``y`` becomes the values and ``x`` the index
        of the series that is averaged (presumably equal-length 1-D arrays
        -- confirm against the caller).
    x_label, y_label :
        Accepted so every trendline function shares one signature;
        unused here.
    non_missing :
        Selector applied to the averaged series to keep only the rows the
        caller treats as non-missing (presumably a boolean mask -- confirm
        against the caller).

    Returns
    -------
    tuple
        ``(y_out, hover_header, fit_results)`` where ``y_out`` is the rolling
        mean restricted to ``non_missing``, ``hover_header`` is the HTML hover
        title and ``fit_results`` is always ``None``.
    """
    # Work on a copy so the caller's options dict is never mutated.
    rolling_kwargs = dict(options or {})
    rolling_kwargs.setdefault("window", 3)
    y_out = pd.Series(y, index=x).rolling(**rolling_kwargs).mean()[non_missing]
    hover_header = "<b>Moving Average trendline</b><br><br>"
    return y_out, hover_header, None
257+
258+
259+
def ewm(options, x, y, x_label, y_label, non_missing):
    """Compute an exponentially weighted moving-average trendline.

    Parameters
    ----------
    options : dict or None
        Keyword arguments forwarded to ``pandas.Series.ewm``. When none of
        the decay parameters (``com``, ``span``, ``halflife``, ``alpha``) is
        supplied, ``alpha=0.5`` is used so requesting the trendline without
        any options works instead of raising from pandas.
    x, y :
        Data for the trendline; ``y`` becomes the values and ``x`` the index
        of the series that is averaged (presumably equal-length 1-D arrays
        -- confirm against the caller).
    x_label, y_label :
        Accepted so every trendline function shares one signature;
        unused here.
    non_missing :
        Selector applied to the averaged series to keep only the rows the
        caller treats as non-missing (presumably a boolean mask -- confirm
        against the caller).

    Returns
    -------
    tuple
        ``(y_out, hover_header, fit_results)`` where ``y_out`` is the
        exponentially weighted mean restricted to ``non_missing``,
        ``hover_header`` is the HTML hover title and ``fit_results`` is
        always ``None``.
    """
    # Work on a copy so the caller's options dict is never mutated.
    ewm_kwargs = dict(options or {})
    decay_params = ("com", "span", "halflife", "alpha")
    if not any(name in ewm_kwargs for name in decay_params):
        ewm_kwargs["alpha"] = 0.5
    y_out = pd.Series(y, index=x).ewm(**ewm_kwargs).mean()[non_missing]
    hover_header = "<b>EWM trendline</b><br><br>"
    return y_out, hover_header, None
263+
264+
265+
def ols(options, x, y, x_label, y_label, non_missing):
    """Fit an Ordinary Least Squares regression and return its trendline.

    Parameters
    ----------
    options : dict
        Trendline options. Only ``"add_constant"`` is read (default
        ``True``); when truthy an intercept column is added to ``x`` before
        fitting.
    x, y :
        Regressor and response passed to ``statsmodels.api.OLS``; rows with
        missing values are dropped by statsmodels (``missing="drop"``).
    x_label, y_label : str
        Names used for x and y in the equation shown in the hover text.
    non_missing :
        Accepted so every trendline function shares one signature;
        unused here.

    Returns
    -------
    tuple
        ``(y_out, hover_header, fit_results)`` where ``y_out`` is the fitted
        model's in-sample prediction, ``hover_header`` is the HTML hover text
        (title, fitted equation and R squared) and ``fit_results`` is the
        statsmodels results object.
    """
    # Imported lazily so statsmodels is only needed when a trendline is drawn.
    import statsmodels.api as sm

    add_constant = options.get("add_constant", True)
    if add_constant:
        design = sm.add_constant(x)
    else:
        design = x
    fit_results = sm.OLS(y, design, missing="drop").fit()
    y_out = fit_results.predict()

    coefficients = fit_results.params
    if len(coefficients) == 2:
        # Intercept and slope: params[0] is the constant, params[1] the slope.
        equation = "%s = %g * %s + %g<br>" % (
            y_label,
            coefficients[1],
            x_label,
            coefficients[0],
        )
    elif not add_constant:
        # Regression through the origin: a single slope coefficient.
        equation = "%s = %g* %s<br>" % (y_label, coefficients[0], x_label)
    else:
        # A constant was requested but the fit did not yield two parameters.
        equation = "%s = %g<br>" % (y_label, coefficients[0])

    r_squared = "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
    hover_header = "<b>OLS trendline</b><br>" + equation + r_squared
    return y_out, hover_header, fit_results
287+
288+
289+
# Registry mapping each supported `trendline` name to the function computing it.
# Every function is called as f(options, x, y, x_label, y_label, non_missing)
# and returns a (y_out, hover_header, fit_results) tuple.
trendline_functions = {
    "lowess": lowess,
    "ma": ma,
    "ewm": ewm,
    "ols": ols,
}
290+
291+
242292
def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
243293
"""Populates a dict with arguments to update trace
244294
@@ -313,12 +363,11 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
313363
mapping_labels["count"] = "%{x}"
314364
elif attr_name == "trendline":
315365
if (
316-
attr_value[0] in ["ols", "lowess", "ma", "ewm"]
366+
attr_value in trendline_functions
317367
and args["x"]
318368
and args["y"]
319369
and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
320370
):
321-
import statsmodels.api as sm
322371

323372
# sorting is bad but trace_specs with "trendline" have no other attrs
324373
sorted_trace_data = trace_data.sort_values(by=args["x"])
@@ -349,56 +398,19 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
349398
np.logical_or(np.isnan(y), np.isnan(x))
350399
)
351400
trace_patch["x"] = sorted_trace_data[args["x"]][non_missing]
352-
353-
if attr_value[0] == "lowess":
354-
alpha = attr_value[1] or 0.6666666
355-
# missing ='drop' is the default value for lowess but not for OLS (None)
356-
# we force it here in case statsmodels change their defaults
357-
trendline = sm.nonparametric.lowess(
358-
y, x, missing="drop", frac=alpha
359-
)
360-
trace_patch["y"] = trendline[:, 1]
361-
hover_header = "<b>LOWESS trendline</b><br><br>"
362-
elif attr_value[0] == "ma":
363-
trace_patch["y"] = (
364-
pd.Series(y[non_missing])
365-
.rolling(window=attr_value[1] or 3)
366-
.mean()
367-
)
368-
elif attr_value[0] == "ewm":
369-
trace_patch["y"] = (
370-
pd.Series(y[non_missing])
371-
.ewm(alpha=attr_value[1] or 0.5)
372-
.mean()
373-
)
374-
elif attr_value[0] == "ols":
375-
add_constant = attr_value[1] is not False
376-
fit_results = sm.OLS(
377-
y, sm.add_constant(x) if add_constant else x, missing="drop"
378-
).fit()
379-
trace_patch["y"] = fit_results.predict()
380-
hover_header = "<b>OLS trendline</b><br>"
381-
if len(fit_results.params) == 2:
382-
hover_header += "%s = %g * %s + %g<br>" % (
383-
args["y"],
384-
fit_results.params[1],
385-
args["x"],
386-
fit_results.params[0],
387-
)
388-
elif not add_constant:
389-
hover_header += "%s = %g* %s<br>" % (
390-
args["y"],
391-
fit_results.params[0],
392-
args["x"],
393-
)
394-
else:
395-
hover_header += "%s = %g<br>" % (
396-
args["y"],
397-
fit_results.params[0],
398-
)
399-
hover_header += (
400-
"R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
401-
)
401+
trendline_function = trendline_functions[attr_value]
402+
y_out, hover_header, fit_results = trendline_function(
403+
args["trendline_options"],
404+
x,
405+
y,
406+
args["x"],
407+
args["y"],
408+
non_missing,
409+
)
410+
assert len(y_out) == len(
411+
trace_patch["x"]
412+
), "missing-data-handling failure in trendline code"
413+
trace_patch["y"] = y_out
402414
mapping_labels[get_label(args, args["x"])] = "%{x}"
403415
mapping_labels[get_label(args, args["y"])] = "%{y} <b>(trend)</b>"
404416
elif attr_name.startswith("error"):
@@ -1845,9 +1857,8 @@ def infer_config(args, constructor, trace_patch, layout_patch):
18451857
):
18461858
args["facet_col_wrap"] = 0
18471859

1848-
if args.get("trendline", None) is not None:
1849-
if isinstance(args["trendline"], str):
1850-
args["trendline"] = (args["trendline"], None)
1860+
if "trendline_options" in args and args["trendline_options"] is None:
1861+
args["trendline_options"] = dict()
18511862

18521863
# Compute applicable grouping attributes
18531864
for k in group_attrables:

packages/python/plotly/plotly/express/_doc.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,10 @@
405405
"If `'ols'`, an Ordinary Least Squares regression line will be drawn for each discrete-color/symbol group.",
406406
"If `'lowess`', a Locally Weighted Scatterplot Smoothing line will be drawn for each discrete-color/symbol group.",
407407
],
408+
trendline_options=[
409+
"dict",
410+
"Options passed to the function named in the `trendline` argument.",
411+
],
408412
trendline_color_override=[
409413
"str",
410414
"Valid CSS color.",

0 commit comments

Comments
 (0)