Skip to content

Commit 9d5d0cb

Browse files
move trendline code to own module
1 parent ede9bd4 commit 9d5d0cb

File tree

4 files changed

+62
-60
lines changed

4 files changed

+62
-60
lines changed

doc/apidoc/plotly.express.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,4 @@ plotly's high-level API for rapid figure generation. ::
5959

6060
generated/plotly.express.data.rst
6161
generated/plotly.express.colors.rst
62+
generated/plotly.express.trendline_functions.rst

packages/python/plotly/plotly/express/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959

6060
from ._special_inputs import IdentityMap, Constant, Range # noqa: F401
6161

62-
from . import data, colors # noqa: F401
62+
from . import data, colors, trendline_functions # noqa: F401
6363

6464
__all__ = [
6565
"scatter",

packages/python/plotly/plotly/express/_core.py

Lines changed: 2 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import plotly.io as pio
33
from collections import namedtuple, OrderedDict
44
from ._special_inputs import IdentityMap, Constant, Range
5+
from .trendline_functions import ols, lowess, ma, ewm
56

67
from _plotly_utils.basevalidators import ColorscaleValidator
78
from plotly.colors import qualitative, sequential
@@ -229,65 +230,6 @@ def make_mapping(args, variable):
229230
)
230231

231232

232-
def lowess(options, x, y, x_label, y_label, non_missing):
233-
import statsmodels.api as sm
234-
235-
frac = options.get("frac", 0.6666666)
236-
# missing ='drop' is the default value for lowess but not for OLS (None)
237-
# we force it here in case statsmodels change their defaults
238-
y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1]
239-
hover_header = "<b>LOWESS trendline</b><br><br>"
240-
return y_out, hover_header, None
241-
242-
243-
def ma(options, x, y, x_label, y_label, non_missing):
244-
y_out = pd.Series(y, index=x).rolling(**options).mean()[non_missing]
245-
hover_header = "<b>Moving Average trendline</b><br><br>"
246-
return y_out, hover_header, None
247-
248-
249-
def ewm(options, x, y, x_label, y_label, non_missing):
250-
y_out = pd.Series(y, index=x).ewm(**options).mean()[non_missing]
251-
hover_header = "<b>EWM trendline</b><br><br>"
252-
return y_out, hover_header, None
253-
254-
255-
def ols(options, x, y, x_label, y_label, non_missing):
256-
import statsmodels.api as sm
257-
258-
add_constant = options.get("add_constant", True)
259-
log_x = options.get("log_x", False)
260-
log_y = options.get("log_y", False)
261-
262-
if log_y:
263-
y = np.log(y)
264-
if log_x:
265-
x = np.log(x)
266-
if add_constant:
267-
x = sm.add_constant(x)
268-
fit_results = sm.OLS(y, x, missing="drop").fit()
269-
y_out = fit_results.predict()
270-
if log_y:
271-
y_out = np.exp(y_out)
272-
hover_header = "<b>OLS trendline</b><br>"
273-
if len(fit_results.params) == 2:
274-
hover_header += "%s = %g * %s + %g<br>" % (
275-
y_label,
276-
fit_results.params[1],
277-
x_label,
278-
fit_results.params[0],
279-
)
280-
elif not add_constant:
281-
hover_header += "%s = %g* %s<br>" % (y_label, fit_results.params[0], x_label,)
282-
else:
283-
hover_header += "%s = %g<br>" % (y_label, fit_results.params[0],)
284-
hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
285-
return y_out, hover_header, fit_results
286-
287-
288-
trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols)
289-
290-
291233
def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
292234
"""Populates a dict with arguments to update trace
293235
@@ -361,6 +303,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
361303
if trace_spec.constructor == go.Histogram:
362304
mapping_labels["count"] = "%{x}"
363305
elif attr_name == "trendline":
306+
trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols)
364307
if (
365308
attr_value in trendline_functions
366309
and args["x"]
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import pandas as pd
2+
import numpy as np
3+
4+
5+
def ols(options, x, y, x_label, y_label, non_missing):
6+
import statsmodels.api as sm
7+
8+
add_constant = options.get("add_constant", True)
9+
log_x = options.get("log_x", False)
10+
log_y = options.get("log_y", False)
11+
12+
if log_y:
13+
y = np.log(y)
14+
y_label = "log(%s)" % y_label
15+
if log_x:
16+
x = np.log(x)
17+
x_label = "log(%s)" % x_label
18+
if add_constant:
19+
x = sm.add_constant(x)
20+
fit_results = sm.OLS(y, x, missing="drop").fit()
21+
y_out = fit_results.predict()
22+
if log_y:
23+
y_out = np.exp(y_out)
24+
hover_header = "<b>OLS trendline</b><br>"
25+
if len(fit_results.params) == 2:
26+
hover_header += "%s = %g * %s + %g<br>" % (
27+
y_label,
28+
fit_results.params[1],
29+
x_label,
30+
fit_results.params[0],
31+
)
32+
elif not add_constant:
33+
hover_header += "%s = %g* %s<br>" % (y_label, fit_results.params[0], x_label,)
34+
else:
35+
hover_header += "%s = %g<br>" % (y_label, fit_results.params[0],)
36+
hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
37+
return y_out, hover_header, fit_results
38+
39+
40+
def lowess(options, x, y, x_label, y_label, non_missing):
41+
import statsmodels.api as sm
42+
43+
frac = options.get("frac", 0.6666666)
44+
y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1]
45+
hover_header = "<b>LOWESS trendline</b><br><br>"
46+
return y_out, hover_header, None
47+
48+
49+
def ma(options, x, y, x_label, y_label, non_missing):
50+
y_out = pd.Series(y, index=x).rolling(**options).mean()[non_missing]
51+
hover_header = "<b>Moving Average trendline</b><br><br>"
52+
return y_out, hover_header, None
53+
54+
55+
def ewm(options, x, y, x_label, y_label, non_missing):
56+
y_out = pd.Series(y, index=x).ewm(**options).mean()[non_missing]
57+
hover_header = "<b>EWM trendline</b><br><br>"
58+
return y_out, hover_header, None

0 commit comments

Comments
 (0)