Skip to content

Commit 236cd2c

Browse files
Merge branch 'auto_orient' into wide_form2
2 parents 33d03d5 + bbc22bc commit 236cd2c

File tree

2 files changed

+79
-47
lines changed

2 files changed

+79
-47
lines changed

packages/python/plotly/plotly/express/_chart_types.py

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def area(
236236
labels={},
237237
color_discrete_sequence=None,
238238
color_discrete_map={},
239-
orientation="v",
239+
orientation=None,
240240
groupnorm=None,
241241
log_x=False,
242242
log_y=False,
@@ -256,9 +256,7 @@ def area(
256256
return make_figure(
257257
args=locals(),
258258
constructor=go.Scatter,
259-
trace_patch=dict(
260-
stackgroup=1, mode="lines", orientation=orientation, groupnorm=groupnorm
261-
),
259+
trace_patch=dict(stackgroup=1, mode="lines", groupnorm=groupnorm),
262260
)
263261

264262

@@ -291,7 +289,7 @@ def bar(
291289
range_color=None,
292290
color_continuous_midpoint=None,
293291
opacity=None,
294-
orientation="v",
292+
orientation=None,
295293
barmode="relative",
296294
log_x=False,
297295
log_y=False,
@@ -309,7 +307,7 @@ def bar(
309307
return make_figure(
310308
args=locals(),
311309
constructor=go.Bar,
312-
trace_patch=dict(orientation=orientation, textposition="auto"),
310+
trace_patch=dict(textposition="auto"),
313311
layout_patch=dict(barmode=barmode),
314312
)
315313

@@ -335,7 +333,7 @@ def histogram(
335333
color_discrete_map={},
336334
marginal=None,
337335
opacity=None,
338-
orientation="v",
336+
orientation=None,
339337
barmode="relative",
340338
barnorm=None,
341339
histnorm=None,
@@ -361,13 +359,7 @@ def histogram(
361359
args=locals(),
362360
constructor=go.Histogram,
363361
trace_patch=dict(
364-
orientation=orientation,
365-
histnorm=histnorm,
366-
histfunc=histfunc,
367-
nbinsx=nbins if orientation == "v" else None,
368-
nbinsy=None if orientation == "v" else nbins,
369-
cumulative=dict(enabled=cumulative),
370-
bingroup="x" if orientation == "v" else "y",
362+
histnorm=histnorm, histfunc=histfunc, cumulative=dict(enabled=cumulative),
371363
),
372364
layout_patch=dict(barmode=barmode, barnorm=barnorm),
373365
)
@@ -393,8 +385,8 @@ def violin(
393385
labels={},
394386
color_discrete_sequence=None,
395387
color_discrete_map={},
396-
orientation="v",
397-
violinmode="group",
388+
orientation=None,
389+
violinmode=None,
398390
log_x=False,
399391
log_y=False,
400392
range_x=None,
@@ -414,12 +406,7 @@ def violin(
414406
args=locals(),
415407
constructor=go.Violin,
416408
trace_patch=dict(
417-
orientation=orientation,
418-
points=points,
419-
box=dict(visible=box),
420-
scalegroup=True,
421-
x0=" ",
422-
y0=" ",
409+
points=points, box=dict(visible=box), scalegroup=True, x0=" ", y0=" ",
423410
),
424411
layout_patch=dict(violinmode=violinmode),
425412
)
@@ -445,8 +432,8 @@ def box(
445432
labels={},
446433
color_discrete_sequence=None,
447434
color_discrete_map={},
448-
orientation="v",
449-
boxmode="group",
435+
orientation=None,
436+
boxmode=None,
450437
log_x=False,
451438
log_y=False,
452439
range_x=None,
@@ -470,9 +457,7 @@ def box(
470457
return make_figure(
471458
args=locals(),
472459
constructor=go.Box,
473-
trace_patch=dict(
474-
orientation=orientation, boxpoints=points, notched=notched, x0=" ", y0=" "
475-
),
460+
trace_patch=dict(boxpoints=points, notched=notched, x0=" ", y0=" "),
476461
layout_patch=dict(boxmode=boxmode),
477462
)
478463

@@ -497,8 +482,8 @@ def strip(
497482
labels={},
498483
color_discrete_sequence=None,
499484
color_discrete_map={},
500-
orientation="v",
501-
stripmode="group",
485+
orientation=None,
486+
stripmode=None,
502487
log_x=False,
503488
log_y=False,
504489
range_x=None,
@@ -516,7 +501,6 @@ def strip(
516501
args=locals(),
517502
constructor=go.Box,
518503
trace_patch=dict(
519-
orientation=orientation,
520504
boxpoints="all",
521505
pointpos=0,
522506
hoveron="points",
@@ -1384,7 +1368,7 @@ def funnel(
13841368
color_discrete_sequence=None,
13851369
color_discrete_map={},
13861370
opacity=None,
1387-
orientation="h",
1371+
orientation=None,
13881372
log_x=False,
13891373
log_y=False,
13901374
range_x=None,
@@ -1398,9 +1382,7 @@ def funnel(
13981382
In a funnel plot, each row of `data_frame` is represented as a
13991383
rectangular sector of a funnel.
14001384
"""
1401-
return make_figure(
1402-
args=locals(), constructor=go.Funnel, trace_patch=dict(orientation=orientation),
1403-
)
1385+
return make_figure(args=locals(), constructor=go.Funnel)
14041386

14051387

14061388
funnel.__doc__ = make_docstring(funnel)

packages/python/plotly/plotly/express/_core.py

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def get_label(args, column):
9292
return column
9393

9494

95+
def _is_continuous(df, col_name):
96+
return df[col_name].dtype.kind in "ifc"
97+
98+
9599
def get_decorated_label(args, column, role):
96100
label = get_label(args, column)
97101
if "histfunc" in args and (
@@ -188,7 +192,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
188192
if ((not attr_value) or (name in attr_value))
189193
and (
190194
trace_spec.constructor != go.Parcoords
191-
or args["data_frame"][name].dtype.kind in "ifc"
195+
or _is_continuous(args["data_frame"], name)
192196
)
193197
and (
194198
trace_spec.constructor != go.Parcats
@@ -1161,7 +1165,7 @@ def aggfunc_discrete(x):
11611165
agg_f[count_colname] = "sum"
11621166

11631167
if args["color"]:
1164-
if df[args["color"]].dtype.kind not in "ifc":
1168+
if not _is_continuous(df, args["color"]):
11651169
aggfunc_color = aggfunc_discrete
11661170
discrete_color = True
11671171
elif not aggfunc_color:
@@ -1227,7 +1231,7 @@ def aggfunc_continuous(x):
12271231
return args
12281232

12291233

1230-
def infer_config(args, constructor, trace_patch):
1234+
def infer_config(args, constructor, trace_patch, layout_patch):
12311235
# Declare all supported attributes, across all plot types
12321236
attrables = (
12331237
["x", "y", "z", "a", "b", "c", "r", "theta", "size", "dimensions"]
@@ -1263,10 +1267,7 @@ def infer_config(args, constructor, trace_patch):
12631267
if "color_discrete_sequence" not in args:
12641268
attrs.append("color")
12651269
else:
1266-
if (
1267-
args["color"]
1268-
and args["data_frame"][args["color"]].dtype.kind in "ifc"
1269-
):
1270+
if args["color"] and _is_continuous(args["data_frame"], args["color"]):
12701271
attrs.append("color")
12711272
args["color_is_continuous"] = True
12721273
elif constructor in [go.Sunburst, go.Treemap]:
@@ -1305,8 +1306,55 @@ def infer_config(args, constructor, trace_patch):
13051306
if "symbol" in args:
13061307
grouped_attrs.append("marker.symbol")
13071308

1308-
# Compute final trace patch
1309-
trace_patch = trace_patch.copy()
1309+
if "orientation" in args:
1310+
has_x = args["x"] is not None
1311+
has_y = args["y"] is not None
1312+
if args["orientation"] is None:
1313+
if constructor in [go.Histogram, go.Scatter]:
1314+
if has_y and not has_x:
1315+
args["orientation"] = "h"
1316+
elif constructor in [go.Violin, go.Box, go.Bar, go.Funnel]:
1317+
if has_x and not has_y:
1318+
args["orientation"] = "h"
1319+
1320+
if args["orientation"] is None and has_x and has_y:
1321+
x_is_continuous = _is_continuous(args["data_frame"], args["x"])
1322+
y_is_continuous = _is_continuous(args["data_frame"], args["y"])
1323+
if x_is_continuous and not y_is_continuous:
1324+
args["orientation"] = "h"
1325+
if y_is_continuous and not x_is_continuous:
1326+
args["orientation"] = "v"
1327+
1328+
if args["orientation"] is None:
1329+
args["orientation"] = "v"
1330+
1331+
if constructor == go.Histogram:
1332+
if has_x and has_y and args["histfunc"] is None:
1333+
args["histfunc"] = trace_patch["histfunc"] = "sum"
1334+
1335+
orientation = args["orientation"]
1336+
nbins = args["nbins"]
1337+
trace_patch["nbinsx"] = nbins if orientation == "v" else None
1338+
trace_patch["nbinsy"] = None if orientation == "v" else nbins
1339+
trace_patch["bingroup"] = "x" if orientation == "v" else "y"
1340+
trace_patch["orientation"] = args["orientation"]
1341+
1342+
if constructor in [go.Violin, go.Box]:
1343+
mode = "boxmode" if constructor == go.Box else "violinmode"
1344+
if layout_patch[mode] is None and args["color"] is not None:
1345+
if args["y"] == args["color"] and args["orientation"] == "h":
1346+
layout_patch[mode] = "overlay"
1347+
elif args["x"] == args["color"] and args["orientation"] == "v":
1348+
layout_patch[mode] = "overlay"
1349+
if layout_patch[mode] is None:
1350+
layout_patch[mode] = "group"
1351+
1352+
if (
1353+
constructor == go.Histogram2d
1354+
and args["z"] is not None
1355+
and args["histfunc"] is None
1356+
):
1357+
args["histfunc"] = trace_patch["histfunc"] = "sum"
13101358

13111359
if constructor in [go.Histogram2d, go.Densitymapbox]:
13121360
show_colorbar = True
@@ -1354,7 +1402,7 @@ def infer_config(args, constructor, trace_patch):
13541402

13551403
# Create trace specs
13561404
trace_specs = make_trace_spec(args, constructor, attrs, trace_patch)
1357-
return args, trace_specs, grouped_mappings, sizeref, show_colorbar
1405+
return trace_specs, grouped_mappings, sizeref, show_colorbar
13581406

13591407

13601408
def get_orderings(args, grouper, grouped):
@@ -1398,11 +1446,13 @@ def get_orderings(args, grouper, grouped):
13981446
return orders, group_names, group_values
13991447

14001448

1401-
def make_figure(args, constructor, trace_patch={}, layout_patch={}):
1449+
def make_figure(args, constructor, trace_patch=None, layout_patch=None):
1450+
trace_patch = trace_patch or {}
1451+
layout_patch = layout_patch or {}
14021452
apply_default_cascade(args)
14031453

1404-
args, trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
1405-
args, constructor, trace_patch
1454+
trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
1455+
args, constructor, trace_patch, layout_patch
14061456
)
14071457
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
14081458
grouped = args["data_frame"].groupby(grouper, sort=False)

0 commit comments

Comments
 (0)