Skip to content

Add syntax for higher order quote pattern holes #8876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ object desugar {
def quotedPatternTypeDef(tree: TypeDef)(implicit ctx: Context): TypeDef = {
assert(ctx.mode.is(Mode.QuotedPattern))
if (tree.name.startsWith("$") && !tree.isBackquoted) {
val patternBindHoleAnnot = New(ref(defn.InternalQuoted_patternTypeAnnot.typeRef)).withSpan(tree.span)
val patternBindHoleAnnot = New(ref(defn.InternalQuotedMatcher_patternTypeAnnot.typeRef)).withSpan(tree.span)
val mods = tree.mods.withAddedAnnotation(patternBindHoleAnnot)
tree.withMods(mods)
}
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
}
case class Throw(expr: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
case class Quote(quoted: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
case class Splice(expr: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
case class Splice(expr: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree {
def isInBraces: Boolean = span.end != expr.span.end
}
case class TypSplice(expr: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree
case class ForYield(enums: List[Tree], expr: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
case class ForDo(enums: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
Expand Down
9 changes: 6 additions & 3 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -694,11 +694,14 @@ class Definitions {
@tu lazy val InternalQuoted_exprSplice : Symbol = InternalQuotedModule.requiredMethod("exprSplice")
@tu lazy val InternalQuoted_exprNestedSplice : Symbol = InternalQuotedModule.requiredMethod("exprNestedSplice")
@tu lazy val InternalQuoted_typeQuote : Symbol = InternalQuotedModule.requiredMethod("typeQuote")
@tu lazy val InternalQuoted_patternHole: Symbol = InternalQuotedModule.requiredMethod("patternHole")
@tu lazy val InternalQuoted_patternTypeAnnot: ClassSymbol = InternalQuotedModule.requiredClass("patternType")
@tu lazy val InternalQuoted_QuoteTypeTagAnnot: ClassSymbol = InternalQuotedModule.requiredClass("quoteTypeTag")
@tu lazy val InternalQuoted_fromAboveAnnot: ClassSymbol = InternalQuotedModule.requiredClass("fromAbove")

@tu lazy val InternalQuotedMatcher: Symbol = ctx.requiredModule("scala.internal.quoted.Matcher")
@tu lazy val InternalQuotedMatcher_patternHole: Symbol = InternalQuotedMatcher.requiredMethod("patternHole")
@tu lazy val InternalQuotedMatcher_patternHigherOrderHole: Symbol = InternalQuotedMatcher.requiredMethod("patternHigherOrderHole")
@tu lazy val InternalQuotedMatcher_higherOrderHole: Symbol = InternalQuotedMatcher.requiredMethod("higherOrderHole")
@tu lazy val InternalQuotedMatcher_patternTypeAnnot: ClassSymbol = InternalQuotedMatcher.requiredClass("patternType")
@tu lazy val InternalQuotedMatcher_fromAboveAnnot: ClassSymbol = InternalQuotedMatcher.requiredClass("fromAbove")

@tu lazy val InternalQuotedExprModule: Symbol = ctx.requiredModule("scala.internal.quoted.Expr")
@tu lazy val InternalQuotedExpr_unapply: Symbol = InternalQuotedExprModule.requiredMethod(nme.unapply)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1982,9 +1982,10 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def Definitions_TupleClass(arity: Int): Symbol = defn.TupleType(arity).classSymbol.asClass
def Definitions_isTupleClass(sym: Symbol): Boolean = defn.isTupleClass(sym)

def Definitions_InternalQuoted_patternHole: Symbol = defn.InternalQuoted_patternHole
def Definitions_InternalQuoted_patternTypeAnnot: Symbol = defn.InternalQuoted_patternTypeAnnot
def Definitions_InternalQuoted_fromAboveAnnot: Symbol = defn.InternalQuoted_fromAboveAnnot
def Definitions_InternalQuotedMatcher_patternHole: Symbol = defn.InternalQuotedMatcher_patternHole
def Definitions_InternalQuotedMatcher_higherOrderHole: Symbol = defn.InternalQuotedMatcher_higherOrderHole
def Definitions_InternalQuotedMatcher_patternTypeAnnot: Symbol = defn.InternalQuotedMatcher_patternTypeAnnot
def Definitions_InternalQuotedMatcher_fromAboveAnnot: Symbol = defn.InternalQuotedMatcher_fromAboveAnnot

// Types

Expand Down
37 changes: 29 additions & 8 deletions compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,27 @@ trait QuotesAndSplices {
def typedAppliedSplice(tree: untpd.Apply, pt: Type)(using Context): Tree = {
assert(ctx.mode.is(Mode.QuotedPattern))
val untpd.Apply(splice: untpd.Splice, args) = tree
if (isFullyDefined(pt, ForceDegree.flipBottom)) then
if !isFullyDefined(pt, ForceDegree.flipBottom) then
ctx.error(i"Type must be fully defined.", splice.sourcePos)
tree.withType(UnspecifiedErrorType)
else if splice.isInBraces then // ${x}(...) match an application
val typedArgs = args.map(arg => typedExpr(arg))
val argTypes = typedArgs.map(_.tpe.widenTermRefExpr)
val splice1 = typedSplice(splice, defn.FunctionOf(argTypes, pt))
Apply(splice1.select(nme.apply), typedArgs).withType(pt).withSpan(tree.span)
else
ctx.error(i"Type must be fully defined.", splice.sourcePos)
tree.withType(UnspecifiedErrorType)
else // $x(...) higher-order quasipattern
val typedArgs = args.map {
case arg: untpd.Ident =>
typedExpr(arg)
case arg =>
ctx.error("Open patttern exprected an identifier", arg.sourcePos)
EmptyTree
}
if args.isEmpty then
ctx.error("Missing arguments for open pattern", tree.sourcePos)
val argTypes = typedArgs.map(_.tpe.widenTermRefExpr)
val typedPat = typedSplice(splice, defn.FunctionOf(argTypes, pt))
ref(defn.InternalQuotedMatcher_patternHigherOrderHole).appliedToType(pt).appliedTo(typedPat, SeqLiteral(typedArgs, TypeTree(defn.AnyType)))
}

/** Translate ${ t: Type[T] }` into type `t.splice` while tracking the quotation level in the context */
Expand Down Expand Up @@ -154,7 +167,7 @@ trait QuotesAndSplices {
case pt: TypeBounds => pt
case _ => TypeBounds.empty
val typeSym = ctx.newSymbol(spliceOwner(ctx), name, EmptyFlags, typeSymInfo, NoSymbol, tree.expr.span)
typeSym.addAnnotation(Annotation(New(ref(defn.InternalQuoted_patternTypeAnnot.typeRef)).withSpan(tree.expr.span)))
typeSym.addAnnotation(Annotation(New(ref(defn.InternalQuotedMatcher_patternTypeAnnot.typeRef)).withSpan(tree.expr.span)))
val pat = typedPattern(tree.expr, defn.QuotedTypeClass.typeRef.appliedTo(typeSym.typeRef))(
using spliceContext.retractMode(Mode.QuotedPattern).withOwner(spliceOwner(ctx)))
pat.select(tpnme.splice)
Expand Down Expand Up @@ -224,8 +237,16 @@ trait QuotesAndSplices {
val exprTpt = AppliedTypeTree(TypeTree(defn.QuotedExprClass.typeRef), tpt1 :: Nil)
val newSplice = ref(defn.InternalQuoted_exprSplice).appliedToType(tpt1.tpe).appliedTo(Typed(pat, exprTpt))
transform(newSplice)
case Apply(TypeApply(fn, targs), Apply(sp, pat :: Nil) :: args :: Nil) if fn.symbol == defn.InternalQuotedMatcher_patternHigherOrderHole =>
try ref(defn.InternalQuotedMatcher_higherOrderHole.termRef).appliedToTypeTrees(targs).appliedTo(args).withSpan(tree.span)
finally {
val patType = pat.tpe.widen
val patType1 = patType.translateFromRepeated(toArray = false)
val pat1 = if (patType eq patType1) pat else pat.withType(patType1)
patBuf += pat1
}
case Apply(fn, pat :: Nil) if fn.symbol == defn.InternalQuoted_exprSplice =>
try ref(defn.InternalQuoted_patternHole.termRef).appliedToType(tree.tpe).withSpan(tree.span)
try ref(defn.InternalQuotedMatcher_patternHole.termRef).appliedToType(tree.tpe).withSpan(tree.span)
finally {
val patType = pat.tpe.widen
val patType1 = patType.translateFromRepeated(toArray = false)
Expand All @@ -241,7 +262,7 @@ trait QuotesAndSplices {
else
tree
case tdef: TypeDef =>
if tdef.symbol.hasAnnotation(defn.InternalQuoted_patternTypeAnnot) then
if tdef.symbol.hasAnnotation(defn.InternalQuotedMatcher_patternTypeAnnot) then
transformTypeBindingTypeDef(tdef, typePatBuf)
else if tdef.symbol.isClass then
val kind = if tdef.symbol.is(Module) then "objects" else "classes"
Expand Down Expand Up @@ -276,7 +297,7 @@ trait QuotesAndSplices {

private def transformTypeBindingTypeDef(tdef: TypeDef, buff: mutable.Builder[Tree, List[Tree]])(using Context): Tree = {
if (variance == -1)
tdef.symbol.addAnnotation(Annotation(New(ref(defn.InternalQuoted_fromAboveAnnot.typeRef)).withSpan(tdef.span)))
tdef.symbol.addAnnotation(Annotation(New(ref(defn.InternalQuotedMatcher_fromAboveAnnot.typeRef)).withSpan(tdef.span)))
val bindingType = getBinding(tdef.symbol).symbol.typeRef
val bindingTypeTpe = AppliedType(defn.QuotedTypeClass.typeRef, bindingType :: Nil)
val bindName = tdef.name.toString.stripPrefix("$").toTermName
Expand Down
46 changes: 46 additions & 0 deletions docs/docs/reference/metaprogramming/macros.md
Original file line number Diff line number Diff line change
Expand Up @@ -744,5 +744,51 @@ trait Show[-T] {
}
```

#### Open code patterns

Quote pattern matching also provides higher-order patterns to match open terms. If a quoted term contains a definition,
then the rest of the quote can refer to this definition.
```
'{
val x: Int = 4
x * x
}
```

To match such a term we need to match the definition and the rest of the code, but we need to expicilty state that the rest of the code may refer to this definition.
```scala
case '{ val y: Int = $x; $body(y): Int } =>
```
Here `$x` will match any closed expression while `$body(y)` will match expression that is closed under `y`. Then
the subxpression of type `Expr[Int]` is bound to `body` as an `Expr[Int => Int]`. The extra argument represents the references to `y`. Usually this expression is used in compination with `Expr.betaReduce` to replace the extra argument.

```scala
inline def eval(inline e: Int): Int = ${ evalExpr('e) }

private def evalExpr(using QuoteContext)(e: Expr[Int]): Expr[Int] = {
e match {
case '{ val y: Int = $x; $body(y): Int } =>
// body: Expr[Int => Int] where the argument represents references to y
evalExpr(Expr.betaReduce(body)(evalExpr(x)))
case '{ ($x: Int) * ($y: Int) } =>
(x, y) match
case (Const(a), Const(b)) => Expr(a * b)
case _ => e
case _ => e
}
}
```

```scala
eval { // expands to the code: (16: Int)
val x: Int = 4
x * x
}
```

We can also close over several bindings using `$b(a1, a2, ..., an)`.
To match an actual application we can use braces on the function part `${b}(a1, a2, ..., an)`.


### More details
[More details](./macros-spec.md)
17 changes: 0 additions & 17 deletions library/src-bootstrapped/scala/internal/quoted/CompileTime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,6 @@ object CompileTime {
@compileTimeOnly("Illegal reference to `scala.internal.quoted.CompileTime.typeQuote`")
def typeQuote[T <: AnyKind]: QuoteContext ?=> Type[T] = ???

/** A splice in a quoted pattern is desugared by the compiler into a call to this method */
@compileTimeOnly("Illegal reference to `scala.internal.quoted.CompileTime.patternHole`")
def patternHole[T]: T = ???

// TODO remove
/** A splice of a name in a quoted pattern is desugared by wrapping getting this annotation */
@compileTimeOnly("Illegal reference to `scala.internal.quoted.CompileTime.patternBindHole`")
class patternBindHole extends Annotation

/** A splice of a name in a quoted pattern is that marks the definition of a type splice */
@compileTimeOnly("Illegal reference to `scala.internal.quoted.CompileTime.patternType`")
class patternType extends Annotation

/** A type pattern that must be aproximated from above */
@compileTimeOnly("Illegal reference to `scala.internal.quoted.CompileTime.fromAbove`")
class fromAbove extends Annotation

/** Artifact of pickled type splices
*
* During quote reification a quote `'{ ... F[$t] ... }` will be transformed into
Expand Down
53 changes: 41 additions & 12 deletions library/src/scala/internal/quoted/Matcher.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package scala.internal.quoted

import scala.annotation.internal.sharable
import scala.annotation.{Annotation, compileTimeOnly}

import scala.quoted._

Expand Down Expand Up @@ -94,7 +95,32 @@ import scala.quoted._
*
* ```
*/
private[quoted] object Matcher {
object Matcher {

/** A splice in a quoted pattern is desugared by the compiler into a call to this method */
@compileTimeOnly("Illegal reference to `scala.internal.quoted.CompileTime.patternHole`")
def patternHole[T]: T = ???

@compileTimeOnly("Illegal reference to `scala.internal.quoted.CompileTime.patternHigherOrderHole`")
/** A higher order splice in a quoted pattern is desugared by the compiler into a call to this method */
def patternHigherOrderHole[U](pat: Any, args: Any*): U = ???

@compileTimeOnly("Illegal reference to `scala.internal.quoted.CompileTime.higherOrderHole`")
/** A higher order splice in a quoted pattern is desugared by the compiler into a call to this method */
def higherOrderHole[U](args: Any*): U = ???

// TODO remove
/** A splice of a name in a quoted pattern is desugared by wrapping getting this annotation */
@compileTimeOnly("Illegal reference to `scala.internal.quoted.CompileTime.patternBindHole`")
class patternBindHole extends Annotation

/** A splice of a name in a quoted pattern is that marks the definition of a type splice */
@compileTimeOnly("Illegal reference to `scala.internal.quoted.CompileTime.patternType`")
class patternType extends Annotation

/** A type pattern that must be aproximated from above */
@compileTimeOnly("Illegal reference to `scala.internal.quoted.CompileTime.fromAbove`")
class fromAbove extends Annotation

class QuoteMatcher[QCtx <: QuoteContext & Singleton](using val qctx: QCtx) {
// TODO improve performance
Expand Down Expand Up @@ -164,13 +190,13 @@ private[quoted] object Matcher {
private def hasFromAboveAnnotation(sym: Symbol) = sym.annots.exists(isFromAboveAnnotation)

private def isPatternTypeAnnotation(tree: Tree): Boolean = tree match {
case New(tpt) => tpt.symbol == internal.Definitions_InternalQuoted_patternTypeAnnot
case annot => annot.symbol.owner == internal.Definitions_InternalQuoted_patternTypeAnnot
case New(tpt) => tpt.symbol == internal.Definitions_InternalQuotedMatcher_patternTypeAnnot
case annot => annot.symbol.owner == internal.Definitions_InternalQuotedMatcher_patternTypeAnnot
}

private def isFromAboveAnnotation(tree: Tree): Boolean = tree match {
case New(tpt) => tpt.symbol == internal.Definitions_InternalQuoted_fromAboveAnnot
case annot => annot.symbol.owner == internal.Definitions_InternalQuoted_fromAboveAnnot
case New(tpt) => tpt.symbol == internal.Definitions_InternalQuotedMatcher_fromAboveAnnot
case annot => annot.symbol.owner == internal.Definitions_InternalQuotedMatcher_fromAboveAnnot
}

/** Check that all trees match with `mtch` and concatenate the results with &&& */
Expand Down Expand Up @@ -226,23 +252,23 @@ private[quoted] object Matcher {
/* Term hole */
// Match a scala.internal.Quoted.patternHole typed as a repeated argument and return the scrutinee tree
case (scrutinee @ Typed(s, tpt1), Typed(TypeApply(patternHole, tpt :: Nil), tpt2))
if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole &&
if patternHole.symbol == internal.Definitions_InternalQuotedMatcher_patternHole &&
s.tpe <:< tpt.tpe &&
tpt2.tpe.derivesFrom(defn.RepeatedParamClass) =>
matched(scrutinee.seal)

/* Term hole */
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
case (ClosedPatternTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole &&
if patternHole.symbol == internal.Definitions_InternalQuotedMatcher_patternHole &&
scrutinee.tpe <:< tpt.tpe =>
matched(scrutinee.seal)

/* Higher order term hole */
// Matches an open term and wraps it into a lambda that provides the free variables
// TODO do not encode with `hole`. Maybe use `higherOrderHole[(T1, ..., Tn) => R]((x1: T1, ..., xn: Tn)): R`
case (scrutinee, pattern @ Apply(Select(TypeApply(patternHole, List(Inferred())), "apply"), args0 @ IdentArgs(args)))
if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole =>
case (scrutinee, pattern @ Apply(TypeApply(Ident("higherOrderHole"), List(Inferred())), Repeated(args, _) :: Nil))
if pattern.symbol == internal.Definitions_InternalQuotedMatcher_higherOrderHole =>

def bodyFn(lambdaArgs: List[Tree]): Tree = {
val argsMap = args.map(_.symbol).zip(lambdaArgs.asInstanceOf[List[Term]]).toMap
new TreeMap {
Expand All @@ -252,8 +278,11 @@ private[quoted] object Matcher {
case tree => super.transformTerm(tree)
}.transformTree(scrutinee)
}
val names = args.map(_.name)
val argTypes = args0.map(x => x.tpe.widenTermRefExpr)
val names = args.map {
case Block(List(DefDef("$anonfun", _, _, _, Some(Apply(Ident(name), _)))), _) => name
case arg => arg.symbol.name
}
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
val resType = pattern.tpe
val res = Lambda(MethodType(names)(_ => argTypes, _ => resType), bodyFn)
matched(res.seal)
Expand Down
15 changes: 9 additions & 6 deletions library/src/scala/tasty/reflect/CompilerInterface.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1511,14 +1511,17 @@ trait CompilerInterface {
def Definitions_TupleClass(arity: Int): Symbol
def Definitions_isTupleClass(sym: Symbol): Boolean

/** Symbol of scala.internal.Quoted.patternHole */
def Definitions_InternalQuoted_patternHole: Symbol
/** Symbol of scala.internal.CompileTime.patternHole */
def Definitions_InternalQuotedMatcher_patternHole: Symbol

/** Symbol of scala.internal.Quoted.patternType */
def Definitions_InternalQuoted_patternTypeAnnot: Symbol
/** Symbol of scala.internal.CompileTime.higherOrderHole */
def Definitions_InternalQuotedMatcher_higherOrderHole: Symbol

/** Symbol of scala.internal.Quoted.fromAbove */
def Definitions_InternalQuoted_fromAboveAnnot: Symbol
/** Symbol of scala.internal.CompileTime.patternType */
def Definitions_InternalQuotedMatcher_patternTypeAnnot: Symbol

/** Symbol of scala.internal.CompileTime.fromAbove */
def Definitions_InternalQuotedMatcher_fromAboveAnnot: Symbol

def Definitions_UnitType: Type
def Definitions_ByteType: Type
Expand Down
5 changes: 5 additions & 0 deletions tests/neg/quote-open-patterns-stages.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import scala.quoted._

def f(using QuoteContext)(x: Expr[Any]) = x match {
case '{ identity($y(x)) } => // error: access to value x from wrong staging level
}
6 changes: 6 additions & 0 deletions tests/neg/quote-open-patterns-typer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import scala.quoted._

def f(using QuoteContext)(x: Expr[Any]) = x match {
case '{ val a: Int = 3; $y(identity(a)) } => // error: Exprected an identifier
case '{ identity($y()) } => // error: Missing arguments for open pattern
}
6 changes: 3 additions & 3 deletions tests/pos/quoted-splice-pattern-applied.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import scala.quoted._

def f(x: Expr[Int])(using QuoteContext) = x match {
case '{ $f($a: Int): Int } =>
case '{ ${f}($a: Int): Int } =>
val f1: Expr[Int => Int] = f
val a1: Expr[Int] = a
case '{ def a: Int = $f($b: Int); () } =>
case '{ def a: Int = ${f}($b: Int); () } =>
val f1: Expr[Int => Int] = f
val b1: Expr[Int] = b
case '{ val a: Int = 3; $f(a): Int } =>
case '{ val a: Int = 3; ${f}(a): Int } =>
val f1: Expr[Int => Int] = f
}
Loading