Skip to content

Commit 05bde2a

Browse files
oderskyKacperFKorban
authored andcommitted
Streamline translation of for expressions
- [] Avoid redundant map call if the yielded value is the same as the last result. This makes for expressions more efficient and provides more opportunities for tail recursion.
1 parent c6fbe6f commit 05bde2a

File tree

4 files changed

+56
-23
lines changed

4 files changed

+56
-23
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1807,38 +1807,44 @@ object desugar {
18071807
*
18081808
* 1.
18091809
*
1810-
* for (P <- G) E ==> G.foreach (P => E)
1810+
* for (P <- G) E ==> G.foreach (P => E)
18111811
*
1812-
* Here and in the following (P => E) is interpreted as the function (P => E)
1813-
* if P is a variable pattern and as the partial function { case P => E } otherwise.
1812+
* Here and in the following (P => E) is interpreted as the function (P => E)
1813+
* if P is a variable pattern and as the partial function { case P => E } otherwise.
18141814
*
18151815
* 2.
18161816
*
1817-
* for (P <- G) yield E ==> G.map (P => E)
1817+
* for (P <- G) yield P ==> G
1818+
*
1819+
* if P is a variable or a tuple of variables and G is not a withFilter.
1820+
*
1821+
* for (P <- G) yield E ==> G.map (P => E)
1822+
*
1823+
* otherwise
18181824
*
18191825
* 3.
18201826
*
1821-
* for (P_1 <- G_1; P_2 <- G_2; ...) ...
1822-
* ==>
1823-
* G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...)
1827+
* for (P_1 <- G_1; P_2 <- G_2; ...) ...
1828+
* ==>
1829+
* G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...)
18241830
*
18251831
* 4.
18261832
*
1827-
* for (P <- G; E; ...) ...
1828-
* =>
1829-
* for (P <- G.filter (P => E); ...) ...
1833+
* for (P <- G; E; ...) ...
1834+
* =>
1835+
* for (P <- G.filter (P => E); ...) ...
18301836
*
18311837
* 5. For any N:
18321838
*
1833-
* for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...)
1834-
* ==>
1835-
* for (TupleN(P_1, P_2, ... P_N) <-
1836-
* for (x_1 @ P_1 <- G) yield {
1837-
* val x_2 @ P_2 = E_2
1838-
* ...
1839-
* val x_N & P_N = E_N
1840-
* TupleN(x_1, ..., x_N)
1841-
* } ...)
1839+
* for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...)
1840+
* ==>
1841+
* for (TupleN(P_1, P_2, ... P_N) <-
1842+
* for (x_1 @ P_1 <- G) yield {
1843+
* val x_2 @ P_2 = E_2
1844+
* ...
1845+
* val x_N & P_N = E_N
1846+
* TupleN(x_1, ..., x_N)
1847+
* } ...)
18421848
*
18431849
* If any of the P_i are variable patterns, the corresponding `x_i @ P_i` is not generated
18441850
* and the variable constituting P_i is used instead of x_i
@@ -1951,7 +1957,7 @@ object desugar {
19511957
case GenCheckMode.FilterAlways => false // pattern was prefixed by `case`
19521958
case GenCheckMode.FilterNow | GenCheckMode.CheckAndFilter => isVarBinding(gen.pat) || isIrrefutable(gen.pat, gen.expr)
19531959
case GenCheckMode.Check => true
1954-
case GenCheckMode.Ignore => true
1960+
case GenCheckMode.Ignore | GenCheckMode.Filtered => true
19551961

19561962
/** rhs.name with a pattern filter on rhs unless `pat` is irrefutable when
19571963
* matched against `rhs`.
@@ -1961,9 +1967,18 @@ object desugar {
19611967
Select(rhs, name)
19621968
}
19631969

1970+
def deepEquals(t1: Tree, t2: Tree): Boolean =
1971+
(unsplice(t1), unsplice(t2)) match
1972+
case (Ident(n1), Ident(n2)) => n1 == n2
1973+
case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals)
1974+
case _ => false
1975+
19641976
enums match {
19651977
case (gen: GenFrom) :: Nil =>
1966-
Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
1978+
if gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
1979+
&& deepEquals(gen.pat, body)
1980+
then gen.expr // avoid a redundant map with identity
1981+
else Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
19671982
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
19681983
val cont = makeFor(mapName, flatMapName, rest, body)
19691984
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
@@ -1985,7 +2000,7 @@ object desugar {
19852000
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
19862001
case (gen: GenFrom) :: test :: rest =>
19872002
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
1988-
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore)
2003+
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Filtered)
19892004
makeFor(mapName, flatMapName, genFrom :: rest, body)
19902005
case _ =>
19912006
EmptyTree //may happen for erroneous input

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
183183

184184
/** An enum to control checking or filtering of patterns in GenFrom trees */
185185
enum GenCheckMode {
186-
case Ignore // neither filter nor check since filtering was done before
186+
case Ignore // neither filter since pattern is trivially irrefutable
187+
case Filtered // neither filter nor check since filtering was done before
187188
case Check // check that pattern is irrefutable
188189
case CheckAndFilter // both check and filter (transitional period starting with 3.2)
189190
case FilterNow // filter out non-matching elements if we are not in 3.2 or later

tests/run/fors.check

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ hello world
4545
hello/1~2 hello/3~4 /1~2 /3~4 world/1~2 world/3~4
4646
(2,1) (4,3)
4747

48+
testTailrec
49+
List((4,Symbol(a)), (5,Symbol(b)), (6,Symbol(c)))
50+
4851
testGivens
4952
123
5053
456

tests/run/fors.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
//############################################################################
66

7+
import annotation.tailrec
8+
79
object Test extends App {
810
val xs = List(1, 2, 3)
911
val ys = List(Symbol("a"), Symbol("b"), Symbol("c"))
@@ -108,6 +110,17 @@ object Test extends App {
108110
for case (x, y) <- xs do print(s"${(y, x)} "); println()
109111
}
110112

113+
/////////////////// elimination of map ///////////////////
114+
115+
@tailrec
116+
def pair[B](xs: List[Int], ys: List[B], n: Int): List[(Int, B)] =
117+
if n == 0 then xs.zip(ys)
118+
else for (x, y) <- pair(xs.map(_ + 1), ys, n - 1) yield (x, y)
119+
120+
def testTailrec() =
121+
println("\ntestTailrec")
122+
println(pair(xs, ys, 3))
123+
111124
def testGivens(): Unit = {
112125
println("\ntestGivens")
113126

@@ -141,5 +154,6 @@ object Test extends App {
141154
testOld()
142155
testNew()
143156
testFiltering()
157+
testTailrec()
144158
testGivens()
145159
}

0 commit comments

Comments
 (0)