@@ -229,6 +229,56 @@ def make_mapping(args, variable):
229
229
)
230
230
231
231
232
+ def lowess (options , x , y , x_label , y_label , non_missing ):
233
+ import statsmodels .api as sm
234
+
235
+ frac = options .get ("frac" , 0.6666666 )
236
+ # missing ='drop' is the default value for lowess but not for OLS (None)
237
+ # we force it here in case statsmodels change their defaults
238
+ y_out = sm .nonparametric .lowess (y , x , missing = "drop" , frac = frac )[:, 1 ]
239
+ hover_header = "<b>LOWESS trendline</b><br><br>"
240
+ return y_out , hover_header , None
241
+
242
+
243
+ def ma (options , x , y , x_label , y_label , non_missing ):
244
+ y_out = pd .Series (y , index = x ).rolling (** options ).mean ()[non_missing ]
245
+ hover_header = "<b>Moving Average trendline</b><br><br>"
246
+ return y_out , hover_header , None
247
+
248
+
249
+ def ewm (options , x , y , x_label , y_label , non_missing ):
250
+ y_out = pd .Series (y , index = x ).ewm (** options ).mean ()[non_missing ]
251
+ hover_header = "<b>EWM trendline</b><br><br>"
252
+ return y_out , hover_header , None
253
+
254
+
255
+ def ols (options , x , y , x_label , y_label , non_missing ):
256
+ import statsmodels .api as sm
257
+
258
+ add_constant = options .get ("add_constant" , True )
259
+ fit_results = sm .OLS (
260
+ y , sm .add_constant (x ) if add_constant else x , missing = "drop"
261
+ ).fit ()
262
+ y_out = fit_results .predict ()
263
+ hover_header = "<b>OLS trendline</b><br>"
264
+ if len (fit_results .params ) == 2 :
265
+ hover_header += "%s = %g * %s + %g<br>" % (
266
+ y_label ,
267
+ fit_results .params [1 ],
268
+ x_label ,
269
+ fit_results .params [0 ],
270
+ )
271
+ elif not add_constant :
272
+ hover_header += "%s = %g* %s<br>" % (y_label , fit_results .params [0 ], x_label ,)
273
+ else :
274
+ hover_header += "%s = %g<br>" % (y_label , fit_results .params [0 ],)
275
+ hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results .rsquared
276
+ return y_out , hover_header , fit_results
277
+
278
+
279
+ trendline_functions = dict (lowess = lowess , ma = ma , ewm = ewm , ols = ols )
280
+
281
+
232
282
def make_trace_kwargs (args , trace_spec , trace_data , mapping_labels , sizeref ):
233
283
"""Populates a dict with arguments to update trace
234
284
@@ -303,12 +353,11 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
303
353
mapping_labels ["count" ] = "%{x}"
304
354
elif attr_name == "trendline" :
305
355
if (
306
- attr_value [ 0 ] in [ "ols" , "lowess" , "ma" , "ewm" ]
356
+ attr_value in trendline_functions
307
357
and args ["x" ]
308
358
and args ["y" ]
309
359
and len (trace_data [[args ["x" ], args ["y" ]]].dropna ()) > 1
310
360
):
311
- import statsmodels .api as sm
312
361
313
362
# sorting is bad but trace_specs with "trendline" have no other attrs
314
363
sorted_trace_data = trace_data .sort_values (by = args ["x" ])
@@ -339,56 +388,19 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
339
388
np .logical_or (np .isnan (y ), np .isnan (x ))
340
389
)
341
390
trace_patch ["x" ] = sorted_trace_data [args ["x" ]][non_missing ]
342
-
343
- if attr_value [0 ] == "lowess" :
344
- alpha = attr_value [1 ] or 0.6666666
345
- # missing ='drop' is the default value for lowess but not for OLS (None)
346
- # we force it here in case statsmodels change their defaults
347
- trendline = sm .nonparametric .lowess (
348
- y , x , missing = "drop" , frac = alpha
349
- )
350
- trace_patch ["y" ] = trendline [:, 1 ]
351
- hover_header = "<b>LOWESS trendline</b><br><br>"
352
- elif attr_value [0 ] == "ma" :
353
- trace_patch ["y" ] = (
354
- pd .Series (y [non_missing ])
355
- .rolling (window = attr_value [1 ] or 3 )
356
- .mean ()
357
- )
358
- elif attr_value [0 ] == "ewm" :
359
- trace_patch ["y" ] = (
360
- pd .Series (y [non_missing ])
361
- .ewm (alpha = attr_value [1 ] or 0.5 )
362
- .mean ()
363
- )
364
- elif attr_value [0 ] == "ols" :
365
- add_constant = attr_value [1 ] is not False
366
- fit_results = sm .OLS (
367
- y , sm .add_constant (x ) if add_constant else x , missing = "drop"
368
- ).fit ()
369
- trace_patch ["y" ] = fit_results .predict ()
370
- hover_header = "<b>OLS trendline</b><br>"
371
- if len (fit_results .params ) == 2 :
372
- hover_header += "%s = %g * %s + %g<br>" % (
373
- args ["y" ],
374
- fit_results .params [1 ],
375
- args ["x" ],
376
- fit_results .params [0 ],
377
- )
378
- elif not add_constant :
379
- hover_header += "%s = %g* %s<br>" % (
380
- args ["y" ],
381
- fit_results .params [0 ],
382
- args ["x" ],
383
- )
384
- else :
385
- hover_header += "%s = %g<br>" % (
386
- args ["y" ],
387
- fit_results .params [0 ],
388
- )
389
- hover_header += (
390
- "R<sup>2</sup>=%f<br><br>" % fit_results .rsquared
391
- )
391
+ trendline_function = trendline_functions [attr_value ]
392
+ y_out , hover_header , fit_results = trendline_function (
393
+ args ["trendline_options" ],
394
+ x ,
395
+ y ,
396
+ args ["x" ],
397
+ args ["y" ],
398
+ non_missing ,
399
+ )
400
+ assert len (y_out ) == len (
401
+ trace_patch ["x" ]
402
+ ), "missing-data-handling failure in trendline code"
403
+ trace_patch ["y" ] = y_out
392
404
mapping_labels [get_label (args , args ["x" ])] = "%{x}"
393
405
mapping_labels [get_label (args , args ["y" ])] = "%{y} <b>(trend)</b>"
394
406
elif attr_name .startswith ("error" ):
@@ -1822,9 +1834,8 @@ def infer_config(args, constructor, trace_patch, layout_patch):
1822
1834
):
1823
1835
args ["facet_col_wrap" ] = 0
1824
1836
1825
- if args .get ("trendline" , None ) is not None :
1826
- if isinstance (args ["trendline" ], str ):
1827
- args ["trendline" ] = (args ["trendline" ], None )
1837
+ if "trendline_options" in args and args ["trendline_options" ] is None :
1838
+ args ["trendline_options" ] = dict ()
1828
1839
1829
1840
# Compute applicable grouping attributes
1830
1841
for k in group_attrables :
0 commit comments