Skip to content

Add rewrite prototype and couple of fixes #7506

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def PackageClause_apply(pid: Ref, stats: List[Tree])(given Context): PackageClause =
withDefaultPos(tpd.PackageDef(pid.asInstanceOf[tpd.RefTree], stats))

def PackageClause_copy(original: PackageClause)(pid: Ref, stats: List[Tree])(given Context): PackageClause =
def PackageClause_copy(original: Tree)(pid: Ref, stats: List[Tree])(given Context): PackageClause =
tpd.cpy.PackageDef(original)(pid, stats)

type Statement = tpd.Tree
Expand All @@ -128,7 +128,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def Import_apply(expr: Term, selectors: List[ImportSelector])(given Context): Import =
withDefaultPos(tpd.Import(expr, selectors))

def Import_copy(original: Import)(expr: Term, selectors: List[ImportSelector])(given Context): Import =
def Import_copy(original: Tree)(expr: Term, selectors: List[ImportSelector])(given Context): Import =
tpd.cpy.Import(original)(expr, selectors)

type Definition = tpd.Tree
Expand Down Expand Up @@ -171,7 +171,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def ClassDef_body(self: ClassDef)(given Context): List[Statement] = ClassDef_rhs(self).body
private def ClassDef_rhs(self: ClassDef) = self.rhs.asInstanceOf[tpd.Template]

def ClassDef_copy(original: ClassDef)(name: String, constr: DefDef, parents: List[Term | TypeTree], derived: List[TypeTree], selfOpt: Option[ValDef], body: List[Statement])(given Context): ClassDef = {
def ClassDef_copy(original: Tree)(name: String, constr: DefDef, parents: List[Term | TypeTree], derived: List[TypeTree], selfOpt: Option[ValDef], body: List[Statement])(given Context): ClassDef = {
val Trees.TypeDef(_, originalImpl: tpd.Template) = original
tpd.cpy.TypeDef(original)(name.toTypeName, tpd.cpy.Template(originalImpl)(constr, parents, derived, selfOpt.getOrElse(tpd.EmptyValDef), body))
}
Expand All @@ -186,7 +186,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def TypeDef_rhs(self: TypeDef)(given Context): TypeTree | TypeBoundsTree = self.rhs

def TypeDef_apply(symbol: Symbol)(given Context): TypeDef = withDefaultPos(tpd.TypeDef(symbol.asType))
def TypeDef_copy(original: TypeDef)(name: String, rhs: TypeTree | TypeBoundsTree)(given Context): TypeDef =
def TypeDef_copy(original: Tree)(name: String, rhs: TypeTree | TypeBoundsTree)(given Context): TypeDef =
tpd.cpy.TypeDef(original)(name.toTypeName, rhs)

type DefDef = tpd.DefDef
Expand All @@ -204,7 +204,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def DefDef_apply(symbol: Symbol, rhsFn: List[Type] => List[List[Term]] => Option[Term])(given Context): DefDef =
withDefaultPos(tpd.polyDefDef(symbol.asTerm, tparams => vparamss => rhsFn(tparams)(vparamss).getOrElse(tpd.EmptyTree)))

def DefDef_copy(original: DefDef)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term])(given Context): DefDef =
def DefDef_copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term])(given Context): DefDef =
tpd.cpy.DefDef(original)(name.toTermName, typeParams, paramss, tpt, rhs.getOrElse(tpd.EmptyTree))

type ValDef = tpd.ValDef
Expand All @@ -220,7 +220,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def ValDef_apply(symbol: Symbol, rhs: Option[Term])(given Context): ValDef =
tpd.ValDef(symbol.asTerm, rhs.getOrElse(tpd.EmptyTree))

def ValDef_copy(original: ValDef)(name: String, tpt: TypeTree, rhs: Option[Term])(given Context): ValDef =
def ValDef_copy(original: Tree)(name: String, tpt: TypeTree, rhs: Option[Term])(given Context): ValDef =
tpd.cpy.ValDef(original)(name.toTermName, tpt, rhs.getOrElse(tpd.EmptyTree))

type Term = tpd.Tree
Expand Down Expand Up @@ -347,8 +347,8 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def NamedArg_apply(name: String, arg: Term)(given Context): NamedArg =
withDefaultPos(tpd.NamedArg(name.toTermName, arg))

def NamedArg_copy(tree: NamedArg)(name: String, arg: Term)(given Context): NamedArg =
tpd.cpy.NamedArg(tree)(name.toTermName, arg)
def NamedArg_copy(original: Tree)(name: String, arg: Term)(given Context): NamedArg =
tpd.cpy.NamedArg(original)(name.toTermName, arg)

type Apply = tpd.Apply

Expand Down Expand Up @@ -672,7 +672,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend

def TypeIdent_name(self: TypeIdent)(given Context): String = self.name.toString

def TypeIdent_copy(original: TypeIdent)(name: String)(given Context): TypeIdent =
def TypeIdent_copy(original: Tree)(name: String)(given Context): TypeIdent =
tpd.cpy.Ident(original)(name.toTypeName)

type TypeSelect = tpd.Select
Expand All @@ -688,7 +688,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def TypeSelect_apply(qualifier: Term, name: String)(given Context): TypeSelect =
withDefaultPos(tpd.Select(qualifier, name.toTypeName))

def TypeSelect_copy(original: TypeSelect)(qualifier: Term, name: String)(given Context): TypeSelect =
def TypeSelect_copy(original: Tree)(qualifier: Term, name: String)(given Context): TypeSelect =
tpd.cpy.Select(original)(qualifier, name.toTypeName)


Expand All @@ -702,7 +702,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def Projection_qualifier(self: Projection)(given Context): TypeTree = self.qualifier
def Projection_name(self: Projection)(given Context): String = self.name.toString

def Projection_copy(original: Projection)(qualifier: TypeTree, name: String)(given Context): Projection =
def Projection_copy(original: Tree)(qualifier: TypeTree, name: String)(given Context): Projection =
tpd.cpy.Select(original)(qualifier, name.toTypeName)

type Singleton = tpd.SingletonTypeTree
Expand All @@ -717,7 +717,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def Singleton_apply(ref: Term)(given Context): Singleton =
withDefaultPos(tpd.SingletonTypeTree(ref))

def Singleton_copy(original: Singleton)(ref: Term)(given Context): Singleton =
def Singleton_copy(original: Tree)(ref: Term)(given Context): Singleton =
tpd.cpy.SingletonTypeTree(original)(ref)

type Refined = tpd.RefinedTypeTree
Expand All @@ -730,7 +730,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def Refined_tpt(self: Refined)(given Context): TypeTree = self.tpt
def Refined_refinements(self: Refined)(given Context): List[Definition] = self.refinements

def Refined_copy(original: Refined)(tpt: TypeTree, refinements: List[Definition])(given Context): Refined =
def Refined_copy(original: Tree)(tpt: TypeTree, refinements: List[Definition])(given Context): Refined =
tpd.cpy.RefinedTypeTree(original)(tpt, refinements)

type Applied = tpd.AppliedTypeTree
Expand All @@ -746,7 +746,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def Applied_apply(tpt: TypeTree, args: List[TypeTree | TypeBoundsTree])(given Context): Applied =
withDefaultPos(tpd.AppliedTypeTree(tpt, args))

def Applied_copy(original: Applied)(tpt: TypeTree, args: List[TypeTree | TypeBoundsTree])(given Context): Applied =
def Applied_copy(original: Tree)(tpt: TypeTree, args: List[TypeTree | TypeBoundsTree])(given Context): Applied =
tpd.cpy.AppliedTypeTree(original)(tpt, args)

type Annotated = tpd.Annotated
Expand All @@ -762,7 +762,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def Annotated_apply(arg: TypeTree, annotation: Term)(given Context): Annotated =
withDefaultPos(tpd.Annotated(arg, annotation))

def Annotated_copy(original: Annotated)(arg: TypeTree, annotation: Term)(given Context): Annotated =
def Annotated_copy(original: Tree)(arg: TypeTree, annotation: Term)(given Context): Annotated =
tpd.cpy.Annotated(original)(arg, annotation)

type MatchTypeTree = tpd.MatchTypeTree
Expand All @@ -779,7 +779,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def MatchTypeTree_apply(bound: Option[TypeTree], selector: TypeTree, cases: List[TypeCaseDef])(given Context): MatchTypeTree =
withDefaultPos(tpd.MatchTypeTree(bound.getOrElse(tpd.EmptyTree), selector, cases))

def MatchTypeTree_copy(original: MatchTypeTree)(bound: Option[TypeTree], selector: TypeTree, cases: List[TypeCaseDef])(given Context): MatchTypeTree =
def MatchTypeTree_copy(original: Tree)(bound: Option[TypeTree], selector: TypeTree, cases: List[TypeCaseDef])(given Context): MatchTypeTree =
tpd.cpy.MatchTypeTree(original)(bound.getOrElse(tpd.EmptyTree), selector, cases)

type ByName = tpd.ByNameTypeTree
Expand All @@ -794,7 +794,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def ByName_apply(result: TypeTree)(given Context): ByName =
withDefaultPos(tpd.ByNameTypeTree(result))

def ByName_copy(original: ByName)(result: TypeTree)(given Context): ByName =
def ByName_copy(original: Tree)(result: TypeTree)(given Context): ByName =
tpd.cpy.ByNameTypeTree(original)(result)

type LambdaTypeTree = tpd.LambdaTypeTree
Expand All @@ -810,7 +810,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def Lambdaapply(tparams: List[TypeDef], body: TypeTree | TypeBoundsTree)(given Context): LambdaTypeTree =
withDefaultPos(tpd.LambdaTypeTree(tparams, body))

def Lambdacopy(original: LambdaTypeTree)(tparams: List[TypeDef], body: TypeTree | TypeBoundsTree)(given Context): LambdaTypeTree =
def Lambdacopy(original: Tree)(tparams: List[TypeDef], body: TypeTree | TypeBoundsTree)(given Context): LambdaTypeTree =
tpd.cpy.LambdaTypeTree(original)(tparams, body)

type TypeBind = tpd.Bind
Expand All @@ -823,7 +823,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def TypeBind_name(self: TypeBind)(given Context): String = self.name.toString
def TypeBind_body(self: TypeBind)(given Context): TypeTree | TypeBoundsTree = self.body

def TypeBind_copy(original: TypeBind)(name: String, tpt: TypeTree | TypeBoundsTree)(given Context): TypeBind =
def TypeBind_copy(original: Tree)(name: String, tpt: TypeTree | TypeBoundsTree)(given Context): TypeBind =
tpd.cpy.Bind(original)(name.toTypeName, tpt)

type TypeBlock = tpd.Block
Expand All @@ -839,7 +839,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def TypeBlock_apply(aliases: List[TypeDef], tpt: TypeTree)(given Context): TypeBlock =
withDefaultPos(tpd.Block(aliases, tpt))

def TypeBlock_copy(original: TypeBlock)(aliases: List[TypeDef], tpt: TypeTree)(given Context): TypeBlock =
def TypeBlock_copy(original: Tree)(aliases: List[TypeDef], tpt: TypeTree)(given Context): TypeBlock =
tpd.cpy.Block(original)(aliases, tpt)

type TypeBoundsTree = tpd.TypeBoundsTree
Expand Down Expand Up @@ -883,7 +883,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def CaseDef_module_apply(pattern: Tree, guard: Option[Term], body: Term)(given Context): CaseDef =
tpd.CaseDef(pattern, guard.getOrElse(tpd.EmptyTree), body)

def CaseDef_module_copy(original: CaseDef)(pattern: Tree, guard: Option[Term], body: Term)(given Context): CaseDef =
def CaseDef_module_copy(original: Tree)(pattern: Tree, guard: Option[Term], body: Term)(given Context): CaseDef =
tpd.cpy.CaseDef(original)(pattern, guard.getOrElse(tpd.EmptyTree), body)

type TypeCaseDef = tpd.CaseDef
Expand All @@ -899,7 +899,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def TypeCaseDef_module_apply(pattern: TypeTree, body: TypeTree)(given Context): TypeCaseDef =
tpd.CaseDef(pattern, tpd.EmptyTree, body)

def TypeCaseDef_module_copy(original: TypeCaseDef)(pattern: TypeTree, body: TypeTree)(given Context): TypeCaseDef =
def TypeCaseDef_module_copy(original: Tree)(pattern: TypeTree, body: TypeTree)(given Context): TypeCaseDef =
tpd.cpy.CaseDef(original)(pattern, tpd.EmptyTree, body)

type Bind = tpd.Bind
Expand All @@ -913,7 +913,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend

def Tree_Bind_pattern(self: Bind)(given Context): Tree = self.body

def Tree_Bind_module_copy(original: Bind)(name: String, pattern: Tree)(given Context): Bind =
def Tree_Bind_module_copy(original: Tree)(name: String, pattern: Tree)(given Context): Bind =
withDefaultPos(tpd.cpy.Bind(original)(name.toTermName, pattern))

type Unapply = tpd.UnApply
Expand All @@ -928,7 +928,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def Tree_Unapply_implicits(self: Unapply)(given Context): List[Term] = self.implicits
def Tree_Unapply_patterns(self: Unapply)(given Context): List[Tree] = effectivePatterns(self.patterns)

def Tree_Unapply_module_copy(original: Unapply)(fun: Term, implicits: List[Term], patterns: List[Tree])(given Context): Unapply =
def Tree_Unapply_module_copy(original: Tree)(fun: Term, implicits: List[Term], patterns: List[Tree])(given Context): Unapply =
withDefaultPos(tpd.cpy.UnApply(original)(fun, implicits, patterns))

private def effectivePatterns(patterns: List[Tree]): List[Tree] = patterns match {
Expand All @@ -948,7 +948,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def Tree_Alternatives_module_apply(patterns: List[Tree])(given Context): Alternatives =
withDefaultPos(tpd.Alternative(patterns))

def Tree_Alternatives_module_copy(original: Alternatives)(patterns: List[Tree])(given Context): Alternatives =
def Tree_Alternatives_module_copy(original: Tree)(patterns: List[Tree])(given Context): Alternatives =
tpd.cpy.Alternative(original)(patterns)

//
Expand Down
8 changes: 8 additions & 0 deletions library/src/scala/quoted/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,14 @@ package internal {
* May contain references to code defined outside this TastyTreeExpr instance.
*/
final class TastyTreeExpr[Tree](val tree: Tree, val scopeId: Int) extends Expr[Any] {
override def equals(that: Any): Boolean = that match {
case that: TastyTreeExpr[_] =>
// TastyTreeExpr are wrappers around trees, therfore they are equals if their trees are equal.
// All scopeId should be equal unless two different runs of the compiler created the trees.
tree == that.tree && scopeId == that.scopeId
case _ => false
}
override def hashCode: Int = tree.hashCode
override def toString: String = s"Expr(<tasty tree>)"
}

Expand Down
8 changes: 8 additions & 0 deletions library/src/scala/quoted/Type.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ package internal {

/** An Type backed by a tree */
final class TreeType[Tree](val typeTree: Tree, val scopeId: Int) extends scala.quoted.Type[Any] {
override def equals(that: Any): Boolean = that match {
case that: TreeType[_] => typeTree ==
// TastyTreeExpr are wrappers around trees, therfore they are equals if their trees are equal.
// All scopeId should be equal unless two different runs of the compiler created the trees.
that.typeTree && scopeId == that.scopeId
case _ => false
}
override def hashCode: Int = typeTree.hashCode
override def toString: String = s"Type(<tasty tree>)"
}

Expand Down
Loading