Skip to content

Fix #9103: Add additional config to NamedParts accumulator #9106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/ast/Positioned.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import core.Contexts.Context
import core.Decorators._
import core.Flags.{JavaDefined, Extension}
import core.StdNames.nme
import ast.Trees.mods
import annotation.constructorOnly
import annotation.internal.sharable
import reporting.Reporter
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,8 @@ object Trees {
def namedType: NamedType = tpe.asInstanceOf[NamedType]
}

def (mdef: untpd.DefTree).mods: untpd.Modifiers = mdef.rawMods

abstract class NamedDefTree[-T >: Untyped](implicit @constructorOnly src: SourceFile) extends NameTree[T] with DefTree[T] {
type ThisTree[-T >: Untyped] <: NamedDefTree[T]

Expand Down Expand Up @@ -1538,6 +1540,7 @@ object Trees {
receiver: tpd.Tree, method: TermName, args: List[Tree], targs: List[Type],
expectedType: Type)(using parentCtx: Context): tpd.Tree = {
given ctx as Context = parentCtx.retractMode(Mode.ImplicitsEnabled)
import dotty.tools.dotc.ast.tpd.TreeOps

val typer = ctx.typer
val proto = FunProto(args, expectedType)
Expand Down
7 changes: 0 additions & 7 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -508,13 +508,6 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
/** A repeated argument such as `arg: _*` */
def repeated(arg: Tree)(implicit ctx: Context): Typed = Typed(arg, Ident(tpnme.WILDCARD_STAR))

// ----- Accessing modifiers ----------------------------------------------------

abstract class ModsDecorator { def mods: Modifiers }

implicit class modsDeco(val mdef: DefTree)(implicit ctx: Context) {
def mods: Modifiers = mdef.rawMods
}

// --------- Copier/Transformer/Accumulator classes for untyped trees -----

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ class TypeApplications(val self: Type) extends AnyVal {
case dealiased: TypeBounds =>
dealiased.derivedTypeBounds(dealiased.lo.appliedTo(args), dealiased.hi.appliedTo(args))
case dealiased: LazyRef =>
LazyRef(c => dealiased.ref(c).appliedTo(args))
LazyRef(c => dealiased.ref(c).appliedTo(args)(using c))
case dealiased: WildcardType =>
WildcardType(dealiased.optBounds.orElse(TypeBounds.empty).appliedTo(args).bounds)
case dealiased: TypeRef if dealiased.symbol == defn.NothingClass =>
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,8 @@ object TypeOps:
val widenMap = new ApproximatingTypeMap {
@threadUnsafe lazy val forbidden = symsToAvoid.toSet
def toAvoid(sym: Symbol) = !sym.isStatic && forbidden.contains(sym)
def partsToAvoid = new NamedPartsAccumulator(tp => toAvoid(tp.symbol))
def partsToAvoid =
new NamedPartsAccumulator(tp => toAvoid(tp.symbol), widenSingletons = true)
def apply(tp: Type): Type = tp match {
case tp: TermRef
if toAvoid(tp.symbol) || partsToAvoid(mutable.Set.empty, tp.info).nonEmpty =>
Expand Down
62 changes: 43 additions & 19 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -385,20 +385,18 @@ object Types {
final def foreachPart(p: Type => Unit, stopAtStatic: Boolean = false)(implicit ctx: Context): Unit =
new ForeachAccumulator(p, stopAtStatic).apply((), this)

/** The parts of this type which are type or term refs */
final def namedParts(implicit ctx: Context): collection.Set[NamedType] =
namedPartsWith(alwaysTrue)

/** The parts of this type which are type or term refs and which
* satisfy predicate `p`.
*
* @param p The predicate to satisfy
* @param excludeLowerBounds If set to true, the lower bounds of abstract
* types will be ignored.
*/
def namedPartsWith(p: NamedType => Boolean, excludeLowerBounds: Boolean = false)
(implicit ctx: Context): collection.Set[NamedType] =
new NamedPartsAccumulator(p, excludeLowerBounds).apply(mutable.LinkedHashSet(), this)
def namedPartsWith(p: NamedType => Boolean,
widenSingletons: Boolean = false,
excludeLowerBounds: Boolean = false)
(implicit ctx: Context): collection.Set[NamedType] =
new NamedPartsAccumulator(p, widenSingletons, excludeLowerBounds).apply(mutable.LinkedHashSet(), this)

/** Map function `f` over elements of an AndType, rebuilding with function `g` */
def mapReduceAnd[T](f: Type => T)(g: (T, T) => T)(implicit ctx: Context): T = stripTypeVar match {
Expand Down Expand Up @@ -4862,7 +4860,7 @@ object Types {
}
}

abstract class TypeMap(implicit protected val mapCtx: Context)
abstract class TypeMap(implicit protected var mapCtx: Context)
extends VariantTraversal with (Type => Type) { thisMap =>

protected def stopAtStatic: Boolean = true
Expand Down Expand Up @@ -4979,7 +4977,16 @@ object Types {
derivedSuperType(tp, this(thistp), this(supertp))

case tp: LazyRef =>
LazyRef(_ => this(tp.ref))
LazyRef { c =>
val ref1 = tp.ref(using c)
if c.runId == mapCtx.runId then this(ref1)
else // splice in new run into map context
val saved = mapCtx
mapCtx = mapCtx.fresh
.setPeriod(Period(c.runId, mapCtx.phaseId))
.setRun(c.run)
try this(ref1) finally mapCtx = saved
}

case tp: ClassInfo =>
mapClassInfo(tp)
Expand Down Expand Up @@ -5331,7 +5338,7 @@ object Types {

protected def applyToAnnot(x: T, annot: Annotation): T = x // don't go into annotations

protected final def applyToPrefix(x: T, tp: NamedType): T =
protected def applyToPrefix(x: T, tp: NamedType): T =
atVariance(variance max 0)(this(x, tp.prefix)) // see remark on NamedType case in TypeMap

def foldOver(x: T, tp: Type): T = {
Expand Down Expand Up @@ -5453,27 +5460,44 @@ object Types {
def apply(x: Unit, tp: Type): Unit = foldOver(p(tp), tp)
}

class NamedPartsAccumulator(p: NamedType => Boolean, excludeLowerBounds: Boolean = false)
(implicit ctx: Context) extends TypeAccumulator[mutable.Set[NamedType]] {
class TypeHashSet extends util.HashSet[Type](64):
override def hash(x: Type): Int = System.identityHashCode(x)
override def isEqual(x: Type, y: Type) = x.eq(y)

class NamedPartsAccumulator(p: NamedType => Boolean,
widenSingletons: Boolean = false, // if set, also consider underlying types in singleton path prefixes
excludeLowerBounds: Boolean = false)
(implicit ctx: Context) extends TypeAccumulator[mutable.Set[NamedType]] {

override def stopAtStatic: Boolean = false
def maybeAdd(x: mutable.Set[NamedType], tp: NamedType): mutable.Set[NamedType] = if (p(tp)) x += tp else x
val seen: util.HashSet[Type] = new util.HashSet[Type](64) {
override def hash(x: Type): Int = System.identityHashCode(x)
override def isEqual(x: Type, y: Type) = x.eq(y)
}

def maybeAdd(x: mutable.Set[NamedType], tp: NamedType): mutable.Set[NamedType] =
if (p(tp)) x += tp else x

val seen = TypeHashSet()

override def applyToPrefix(x: mutable.Set[NamedType], tp: NamedType): mutable.Set[NamedType] =
tp.prefix match
case pre: TermRef if !widenSingletons =>
foldOver(maybeAdd(x, pre), pre)
case pre: ThisType if !widenSingletons =>
x
case _ =>
super.applyToPrefix(x, tp)

def apply(x: mutable.Set[NamedType], tp: Type): mutable.Set[NamedType] =
if (seen contains tp) x
else {
seen.addEntry(tp)
tp match {
case tp: TypeRef =>
foldOver(maybeAdd(x, tp), tp)
case tp: TermRef =>
apply(foldOver(maybeAdd(x, tp), tp), tp.underlying)
case tp: ThisType =>
apply(x, tp.tref)
case NoPrefix =>
foldOver(x, tp)
case tp: TermRef =>
apply(foldOver(maybeAdd(x, tp), tp), tp.underlying)
case tp: AppliedType =>
foldOver(x, tp)
case TypeBounds(lo, hi) =>
Expand Down
16 changes: 7 additions & 9 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
}

protected def toTextCore[T >: Untyped](tree: Tree[T]): Text = {
import untpd.{modsDeco => _, _}
import untpd._

def isLocalThis(tree: Tree) = tree.typeOpt match {
case tp: ThisType => tp.cls == ctx.owner.enclosingClass
Expand Down Expand Up @@ -647,7 +647,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
}

override def toText[T >: Untyped](tree: Tree[T]): Text = controlled {
import untpd.{modsDeco => _, _}
import untpd._

var txt = toTextCore(tree)

Expand Down Expand Up @@ -722,11 +722,9 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
}
}

/** Print modifiers from symbols if tree has type, overriding the untpd behavior. */
private implicit def modsDeco(mdef: untpd.DefTree): untpd.ModsDecorator =
new untpd.ModsDecorator {
def mods = if (mdef.hasType) Modifiers(mdef.symbol) else mdef.rawMods
}
/** Print modifiers from symbols if tree has type, overriding the behavior in Trees. */
def (mdef: untpd.DefTree).mods: untpd.Modifiers =
if mdef.hasType then Modifiers(mdef.symbol) else mdef.rawMods

private def Modifiers(sym: Symbol): Modifiers = untpd.Modifiers(
sym.flags & (if (sym.isType) ModifierFlags | VarianceFlags else ModifierFlags),
Expand Down Expand Up @@ -770,7 +768,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
vparamss.foldLeft(leading)((txt, params) => txt ~ paramsText(params))

protected def valDefToText[T >: Untyped](tree: ValDef[T]): Text = {
import untpd.{modsDeco => _}
import untpd._
dclTextOr(tree) {
modText(tree.mods, tree.symbol, keywordStr(if (tree.mods.is(Mutable)) "var" else "val"), isType = false) ~~
valDefText(nameIdText(tree)) ~ optAscription(tree.tpt) ~
Expand All @@ -784,7 +782,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
~ toText(params, ", ") ~ ")"

protected def defDefToText[T >: Untyped](tree: DefDef[T]): Text = {
import untpd.{modsDeco => _}
import untpd._
dclTextOr(tree) {
val defKeyword = modText(tree.mods, tree.symbol, keywordStr("def"), isType = false)
val isExtension = tree.hasType && tree.symbol.is(Extension)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package semanticdb
import core._
import Phases._
import ast.tpd._
import ast.Trees.mods
import Contexts._
import Symbols._
import Flags._
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,6 @@ trait Checking {
* 2. Check that case class `enum` cases do not extend java.lang.Enum.
*/
def checkEnum(cdef: untpd.TypeDef, cls: Symbol, firstParent: Symbol)(using Context): Unit = {
import untpd.modsDeco
def isEnumAnonCls =
cls.isAnonymousClass &&
cls.owner.isTerm &&
Expand Down
Loading