Skip to content

Commit 0afccf5

Browse files
committed
Merge pull request #898 from dotty-staging/add/auto-uncurry
Implement auto tupling of function arguments
2 parents 4be70a5 + 4ceb3e7 commit 0afccf5

File tree

7 files changed

+105
-20
lines changed

7 files changed

+105
-20
lines changed

src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,26 @@ object desugar {
588588
Function(params, Match(selector, cases))
589589
}
590590

591+
/** Map n-ary function `(p1, ..., pn) => body` where n != 1 to unary function as follows:
592+
*
593+
* x$1 => {
594+
* def p1 = x$1._1
595+
* ...
596+
* def pn = x$1._n
597+
* body
598+
* }
599+
*/
600+
def makeTupledFunction(params: List[ValDef], body: Tree)(implicit ctx: Context): Tree = {
601+
val param = makeSyntheticParameter()
602+
def selector(n: Int) = Select(refOfDef(param), nme.selectorName(n))
603+
val vdefs =
604+
params.zipWithIndex.map{
605+
case (param, idx) =>
606+
DefDef(param.name, Nil, Nil, TypeTree(), selector(idx)).withPos(param.pos)
607+
}
608+
Function(param :: Nil, Block(vdefs, body))
609+
}
610+
591611
/** Add annotation with class `cls` to tree:
592612
* tree @cls
593613
*/

src/dotty/tools/dotc/transform/DropEmptyCompanions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class DropEmptyCompanions extends MiniPhaseTransform { thisTransform =>
4040
case TypeDef(_, impl: Template) if tree.symbol.is(SyntheticModule) &&
4141
tree.symbol.companionClass.exists &&
4242
impl.body.forall(_.symbol.isPrimaryConstructor) =>
43-
println(i"removing ${tree.symbol}")
43+
ctx.log(i"removing ${tree.symbol}")
4444
true
4545
case _ =>
4646
false

src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -613,26 +613,44 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
613613
if (protoFormals.length == params.length) protoFormals(i)
614614
else errorType(i"wrong number of parameters, expected: ${protoFormals.length}", tree.pos)
615615

616-
val inferredParams: List[untpd.ValDef] =
617-
for ((param, i) <- params.zipWithIndex) yield
618-
if (!param.tpt.isEmpty) param
619-
else cpy.ValDef(param)(
620-
tpt = untpd.TypeTree(
621-
inferredParamType(param, protoFormal(i)).underlyingIfRepeated(isJava = false)))
622-
623-
// Define result type of closure as the expected type, thereby pushing
624-
// down any implicit searches. We do this even if the expected type is not fully
625-
// defined, which is a bit of a hack. But it's needed to make the following work
626-
// (see typers.scala and printers/PlainPrinter.scala for examples).
627-
//
628-
// def double(x: Char): String = s"$x$x"
629-
// "abc" flatMap double
630-
//
631-
val resultTpt = protoResult match {
632-
case WildcardType(_) => untpd.TypeTree()
633-
case _ => untpd.TypeTree(protoResult)
616+
/** Is `formal` a product type which is elementwise compatible with `params`? */
617+
def ptIsCorrectProduct(formal: Type) = {
618+
val pclass = defn.ProductNType(params.length).symbol
619+
isFullyDefined(formal, ForceDegree.noBottom) &&
620+
formal.derivesFrom(pclass) &&
621+
formal.baseArgTypes(pclass).corresponds(params) {
622+
(argType, param) =>
623+
param.tpt.isEmpty || argType <:< typedAheadType(param.tpt).tpe
624+
}
634625
}
635-
typed(desugar.makeClosure(inferredParams, fnBody, resultTpt), pt)
626+
627+
val desugared =
628+
if (protoFormals.length == 1 && params.length != 1 && ptIsCorrectProduct(protoFormals.head)) {
629+
desugar.makeTupledFunction(params, fnBody)
630+
}
631+
else {
632+
val inferredParams: List[untpd.ValDef] =
633+
for ((param, i) <- params.zipWithIndex) yield
634+
if (!param.tpt.isEmpty) param
635+
else cpy.ValDef(param)(
636+
tpt = untpd.TypeTree(
637+
inferredParamType(param, protoFormal(i)).underlyingIfRepeated(isJava = false)))
638+
639+
// Define result type of closure as the expected type, thereby pushing
640+
// down any implicit searches. We do this even if the expected type is not fully
641+
// defined, which is a bit of a hack. But it's needed to make the following work
642+
// (see typers.scala and printers/PlainPrinter.scala for examples).
643+
//
644+
// def double(x: Char): String = s"$x$x"
645+
// "abc" flatMap double
646+
//
647+
val resultTpt = protoResult match {
648+
case WildcardType(_) => untpd.TypeTree()
649+
case _ => untpd.TypeTree(protoResult)
650+
}
651+
desugar.makeClosure(inferredParams, fnBody, resultTpt)
652+
}
653+
typed(desugared, pt)
636654
}
637655
}
638656

test/dotc/tests.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ class tests extends CompilerTest {
111111
@Test def neg_abstractOverride() = compileFile(negDir, "abstract-override", xerrors = 2)
112112
@Test def neg_blockescapes() = compileFile(negDir, "blockescapesNeg", xerrors = 1)
113113
@Test def neg_bounds() = compileFile(negDir, "bounds", xerrors = 2)
114+
@Test def neg_functionArity() = compileFile(negDir, "function-arity", xerrors = 7)
114115
@Test def neg_typedapply() = compileFile(negDir, "typedapply", xerrors = 3)
115116
@Test def neg_typedIdents() = compileDir(negDir, "typedIdents", xerrors = 2)
116117
@Test def neg_assignments() = compileFile(negDir, "assignments", xerrors = 3)

tests/neg/function-arity.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
object Test {
2+
3+
// From #873:
4+
5+
trait X extends Function1[Int, String]
6+
implicit def f2x(f: Function1[Int, String]): X = ???
7+
({case _ if "".isEmpty => 0} : X) // error: expected String, found Int
8+
9+
// Tests where parameter list cannot be made into a pattern
10+
11+
def unary[T](x: T => Unit) = ???
12+
unary((x, y) => ()) // error
13+
14+
unary[(Int, Int)]((x, y) => ())
15+
16+
unary[(Int, Int)](() => ()) // error
17+
unary[(Int, Int)]((x, y, _) => ()) // error
18+
19+
unary[(Int, Int)]((x: String, y) => ()) // error
20+
21+
def foo(a: Tuple2[Int, Int] => String): String = ""
22+
def foo(a: Any => String) = ()
23+
foo((a: Int, b: String) => a + b) // error: none of the overloaded alternatives of method foo match arguments (Int, Int)
24+
}
25+
object jasonComment {
26+
implicit def i2s(i: Int): String = i.toString
27+
((x: String, y: String) => 42) : (((Int, Int)) => String) // error
28+
}

tests/pos/i873.scala renamed to tests/pos/function-arity.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,14 @@ object Test {
77
({case _ if "".isEmpty => ""} : X) // allowed, implicit view used to adapt
88

99
// ({case _ if "".isEmpty => 0} : X) // expected String, found Int
10+
11+
def unary[T](a: T, b: T, f: ((T, T)) => T): T = f((a, b))
12+
unary(1, 2, (x, y) => x)
13+
unary(1, 2, (x: Int, y) => x)
14+
unary(1, 2, (x: Int, y: Int) => x)
15+
16+
val xs = List(1, 2, 3)
17+
def f(x: Int, y: Int) = x * y
18+
xs.zipWithIndex.map(_ + _)
19+
xs.zipWithIndex.map(f)
1020
}

tests/run/function-arity.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
object Test {
2+
class T[A] { def foo(f: (=> A) => Int) = f(???) }
3+
4+
def main(args: Array[String]): Unit = {
5+
new T[(Int, Int)].foo((ii) => 0)
6+
new T[(Int, Int)].foo((x, y) => 0) // check that this does not run into ???
7+
}
8+
}

0 commit comments

Comments
 (0)