|
2 | 2 | import plotly.io as pio
|
3 | 3 | from collections import namedtuple, OrderedDict
|
4 | 4 | from ._special_inputs import IdentityMap, Constant, Range
|
| 5 | +from .trendline_functions import ols, lowess, ma, ewm |
5 | 6 |
|
6 | 7 | from _plotly_utils.basevalidators import ColorscaleValidator
|
7 | 8 | from plotly.colors import qualitative, sequential
|
@@ -229,65 +230,6 @@ def make_mapping(args, variable):
|
229 | 230 | )
|
230 | 231 |
|
231 | 232 |
|
232 |
| -def lowess(options, x, y, x_label, y_label, non_missing): |
233 |
| - import statsmodels.api as sm |
234 |
| - |
235 |
| - frac = options.get("frac", 0.6666666) |
236 |
| - # missing ='drop' is the default value for lowess but not for OLS (None) |
237 |
| - # we force it here in case statsmodels change their defaults |
238 |
| - y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1] |
239 |
| - hover_header = "<b>LOWESS trendline</b><br><br>" |
240 |
| - return y_out, hover_header, None |
241 |
| - |
242 |
| - |
243 |
| -def ma(options, x, y, x_label, y_label, non_missing): |
244 |
| - y_out = pd.Series(y, index=x).rolling(**options).mean()[non_missing] |
245 |
| - hover_header = "<b>Moving Average trendline</b><br><br>" |
246 |
| - return y_out, hover_header, None |
247 |
| - |
248 |
| - |
249 |
| -def ewm(options, x, y, x_label, y_label, non_missing): |
250 |
| - y_out = pd.Series(y, index=x).ewm(**options).mean()[non_missing] |
251 |
| - hover_header = "<b>EWM trendline</b><br><br>" |
252 |
| - return y_out, hover_header, None |
253 |
| - |
254 |
| - |
255 |
| -def ols(options, x, y, x_label, y_label, non_missing): |
256 |
| - import statsmodels.api as sm |
257 |
| - |
258 |
| - add_constant = options.get("add_constant", True) |
259 |
| - log_x = options.get("log_x", False) |
260 |
| - log_y = options.get("log_y", False) |
261 |
| - |
262 |
| - if log_y: |
263 |
| - y = np.log(y) |
264 |
| - if log_x: |
265 |
| - x = np.log(x) |
266 |
| - if add_constant: |
267 |
| - x = sm.add_constant(x) |
268 |
| - fit_results = sm.OLS(y, x, missing="drop").fit() |
269 |
| - y_out = fit_results.predict() |
270 |
| - if log_y: |
271 |
| - y_out = np.exp(y_out) |
272 |
| - hover_header = "<b>OLS trendline</b><br>" |
273 |
| - if len(fit_results.params) == 2: |
274 |
| - hover_header += "%s = %g * %s + %g<br>" % ( |
275 |
| - y_label, |
276 |
| - fit_results.params[1], |
277 |
| - x_label, |
278 |
| - fit_results.params[0], |
279 |
| - ) |
280 |
| - elif not add_constant: |
281 |
| - hover_header += "%s = %g* %s<br>" % (y_label, fit_results.params[0], x_label,) |
282 |
| - else: |
283 |
| - hover_header += "%s = %g<br>" % (y_label, fit_results.params[0],) |
284 |
| - hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared |
285 |
| - return y_out, hover_header, fit_results |
286 |
| - |
287 |
| - |
288 |
| -trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols) |
289 |
| - |
290 |
| - |
291 | 233 | def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
|
292 | 234 | """Populates a dict with arguments to update trace
|
293 | 235 |
|
@@ -361,6 +303,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
|
361 | 303 | if trace_spec.constructor == go.Histogram:
|
362 | 304 | mapping_labels["count"] = "%{x}"
|
363 | 305 | elif attr_name == "trendline":
|
| 306 | + trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols) |
364 | 307 | if (
|
365 | 308 | attr_value in trendline_functions
|
366 | 309 | and args["x"]
|
|
0 commit comments