Skip to content

Commit cb34e6d

Browse files
authored
Avoid parenthesizing comprehension element (#6198)
## Summary This PR adds a new precedence level for the comprehension element. This fixes the generator to not add parentheses around the comprehension element every time. The new precedence level is `COMPREHENSION_ELEMENT` and it should occur after the `NAMED_EXPR` precedence level because named expressions are always parenthesized. This matches the behavior of Python `ast.unparse` and tested with the following snippet: ```python import ast code = "" ast.unparse(ast.parse(code)) ``` ## Test Plan Add a bunch of test cases for all the valid nodes at that position. fixes: #5777
1 parent 0274de1 commit cb34e6d

File tree

1 file changed

+22
-7
lines changed

1 file changed

+22
-7
lines changed

crates/ruff_python_codegen/src/generator.rs

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ mod precedence {
3434
pub(crate) const COMMA: u8 = 21;
3535
pub(crate) const NAMED_EXPR: u8 = 23;
3636
pub(crate) const ASSERT: u8 = 23;
37+
pub(crate) const COMPREHENSION_ELEMENT: u8 = 27;
3738
pub(crate) const LAMBDA: u8 = 27;
3839
pub(crate) const IF_EXP: u8 = 27;
3940
pub(crate) const COMPREHENSION: u8 = 29;
@@ -1052,7 +1053,7 @@ impl<'a> Generator<'a> {
10521053
range: _range,
10531054
}) => {
10541055
self.p("[");
1055-
self.unparse_expr(elt, precedence::MAX);
1056+
self.unparse_expr(elt, precedence::COMPREHENSION_ELEMENT);
10561057
self.unparse_comp(generators);
10571058
self.p("]");
10581059
}
@@ -1062,7 +1063,7 @@ impl<'a> Generator<'a> {
10621063
range: _range,
10631064
}) => {
10641065
self.p("{");
1065-
self.unparse_expr(elt, precedence::MAX);
1066+
self.unparse_expr(elt, precedence::COMPREHENSION_ELEMENT);
10661067
self.unparse_comp(generators);
10671068
self.p("}");
10681069
}
@@ -1073,9 +1074,9 @@ impl<'a> Generator<'a> {
10731074
range: _range,
10741075
}) => {
10751076
self.p("{");
1076-
self.unparse_expr(key, precedence::MAX);
1077+
self.unparse_expr(key, precedence::COMPREHENSION_ELEMENT);
10771078
self.p(": ");
1078-
self.unparse_expr(value, precedence::MAX);
1079+
self.unparse_expr(value, precedence::COMPREHENSION_ELEMENT);
10791080
self.unparse_comp(generators);
10801081
self.p("}");
10811082
}
@@ -1085,7 +1086,7 @@ impl<'a> Generator<'a> {
10851086
range: _range,
10861087
}) => {
10871088
self.p("(");
1088-
self.unparse_expr(elt, precedence::COMMA);
1089+
self.unparse_expr(elt, precedence::COMPREHENSION_ELEMENT);
10891090
self.unparse_comp(generators);
10901091
self.p(")");
10911092
}
@@ -1570,6 +1571,8 @@ mod tests {
15701571
assert_round_trip!("foo(1)");
15711572
assert_round_trip!("foo(1, 2)");
15721573
assert_round_trip!("foo(x for x in y)");
1574+
assert_round_trip!("foo([x for x in y])");
1575+
assert_round_trip!("foo([(x := 2) for x in y])");
15731576
assert_round_trip!("x = yield 1");
15741577
assert_round_trip!("return (yield 1)");
15751578
assert_round_trip!("lambda: (1, 2, 3)");
@@ -1622,8 +1625,8 @@ mod tests {
16221625
r#"def f() -> (int, str):
16231626
pass"#
16241627
);
1625-
assert_round_trip!("[(await x) async for x in y]");
1626-
assert_round_trip!("[(await i) for i in b if await c]");
1628+
assert_round_trip!("[await x async for x in y]");
1629+
assert_round_trip!("[await i for i in b if await c]");
16271630
assert_round_trip!("(await x async for x in y)");
16281631
assert_round_trip!(
16291632
r#"async def read_data(db):
@@ -1719,6 +1722,18 @@ class Foo:
17191722
pass"#
17201723
);
17211724

1725+
assert_round_trip!(r#"[lambda n: n for n in range(10)]"#);
1726+
assert_round_trip!(r#"[n[0:2] for n in range(10)]"#);
1727+
assert_round_trip!(r#"[n[0] for n in range(10)]"#);
1728+
assert_round_trip!(r#"[(n, n * 2) for n in range(10)]"#);
1729+
assert_round_trip!(r#"[1 if n % 2 == 0 else 0 for n in range(10)]"#);
1730+
assert_round_trip!(r#"[n % 2 == 0 or 0 for n in range(10)]"#);
1731+
assert_round_trip!(r#"[(n := 2) for n in range(10)]"#);
1732+
assert_round_trip!(r#"((n := 2) for n in range(10))"#);
1733+
assert_round_trip!(r#"[n * 2 for n in range(10)]"#);
1734+
assert_round_trip!(r#"{n * 2 for n in range(10)}"#);
1735+
assert_round_trip!(r#"{i: n * 2 for i, n in enumerate(range(10))}"#);
1736+
17221737
// Type aliases
17231738
assert_round_trip!(r#"type Foo = int | str"#);
17241739
assert_round_trip!(r#"type Foo[T] = list[T]"#);

0 commit comments

Comments
 (0)