Skip to content

Join Tree and Pattern in TASTy Reflection #7344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
type Statement = tpd.Tree

def matchStatement(tree: Tree)(given Context): Option[Statement] = tree match {
case _: PatternTree => None
case tree if tree.isTerm => Some(tree)
case _ => matchDefinition(tree)
}
Expand Down Expand Up @@ -231,6 +232,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
type Term = tpd.Tree

def matchTerm(tree: Tree)(given Context): Option[Term] = tree match {
case _: PatternTree => None
case x: tpd.SeqLiteral => Some(tree)
case _ if tree.isTerm => Some(tree)
case _ => None
Expand Down Expand Up @@ -884,14 +886,14 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
case _ => None
}

def CaseDef_pattern(self: CaseDef)(given Context): Pattern = self.pat
def CaseDef_pattern(self: CaseDef)(given Context): Tree = self.pat
def CaseDef_guard(self: CaseDef)(given Context): Option[Term] = optional(self.guard)
def CaseDef_rhs(self: CaseDef)(given Context): Term = self.body

def CaseDef_module_apply(pattern: Pattern, guard: Option[Term], body: Term)(given Context): CaseDef =
def CaseDef_module_apply(pattern: Tree, guard: Option[Term], body: Term)(given Context): CaseDef =
tpd.CaseDef(pattern, guard.getOrElse(tpd.EmptyTree), body)

def CaseDef_module_copy(original: CaseDef)(pattern: Pattern, guard: Option[Term], body: Term)(given Context): CaseDef =
def CaseDef_module_copy(original: CaseDef)(pattern: Tree, guard: Option[Term], body: Term)(given Context): CaseDef =
tpd.cpy.CaseDef(original)(pattern, guard.getOrElse(tpd.EmptyTree), body)

type TypeCaseDef = tpd.CaseDef
Expand All @@ -910,114 +912,55 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def TypeCaseDef_module_copy(original: TypeCaseDef)(pattern: TypeTree, body: TypeTree)(given Context): TypeCaseDef =
tpd.cpy.CaseDef(original)(pattern, tpd.EmptyTree, body)

//
// PATTERNS
//

type Pattern = tpd.Tree

def Pattern_pos(self: Pattern)(given Context): Position = self.sourcePos
def Pattern_tpe(self: Pattern)(given Context): Type = self.tpe.stripTypeVar
def Pattern_symbol(self: Pattern)(given Context): Symbol = self.symbol

type Value = tpd.Tree

def matchPattern_Value(pattern: Pattern): Option[Value] = pattern match {
case lit: tpd.Literal => Some(lit)
case ref: tpd.RefTree if ref.isTerm && !tpd.isWildcardArg(ref) => Some(ref)
case ths: tpd.This => Some(ths)
case _ => None
}

def Pattern_Value_value(self: Value)(given Context): Term = self

def Pattern_Value_module_apply(term: Term)(given Context): Value = term match {
case lit: tpd.Literal => lit
case ref: tpd.RefTree if ref.isTerm => ref
case ths: tpd.This => ths
}
def Pattern_Value_module_copy(original: Value)(term: Term)(given Context): Value = term match {
case lit: tpd.Literal => tpd.cpy.Literal(original)(lit.const)
case ref: tpd.RefTree if ref.isTerm => tpd.cpy.Ref(original.asInstanceOf[tpd.RefTree])(ref.name)
case ths: tpd.This => tpd.cpy.This(original)(ths.qual)
}

type Bind = tpd.Bind

def matchPattern_Bind(x: Pattern)(given Context): Option[Bind] = x match {
def matchTree_Bind(x: Tree)(given Context): Option[Bind] = x match {
case x: tpd.Bind if x.name.isTermName => Some(x)
case _ => None
}

def Pattern_Bind_name(self: Bind)(given Context): String = self.name.toString
def Tree_Bind_name(self: Bind)(given Context): String = self.name.toString

def Pattern_Bind_pattern(self: Bind)(given Context): Pattern = self.body
def Tree_Bind_pattern(self: Bind)(given Context): Tree = self.body

def Pattern_Bind_module_copy(original: Bind)(name: String, pattern: Pattern)(given Context): Bind =
def Tree_Bind_module_copy(original: Bind)(name: String, pattern: Tree)(given Context): Bind =
withDefaultPos(tpd.cpy.Bind(original)(name.toTermName, pattern))

type Unapply = tpd.UnApply

def matchPattern_Unapply(pattern: Pattern)(given Context): Option[Unapply] = pattern match {
def matchTree_Unapply(pattern: Tree)(given Context): Option[Unapply] = pattern match {
case pattern @ Trees.UnApply(_, _, _) => Some(pattern)
case Trees.Typed(pattern @ Trees.UnApply(_, _, _), _) => Some(pattern)
case _ => None
}

def Pattern_Unapply_fun(self: Unapply)(given Context): Term = self.fun
def Pattern_Unapply_implicits(self: Unapply)(given Context): List[Term] = self.implicits
def Pattern_Unapply_patterns(self: Unapply)(given Context): List[Pattern] = effectivePatterns(self.patterns)
def Tree_Unapply_fun(self: Unapply)(given Context): Term = self.fun
def Tree_Unapply_implicits(self: Unapply)(given Context): List[Term] = self.implicits
def Tree_Unapply_patterns(self: Unapply)(given Context): List[Tree] = effectivePatterns(self.patterns)

def Pattern_Unapply_module_copy(original: Unapply)(fun: Term, implicits: List[Term], patterns: List[Pattern])(given Context): Unapply =
def Tree_Unapply_module_copy(original: Unapply)(fun: Term, implicits: List[Term], patterns: List[Tree])(given Context): Unapply =
withDefaultPos(tpd.cpy.UnApply(original)(fun, implicits, patterns))

private def effectivePatterns(patterns: List[Pattern]): List[Pattern] = patterns match {
private def effectivePatterns(patterns: List[Tree]): List[Tree] = patterns match {
case patterns0 :+ Trees.SeqLiteral(elems, _) => patterns0 ::: elems
case _ => patterns
}

type Alternatives = tpd.Alternative

def matchPattern_Alternatives(pattern: Pattern)(given Context): Option[Alternatives] = pattern match {
def matchTree_Alternatives(pattern: Tree)(given Context): Option[Alternatives] = pattern match {
case pattern: tpd.Alternative => Some(pattern)
case _ => None
}

def Pattern_Alternatives_patterns(self: Alternatives)(given Context): List[Pattern] = self.trees
def Tree_Alternatives_patterns(self: Alternatives)(given Context): List[Tree] = self.trees

def Pattern_Alternatives_module_apply(patterns: List[Pattern])(given Context): Alternatives =
def Tree_Alternatives_module_apply(patterns: List[Tree])(given Context): Alternatives =
withDefaultPos(tpd.Alternative(patterns))

def Pattern_Alternatives_module_copy(original: Alternatives)(patterns: List[Pattern])(given Context): Alternatives =
def Tree_Alternatives_module_copy(original: Alternatives)(patterns: List[Tree])(given Context): Alternatives =
tpd.cpy.Alternative(original)(patterns)

type TypeTest = tpd.Typed

def matchPattern_TypeTest(pattern: Pattern)(given Context): Option[TypeTest] = pattern match {
case Trees.Typed(_: tpd.UnApply, _) => None
case pattern: tpd.Typed => Some(pattern)
case _ => None
}

def Pattern_TypeTest_tpt(self: TypeTest)(given Context): TypeTree = self.tpt

def Pattern_TypeTest_module_apply(tpt: TypeTree)(given ctx: Context): TypeTest =
withDefaultPos(tpd.Typed(untpd.Ident(nme.WILDCARD)(ctx.source).withType(tpt.tpe), tpt))

def Pattern_TypeTest_module_copy(original: TypeTest)(tpt: TypeTree)(given Context): TypeTest =
tpd.cpy.Typed(original)(untpd.Ident(nme.WILDCARD).withSpan(original.span).withType(tpt.tpe), tpt)

type WildcardPattern = tpd.Ident

def matchPattern_WildcardPattern(pattern: Pattern)(given Context): Option[WildcardPattern] =
pattern match {
case pattern: tpd.Ident if tpd.isWildcardArg(pattern) => Some(pattern)
case _ => None
}

def Pattern_WildcardPattern_module_apply(tpe: TypeOrBounds)(given Context): WildcardPattern =
untpd.Ident(nme.WILDCARD).withType(tpe)

//
// TYPES
//
Expand Down Expand Up @@ -1461,9 +1404,6 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def Symbol_tree(self: Symbol)(given Context): Tree =
FromSymbol.definitionFromSym(self)

def Symbol_pattern(self: Symbol)(given ctx: Context): Pattern =
FromSymbol.definitionFromSym(self)

def Symbol_privateWithin(self: Symbol)(given Context): Option[Type] = {
val within = self.privateWithin
if (within.exists && !self.is(core.Flags.Protected)) Some(within.typeRef)
Expand Down
45 changes: 11 additions & 34 deletions library/src-bootstrapped/scala/tasty/reflect/TreeUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,15 @@ package reflect
/** Tasty reflect case definition */
trait TreeUtils
extends Core
with PatternOps
with SymbolOps
with TreeOps { self: Reflection =>

abstract class TreeAccumulator[X] {

// Ties the knot of the traversal: call `foldOver(x, tree))` to dive in the `tree` node.
def foldTree(x: X, tree: Tree)(given ctx: Context): X
def foldPattern(x: X, tree: Pattern)(given ctx: Context): X

def foldTrees(x: X, trees: Iterable[Tree])(given ctx: Context): X = trees.foldLeft(x)(foldTree)
def foldPatterns(x: X, trees: Iterable[Pattern])(given ctx: Context): X = trees.foldLeft(x)(foldPattern)

def foldOverTree(x: X, tree: Tree)(given ctx: Context): X = {
def localCtx(definition: Definition): Context = definition.symbol.localContext
Expand Down Expand Up @@ -92,32 +89,22 @@ trait TreeUtils
foldTrees(foldTree(boundopt.fold(x)(foldTree(x, _)), selector), cases)
case WildcardTypeTree() => x
case TypeBoundsTree(lo, hi) => foldTree(foldTree(x, lo), hi)
case CaseDef(pat, guard, body) => foldTree(foldTrees(foldPattern(x, pat), guard), body)
case CaseDef(pat, guard, body) => foldTree(foldTrees(foldTree(x, pat), guard), body)
case TypeCaseDef(pat, body) => foldTree(foldTree(x, pat), body)
case Bind(_, body) => foldTree(x, body)
case Unapply(fun, implicits, patterns) => foldTrees(foldTrees(foldTree(x, fun), implicits), patterns)
case Alternatives(patterns) => foldTrees(x, patterns)
}
}

def foldOverPattern(x: X, tree: Pattern)(given ctx: Context): X = tree match {
case Pattern.Value(v) => foldTree(x, v)
case Pattern.Bind(_, body) => foldPattern(x, body)
case Pattern.Unapply(fun, implicits, patterns) => foldPatterns(foldTrees(foldTree(x, fun), implicits), patterns)
case Pattern.Alternatives(patterns) => foldPatterns(x, patterns)
case Pattern.TypeTest(tpt) => foldTree(x, tpt)
case Pattern.WildcardPattern() => x
}

}

abstract class TreeTraverser extends TreeAccumulator[Unit] {

def traverseTree(tree: Tree)(given ctx: Context): Unit = traverseTreeChildren(tree)
def traversePattern(tree: Pattern)(given ctx: Context): Unit = traversePatternChildren(tree)

def foldTree(x: Unit, tree: Tree)(given ctx: Context): Unit = traverseTree(tree)
def foldPattern(x: Unit, tree: Pattern)(given ctx: Context) = traversePattern(tree)

protected def traverseTreeChildren(tree: Tree)(given ctx: Context): Unit = foldOverTree((), tree)
protected def traversePatternChildren(tree: Pattern)(given ctx: Context): Unit = foldOverPattern((), tree)

}

Expand All @@ -138,6 +125,12 @@ trait TreeUtils
transformCaseDef(tree)
case IsTypeCaseDef(tree) =>
transformTypeCaseDef(tree)
case IsBind(pattern) =>
Bind.copy(pattern)(pattern.name, pattern.pattern)
case IsUnapply(pattern) =>
Unapply.copy(pattern)(transformTerm(pattern.fun), transformSubTrees(pattern.implicits), transformTrees(pattern.patterns))
case IsAlternatives(pattern) =>
Alternatives.copy(pattern)(transformTrees(pattern.patterns))
}
}

Expand Down Expand Up @@ -237,26 +230,13 @@ trait TreeUtils
}

def transformCaseDef(tree: CaseDef)(given ctx: Context): CaseDef = {
CaseDef.copy(tree)(transformPattern(tree.pattern), tree.guard.map(transformTerm), transformTerm(tree.rhs))
CaseDef.copy(tree)(transformTree(tree.pattern), tree.guard.map(transformTerm), transformTerm(tree.rhs))
}

def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef = {
TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs))
}

def transformPattern(pattern: Pattern)(given ctx: Context): Pattern = pattern match {
case Pattern.Value(_) | Pattern.WildcardPattern() =>
pattern
case Pattern.IsTypeTest(pattern) =>
Pattern.TypeTest.copy(pattern)(transformTypeTree(pattern.tpt))
case Pattern.IsUnapply(pattern) =>
Pattern.Unapply.copy(pattern)(transformTerm(pattern.fun), transformSubTrees(pattern.implicits), transformPatterns(pattern.patterns))
case Pattern.IsAlternatives(pattern) =>
Pattern.Alternatives.copy(pattern)(transformPatterns(pattern.patterns))
case Pattern.IsBind(pattern) =>
Pattern.Bind.copy(pattern)(pattern.name, transformPattern(pattern.pattern))
}

def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] =
trees mapConserve (transformStatement(_))

Expand All @@ -275,9 +255,6 @@ trait TreeUtils
def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] =
trees mapConserve (transformTypeCaseDef(_))

def transformPatterns(trees: List[Pattern])(given ctx: Context): List[Pattern] =
trees mapConserve (transformPattern(_))

def transformSubTrees[Tr <: Tree](trees: List[Tr])(given ctx: Context): List[Tr] =
transformTrees(trees).asInstanceOf[List[Tr]]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package reflect
/** Tasty reflect case definition */
trait TreeUtils
extends Core
with PatternOps
with SymbolOps
with TreeOps { self: Reflection =>

Expand Down
Loading