Skip to content

Commit f61c631

Browse files
Merge pull request #7506 from dotty-staging/add-rewrite-prototype-and-fixes
Add rewrite prototype and couple of fixes
2 parents b5f2f18 + 845fe93 commit f61c631

File tree

15 files changed

+495
-73
lines changed

15 files changed

+495
-73
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
103103
def PackageClause_apply(pid: Ref, stats: List[Tree])(given Context): PackageClause =
104104
withDefaultPos(tpd.PackageDef(pid.asInstanceOf[tpd.RefTree], stats))
105105

106-
def PackageClause_copy(original: PackageClause)(pid: Ref, stats: List[Tree])(given Context): PackageClause =
106+
def PackageClause_copy(original: Tree)(pid: Ref, stats: List[Tree])(given Context): PackageClause =
107107
tpd.cpy.PackageDef(original)(pid, stats)
108108

109109
type Statement = tpd.Tree
@@ -128,7 +128,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
128128
def Import_apply(expr: Term, selectors: List[ImportSelector])(given Context): Import =
129129
withDefaultPos(tpd.Import(expr, selectors))
130130

131-
def Import_copy(original: Import)(expr: Term, selectors: List[ImportSelector])(given Context): Import =
131+
def Import_copy(original: Tree)(expr: Term, selectors: List[ImportSelector])(given Context): Import =
132132
tpd.cpy.Import(original)(expr, selectors)
133133

134134
type Definition = tpd.Tree
@@ -171,7 +171,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
171171
def ClassDef_body(self: ClassDef)(given Context): List[Statement] = ClassDef_rhs(self).body
172172
private def ClassDef_rhs(self: ClassDef) = self.rhs.asInstanceOf[tpd.Template]
173173

174-
def ClassDef_copy(original: ClassDef)(name: String, constr: DefDef, parents: List[Term | TypeTree], derived: List[TypeTree], selfOpt: Option[ValDef], body: List[Statement])(given Context): ClassDef = {
174+
def ClassDef_copy(original: Tree)(name: String, constr: DefDef, parents: List[Term | TypeTree], derived: List[TypeTree], selfOpt: Option[ValDef], body: List[Statement])(given Context): ClassDef = {
175175
val Trees.TypeDef(_, originalImpl: tpd.Template) = original
176176
tpd.cpy.TypeDef(original)(name.toTypeName, tpd.cpy.Template(originalImpl)(constr, parents, derived, selfOpt.getOrElse(tpd.EmptyValDef), body))
177177
}
@@ -186,7 +186,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
186186
def TypeDef_rhs(self: TypeDef)(given Context): TypeTree | TypeBoundsTree = self.rhs
187187

188188
def TypeDef_apply(symbol: Symbol)(given Context): TypeDef = withDefaultPos(tpd.TypeDef(symbol.asType))
189-
def TypeDef_copy(original: TypeDef)(name: String, rhs: TypeTree | TypeBoundsTree)(given Context): TypeDef =
189+
def TypeDef_copy(original: Tree)(name: String, rhs: TypeTree | TypeBoundsTree)(given Context): TypeDef =
190190
tpd.cpy.TypeDef(original)(name.toTypeName, rhs)
191191

192192
type DefDef = tpd.DefDef
@@ -204,7 +204,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
204204
def DefDef_apply(symbol: Symbol, rhsFn: List[Type] => List[List[Term]] => Option[Term])(given Context): DefDef =
205205
withDefaultPos(tpd.polyDefDef(symbol.asTerm, tparams => vparamss => rhsFn(tparams)(vparamss).getOrElse(tpd.EmptyTree)))
206206

207-
def DefDef_copy(original: DefDef)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term])(given Context): DefDef =
207+
def DefDef_copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term])(given Context): DefDef =
208208
tpd.cpy.DefDef(original)(name.toTermName, typeParams, paramss, tpt, rhs.getOrElse(tpd.EmptyTree))
209209

210210
type ValDef = tpd.ValDef
@@ -220,7 +220,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
220220
def ValDef_apply(symbol: Symbol, rhs: Option[Term])(given Context): ValDef =
221221
tpd.ValDef(symbol.asTerm, rhs.getOrElse(tpd.EmptyTree))
222222

223-
def ValDef_copy(original: ValDef)(name: String, tpt: TypeTree, rhs: Option[Term])(given Context): ValDef =
223+
def ValDef_copy(original: Tree)(name: String, tpt: TypeTree, rhs: Option[Term])(given Context): ValDef =
224224
tpd.cpy.ValDef(original)(name.toTermName, tpt, rhs.getOrElse(tpd.EmptyTree))
225225

226226
type Term = tpd.Tree
@@ -347,8 +347,8 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
347347
def NamedArg_apply(name: String, arg: Term)(given Context): NamedArg =
348348
withDefaultPos(tpd.NamedArg(name.toTermName, arg))
349349

350-
def NamedArg_copy(tree: NamedArg)(name: String, arg: Term)(given Context): NamedArg =
351-
tpd.cpy.NamedArg(tree)(name.toTermName, arg)
350+
def NamedArg_copy(original: Tree)(name: String, arg: Term)(given Context): NamedArg =
351+
tpd.cpy.NamedArg(original)(name.toTermName, arg)
352352

353353
type Apply = tpd.Apply
354354

@@ -672,7 +672,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
672672

673673
def TypeIdent_name(self: TypeIdent)(given Context): String = self.name.toString
674674

675-
def TypeIdent_copy(original: TypeIdent)(name: String)(given Context): TypeIdent =
675+
def TypeIdent_copy(original: Tree)(name: String)(given Context): TypeIdent =
676676
tpd.cpy.Ident(original)(name.toTypeName)
677677

678678
type TypeSelect = tpd.Select
@@ -688,7 +688,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
688688
def TypeSelect_apply(qualifier: Term, name: String)(given Context): TypeSelect =
689689
withDefaultPos(tpd.Select(qualifier, name.toTypeName))
690690

691-
def TypeSelect_copy(original: TypeSelect)(qualifier: Term, name: String)(given Context): TypeSelect =
691+
def TypeSelect_copy(original: Tree)(qualifier: Term, name: String)(given Context): TypeSelect =
692692
tpd.cpy.Select(original)(qualifier, name.toTypeName)
693693

694694

@@ -702,7 +702,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
702702
def Projection_qualifier(self: Projection)(given Context): TypeTree = self.qualifier
703703
def Projection_name(self: Projection)(given Context): String = self.name.toString
704704

705-
def Projection_copy(original: Projection)(qualifier: TypeTree, name: String)(given Context): Projection =
705+
def Projection_copy(original: Tree)(qualifier: TypeTree, name: String)(given Context): Projection =
706706
tpd.cpy.Select(original)(qualifier, name.toTypeName)
707707

708708
type Singleton = tpd.SingletonTypeTree
@@ -717,7 +717,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
717717
def Singleton_apply(ref: Term)(given Context): Singleton =
718718
withDefaultPos(tpd.SingletonTypeTree(ref))
719719

720-
def Singleton_copy(original: Singleton)(ref: Term)(given Context): Singleton =
720+
def Singleton_copy(original: Tree)(ref: Term)(given Context): Singleton =
721721
tpd.cpy.SingletonTypeTree(original)(ref)
722722

723723
type Refined = tpd.RefinedTypeTree
@@ -730,7 +730,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
730730
def Refined_tpt(self: Refined)(given Context): TypeTree = self.tpt
731731
def Refined_refinements(self: Refined)(given Context): List[Definition] = self.refinements
732732

733-
def Refined_copy(original: Refined)(tpt: TypeTree, refinements: List[Definition])(given Context): Refined =
733+
def Refined_copy(original: Tree)(tpt: TypeTree, refinements: List[Definition])(given Context): Refined =
734734
tpd.cpy.RefinedTypeTree(original)(tpt, refinements)
735735

736736
type Applied = tpd.AppliedTypeTree
@@ -746,7 +746,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
746746
def Applied_apply(tpt: TypeTree, args: List[TypeTree | TypeBoundsTree])(given Context): Applied =
747747
withDefaultPos(tpd.AppliedTypeTree(tpt, args))
748748

749-
def Applied_copy(original: Applied)(tpt: TypeTree, args: List[TypeTree | TypeBoundsTree])(given Context): Applied =
749+
def Applied_copy(original: Tree)(tpt: TypeTree, args: List[TypeTree | TypeBoundsTree])(given Context): Applied =
750750
tpd.cpy.AppliedTypeTree(original)(tpt, args)
751751

752752
type Annotated = tpd.Annotated
@@ -762,7 +762,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
762762
def Annotated_apply(arg: TypeTree, annotation: Term)(given Context): Annotated =
763763
withDefaultPos(tpd.Annotated(arg, annotation))
764764

765-
def Annotated_copy(original: Annotated)(arg: TypeTree, annotation: Term)(given Context): Annotated =
765+
def Annotated_copy(original: Tree)(arg: TypeTree, annotation: Term)(given Context): Annotated =
766766
tpd.cpy.Annotated(original)(arg, annotation)
767767

768768
type MatchTypeTree = tpd.MatchTypeTree
@@ -779,7 +779,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
779779
def MatchTypeTree_apply(bound: Option[TypeTree], selector: TypeTree, cases: List[TypeCaseDef])(given Context): MatchTypeTree =
780780
withDefaultPos(tpd.MatchTypeTree(bound.getOrElse(tpd.EmptyTree), selector, cases))
781781

782-
def MatchTypeTree_copy(original: MatchTypeTree)(bound: Option[TypeTree], selector: TypeTree, cases: List[TypeCaseDef])(given Context): MatchTypeTree =
782+
def MatchTypeTree_copy(original: Tree)(bound: Option[TypeTree], selector: TypeTree, cases: List[TypeCaseDef])(given Context): MatchTypeTree =
783783
tpd.cpy.MatchTypeTree(original)(bound.getOrElse(tpd.EmptyTree), selector, cases)
784784

785785
type ByName = tpd.ByNameTypeTree
@@ -794,7 +794,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
794794
def ByName_apply(result: TypeTree)(given Context): ByName =
795795
withDefaultPos(tpd.ByNameTypeTree(result))
796796

797-
def ByName_copy(original: ByName)(result: TypeTree)(given Context): ByName =
797+
def ByName_copy(original: Tree)(result: TypeTree)(given Context): ByName =
798798
tpd.cpy.ByNameTypeTree(original)(result)
799799

800800
type LambdaTypeTree = tpd.LambdaTypeTree
@@ -810,7 +810,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
810810
def Lambdaapply(tparams: List[TypeDef], body: TypeTree | TypeBoundsTree)(given Context): LambdaTypeTree =
811811
withDefaultPos(tpd.LambdaTypeTree(tparams, body))
812812

813-
def Lambdacopy(original: LambdaTypeTree)(tparams: List[TypeDef], body: TypeTree | TypeBoundsTree)(given Context): LambdaTypeTree =
813+
def Lambdacopy(original: Tree)(tparams: List[TypeDef], body: TypeTree | TypeBoundsTree)(given Context): LambdaTypeTree =
814814
tpd.cpy.LambdaTypeTree(original)(tparams, body)
815815

816816
type TypeBind = tpd.Bind
@@ -823,7 +823,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
823823
def TypeBind_name(self: TypeBind)(given Context): String = self.name.toString
824824
def TypeBind_body(self: TypeBind)(given Context): TypeTree | TypeBoundsTree = self.body
825825

826-
def TypeBind_copy(original: TypeBind)(name: String, tpt: TypeTree | TypeBoundsTree)(given Context): TypeBind =
826+
def TypeBind_copy(original: Tree)(name: String, tpt: TypeTree | TypeBoundsTree)(given Context): TypeBind =
827827
tpd.cpy.Bind(original)(name.toTypeName, tpt)
828828

829829
type TypeBlock = tpd.Block
@@ -839,7 +839,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
839839
def TypeBlock_apply(aliases: List[TypeDef], tpt: TypeTree)(given Context): TypeBlock =
840840
withDefaultPos(tpd.Block(aliases, tpt))
841841

842-
def TypeBlock_copy(original: TypeBlock)(aliases: List[TypeDef], tpt: TypeTree)(given Context): TypeBlock =
842+
def TypeBlock_copy(original: Tree)(aliases: List[TypeDef], tpt: TypeTree)(given Context): TypeBlock =
843843
tpd.cpy.Block(original)(aliases, tpt)
844844

845845
type TypeBoundsTree = tpd.TypeBoundsTree
@@ -883,7 +883,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
883883
def CaseDef_module_apply(pattern: Tree, guard: Option[Term], body: Term)(given Context): CaseDef =
884884
tpd.CaseDef(pattern, guard.getOrElse(tpd.EmptyTree), body)
885885

886-
def CaseDef_module_copy(original: CaseDef)(pattern: Tree, guard: Option[Term], body: Term)(given Context): CaseDef =
886+
def CaseDef_module_copy(original: Tree)(pattern: Tree, guard: Option[Term], body: Term)(given Context): CaseDef =
887887
tpd.cpy.CaseDef(original)(pattern, guard.getOrElse(tpd.EmptyTree), body)
888888

889889
type TypeCaseDef = tpd.CaseDef
@@ -899,7 +899,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
899899
def TypeCaseDef_module_apply(pattern: TypeTree, body: TypeTree)(given Context): TypeCaseDef =
900900
tpd.CaseDef(pattern, tpd.EmptyTree, body)
901901

902-
def TypeCaseDef_module_copy(original: TypeCaseDef)(pattern: TypeTree, body: TypeTree)(given Context): TypeCaseDef =
902+
def TypeCaseDef_module_copy(original: Tree)(pattern: TypeTree, body: TypeTree)(given Context): TypeCaseDef =
903903
tpd.cpy.CaseDef(original)(pattern, tpd.EmptyTree, body)
904904

905905
type Bind = tpd.Bind
@@ -913,7 +913,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
913913

914914
def Tree_Bind_pattern(self: Bind)(given Context): Tree = self.body
915915

916-
def Tree_Bind_module_copy(original: Bind)(name: String, pattern: Tree)(given Context): Bind =
916+
def Tree_Bind_module_copy(original: Tree)(name: String, pattern: Tree)(given Context): Bind =
917917
withDefaultPos(tpd.cpy.Bind(original)(name.toTermName, pattern))
918918

919919
type Unapply = tpd.UnApply
@@ -928,7 +928,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
928928
def Tree_Unapply_implicits(self: Unapply)(given Context): List[Term] = self.implicits
929929
def Tree_Unapply_patterns(self: Unapply)(given Context): List[Tree] = effectivePatterns(self.patterns)
930930

931-
def Tree_Unapply_module_copy(original: Unapply)(fun: Term, implicits: List[Term], patterns: List[Tree])(given Context): Unapply =
931+
def Tree_Unapply_module_copy(original: Tree)(fun: Term, implicits: List[Term], patterns: List[Tree])(given Context): Unapply =
932932
withDefaultPos(tpd.cpy.UnApply(original)(fun, implicits, patterns))
933933

934934
private def effectivePatterns(patterns: List[Tree]): List[Tree] = patterns match {
@@ -948,7 +948,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
948948
def Tree_Alternatives_module_apply(patterns: List[Tree])(given Context): Alternatives =
949949
withDefaultPos(tpd.Alternative(patterns))
950950

951-
def Tree_Alternatives_module_copy(original: Alternatives)(patterns: List[Tree])(given Context): Alternatives =
951+
def Tree_Alternatives_module_copy(original: Tree)(patterns: List[Tree])(given Context): Alternatives =
952952
tpd.cpy.Alternative(original)(patterns)
953953

954954
//

library/src/scala/quoted/Expr.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,14 @@ package internal {
190190
* May contain references to code defined outside this TastyTreeExpr instance.
191191
*/
192192
final class TastyTreeExpr[Tree](val tree: Tree, val scopeId: Int) extends Expr[Any] {
193+
override def equals(that: Any): Boolean = that match {
194+
case that: TastyTreeExpr[_] =>
195+
// TastyTreeExpr are wrappers around trees, therfore they are equals if their trees are equal.
196+
// All scopeId should be equal unless two different runs of the compiler created the trees.
197+
tree == that.tree && scopeId == that.scopeId
198+
case _ => false
199+
}
200+
override def hashCode: Int = tree.hashCode
193201
override def toString: String = s"Expr(<tasty tree>)"
194202
}
195203

library/src/scala/quoted/Type.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ package internal {
7171

7272
/** An Type backed by a tree */
7373
final class TreeType[Tree](val typeTree: Tree, val scopeId: Int) extends scala.quoted.Type[Any] {
74+
override def equals(that: Any): Boolean = that match {
75+
case that: TreeType[_] => typeTree ==
76+
// TastyTreeExpr are wrappers around trees, therfore they are equals if their trees are equal.
77+
// All scopeId should be equal unless two different runs of the compiler created the trees.
78+
that.typeTree && scopeId == that.scopeId
79+
case _ => false
80+
}
81+
override def hashCode: Int = typeTree.hashCode
7482
override def toString: String = s"Type(<tasty tree>)"
7583
}
7684

0 commit comments

Comments
 (0)