Skip to content

Commit 9941749

Browse files
short circuit more machinery when one_group
1 parent 706b04e commit 9941749

File tree

1 file changed

+27
-15
lines changed
  • packages/python/plotly/plotly/express

1 file changed

+27
-15
lines changed

packages/python/plotly/plotly/express/_core.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,6 +1898,7 @@ def infer_config(args, constructor, trace_patch, layout_patch):
18981898

18991899
# Create grouped mappings
19001900
grouped_mappings = [make_mapping(args, a) for a in grouped_attrs]
1901+
grouped_mappings = [x for x in grouped_mappings if x.grouper]
19011902

19021903
# Create trace specs
19031904
trace_specs = make_trace_spec(args, constructor, attrs, trace_patch)
@@ -1915,15 +1916,18 @@ def get_orderings(args, grouper, grouped):
19151916
of tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
19161917
of a single dimension-group
19171918
"""
1918-
19191919
orders = {} if "category_orders" not in args else args["category_orders"].copy()
1920+
1921+
if grouper == [one_group]:
1922+
sorted_group_names = [("",)]
1923+
return orders, sorted_group_names
1924+
19201925
for col in grouper:
1921-
if col != one_group:
1922-
uniques = list(args["data_frame"][col].unique())
1923-
if col not in orders:
1924-
orders[col] = uniques
1925-
else:
1926-
orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques))
1926+
uniques = list(args["data_frame"][col].unique())
1927+
if col not in orders:
1928+
orders[col] = uniques
1929+
else:
1930+
orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques))
19271931

19281932
sorted_group_names = []
19291933
for group_name in grouped.groups:
@@ -1932,11 +1936,10 @@ def get_orderings(args, grouper, grouped):
19321936
sorted_group_names.append(group_name)
19331937

19341938
for i, col in reversed(list(enumerate(grouper))):
1935-
if col != one_group:
1936-
sorted_group_names = sorted(
1937-
sorted_group_names,
1938-
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
1939-
)
1939+
sorted_group_names = sorted(
1940+
sorted_group_names,
1941+
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
1942+
)
19401943
return orders, sorted_group_names
19411944

19421945

@@ -1955,8 +1958,12 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19551958
trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
19561959
args, constructor, trace_patch, layout_patch
19571960
)
1958-
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
1959-
grouped = args["data_frame"].groupby(grouper, sort=False)
1961+
if len(grouped_mappings):
1962+
grouper = [x.grouper for x in grouped_mappings]
1963+
grouped = args["data_frame"].groupby(grouper, sort=False)
1964+
else:
1965+
grouper = [one_group]
1966+
grouped = None
19601967

19611968
orders, sorted_group_names = get_orderings(args, grouper, grouped)
19621969

@@ -1988,7 +1995,12 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19881995
trace_name_labels = None
19891996
facet_col_wrap = args.get("facet_col_wrap", 0)
19901997
for group_name in sorted_group_names:
1991-
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
1998+
if grouped is not None:
1999+
group = grouped.get_group(
2000+
group_name if len(group_name) > 1 else group_name[0]
2001+
)
2002+
else:
2003+
group = args["data_frame"]
19922004
mapping_labels = OrderedDict()
19932005
trace_name_labels = OrderedDict()
19942006
frame_name = ""

0 commit comments

Comments
 (0)