Skip to content

Commit 6eac898

Browse files
ma and ewm trendlines
1 parent ce1bf36 commit 6eac898

File tree

1 file changed

+35
-8
lines changed
  • packages/python/plotly/plotly/express

1 file changed

+35
-8
lines changed

packages/python/plotly/plotly/express/_core.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
313313
mapping_labels["count"] = "%{x}"
314314
elif attr_name == "trendline":
315315
if (
316-
attr_value in ["ols", "lowess"]
316+
attr_value[0] in ["ols", "lowess", "ma", "ewm"]
317317
and args["x"]
318318
and args["y"]
319319
and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
@@ -345,19 +345,36 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
345345
)
346346

347347
# preserve original values of "x" in case they're dates
348-
trace_patch["x"] = sorted_trace_data[args["x"]][
349-
np.logical_not(np.logical_or(np.isnan(y), np.isnan(x)))
350-
]
348+
non_missing = np.logical_not(
349+
np.logical_or(np.isnan(y), np.isnan(x))
350+
)
351+
trace_patch["x"] = sorted_trace_data[args["x"]][non_missing]
351352

352-
if attr_value == "lowess":
353+
if attr_value[0] == "lowess":
354+
alpha = attr_value[1] or 0.6666666
353355
# missing ='drop' is the default value for lowess but not for OLS (None)
354356
# we force it here in case statsmodels change their defaults
355-
trendline = sm.nonparametric.lowess(y, x, missing="drop")
357+
trendline = sm.nonparametric.lowess(
358+
y, x, missing="drop", frac=alpha
359+
)
356360
trace_patch["y"] = trendline[:, 1]
357361
hover_header = "<b>LOWESS trendline</b><br><br>"
358-
elif attr_value == "ols":
362+
elif attr_value[0] == "ma":
363+
trace_patch["y"] = (
364+
pd.Series(y[non_missing])
365+
.rolling(window=attr_value[1] or 3)
366+
.mean()
367+
)
368+
elif attr_value[0] == "ewm":
369+
trace_patch["y"] = (
370+
pd.Series(y[non_missing])
371+
.ewm(alpha=attr_value[1] or 0.5)
372+
.mean()
373+
)
374+
elif attr_value[0] == "ols":
375+
add_constant = attr_value[1] is not False
359376
fit_results = sm.OLS(
360-
y, sm.add_constant(x), missing="drop"
377+
y, sm.add_constant(x) if add_constant else x, missing="drop"
361378
).fit()
362379
trace_patch["y"] = fit_results.predict()
363380
hover_header = "<b>OLS trendline</b><br>"
@@ -368,6 +385,12 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
368385
args["x"],
369386
fit_results.params[0],
370387
)
388+
elif not add_constant:
389+
hover_header += "%s = %g* %s<br>" % (
390+
args["y"],
391+
fit_results.params[0],
392+
args["x"],
393+
)
371394
else:
372395
hover_header += "%s = %g<br>" % (
373396
args["y"],
@@ -1822,6 +1845,10 @@ def infer_config(args, constructor, trace_patch, layout_patch):
18221845
):
18231846
args["facet_col_wrap"] = 0
18241847

1848+
if args.get("trendline", None) is not None:
1849+
if isinstance(args["trendline"], str):
1850+
args["trendline"] = (args["trendline"], None)
1851+
18251852
# Compute applicable grouping attributes
18261853
for k in group_attrables:
18271854
if k in args:

0 commit comments

Comments
 (0)