Skip to content

Commit dace44a

Browse files
ma and ewm trendlines
1 parent 5ab8da3 commit dace44a

File tree

1 file changed

+35
-8
lines changed
  • packages/python/plotly/plotly/express

1 file changed

+35
-8
lines changed

packages/python/plotly/plotly/express/_core.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
303303
mapping_labels["count"] = "%{x}"
304304
elif attr_name == "trendline":
305305
if (
306-
attr_value in ["ols", "lowess"]
306+
attr_value[0] in ["ols", "lowess", "ma", "ewm"]
307307
and args["x"]
308308
and args["y"]
309309
and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
@@ -335,19 +335,36 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
335335
)
336336

337337
# preserve original values of "x" in case they're dates
338-
trace_patch["x"] = sorted_trace_data[args["x"]][
339-
np.logical_not(np.logical_or(np.isnan(y), np.isnan(x)))
340-
]
338+
non_missing = np.logical_not(
339+
np.logical_or(np.isnan(y), np.isnan(x))
340+
)
341+
trace_patch["x"] = sorted_trace_data[args["x"]][non_missing]
341342

342-
if attr_value == "lowess":
343+
if attr_value[0] == "lowess":
344+
alpha = attr_value[1] or 0.6666666
343345
# missing ='drop' is the default value for lowess but not for OLS (None)
344346
# we force it here in case statsmodels change their defaults
345-
trendline = sm.nonparametric.lowess(y, x, missing="drop")
347+
trendline = sm.nonparametric.lowess(
348+
y, x, missing="drop", frac=alpha
349+
)
346350
trace_patch["y"] = trendline[:, 1]
347351
hover_header = "<b>LOWESS trendline</b><br><br>"
348-
elif attr_value == "ols":
352+
elif attr_value[0] == "ma":
353+
trace_patch["y"] = (
354+
pd.Series(y[non_missing])
355+
.rolling(window=attr_value[1] or 3)
356+
.mean()
357+
)
358+
elif attr_value[0] == "ewm":
359+
trace_patch["y"] = (
360+
pd.Series(y[non_missing])
361+
.ewm(alpha=attr_value[1] or 0.5)
362+
.mean()
363+
)
364+
elif attr_value[0] == "ols":
365+
add_constant = attr_value[1] is not False
349366
fit_results = sm.OLS(
350-
y, sm.add_constant(x), missing="drop"
367+
y, sm.add_constant(x) if add_constant else x, missing="drop"
351368
).fit()
352369
trace_patch["y"] = fit_results.predict()
353370
hover_header = "<b>OLS trendline</b><br>"
@@ -358,6 +375,12 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
358375
args["x"],
359376
fit_results.params[0],
360377
)
378+
elif not add_constant:
379+
hover_header += "%s = %g* %s<br>" % (
380+
args["y"],
381+
fit_results.params[0],
382+
args["x"],
383+
)
361384
else:
362385
hover_header += "%s = %g<br>" % (
363386
args["y"],
@@ -1799,6 +1822,10 @@ def infer_config(args, constructor, trace_patch, layout_patch):
17991822
):
18001823
args["facet_col_wrap"] = 0
18011824

1825+
if args.get("trendline", None) is not None:
1826+
if isinstance(args["trendline"], str):
1827+
args["trendline"] = (args["trendline"], None)
1828+
18021829
# Compute applicable grouping attributes
18031830
for k in group_attrables:
18041831
if k in args:

0 commit comments

Comments
 (0)