Skip to content

Commit 9762086

Browse files
nicholas-esterernicolaskruchten
authored andcommitted
traces can be selected with index or string
if selector is string, it is converted to dict(type=selector)
1 parent 6627d3e commit 9762086

File tree

1 file changed

+40
-35
lines changed

1 file changed

+40
-35
lines changed

packages/python/plotly/plotly/basedatatypes.py

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,18 +1177,16 @@ def select_traces(self, selector=None, row=None, col=None, secondary_y=None):
11771177
def _perform_select_traces(self, filter_by_subplot, grid_subplot_refs, selector):
11781178
from plotly.subplots import _get_subplot_ref_for_trace
11791179

1180-
for trace in self.data:
1181-
# Filter by subplot
1182-
if filter_by_subplot:
1183-
trace_subplot_ref = _get_subplot_ref_for_trace(trace)
1184-
if trace_subplot_ref not in grid_subplot_refs:
1185-
continue
1180+
# functions for filtering
1181+
def _filter_by_subplot_ref(trace):
1182+
trace_subplot_ref = _get_subplot_ref_for_trace(trace)
1183+
return trace_subplot_ref in grid_subplot_refs
11861184

1187-
# Filter by selector
1188-
if not self._selector_matches(trace, selector):
1189-
continue
1185+
funcs = []
1186+
if filter_by_subplot:
1187+
funcs.append(_filter_by_subplot_ref)
11901188

1191-
yield trace
1189+
return self._filter_by_selector(self.data, funcs, selector)
11921190

11931191
@staticmethod
11941192
def _selector_matches(obj, selector):
@@ -1214,8 +1212,8 @@ def _selector_matches(obj, selector):
12141212
if isinstance(selector_val, BasePlotlyType):
12151213
selector_val = selector_val.to_plotly_json()
12161214

1217-
if obj_val != selector_val:
1218-
return False
1215+
return obj_val == selector_val
1216+
12191217
return True
12201218
# If selector is a function, call it with the obj as the argument
12211219
elif type(selector) == type(lambda x: True):
@@ -1226,6 +1224,34 @@ def _selector_matches(obj, selector):
12261224
"accepting a graph object returning a boolean."
12271225
)
12281226

1227+
def _filter_by_selector(self, objects, funcs, selector):
1228+
"""
1229+
objects is a sequence of objects, funcs a list of functions that
1230+
return True if the object should be included in the selection and False
1231+
otherwise and selector is an argument to the self._selector_matches
1232+
function.
1233+
If selector is an integer, the resulting sequence obtained after
1234+
sucessively filtering by each function in funcs is indexed by this
1235+
integer.
1236+
Otherwise selector is used as the selector argument to
1237+
self._selector_matches which is used to filter down the sequence.
1238+
The function returns the sequence (an iterator).
1239+
"""
1240+
1241+
# if selector is not an int, we call it on each trace to test it for selection
1242+
if type(selector) != type(int()):
1243+
funcs.append(lambda obj: self._selector_matches(obj, selector))
1244+
1245+
def _filt(last, f):
1246+
return filter(f, last)
1247+
1248+
filtered_objects = reduce(_filt, funcs, objects)
1249+
1250+
if type(selector) == type(int()):
1251+
return iter([list(filtered_objects)[selector]])
1252+
1253+
return filtered_objects
1254+
12291255
def for_each_trace(self, fn, selector=None, row=None, col=None, secondary_y=None):
12301256
"""
12311257
Apply a function to all traces that satisfy the specified selection
@@ -1470,30 +1496,9 @@ def _filter_sec_y(obj):
14701496
yref_to_secondary_y.get(obj.yref, None) == secondary_y
14711497
)
14721498

1473-
def _filter_selector_matches(obj):
1474-
""" Filter objects for which selector matches """
1475-
return self._selector_matches(obj, selector)
1476-
14771499
funcs = [_filter_row, _filter_col, _filter_sec_y]
1478-
# If selector is not an int, we use the _filter_selector_matches to
1479-
# filter out items
1480-
if type(selector) != type(int()):
1481-
# append selector as filter function
1482-
funcs += [_filter_selector_matches]
1483-
1484-
def _reducer(last, f):
1485-
# takes list of objects that has been filtered down up to now (last)
1486-
# and applies the next filter function (f) to filter it down further.
1487-
return filter(lambda o: f(o), last)
1488-
1489-
# filtered_objs is a sequence of objects filtered by the above functions
1490-
filtered_objs = reduce(_reducer, funcs, self.layout[prop])
1491-
# If selector is an integer, use it as an index into the sequence of
1492-
# filtered objects. Note in this case we do not call _filter_selector_matches.
1493-
if type(selector) == type(int()):
1494-
# wrap in iter because this function should always return an iterator
1495-
return iter([list(filtered_objs)[selector]])
1496-
return filtered_objs
1500+
1501+
return self._filter_by_selector(self.layout[prop], funcs, selector)
14971502

14981503
def _add_annotation_like(
14991504
self,

0 commit comments

Comments
 (0)