Skip to content

Add quote pattern bindings #6212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ object desugar {
* def x: Int = expr
* def x_=($1: <TypeTree()>): Unit = ()
*/
def valDef(vdef: ValDef)(implicit ctx: Context): Tree = {
val ValDef(name, tpt, rhs) = vdef
def valDef(vdef0: ValDef)(implicit ctx: Context): Tree = {
val vdef @ ValDef(name, tpt, rhs) = transformQuotedPatternName(vdef0)
val mods = vdef.mods
val setterNeeded =
(mods is Mutable) && ctx.owner.isClass && (!(mods is PrivateLocal) || (ctx.owner is Trait))
Expand Down Expand Up @@ -197,8 +197,8 @@ object desugar {
* ==>
* inline def f(x: Boolean): Any = (if (x) 1 else ""): Any
*/
private def defDef(meth: DefDef, isPrimaryConstructor: Boolean = false)(implicit ctx: Context): Tree = {
val DefDef(_, tparams, vparamss, tpt, rhs) = meth
private def defDef(meth0: DefDef, isPrimaryConstructor: Boolean = false)(implicit ctx: Context): Tree = {
val meth @ DefDef(_, tparams, vparamss, tpt, rhs) = transformQuotedPatternName(meth0)
val methName = normalizeName(meth, tpt).asTermName
val mods = meth.mods
val epbuf = new ListBuffer[ValDef]
Expand Down Expand Up @@ -272,6 +272,32 @@ object desugar {
}
}

/** Transforms a definition with a name starting with a `$` in a quoted pattern into a `quoted.binding.Binding` splice.
*
* The desugaring consists in renaming the the definition and adding the `@patternBindHole` annotation. This
* annotation is used during typing to perform the full transformation.
*
* A definition
* ```scala
* case '{ def $a(...) = ... a() ...; ... a() ... }
* ```
* into
* ```scala
* case '{ @patternBindHole def a(...) = ... a() ...; ... a() ... }
* ```
*/
def transformQuotedPatternName(tree: ValOrDefDef)(implicit ctx: Context): ValOrDefDef = {
if (ctx.mode.is(Mode.QuotedPattern) && !tree.isBackquoted && tree.name != nme.ANON_FUN && tree.name.startsWith("$")) {
val name = tree.name.toString.substring(1).toTermName
val newTree: ValOrDefDef = tree match {
case tree: ValDef => cpy.ValDef(tree)(name)
case tree: DefDef => cpy.DefDef(tree)(name)
}
val mods = tree.mods.withAddedAnnotation(New(ref(defn.InternalQuoted_patternBindHoleAnnot.typeRef)).withSpan(tree.span))
newTree.withMods(mods)
} else tree
}

// Add all evidence parameters in `params` as implicit parameters to `meth` */
private def addEvidenceParams(meth: DefDef, params: List[ValDef])(implicit ctx: Context): DefDef =
params match {
Expand Down
25 changes: 25 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,14 @@ object Trees {

/** A ValDef or DefDef tree */
abstract class ValOrDefDef[-T >: Untyped](implicit @constructorOnly src: SourceFile) extends MemberDef[T] with WithLazyField[Tree[T]] {
type ThisTree[-T >: Untyped] <: ValOrDefDef[T]
def name: TermName
def tpt: Tree[T]
def unforcedRhs: LazyTree = unforced
def rhs(implicit ctx: Context): Tree[T] = forceIfLazy

/** Is this a `BackquotedValDef` or `BackquotedDefDef` ? */
def isBackquoted: Boolean = false
}

// ----------- Tree case classes ------------------------------------
Expand Down Expand Up @@ -706,6 +710,12 @@ object Trees {
protected def force(x: AnyRef): Unit = preRhs = x
}

class BackquotedValDef[-T >: Untyped] private[ast] (name: TermName, tpt: Tree[T], preRhs: LazyTree)(implicit @constructorOnly src: SourceFile)
extends ValDef[T](name, tpt, preRhs) {
override def isBackquoted: Boolean = true
override def productPrefix: String = "BackquotedValDef"
}

/** mods def name[tparams](vparams_1)...(vparams_n): tpt = rhs */
case class DefDef[-T >: Untyped] private[ast] (name: TermName, tparams: List[TypeDef[T]],
vparamss: List[List[ValDef[T]]], tpt: Tree[T], private var preRhs: LazyTree)(implicit @constructorOnly src: SourceFile)
Expand All @@ -716,6 +726,13 @@ object Trees {
protected def force(x: AnyRef): Unit = preRhs = x
}

class BackquotedDefDef[-T >: Untyped] private[ast] (name: TermName, tparams: List[TypeDef[T]],
vparamss: List[List[ValDef[T]]], tpt: Tree[T], preRhs: LazyTree)(implicit @constructorOnly src: SourceFile)
extends DefDef[T](name, tparams, vparamss, tpt, preRhs) {
override def isBackquoted: Boolean = true
override def productPrefix: String = "BackquotedDefDef"
}

/** mods class name template or
* mods trait name template or
* mods type name = rhs or
Expand Down Expand Up @@ -932,7 +949,9 @@ object Trees {
type Alternative = Trees.Alternative[T]
type UnApply = Trees.UnApply[T]
type ValDef = Trees.ValDef[T]
type BackquotedValDef = Trees.BackquotedValDef[T]
type DefDef = Trees.DefDef[T]
type BackquotedDefDef = Trees.BackquotedDefDef[T]
type TypeDef = Trees.TypeDef[T]
type Template = Trees.Template[T]
type Import = Trees.Import[T]
Expand Down Expand Up @@ -1125,10 +1144,16 @@ object Trees {
case _ => finalize(tree, untpd.UnApply(fun, implicits, patterns)(sourceFile(tree)))
}
def ValDef(tree: Tree)(name: TermName, tpt: Tree, rhs: LazyTree)(implicit ctx: Context): ValDef = tree match {
case tree: BackquotedValDef =>
if ((name == tree.name) && (tpt eq tree.tpt) && (rhs eq tree.unforcedRhs)) tree
else finalize(tree, untpd.BackquotedValDef(name, tpt, rhs)(sourceFile(tree)))
case tree: ValDef if (name == tree.name) && (tpt eq tree.tpt) && (rhs eq tree.unforcedRhs) => tree
case _ => finalize(tree, untpd.ValDef(name, tpt, rhs)(sourceFile(tree)))
}
def DefDef(tree: Tree)(name: TermName, tparams: List[TypeDef], vparamss: List[List[ValDef]], tpt: Tree, rhs: LazyTree)(implicit ctx: Context): DefDef = tree match {
case tree: BackquotedDefDef =>
if ((name == tree.name) && (tparams eq tree.tparams) && (vparamss eq tree.vparamss) && (tpt eq tree.tpt) && (rhs eq tree.unforcedRhs)) tree
else finalize(tree, untpd.BackquotedDefDef(name, tparams, vparamss, tpt, rhs)(sourceFile(tree)))
case tree: DefDef if (name == tree.name) && (tparams eq tree.tparams) && (vparamss eq tree.vparamss) && (tpt eq tree.tpt) && (rhs eq tree.unforcedRhs) => tree
case _ => finalize(tree, untpd.DefDef(name, tparams, vparamss, tpt, rhs)(sourceFile(tree)))
}
Expand Down
10 changes: 8 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def Alternative(trees: List[Tree])(implicit src: SourceFile): Alternative = new Alternative(trees)
def UnApply(fun: Tree, implicits: List[Tree], patterns: List[Tree])(implicit src: SourceFile): UnApply = new UnApply(fun, implicits, patterns)
def ValDef(name: TermName, tpt: Tree, rhs: LazyTree)(implicit src: SourceFile): ValDef = new ValDef(name, tpt, rhs)
def BackquotedValDef(name: TermName, tpt: Tree, rhs: LazyTree)(implicit src: SourceFile): ValDef = new BackquotedValDef(name, tpt, rhs)
def DefDef(name: TermName, tparams: List[TypeDef], vparamss: List[List[ValDef]], tpt: Tree, rhs: LazyTree)(implicit src: SourceFile): DefDef = new DefDef(name, tparams, vparamss, tpt, rhs)
def BackquotedDefDef(name: TermName, tparams: List[TypeDef], vparamss: List[List[ValDef]], tpt: Tree, rhs: LazyTree)(implicit src: SourceFile): DefDef = new BackquotedDefDef(name, tparams, vparamss, tpt, rhs)
def TypeDef(name: TypeName, rhs: Tree)(implicit src: SourceFile): TypeDef = new TypeDef(name, rhs)
def Template(constr: DefDef, parents: List[Tree], derived: List[Tree], self: ValDef, body: LazyTreeList)(implicit src: SourceFile): Template =
if (derived.isEmpty) new Template(constr, parents, self, body)
Expand Down Expand Up @@ -406,8 +408,12 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def makeAndType(left: Tree, right: Tree)(implicit ctx: Context): AppliedTypeTree =
AppliedTypeTree(ref(defn.andType.typeRef), left :: right :: Nil)

def makeParameter(pname: TermName, tpe: Tree, mods: Modifiers = EmptyModifiers)(implicit ctx: Context): ValDef =
ValDef(pname, tpe, EmptyTree).withMods(mods | Param)
def makeParameter(pname: TermName, tpe: Tree, mods: Modifiers = EmptyModifiers, isBackquoted: Boolean = false)(implicit ctx: Context): ValDef = {
val vdef =
if (isBackquoted) BackquotedValDef(pname, tpe, EmptyTree)
else ValDef(pname, tpe, EmptyTree)
vdef.withMods(mods | Param)
}

def makeSyntheticParameter(n: Int = 1, tpt: Tree = null, flags: FlagSet = EmptyFlags)(implicit ctx: Context): ValDef =
ValDef(nme.syntheticParamName(n), if (tpt == null) TypeTree() else tpt, EmptyTree)
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,7 @@ class Definitions {
def InternalQuoted_typeQuote(implicit ctx: Context): Symbol = InternalQuoted_typeQuoteR.symbol
lazy val InternalQuoted_patternHoleR: TermRef = InternalQuotedModule.requiredMethodRef("patternHole")
def InternalQuoted_patternHole(implicit ctx: Context): Symbol = InternalQuoted_patternHoleR.symbol
lazy val InternalQuoted_patternBindHoleAnnot: ClassSymbol = InternalQuotedModule.requiredClass("patternBindHole")

lazy val InternalQuotedMatcherModuleRef: TermRef = ctx.requiredModuleRef("scala.internal.quoted.Matcher")
def InternalQuotedMatcherModule(implicit ctx: Context): Symbol = InternalQuotedMatcherModuleRef.symbol
Expand All @@ -741,6 +742,9 @@ class Definitions {
lazy val QuotedTypeModuleRef: TermRef = ctx.requiredModuleRef("scala.quoted.Type")
def QuotedTypeModule(implicit ctx: Context): Symbol = QuotedTypeModuleRef.symbol

lazy val QuotedMatchingBindingType: TypeRef = ctx.requiredClassRef("scala.quoted.matching.Bind")
def QuotedMatchingBindingClass(implicit ctx: Context): ClassSymbol = QuotedMatchingBindingType.symbol.asClass

def Unpickler_unpickleExpr: TermSymbol = ctx.requiredMethod("scala.runtime.quoted.Unpickler.unpickleExpr")
def Unpickler_liftedExpr: TermSymbol = ctx.requiredMethod("scala.runtime.quoted.Unpickler.liftedExpr")
def Unpickler_unpickleType: TermSymbol = ctx.requiredMethod("scala.runtime.quoted.Unpickler.unpickleType")
Expand Down
22 changes: 14 additions & 8 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -406,10 +406,12 @@ object Parsers {
/** Convert tree to formal parameter
*/
def convertToParam(tree: Tree, expected: String = "formal parameter"): ValDef = tree match {
case Ident(name) =>
makeParameter(name.asTermName, TypeTree()).withSpan(tree.span)
case Typed(Ident(name), tpt) =>
makeParameter(name.asTermName, tpt).withSpan(tree.span)
case id @ Ident(name) =>
makeParameter(name.asTermName, TypeTree(), isBackquoted = id.isBackquoted).withSpan(tree.span)
case Typed(id @ Ident(name), tpt) =>
makeParameter(name.asTermName, tpt, isBackquoted = id.isBackquoted).withSpan(tree.span)
case Typed(Splice(Ident(name)), tpt) =>
makeParameter(("$" + name).toTermName, tpt).withSpan(tree.span)
case _ =>
syntaxError(s"not a legal $expected", tree.span)
makeParameter(nme.ERROR, tree)
Expand Down Expand Up @@ -2370,7 +2372,9 @@ object Parsers {
}
} else EmptyTree
lhs match {
case (id @ Ident(name: TermName)) :: Nil => {
case (id: BackquotedIdent) :: Nil if id.name.isTermName =>
finalizeDef(BackquotedValDef(id.name.asTermName, tpt, rhs), mods, start)
case Ident(name: TermName) :: Nil => {
finalizeDef(ValDef(name, tpt, rhs), mods, start)
} case _ =>
PatDef(mods, lhs, tpt, rhs)
Expand Down Expand Up @@ -2414,10 +2418,10 @@ object Parsers {
else
(Nil, Method)
val mods1 = addFlag(mods, flags)
val name = ident()
val ident = termIdent()
val tparams = typeParamClauseOpt(ParamOwner.Def)
val vparamss = paramClauses() match {
case rparams :: rparamss if leadingParamss.nonEmpty && !isLeftAssoc(name) =>
case rparams :: rparamss if leadingParamss.nonEmpty && !isLeftAssoc(ident.name) =>
rparams :: leadingParamss ::: rparamss
case rparamss =>
leadingParamss ::: rparamss
Expand Down Expand Up @@ -2447,7 +2451,9 @@ object Parsers {
accept(EQUALS)
expr()
}
finalizeDef(DefDef(name, tparams, vparamss, tpt, rhs), mods1, start)

if (ident.isBackquoted) finalizeDef(BackquotedDefDef(ident.name.asTermName, tparams, vparamss, tpt, rhs), mods1, start)
else finalizeDef(DefDef(ident.name.asTermName, tparams, vparamss, tpt, rhs), mods1, start)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1810,6 +1810,7 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
def Definitions_TupleClass(arity: Int): Symbol = defn.TupleType(arity).classSymbol.asClass

def Definitions_InternalQuoted_patternHole: Symbol = defn.InternalQuoted_patternHole
def Definitions_InternalQuoted_patternBindHoleAnnot: Symbol = defn.InternalQuoted_patternBindHoleAnnot

// Types

Expand Down
17 changes: 17 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1959,6 +1959,7 @@ class Typer extends Namer
}

def splitQuotePattern(quoted: Tree)(implicit ctx: Context): (Tree, List[Tree]) = {
val ctx0 = ctx
object splitter extends tpd.TreeMap {
val patBuf = new mutable.ListBuffer[Tree]
override def transform(tree: Tree)(implicit ctx: Context) = tree match {
Expand All @@ -1973,6 +1974,22 @@ class Typer extends Namer
val pat1 = if (patType eq patType1) pat else pat.withType(patType1)
patBuf += pat1
}
case ddef: ValOrDefDef =>
if (ddef.symbol.hasAnnotation(defn.InternalQuoted_patternBindHoleAnnot)) {
val bindingType = ddef.symbol.info match {
case t: ExprType => t.resType
case t: MethodType => t.toFunctionType()
case t: PolyType =>
HKTypeLambda(t.paramNames)(
x => t.paramInfos.mapConserve(_.subst(t, x).asInstanceOf[TypeBounds]),
x => t.resType.subst(t, x).toFunctionType())
case t => t
}
val bindingExprTpe = AppliedType(defn.QuotedMatchingBindingType, bindingType :: Nil)
val sym = ctx0.newPatternBoundSymbol(ddef.name, bindingExprTpe, ddef.span)
patBuf += Bind(sym, untpd.Ident(nme.WILDCARD).withType(bindingExprTpe)).withSpan(ddef.span)
}
super.transform(tree)
case _ =>
super.transform(tree)
}
Expand Down
5 changes: 5 additions & 0 deletions library/src-bootstrapped/scala/internal/Quoted.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package scala.internal

import scala.annotation.Annotation
import scala.quoted._

object Quoted {
Expand All @@ -19,4 +20,8 @@ object Quoted {
/** A splice in a quoted pattern is desugared by the compiler into a call to this method */
def patternHole[T]: T =
throw new Error("Internal error: this method call should have been replaced by the compiler")

/** A splice of a name in a quoted pattern is desugared by wrapping getting this annotation */
class patternBindHole extends Annotation

}
34 changes: 29 additions & 5 deletions library/src-bootstrapped/scala/internal/quoted/Matcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package scala.internal.quoted
import scala.annotation.internal.sharable

import scala.quoted._
import scala.quoted.matching.Bind
import scala.tasty._

object Matcher {
Expand Down Expand Up @@ -30,7 +31,8 @@ object Matcher {
* @return None if it did not match, `Some(tup)` if it matched where `tup` contains `Expr[Ti]``
*/
def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] = {
import reflection._
import reflection.{Bind => BindPattern, _}

// TODO improve performance

/** Check that the trees match and return the contents from the pattern holes.
Expand All @@ -51,6 +53,18 @@ object Matcher {
sFlags.is(Lazy) == pFlags.is(Lazy) && sFlags.is(Mutable) == pFlags.is(Mutable)
}

def bindingMatch(sym: Symbol) =
Some(Tuple1(new Bind(sym.name, sym)))

def hasBindTypeAnnotation(tpt: TypeTree): Boolean = tpt match {
case Annotated(tpt2, Apply(Select(New(TypeIdent("patternBindHole")), "<init>"), Nil)) => true
case Annotated(tpt2, _) => hasBindTypeAnnotation(tpt2)
case _ => false
}

def hasBindAnnotation(sym: Symbol) =
sym.annots.exists { case Apply(Select(New(TypeIdent("patternBindHole")),"<init>"),List()) => true; case _ => true }

def treesMatch(scrutinees: List[Tree], patterns: List[Tree]): Option[Tuple] =
if (scrutinees.size != patterns.size) None
else foldMatchings(scrutinees.zip(patterns).map(treeMatches): _*)
Expand Down Expand Up @@ -142,24 +156,30 @@ object Matcher {
foldMatchings(treeMatches(tycon1, tycon2), treesMatch(args1, args2))

case (ValDef(_, tpt1, rhs1), ValDef(_, tpt2, rhs2)) if checkValFlags() =>
val bindMatch =
if (hasBindAnnotation(pattern.symbol) || hasBindTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol)
else Some(())
val returnTptMatch = treeMatches(tpt1, tpt2)
val rhsEnv = env + (scrutinee.symbol -> pattern.symbol)
val rhsMatchings = treeOptMatches(rhs1, rhs2)(rhsEnv)
foldMatchings(returnTptMatch, rhsMatchings)
foldMatchings(bindMatch, returnTptMatch, rhsMatchings)

case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
val typeParmasMatch = treesMatch(typeParams1, typeParams2)
val paramssMatch =
if (paramss1.size != paramss2.size) None
else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => treesMatch(params1, params2) }: _*)
val bindMatch =
if (hasBindAnnotation(pattern.symbol)) bindingMatch(scrutinee.symbol)
else Some(())
val tptMatch = treeMatches(tpt1, tpt2)
val rhsEnv =
env + (scrutinee.symbol -> pattern.symbol) ++
typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
val rhsMatch = treeMatches(rhs1, rhs2)(rhsEnv)

foldMatchings(typeParmasMatch, paramssMatch, tptMatch, rhsMatch)
foldMatchings(bindMatch, typeParmasMatch, paramssMatch, tptMatch, rhsMatch)

case (Lambda(_, tpt1), Lambda(_, tpt2)) =>
// TODO match tpt1 with tpt2?
Expand All @@ -180,6 +200,10 @@ object Matcher {
val finalizerMatch = treeOptMatches(finalizer1, finalizer2)
foldMatchings(bodyMacth, casesMatch, finalizerMatch)

// Ignore type annotations
case (Annotated(tpt, _), _) => treeMatches(tpt, pattern)
case (_, Annotated(tpt, _)) => treeMatches(scrutinee, tpt)

// No Match
case _ =>
if (debug)
Expand Down
34 changes: 34 additions & 0 deletions library/src-bootstrapped/scala/quoted/matching/Bind.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package scala.quoted
package matching

import scala.tasty.Reflection // TODO do not depend on reflection directly

/** Bind of an Expr[T] used to know if some Expr[T] is a reference to the binding
*
* @param name string name of this binding
* @param id unique id used for equality
*/
class Bind[T <: AnyKind] private[scala](val name: String, private[Bind] val id: Object) { self =>

override def equals(obj: Any): Boolean = obj match {
case obj: Bind[_] => obj.id == id
case _ => false
}

override def hashCode(): Int = id.hashCode()

}

object Bind {

def unapply[T](expr: Expr[T])(implicit reflect: Reflection): Option[Bind[T]] = {
import reflect.{Bind => BindPattern, _}
expr.unseal match {
case IsIdent(ref) =>
val sym = ref.symbol
Some(new Bind[T](sym.name, sym))
case _ => None
}
}

}
Loading