Skip to content

Commit b1d039c

Browse files
make trendlines more robust
1 parent 1db86d0 commit b1d039c

File tree

2 files changed

+108
-18
lines changed

2 files changed

+108
-18
lines changed

packages/python/plotly/plotly/express/_core.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -265,17 +265,24 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
265265
attr_value in ["ols", "lowess"]
266266
and args["x"]
267267
and args["y"]
268-
and len(trace_data) > 1
268+
and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
269269
):
270270
import statsmodels.api as sm
271271

272272
# sorting is bad but trace_specs with "trendline" have no other attrs
273273
sorted_trace_data = trace_data.sort_values(by=args["x"])
274-
y = sorted_trace_data[args["y"]]
275-
x = sorted_trace_data[args["x"]]
274+
y = sorted_trace_data[args["y"]].values
275+
x = sorted_trace_data[args["x"]].values
276276

277+
x_is_date = False
277278
if x.dtype.type == np.datetime64:
278279
x = x.astype(int) / 10 ** 9 # convert to unix epoch seconds
280+
x_is_date = True
281+
elif x.dtype.type == np.object_:
282+
x = x.astype(np.float64)
283+
284+
if y.dtype.type == np.object_:
285+
y = y.astype(np.float64)
279286

280287
if attr_value == "lowess":
281288
# missing ='drop' is the default value for lowess but not for OLS (None)
@@ -286,25 +293,32 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
286293
hover_header = "<b>LOWESS trendline</b><br><br>"
287294
elif attr_value == "ols":
288295
fit_results = sm.OLS(
289-
y.values, sm.add_constant(x.values), missing="drop"
296+
y, sm.add_constant(x), missing="drop"
290297
).fit()
291298
trace_patch["y"] = fit_results.predict()
292299
trace_patch["x"] = x[
293300
np.logical_not(np.logical_or(np.isnan(y), np.isnan(x)))
294301
]
295302
hover_header = "<b>OLS trendline</b><br>"
296-
hover_header += "%s = %g * %s + %g<br>" % (
297-
args["y"],
298-
fit_results.params[1],
299-
args["x"],
300-
fit_results.params[0],
301-
)
303+
if len(fit_results.params) == 2:
304+
hover_header += "%s = %g * %s + %g<br>" % (
305+
args["y"],
306+
fit_results.params[1],
307+
args["x"],
308+
fit_results.params[0],
309+
)
310+
else:
311+
hover_header += "%s = %g<br>" % (
312+
args["y"],
313+
fit_results.params[0],
314+
)
302315
hover_header += (
303316
"R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
304317
)
318+
if x_is_date:
319+
trace_patch["x"] = pd.to_datetime(trace_patch["x"] * 10 ** 9)
305320
mapping_labels[get_label(args, args["x"])] = "%{x}"
306321
mapping_labels[get_label(args, args["y"])] = "%{y} <b>(trend)</b>"
307-
308322
elif attr_name.startswith("error"):
309323
error_xy = attr_name[:7]
310324
arr = "arrayminus" if attr_name.endswith("minus") else "array"
Lines changed: 83 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,90 @@
11
import plotly.express as px
22
import numpy as np
3+
import pandas as pd
4+
import pytest
5+
from datetime import datetime
36

47

5-
def test_trendline_nan_values():
8+
@pytest.mark.parametrize("mode", ["ols", "lowess"])
def test_trendline_results_passthrough(mode):
    """Check trendline traces and what ``px.get_trendline_results`` returns.

    The Oceania subset of gapminder is plotted with one colour per country,
    so the figure is expected to hold four traces that alternate between a
    scatter trace and its trendline trace.
    """
    df = px.data.gapminder().query("continent == 'Oceania'")
    fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
    assert len(fig.data) == 4

    # Even positions hold the raw scatter traces, odd positions the trendlines.
    for trace in fig["data"][0::2]:
        assert "trendline" not in trace.hovertemplate
    for trendline in fig["data"][1::2]:
        assert "trendline" in trendline.hovertemplate
        if mode == "ols":
            # Only the OLS hover text reports a coefficient of determination.
            assert "R<sup>2</sup>" in trendline.hovertemplate

    results = px.get_trendline_results(fig)
    if mode == "ols":
        # One fit-result row per country.
        assert len(results) == 2
        assert results["country"].values[0] == "Australia"
        au_result = results["px_fit_results"].values[0]
        # Two fitted parameters are expected for a straight-line fit.
        assert len(au_result.params) == 2
    else:
        # No fit results are exposed for LOWESS trendlines.
        assert len(results) == 0
28+
29+
30+
@pytest.mark.parametrize("mode", ["ols", "lowess"])
def test_trendline_enough_values(mode):
    """Check that a trendline is only populated when enough points remain.

    Every scenario must still yield two traces (scatter + trendline); the
    trendline's ``x`` is either ``None`` or has the expected number of points.
    """
    # (x values, y values, expected trendline length or None for "no data")
    scenarios = [
        ([0, 1], [0, 1], 2),
        ([0], [0], None),
        ([0, 1], [0, None], None),
        ([0, 1, None], [0, None, 1], None),
        ([0, 1, None, 2], [1, None, 1, 2], 2),
    ]
    for xs, ys, expected_points in scenarios:
        fig = px.scatter(x=xs, y=ys, trendline=mode)
        assert len(fig.data) == 2
        trend_x = fig.data[1].x
        if expected_points is None:
            assert trend_x is None
        else:
            assert len(trend_x) == expected_points
47+
48+
49+
@pytest.mark.parametrize("mode", ["ols", "lowess"])
def test_trendline_nan_values(mode):
    """Check that trendlines skip rows whose ``y`` value is NaN.

    Population values before ``start_date`` are blanked out, so every
    trendline must start at or after that year and keep ``x`` and ``y`` the
    same length.
    """
    df = px.data.gapminder().query("continent == 'Oceania'")
    start_date = 1970
    # Blank out the early years without chained assignment: ``Series.where``
    # keeps values where the condition holds and fills NaN elsewhere, and
    # ``assign`` returns a new frame instead of writing into a query result.
    df = df.assign(pop=df["pop"].where(df["year"] >= start_date))
    fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
    # Odd positions hold the trendline traces.
    for trendline in fig["data"][1::2]:
        assert trendline.x[0] >= start_date
        assert len(trendline.x) == len(trendline.y)
58+
59+
60+
def test_no_slope_ols_trendline():
    """Check OLS hover text and parameters for degenerate two-point inputs."""

    def ols_figure(xs, ys):
        # Build the scatter + OLS trendline figure and return it together
        # with the trendline trace's hover template.
        figure = px.scatter(x=xs, y=ys, trendline="ols")
        return figure, figure.data[1].hovertemplate

    def first_fit_params(figure):
        # Parameters of the first (here: only) fit result attached to the figure.
        return px.get_trendline_results(figure)["px_fit_results"].iloc[0].params

    # Points on the identity line: the fitted parameters are close to [0, 1].
    fig, hover = ols_figure([0, 1], [0, 1])
    assert "y = 1" in hover  # then + x*(some small number)
    assert np.isclose(first_fit_params(fig), [0, 1]).all()

    # Identical x values with y == 0: a single parameter close to 0 is expected.
    fig, hover = ols_figure([1, 1], [0, 0])
    assert "y = 0" in hover
    assert np.isclose(first_fit_params(fig), [0]).all()

    # Remaining degenerate inputs: only the hover text is checked.
    _, hover = ols_figure([1, 2], [0, 0])
    assert "y = 0" in hover
    _, hover = ols_figure([0, 0], [1, 1])
    assert "y = 0 * x + 1" in hover
    _, hover = ols_figure([0, 0], [1, 2])
    assert "y = 0 * x + 1.5" in hover
79+
80+
81+
@pytest.mark.parametrize("mode", ["ols", "lowess"])
def test_trendline_on_timeseries(mode):
    """Check that trendlines on a datetime x-axis keep datetime x values."""
    stocks = px.data.stocks()
    stocks["date"] = pd.to_datetime(stocks["date"])
    fig = px.scatter(stocks, x="date", y="GOOG", trendline=mode)
    assert len(fig.data) == 2

    scatter_x = fig.data[0].x
    trend_x = fig.data[1].x
    assert len(scatter_x) == len(trend_x)
    # Exact type match on purpose: subclasses of datetime must not pass.
    assert type(scatter_x[0]) is datetime
    assert type(trend_x[0]) is datetime
    # The trendline must be drawn at exactly the scatter's x positions.
    assert np.all(scatter_x == trend_x)

0 commit comments

Comments
 (0)