@@ -239,6 +239,56 @@ def make_mapping(args, variable):
239
239
)
240
240
241
241
242
def lowess(options, x, y, x_label, y_label, non_missing):
    """Compute a LOWESS trendline through ``y`` as a function of ``x``.

    Args:
        options: dict of trendline options. Only ``"frac"`` is read; it is
            forwarded to statsmodels as ``frac`` and defaults to 0.6666666.
        x: x values handed to statsmodels as the exogenous data.
        y: y values handed to statsmodels as the endogenous data.
        x_label: unused here; accepted so every trendline function can be
            called with the same arguments.
        y_label: unused here, see ``x_label``.
        non_missing: unused here; missing values are dropped by statsmodels
            itself via ``missing="drop"``.

    Returns:
        A 3-tuple ``(y_out, hover_header, fit_results)`` where ``y_out`` is
        column 1 of the array returned by statsmodels' ``lowess``,
        ``hover_header`` is the HTML hover title, and ``fit_results`` is
        always ``None`` for this trendline.
    """
    # Imported lazily so statsmodels is only needed when this trendline is used.
    import statsmodels.api as sm

    # missing="drop" is already the default for lowess (but not for OLS, where
    # it is None); it is passed explicitly in case statsmodels changes defaults.
    smoothed = sm.nonparametric.lowess(
        y, x, missing="drop", frac=options.get("frac", 0.6666666)
    )
    return smoothed[:, 1], "<b>LOWESS trendline</b><br><br>", None
251
+
252
+
253
def ma(options, x, y, x_label, y_label, non_missing):
    """Compute a moving-average trendline of ``y`` indexed by ``x``.

    Args:
        options: dict of keyword arguments forwarded to
            ``pandas.Series.rolling``. When ``"window"`` is absent or
            ``None`` a window of 3 is used, so the function also works with
            an empty options dict. The dict is never mutated.
        x: values used as the index of the intermediate series.
        y: values to average.
            NOTE(review): presumably a NumPy array aligned with ``x``;
            confirm against the caller.
        x_label: unused here; accepted so every trendline function can be
            called with the same arguments.
        y_label: unused here, see ``x_label``.
        non_missing: indexer applied to the rolling mean to keep only the
            wanted rows.
            NOTE(review): presumably a boolean mask the same length as
            ``y``; confirm against the caller.

    Returns:
        A 3-tuple ``(y_out, hover_header, fit_results)`` where ``y_out`` is
        the filtered rolling mean (a pandas Series), ``hover_header`` is the
        HTML hover title, and ``fit_results`` is always ``None``.
    """
    # Work on a copy: the options dict is shared with the caller.
    rolling_options = dict(options)
    if rolling_options.get("window") is None:
        # rolling() cannot be called without a window; fall back to 3.
        rolling_options["window"] = 3
    y_out = pd.Series(y, index=x).rolling(**rolling_options).mean()[non_missing]
    hover_header = "<b>Moving Average trendline</b><br><br>"
    return y_out, hover_header, None
257
+
258
+
259
def ewm(options, x, y, x_label, y_label, non_missing):
    """Compute an exponentially weighted mean trendline of ``y`` indexed by ``x``.

    Args:
        options: dict of keyword arguments forwarded to
            ``pandas.Series.ewm``. When none of ``"com"``, ``"span"``,
            ``"halflife"`` or ``"alpha"`` is supplied (or all are ``None``),
            ``alpha=0.5`` is used, so the function also works with an empty
            options dict. The dict is never mutated.
        x: values used as the index of the intermediate series.
        y: values to smooth.
            NOTE(review): presumably a NumPy array aligned with ``x``;
            confirm against the caller.
        x_label: unused here; accepted so every trendline function can be
            called with the same arguments.
        y_label: unused here, see ``x_label``.
        non_missing: indexer applied to the smoothed series to keep only the
            wanted rows.
            NOTE(review): presumably a boolean mask the same length as
            ``y``; confirm against the caller.

    Returns:
        A 3-tuple ``(y_out, hover_header, fit_results)`` where ``y_out`` is
        the filtered exponentially weighted mean (a pandas Series),
        ``hover_header`` is the HTML hover title, and ``fit_results`` is
        always ``None``.
    """
    # Work on a copy: the options dict is shared with the caller.
    ewm_options = dict(options)
    decay_keys = ("com", "span", "halflife", "alpha")
    if all(ewm_options.get(key) is None for key in decay_keys):
        # pandas needs one decay parameter; only default when the caller
        # chose none, because these parameters are mutually exclusive.
        ewm_options["alpha"] = 0.5
    y_out = pd.Series(y, index=x).ewm(**ewm_options).mean()[non_missing]
    hover_header = "<b>EWM trendline</b><br><br>"
    return y_out, hover_header, None
263
+
264
+
265
def ols(options, x, y, x_label, y_label, non_missing):
    """Fit an ordinary-least-squares trendline of ``y`` on ``x``.

    Args:
        options: dict of trendline options. Only ``"add_constant"`` is read
            (default ``True``); when truthy, ``x`` is passed through
            ``sm.add_constant`` before fitting.
        x: regressor values.
        y: response values.
        x_label: name of the x variable, used in the hover equation.
        y_label: name of the y variable, used in the hover equation.
        non_missing: unused here; missing values are dropped by statsmodels
            itself via ``missing="drop"``.

    Returns:
        A 3-tuple ``(y_out, hover_header, fit_results)`` where ``y_out`` is
        ``fit_results.predict()``, ``hover_header`` is an HTML snippet with
        the fitted equation and R-squared, and ``fit_results`` is the fitted
        statsmodels results object.
    """
    # Imported lazily so statsmodels is only needed when this trendline is used.
    import statsmodels.api as sm

    add_constant = options.get("add_constant", True)
    exog = sm.add_constant(x) if add_constant else x
    fit_results = sm.OLS(y, exog, missing="drop").fit()
    params = fit_results.params

    # Choose the equation text from how many parameters were actually fitted.
    if len(params) == 2:
        equation = "%s = %g * %s + %g<br>" % (y_label, params[1], x_label, params[0])
    elif not add_constant:
        equation = "%s = %g* %s<br>" % (y_label, params[0], x_label)
    else:
        # A constant was requested but only one parameter came back.
        equation = "%s = %g<br>" % (y_label, params[0])

    r_squared = "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
    hover_header = "<b>OLS trendline</b><br>" + equation + r_squared
    return fit_results.predict(), hover_header, fit_results
287
+
288
+
289
# Registry mapping each supported trendline name to the function computing it;
# every function is called with (options, x, y, x_label, y_label, non_missing).
trendline_functions = {"lowess": lowess, "ma": ma, "ewm": ewm, "ols": ols}
290
+
291
+
242
292
def make_trace_kwargs (args , trace_spec , trace_data , mapping_labels , sizeref ):
243
293
"""Populates a dict with arguments to update trace
244
294
@@ -313,12 +363,11 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
313
363
mapping_labels ["count" ] = "%{x}"
314
364
elif attr_name == "trendline" :
315
365
if (
316
- attr_value [ 0 ] in [ "ols" , "lowess" , "ma" , "ewm" ]
366
+ attr_value in trendline_functions
317
367
and args ["x" ]
318
368
and args ["y" ]
319
369
and len (trace_data [[args ["x" ], args ["y" ]]].dropna ()) > 1
320
370
):
321
- import statsmodels .api as sm
322
371
323
372
# sorting is bad but trace_specs with "trendline" have no other attrs
324
373
sorted_trace_data = trace_data .sort_values (by = args ["x" ])
@@ -349,56 +398,19 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
349
398
np .logical_or (np .isnan (y ), np .isnan (x ))
350
399
)
351
400
trace_patch ["x" ] = sorted_trace_data [args ["x" ]][non_missing ]
352
-
353
- if attr_value [0 ] == "lowess" :
354
- alpha = attr_value [1 ] or 0.6666666
355
- # missing ='drop' is the default value for lowess but not for OLS (None)
356
- # we force it here in case statsmodels change their defaults
357
- trendline = sm .nonparametric .lowess (
358
- y , x , missing = "drop" , frac = alpha
359
- )
360
- trace_patch ["y" ] = trendline [:, 1 ]
361
- hover_header = "<b>LOWESS trendline</b><br><br>"
362
- elif attr_value [0 ] == "ma" :
363
- trace_patch ["y" ] = (
364
- pd .Series (y [non_missing ])
365
- .rolling (window = attr_value [1 ] or 3 )
366
- .mean ()
367
- )
368
- elif attr_value [0 ] == "ewm" :
369
- trace_patch ["y" ] = (
370
- pd .Series (y [non_missing ])
371
- .ewm (alpha = attr_value [1 ] or 0.5 )
372
- .mean ()
373
- )
374
- elif attr_value [0 ] == "ols" :
375
- add_constant = attr_value [1 ] is not False
376
- fit_results = sm .OLS (
377
- y , sm .add_constant (x ) if add_constant else x , missing = "drop"
378
- ).fit ()
379
- trace_patch ["y" ] = fit_results .predict ()
380
- hover_header = "<b>OLS trendline</b><br>"
381
- if len (fit_results .params ) == 2 :
382
- hover_header += "%s = %g * %s + %g<br>" % (
383
- args ["y" ],
384
- fit_results .params [1 ],
385
- args ["x" ],
386
- fit_results .params [0 ],
387
- )
388
- elif not add_constant :
389
- hover_header += "%s = %g* %s<br>" % (
390
- args ["y" ],
391
- fit_results .params [0 ],
392
- args ["x" ],
393
- )
394
- else :
395
- hover_header += "%s = %g<br>" % (
396
- args ["y" ],
397
- fit_results .params [0 ],
398
- )
399
- hover_header += (
400
- "R<sup>2</sup>=%f<br><br>" % fit_results .rsquared
401
- )
401
+ trendline_function = trendline_functions [attr_value ]
402
+ y_out , hover_header , fit_results = trendline_function (
403
+ args ["trendline_options" ],
404
+ x ,
405
+ y ,
406
+ args ["x" ],
407
+ args ["y" ],
408
+ non_missing ,
409
+ )
410
+ assert len (y_out ) == len (
411
+ trace_patch ["x" ]
412
+ ), "missing-data-handling failure in trendline code"
413
+ trace_patch ["y" ] = y_out
402
414
mapping_labels [get_label (args , args ["x" ])] = "%{x}"
403
415
mapping_labels [get_label (args , args ["y" ])] = "%{y} <b>(trend)</b>"
404
416
elif attr_name .startswith ("error" ):
@@ -1845,9 +1857,8 @@ def infer_config(args, constructor, trace_patch, layout_patch):
1845
1857
):
1846
1858
args ["facet_col_wrap" ] = 0
1847
1859
1848
- if args .get ("trendline" , None ) is not None :
1849
- if isinstance (args ["trendline" ], str ):
1850
- args ["trendline" ] = (args ["trendline" ], None )
1860
+ if "trendline_options" in args and args ["trendline_options" ] is None :
1861
+ args ["trendline_options" ] = dict ()
1851
1862
1852
1863
# Compute applicable grouping attributes
1853
1864
for k in group_attrables :
0 commit comments