Skip to content

Commit 761614c

Browse files
committed
Allow sealing method references as function types
1 parent 0d4b138 commit 761614c

File tree

5 files changed

+117
-7
lines changed

5 files changed

+117
-7
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/QuotedOpsImpl.scala

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
package dotty.tools.dotc.tastyreflect
22

3+
import dotty.tools.dotc.ast.tpd
4+
import dotty.tools.dotc.ast.Trees
5+
import dotty.tools.dotc.core.Flags._
6+
import dotty.tools.dotc.core.Symbols.defn
7+
import dotty.tools.dotc.core.StdNames.nme
38
import dotty.tools.dotc.core.quoted.PickledQuotes
9+
import dotty.tools.dotc.core.Types.MethodType
410

511
trait QuotedOpsImpl extends scala.tasty.reflect.QuotedOps with CoreImpl {
612

@@ -15,16 +21,28 @@ trait QuotedOpsImpl extends scala.tasty.reflect.QuotedOps with CoreImpl {
1521
def TermToQuoteDeco(term: Term): TermToQuotedAPI = new TermToQuotedAPI {
1622

1723
def seal[T: scala.quoted.Type](implicit ctx: Context): scala.quoted.Expr[T] = {
18-
typecheck()
19-
new scala.quoted.Exprs.TastyTreeExpr(term).asInstanceOf[scala.quoted.Expr[T]]
20-
}
2124

22-
private def typecheck[T: scala.quoted.Type]()(implicit ctx: Context): Unit = {
23-
val tpt = QuotedTypeDeco(implicitly[scala.quoted.Type[T]]).unseal
24-
if (!(term.tpe <:< tpt.tpe)) {
25+
val expectedType = QuotedTypeDeco(implicitly[scala.quoted.Type[T]]).unseal.tpe
26+
27+
def etaExpand(term: Term): Term = term.tpe.widen match {
28+
case mtpe: MethodType =>
29+
val closureResType = mtpe.resType match {
30+
case t: MethodType => t.toFunctionType()
31+
case t => t
32+
}
33+
val closureTpe = MethodType(mtpe.paramNames, mtpe.paramInfos, closureResType)
34+
val closureMethod = ctx.newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, closureTpe)
35+
tpd.Closure(closureMethod, tss => etaExpand(new tpd.TreeOps(term).appliedToArgs(tss.head)))
36+
case _ => term
37+
}
38+
39+
val expanded = etaExpand(term)
40+
if (expanded.tpe <:< expectedType) {
41+
new scala.quoted.Exprs.TastyTreeExpr(expanded).asInstanceOf[scala.quoted.Expr[T]]
42+
} else {
2543
throw new scala.tasty.TastyTypecheckError(
2644
s"""Term: ${term.show}
27-
|did not conform to type: ${tpt.tpe.show}
45+
|did not conform to type: ${expectedType.show}
2846
|""".stripMargin
2947
)
3048
}

compiler/src/dotty/tools/dotc/tastyreflect/TypeOrBoundsOpsImpl.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,18 @@ trait TypeOrBoundsOpsImpl extends scala.tasty.reflect.TypeOrBoundsOps with CoreI
77
def TypeDeco(tpe: Type): TypeAPI = new TypeAPI {
88
def =:=(other: Type)(implicit ctx: Context): Boolean = tpe =:= other
99
def <:<(other: Type)(implicit ctx: Context): Boolean = tpe <:< other
10+
11+
/** Widen from singleton type to its underlying non-singleton
12+
* base type by applying one or more `underlying` dereferences,
13+
* Also go from => T to T.
14+
* Identity for all other types. Example:
15+
*
16+
* class Outer { class C ; val x: C }
17+
* def o: Outer
18+
* <o.x.type>.widen = o.C
19+
*/
20+
def widen(implicit ctx: Context): Type = tpe.widen
21+
1022
}
1123

1224
def ConstantTypeDeco(x: ConstantType): Type.ConstantTypeAPI = new Type.ConstantTypeAPI {

library/src/scala/tasty/reflect/TypeOrBoundsOps.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ trait TypeOrBoundsOps extends Core {
5252
trait TypeAPI {
5353
def =:=(other: Type)(implicit ctx: Context): Boolean
5454
def <:<(other: Type)(implicit ctx: Context): Boolean
55+
def widen(implicit ctx: Context): Type
5556
}
5657

5758
val IsType: IsTypeModule
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import scala.quoted._
2+
3+
import scala.tasty._
4+
5+
object Asserts {
6+
7+
inline def zeroLastArgs(x: => Int): Int =
8+
~zeroLastArgsImpl('(x))
9+
10+
/** Replaces last argument list by 0s */
11+
def zeroLastArgsImpl(x: Expr[Int])(implicit reflect: Reflection): Expr[Int] = {
12+
import reflect._
13+
// For simplicity assumes that all parameters are Int and parameter lists have no more than 3 elements
14+
x.unseal.underlyingArgument match {
15+
case Term.Apply(fn, args) =>
16+
fn.tpe.widen match {
17+
case Type.IsMethodType(_) =>
18+
args.size match {
19+
case 0 => fn.seal[() => Int].apply()
20+
case 1 => fn.seal[Int => Int].apply('(0))
21+
case 2 => fn.seal[(Int, Int) => Int].apply('(0), '(0))
22+
case 3 => fn.seal[(Int, Int, Int) => Int].apply('(0), '(0), '(0))
23+
}
24+
}
25+
case _ => x
26+
}
27+
}
28+
29+
inline def zeroAllArgs(x: => Int): Int =
30+
~zeroAllArgsImpl('(x))
31+
32+
/** Replaces all argument list by 0s */
33+
def zeroAllArgsImpl(x: Expr[Int])(implicit reflect: Reflection): Expr[Int] = {
34+
import reflect._
35+
// For simplicity assumes that all parameters are Int and parameter lists have no more than 3 elements
36+
def rec(term: Term): Term = term match {
37+
case Term.Apply(fn, args) =>
38+
val pre = rec(fn)
39+
args.size match {
40+
case 0 => pre.seal[() => Any].apply().unseal
41+
case 1 => pre.seal[Int => Any].apply('(0)).unseal
42+
case 2 => pre.seal[(Int, Int) => Any].apply('(0), '(0)).unseal
43+
case 3 => pre.seal[(Int, Int, Int) => Any].apply('(0), '(0), '(0)).unseal
44+
}
45+
case _ => term
46+
}
47+
48+
rec(x.unseal.underlyingArgument).seal[Int]
49+
}
50+
51+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
2+
import Asserts._
3+
4+
object Test {
5+
def main(args: Array[String]): Unit = {
6+
assert(zeroLastArgs(-1) == -1)
7+
assert(zeroLastArgs(f) == 41)
8+
assert(zeroLastArgs(f0()) == 42)
9+
assert(zeroLastArgs(f1(2)) == 1)
10+
assert(zeroLastArgs(f2(2, 3)) == 2)
11+
assert(zeroLastArgs(f3(2)(4, 5)) == 5)
12+
13+
assert(zeroAllArgs(-1) == -1)
14+
assert(zeroAllArgs(f) == 41)
15+
assert(zeroAllArgs(f0()) == 42)
16+
assert(zeroAllArgs(f1(2)) == 1)
17+
assert(zeroAllArgs(f2(2, 3)) == 2)
18+
assert(zeroAllArgs(f3(2)(4, 5)) == 3)
19+
}
20+
21+
def f: Int = 41
22+
def f0(): Int = 42
23+
def f1(i: Int): Int = 1 + i
24+
def f2(i: Int, j: Int): Int = 2 + i + j
25+
def f3(i: Int)(j: Int, k: Int): Int = 3 + i + j
26+
def f4(i: Int, j: Int)(k: Int, l: Int): Int = 4 + i + j + k + l
27+
28+
}

0 commit comments

Comments
 (0)