Skip to content

Commit be3a63f

Browse files
move trendline code to own module
1 parent e7f04b1 commit be3a63f

File tree

4 files changed

+62
-60
lines changed

4 files changed

+62
-60
lines changed

doc/apidoc/plotly.express.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,4 @@ plotly's high-level API for rapid figure generation. ::
6060

6161
generated/plotly.express.data.rst
6262
generated/plotly.express.colors.rst
63+
generated/plotly.express.trendline_functions.rst

packages/python/plotly/plotly/express/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060

6161
from ._special_inputs import IdentityMap, Constant, Range # noqa: F401
6262

63-
from . import data, colors # noqa: F401
63+
from . import data, colors, trendline_functions # noqa: F401
6464

6565
__all__ = [
6666
"scatter",

packages/python/plotly/plotly/express/_core.py

Lines changed: 2 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import plotly.io as pio
33
from collections import namedtuple, OrderedDict
44
from ._special_inputs import IdentityMap, Constant, Range
5+
from .trendline_functions import ols, lowess, ma, ewm
56

67
from _plotly_utils.basevalidators import ColorscaleValidator
78
from plotly.colors import qualitative, sequential
@@ -239,65 +240,6 @@ def make_mapping(args, variable):
239240
)
240241

241242

242-
def lowess(options, x, y, x_label, y_label, non_missing):
243-
import statsmodels.api as sm
244-
245-
frac = options.get("frac", 0.6666666)
246-
# missing ='drop' is the default value for lowess but not for OLS (None)
247-
# we force it here in case statsmodels change their defaults
248-
y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1]
249-
hover_header = "<b>LOWESS trendline</b><br><br>"
250-
return y_out, hover_header, None
251-
252-
253-
def ma(options, x, y, x_label, y_label, non_missing):
254-
y_out = pd.Series(y, index=x).rolling(**options).mean()[non_missing]
255-
hover_header = "<b>Moving Average trendline</b><br><br>"
256-
return y_out, hover_header, None
257-
258-
259-
def ewm(options, x, y, x_label, y_label, non_missing):
260-
y_out = pd.Series(y, index=x).ewm(**options).mean()[non_missing]
261-
hover_header = "<b>EWM trendline</b><br><br>"
262-
return y_out, hover_header, None
263-
264-
265-
def ols(options, x, y, x_label, y_label, non_missing):
266-
import statsmodels.api as sm
267-
268-
add_constant = options.get("add_constant", True)
269-
log_x = options.get("log_x", False)
270-
log_y = options.get("log_y", False)
271-
272-
if log_y:
273-
y = np.log(y)
274-
if log_x:
275-
x = np.log(x)
276-
if add_constant:
277-
x = sm.add_constant(x)
278-
fit_results = sm.OLS(y, x, missing="drop").fit()
279-
y_out = fit_results.predict()
280-
if log_y:
281-
y_out = np.exp(y_out)
282-
hover_header = "<b>OLS trendline</b><br>"
283-
if len(fit_results.params) == 2:
284-
hover_header += "%s = %g * %s + %g<br>" % (
285-
y_label,
286-
fit_results.params[1],
287-
x_label,
288-
fit_results.params[0],
289-
)
290-
elif not add_constant:
291-
hover_header += "%s = %g* %s<br>" % (y_label, fit_results.params[0], x_label,)
292-
else:
293-
hover_header += "%s = %g<br>" % (y_label, fit_results.params[0],)
294-
hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
295-
return y_out, hover_header, fit_results
296-
297-
298-
trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols)
299-
300-
301243
def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
302244
"""Populates a dict with arguments to update trace
303245
@@ -371,6 +313,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
371313
if trace_spec.constructor == go.Histogram:
372314
mapping_labels["count"] = "%{x}"
373315
elif attr_name == "trendline":
316+
trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols)
374317
if (
375318
attr_value in trendline_functions
376319
and args["x"]
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import pandas as pd
2+
import numpy as np
3+
4+
5+
def ols(options, x, y, x_label, y_label, non_missing):
6+
import statsmodels.api as sm
7+
8+
add_constant = options.get("add_constant", True)
9+
log_x = options.get("log_x", False)
10+
log_y = options.get("log_y", False)
11+
12+
if log_y:
13+
y = np.log(y)
14+
y_label = "log(%s)" % y_label
15+
if log_x:
16+
x = np.log(x)
17+
x_label = "log(%s)" % x_label
18+
if add_constant:
19+
x = sm.add_constant(x)
20+
fit_results = sm.OLS(y, x, missing="drop").fit()
21+
y_out = fit_results.predict()
22+
if log_y:
23+
y_out = np.exp(y_out)
24+
hover_header = "<b>OLS trendline</b><br>"
25+
if len(fit_results.params) == 2:
26+
hover_header += "%s = %g * %s + %g<br>" % (
27+
y_label,
28+
fit_results.params[1],
29+
x_label,
30+
fit_results.params[0],
31+
)
32+
elif not add_constant:
33+
hover_header += "%s = %g* %s<br>" % (y_label, fit_results.params[0], x_label,)
34+
else:
35+
hover_header += "%s = %g<br>" % (y_label, fit_results.params[0],)
36+
hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
37+
return y_out, hover_header, fit_results
38+
39+
40+
def lowess(options, x, y, x_label, y_label, non_missing):
41+
import statsmodels.api as sm
42+
43+
frac = options.get("frac", 0.6666666)
44+
y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1]
45+
hover_header = "<b>LOWESS trendline</b><br><br>"
46+
return y_out, hover_header, None
47+
48+
49+
def ma(options, x, y, x_label, y_label, non_missing):
50+
y_out = pd.Series(y, index=x).rolling(**options).mean()[non_missing]
51+
hover_header = "<b>Moving Average trendline</b><br><br>"
52+
return y_out, hover_header, None
53+
54+
55+
def ewm(options, x, y, x_label, y_label, non_missing):
56+
y_out = pd.Series(y, index=x).ewm(**options).mean()[non_missing]
57+
hover_header = "<b>EWM trendline</b><br><br>"
58+
return y_out, hover_header, None

0 commit comments

Comments
 (0)