|
2 | 2 | import plotly.io as pio
|
3 | 3 | from collections import namedtuple, OrderedDict
|
4 | 4 | from ._special_inputs import IdentityMap, Constant, Range
|
| 5 | +from .trendline_functions import ols, lowess, ma, ewm |
5 | 6 |
|
6 | 7 | from _plotly_utils.basevalidators import ColorscaleValidator
|
7 | 8 | from plotly.colors import qualitative, sequential
|
@@ -239,65 +240,6 @@ def make_mapping(args, variable):
|
239 | 240 | )
|
240 | 241 |
|
241 | 242 |
|
242 |
| -def lowess(options, x, y, x_label, y_label, non_missing): |
243 |
| - import statsmodels.api as sm |
244 |
| - |
245 |
| - frac = options.get("frac", 0.6666666) |
246 |
| - # missing ='drop' is the default value for lowess but not for OLS (None) |
247 |
| - # we force it here in case statsmodels change their defaults |
248 |
| - y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1] |
249 |
| - hover_header = "<b>LOWESS trendline</b><br><br>" |
250 |
| - return y_out, hover_header, None |
251 |
| - |
252 |
| - |
253 |
| -def ma(options, x, y, x_label, y_label, non_missing): |
254 |
| - y_out = pd.Series(y, index=x).rolling(**options).mean()[non_missing] |
255 |
| - hover_header = "<b>Moving Average trendline</b><br><br>" |
256 |
| - return y_out, hover_header, None |
257 |
| - |
258 |
| - |
259 |
| -def ewm(options, x, y, x_label, y_label, non_missing): |
260 |
| - y_out = pd.Series(y, index=x).ewm(**options).mean()[non_missing] |
261 |
| - hover_header = "<b>EWM trendline</b><br><br>" |
262 |
| - return y_out, hover_header, None |
263 |
| - |
264 |
| - |
265 |
| -def ols(options, x, y, x_label, y_label, non_missing): |
266 |
| - import statsmodels.api as sm |
267 |
| - |
268 |
| - add_constant = options.get("add_constant", True) |
269 |
| - log_x = options.get("log_x", False) |
270 |
| - log_y = options.get("log_y", False) |
271 |
| - |
272 |
| - if log_y: |
273 |
| - y = np.log(y) |
274 |
| - if log_x: |
275 |
| - x = np.log(x) |
276 |
| - if add_constant: |
277 |
| - x = sm.add_constant(x) |
278 |
| - fit_results = sm.OLS(y, x, missing="drop").fit() |
279 |
| - y_out = fit_results.predict() |
280 |
| - if log_y: |
281 |
| - y_out = np.exp(y_out) |
282 |
| - hover_header = "<b>OLS trendline</b><br>" |
283 |
| - if len(fit_results.params) == 2: |
284 |
| - hover_header += "%s = %g * %s + %g<br>" % ( |
285 |
| - y_label, |
286 |
| - fit_results.params[1], |
287 |
| - x_label, |
288 |
| - fit_results.params[0], |
289 |
| - ) |
290 |
| - elif not add_constant: |
291 |
| - hover_header += "%s = %g* %s<br>" % (y_label, fit_results.params[0], x_label,) |
292 |
| - else: |
293 |
| - hover_header += "%s = %g<br>" % (y_label, fit_results.params[0],) |
294 |
| - hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared |
295 |
| - return y_out, hover_header, fit_results |
296 |
| - |
297 |
| - |
298 |
| -trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols) |
299 |
| - |
300 |
| - |
301 | 243 | def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
|
302 | 244 | """Populates a dict with arguments to update trace
|
303 | 245 |
|
@@ -371,6 +313,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
|
371 | 313 | if trace_spec.constructor == go.Histogram:
|
372 | 314 | mapping_labels["count"] = "%{x}"
|
373 | 315 | elif attr_name == "trendline":
|
| 316 | + trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols) |
374 | 317 | if (
|
375 | 318 | attr_value in trendline_functions
|
376 | 319 | and args["x"]
|
|
0 commit comments