Skip to content

Commit d2f8b72

Browse files
committed
Merge branch 'master' of https://github.com/plotly/plotly.py
2 parents f628e94 + 33701df commit d2f8b72

File tree

7 files changed

+227
-26
lines changed

7 files changed

+227
-26
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ This project adheres to [Semantic Versioning](http://semver.org/).
77
### Fixed
88

99
- Fixed special cases with `px.sunburst` and `px.treemap` with `path` input ([#2524](https://github.com/plotly/plotly.py/pull/2524))
10+
- Fixed bug in `hover_data` argument of `px` functions, when the column name is changed with labels and `hover_data` is a dictionary setting up a specific format for the hover data ([#2544](https://github.com/plotly/plotly.py/pull/2544)).
11+
- Made the Plotly Express `trendline` argument more robust and made it work with datetime `x` values ([#2554](https://github.com/plotly/plotly.py/pull/2554))
1012

1113
## [4.8.1] - 2020-05-28
1214

packages/python/plotly/_plotly_utils/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,12 @@ def key(v):
247247
return tuple(v_parts)
248248

249249
return sorted(vals, key=key, reverse=reverse)
250+
251+
252+
def _get_int_type():
    """Return the tuple of types to treat as integers in ``isinstance`` checks.

    ``np.integer`` is included only when ``get_module("numpy",
    should_load=False)`` returns a truthy module; otherwise the tuple holds
    just the builtin ``int``.
    """
    numpy_module = get_module("numpy", should_load=False)
    if not numpy_module:
        return (int,)
    return (int, numpy_module.integer)

packages/python/plotly/plotly/basedatatypes.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from contextlib import contextmanager
1010
from copy import deepcopy, copy
1111

12-
from _plotly_utils.utils import _natural_sort_strings
12+
from _plotly_utils.utils import _natural_sort_strings, _get_int_type
1313
from .optional_imports import get_module
1414

1515
# Create Undefined sentinel value
@@ -1560,12 +1560,7 @@ def _validate_rows_cols(name, n, vals):
15601560
if len(vals) != n:
15611561
BaseFigure._raise_invalid_rows_cols(name=name, n=n, invalid=vals)
15621562

1563-
try:
1564-
import numpy as np
1565-
1566-
int_type = (int, np.integer)
1567-
except ImportError:
1568-
int_type = (int,)
1563+
int_type = _get_int_type()
15691564

15701565
if [r for r in vals if not isinstance(r, int_type)]:
15711566
BaseFigure._raise_invalid_rows_cols(name=name, n=n, invalid=vals)
@@ -1677,14 +1672,19 @@ def add_traces(self, data, rows=None, cols=None, secondary_ys=None):
16771672
- All remaining properties are passed to the constructor
16781673
of the specified trace type.
16791674
1680-
rows : None or list[int] (default None)
1675+
rows : None, list[int], or int (default None)
16811676
List of subplot row indexes (starting from 1) for the traces to be
16821677
added. Only valid if figure was created using
16831678
`plotly.tools.make_subplots`
1679+
If a single integer is passed, all traces will be added to that row.
1680+
16841681
cols : None, list[int], or int (default None)
16851682
List of subplot column indexes (starting from 1) for the traces
16861683
to be added. Only valid if figure was created using
16871684
`plotly.tools.make_subplots`
1685+
If a single integer is passed, all traces will be added to that column.
1686+
1687+
16881688
secondary_ys: None or list[boolean] (default None)
16891689
List of secondary_y booleans for traces to be added. See the
16901690
docstring for `add_trace` for more info.
@@ -1723,6 +1723,15 @@ def add_traces(self, data, rows=None, cols=None, secondary_ys=None):
17231723
for ind, new_trace in enumerate(data):
17241724
new_trace._trace_ind = ind + len(self.data)
17251725

1726+
# Allow integers as inputs to subplots
1727+
int_type = _get_int_type()
1728+
1729+
if isinstance(rows, int_type):
1730+
rows = [rows] * len(data)
1731+
1732+
if isinstance(cols, int_type):
1733+
cols = [cols] * len(data)
1734+
17261735
# Validate rows / cols
17271736
n = len(data)
17281737
BaseFigure._validate_rows_cols("rows", n, rows)

packages/python/plotly/plotly/express/_core.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ def get_label(args, column):
117117

118118

119119
def invert_label(args, column):
120+
"""Invert mapping.
121+
Find key corresponding to value column in dict args["labels"].
122+
Returns `column` if the value does not exist.
123+
"""
120124
reversed_labels = {value: key for (key, value) in args["labels"].items()}
121125
try:
122126
return reversed_labels[column]
@@ -273,17 +277,35 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
273277
attr_value in ["ols", "lowess"]
274278
and args["x"]
275279
and args["y"]
276-
and len(trace_data) > 1
280+
and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
277281
):
278282
import statsmodels.api as sm
279283

280284
# sorting is bad but trace_specs with "trendline" have no other attrs
281285
sorted_trace_data = trace_data.sort_values(by=args["x"])
282-
y = sorted_trace_data[args["y"]]
283-
x = sorted_trace_data[args["x"]]
286+
y = sorted_trace_data[args["y"]].values
287+
x = sorted_trace_data[args["x"]].values
284288

289+
x_is_date = False
285290
if x.dtype.type == np.datetime64:
286291
x = x.astype(int) / 10 ** 9 # convert to unix epoch seconds
292+
x_is_date = True
293+
elif x.dtype.type == np.object_:
294+
try:
295+
x = x.astype(np.float64)
296+
except ValueError:
297+
raise ValueError(
298+
"Could not convert value of 'x' ('%s') into a numeric type. "
299+
"If 'x' contains stringified dates, please convert to a datetime column."
300+
% args["x"]
301+
)
302+
if y.dtype.type == np.object_:
303+
try:
304+
y = y.astype(np.float64)
305+
except ValueError:
306+
raise ValueError(
307+
"Could not convert value of 'y' into a numeric type."
308+
)
287309

288310
if attr_value == "lowess":
289311
# missing ='drop' is the default value for lowess but not for OLS (None)
@@ -294,25 +316,32 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
294316
hover_header = "<b>LOWESS trendline</b><br><br>"
295317
elif attr_value == "ols":
296318
fit_results = sm.OLS(
297-
y.values, sm.add_constant(x.values), missing="drop"
319+
y, sm.add_constant(x), missing="drop"
298320
).fit()
299321
trace_patch["y"] = fit_results.predict()
300322
trace_patch["x"] = x[
301323
np.logical_not(np.logical_or(np.isnan(y), np.isnan(x)))
302324
]
303325
hover_header = "<b>OLS trendline</b><br>"
304-
hover_header += "%s = %g * %s + %g<br>" % (
305-
args["y"],
306-
fit_results.params[1],
307-
args["x"],
308-
fit_results.params[0],
309-
)
326+
if len(fit_results.params) == 2:
327+
hover_header += "%s = %g * %s + %g<br>" % (
328+
args["y"],
329+
fit_results.params[1],
330+
args["x"],
331+
fit_results.params[0],
332+
)
333+
else:
334+
hover_header += "%s = %g<br>" % (
335+
args["y"],
336+
fit_results.params[0],
337+
)
310338
hover_header += (
311339
"R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
312340
)
341+
if x_is_date:
342+
trace_patch["x"] = pd.to_datetime(trace_patch["x"] * 10 ** 9)
313343
mapping_labels[get_label(args, args["x"])] = "%{x}"
314344
mapping_labels[get_label(args, args["y"])] = "%{y} <b>(trend)</b>"
315-
316345
elif attr_name.startswith("error"):
317346
error_xy = attr_name[:7]
318347
arr = "arrayminus" if attr_name.endswith("minus") else "array"
@@ -442,6 +471,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
442471
mapping_labels_copy = OrderedDict(mapping_labels)
443472
if args["hover_data"] and isinstance(args["hover_data"], dict):
444473
for k, v in mapping_labels.items():
474+
# We need to invert the mapping here
445475
k_args = invert_label(args, k)
446476
if k_args in args["hover_data"]:
447477
if args["hover_data"][k_args][0]:

packages/python/plotly/plotly/tests/test_core/test_figure_messages/test_add_traces.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,33 @@ def test_add_traces(self):
6363
{"type": "histogram2dcontour", "line": {"color": "cyan"}},
6464
]
6565
)
66+
67+
68+
class TestAddTracesRowsColsDataTypes(TestCase):
    """Check the ``rows``/``cols`` input types accepted by ``add_traces``."""

    def test_add_traces_with_iterable(self):
        # ``rows`` and ``cols`` are given as lists with one entry per trace.
        import plotly.express as px
        from plotly.subplots import make_subplots

        tips = px.data.tips()
        source_fig = px.scatter(tips, x="total_bill", y="tip", color="day")
        n_traces = len(source_fig.data)

        target_fig = make_subplots(1, 2)
        target_fig.add_traces(
            source_fig.data, rows=[1] * n_traces, cols=[1] * n_traces
        )

        expected_data_length = 4
        self.assertEqual(expected_data_length, len(target_fig.data))

    def test_add_traces_with_integers(self):
        # ``rows`` and ``cols`` are given as single integers for all traces.
        import plotly.express as px
        from plotly.subplots import make_subplots

        tips = px.data.tips()
        source_fig = px.scatter(tips, x="total_bill", y="tip", color="day")

        target_fig = make_subplots(1, 2)
        target_fig.add_traces(source_fig.data, rows=1, cols=2)

        expected_data_length = 4
        self.assertEqual(expected_data_length, len(target_fig.data))
Lines changed: 103 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,110 @@
11
import plotly.express as px
22
import numpy as np
3+
import pandas as pd
4+
import pytest
5+
from datetime import datetime
36

47

5-
def test_trendline_nan_values():
8+
@pytest.mark.parametrize("mode", ["ols", "lowess"])
def test_trendline_results_passthrough(mode):
    """Trendline traces are labelled and OLS fit results are retrievable.

    The figure alternates data traces (even indexes) and trendline traces
    (odd indexes). Only the "ols" mode is expected to expose fit results via
    ``px.get_trendline_results``; "lowess" yields an empty result set.
    """
    df = px.data.gapminder().query("continent == 'Oceania'")
    fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
    assert len(fig.data) == 4

    for trace in fig["data"][0::2]:
        assert "trendline" not in trace.hovertemplate
    for trendline in fig["data"][1::2]:
        assert "trendline" in trendline.hovertemplate
        if mode == "ols":
            assert "R<sup>2</sup>" in trendline.hovertemplate

    results = px.get_trendline_results(fig)
    if mode == "ols":
        assert len(results) == 2
        countries = results["country"].values
        assert countries[0] == "Australia"
        # The original repeated the row-0 check twice, leaving row 1 untested.
        # NOTE(review): assumes the second Oceania country in the gapminder
        # sample is "New Zealand" -- confirm against the data set.
        assert countries[1] == "New Zealand"
        au_result = results["px_fit_results"].values[0]
        assert len(au_result.params) == 2
    else:
        assert len(results) == 0
28+
29+
30+
@pytest.mark.parametrize("mode", ["ols", "lowess"])
def test_trendline_enough_values(mode):
    """A trendline is only drawn when enough complete (x, y) pairs exist.

    Each case is ``(x, y, expected)``, where ``expected`` is the length of
    the trendline's ``x``, or ``None`` when the trendline trace is expected
    to carry no ``x`` data at all.
    """
    nan = np.nan
    cases = [
        ([0, 1], [0, 1], 2),
        ([0], [0], None),
        ([0, 1], [0, None], None),
        ([0, 1], np.array([0, nan]), None),
        ([0, 1, None], [0, None, 1], None),
        (np.array([0, 1, nan]), np.array([0, nan, 1]), None),
        ([0, 1, None, 2], [1, None, 1, 2], 2),
        (np.array([0, 1, nan, 2]), np.array([1, nan, 1, 2]), 2),
    ]
    for xs, ys, expected in cases:
        fig = px.scatter(x=xs, y=ys, trendline=mode)
        # The trendline trace is always present, even when it is empty.
        assert len(fig.data) == 2
        trend_x = fig.data[1].x
        if expected is None:
            assert trend_x is None
        else:
            assert len(trend_x) == expected
60+
61+
62+
@pytest.mark.parametrize("mode", ["ols", "lowess"])
63+
def test_trendline_nan_values(mode):
664
df = px.data.gapminder().query("continent == 'Oceania'")
765
start_date = 1970
866
df["pop"][df["year"] < start_date] = np.nan
9-
modes = ["ols", "lowess"]
10-
for mode in modes:
11-
fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
12-
for trendline in fig["data"][1::2]:
13-
assert trendline.x[0] >= start_date
14-
assert len(trendline.x) == len(trendline.y)
67+
fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
68+
for trendline in fig["data"][1::2]:
69+
assert trendline.x[0] >= start_date
70+
assert len(trendline.x) == len(trendline.y)
71+
72+
73+
def test_no_slope_ols_trendline():
    """Check the OLS hover text and fitted parameters for tiny inputs.

    Covers a two-point fit as well as inputs with constant ``x`` and/or
    constant ``y``, where the hover header is expected to show either the
    two-term or the single-term equation.
    """

    def ols_scatter(xs, ys):
        return px.scatter(x=xs, y=ys, trendline="ols")

    def first_fit_params(figure):
        fit_table = px.get_trendline_results(figure)
        return fit_table["px_fit_results"].iloc[0].params

    fig = ols_scatter([0, 1], [0, 1])
    assert "y = 1" in fig.data[1].hovertemplate  # then + x*(some small number)
    assert np.all(np.isclose(first_fit_params(fig), [0, 1]))

    fig = ols_scatter([1, 1], [0, 0])
    assert "y = 0" in fig.data[1].hovertemplate
    assert np.all(np.isclose(first_fit_params(fig), [0]))

    # Remaining cases only check the rendered equation text.
    text_only_cases = [
        ([1, 2], [0, 0], "y = 0"),
        ([0, 0], [1, 1], "y = 0 * x + 1"),
        ([0, 0], [1, 2], "y = 0 * x + 1.5"),
    ]
    for xs, ys, expected_text in text_only_cases:
        fig = ols_scatter(xs, ys)
        assert expected_text in fig.data[1].hovertemplate
92+
93+
94+
@pytest.mark.parametrize("mode", ["ols", "lowess"])
def test_trendline_on_timeseries(mode):
    """Trendlines reject an unconverted date column and work on datetimes.

    First the raw ``date`` column is used as ``x`` and a ``ValueError`` with
    a specific message is expected. After converting the column with
    ``pd.to_datetime`` the trendline must share the data trace's x values and
    expose them as ``datetime`` objects.
    """
    df = px.data.stocks()

    with pytest.raises(ValueError) as err_msg:
        px.scatter(df, x="date", y="GOOG", trendline=mode)
    # Deliberately placed after the ``with`` block: a statement following the
    # raising call inside the block would never be executed.
    assert "Could not convert value of 'x' ('date') into a numeric type." in str(
        err_msg.value
    )

    df["date"] = pd.to_datetime(df["date"])
    fig = px.scatter(df, x="date", y="GOOG", trendline=mode)
    assert len(fig.data) == 2

    data_x = fig.data[0].x
    trend_x = fig.data[1].x
    assert len(data_x) == len(trend_x)
    # Exact type match is intended here (subclasses are not accepted).
    assert type(data_x[0]) is datetime
    assert type(trend_x[0]) is datetime
    assert np.all(data_x == trend_x)

packages/python/plotly/plotly/tests/test_core/test_utils/test_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,28 @@ def test_numpy_integer_import(self):
7070
value = get_by_path(fig, data_path)
7171
expected_value = (1,)
7272
self.assertEqual(value, expected_value)
73+
74+
def test_get_numpy_int_type(self):
    """With numpy imported, ``_get_int_type`` also reports ``np.integer``."""
    import numpy as np
    from _plotly_utils.utils import _get_int_type

    self.assertEqual(_get_int_type(), (int, np.integer))
82+
83+
84+
class TestNoNumpyIntegerBaseType(TestCase):
    """Check ``_get_int_type`` when numpy is absent from ``sys.modules``."""

    def test_no_numpy_int_type(self):
        import sys
        from _plotly_utils.utils import _get_int_type

        # Hide numpy for the duration of the check only. The original test
        # popped the entry and never put it back, which leaks into every test
        # that runs afterwards in the same process.
        saved_numpy = sys.modules.pop("numpy", None)
        try:
            int_type_tuple = _get_int_type()
            expected_tuple = (int,)

            self.assertEqual(int_type_tuple, expected_tuple)
        finally:
            if saved_numpy is not None:
                sys.modules["numpy"] = saved_numpy

0 commit comments

Comments
 (0)