diff --git a/compiler/src/dotty/tools/dotc/core/Decorators.scala b/compiler/src/dotty/tools/dotc/core/Decorators.scala index f695923d517b..857b78115fbb 100644 --- a/compiler/src/dotty/tools/dotc/core/Decorators.scala +++ b/compiler/src/dotty/tools/dotc/core/Decorators.scala @@ -171,7 +171,7 @@ object Decorators { def & (ys: List[T]): List[T] = xs filter (ys contains _) } - extension [T, U](xss: List[List[T]]): + extension [T, U](xss: List[List[T]]) def nestedMap(f: T => U): List[List[U]] = xss.map(_.map(f)) def nestedMapConserve(f: T => U): List[List[U]] = @@ -180,14 +180,14 @@ object Decorators { xss.zipWithConserve(yss)((xs, ys) => xs.zipWithConserve(ys)(f)) end extension - extension (text: Text): + extension (text: Text) def show(using Context): String = text.mkString(ctx.settings.pageWidth.value, ctx.settings.printLines.value) /** Test whether a list of strings representing phases contains * a given phase. See [[config.CompilerCommand#explainAdvanced]] for the * exact meaning of "contains" here. */ - extension (names: List[String]) { + extension (names: List[String]) def containsPhase(phase: Phase): Boolean = names.nonEmpty && { phase match { @@ -203,18 +203,16 @@ object Decorators { } } } - } - extension [T](x: T) { + extension [T](x: T) def reporting( op: WrappedResult[T] ?=> String, printer: config.Printers.Printer = config.Printers.default): T = { printer.println(op(using WrappedResult(x))) x } - } - extension [T](x: T) { + extension [T](x: T) def assertingErrorsReported(using Context): T = { assert(ctx.reporter.errorsReported) x @@ -223,9 +221,12 @@ object Decorators { assert(ctx.reporter.errorsReported, msg) x } - } - extension (sc: StringContext) { + extension [T <: AnyRef](xs: ::[T]) + def derivedCons(x1: T, xs1: List[T]) = + if (xs.head eq x1) && (xs.tail eq xs1) then xs else x1 :: xs1 + + extension (sc: StringContext) /** General purpose string formatting */ def i(args: Any*)(using Context): String = new StringFormatter(sc).assemble(args) @@ -241,9 +242,8 @@ object Decorators { */ def ex(args: Any*)(using Context): String = explained(em(args: _*)) - } - extension [T <: AnyRef](arr: Array[T]): + extension [T <: AnyRef](arr: Array[T]) def binarySearch(x: T): Int = java.util.Arrays.binarySearch(arr.asInstanceOf[Array[Object]], x) } diff --git a/compiler/src/dotty/tools/dotc/core/Substituters.scala b/compiler/src/dotty/tools/dotc/core/Substituters.scala index e34f4495b648..f00edcb189c6 100644 --- a/compiler/src/dotty/tools/dotc/core/Substituters.scala +++ b/compiler/src/dotty/tools/dotc/core/Substituters.scala @@ -1,6 +1,6 @@ package dotty.tools.dotc.core -import Types._, Symbols._, Contexts._ +import Types._, Symbols._, Contexts._, Decorators._ /** Substitution operations on types. See the corresponding `subst` and * `substThis` methods on class Type for an explanation. @@ -16,6 +16,8 @@ object Substituters: else tp.derivedSelect(subst(tp.prefix, from, to, theMap)) case _: ThisType => tp + case tp: AppliedType => + tp.map(subst(_, from, to, theMap)) case _ => (if (theMap != null) theMap else new SubstBindingMap(from, to)) .mapOver(tp) @@ -94,7 +96,7 @@ object Substituters: ts = ts.tail } tp - case _: ThisType | _: BoundType => + case _: BoundType => tp case _ => (if (theMap != null) theMap else new SubstSymMap(from, to)) @@ -152,45 +154,47 @@ object Substituters: else tp.derivedSelect(substParams(tp.prefix, from, to, theMap)) case _: ThisType => tp + case tp: AppliedType => + tp.map(substParams(_, from, to, theMap)) case _ => (if (theMap != null) theMap else new SubstParamsMap(from, to)) .mapOver(tp) } final class SubstBindingMap(from: BindingType, to: BindingType)(using Context) extends DeepTypeMap { - def apply(tp: Type): Type = subst(tp, from, to, this) + def apply(tp: Type): Type = subst(tp, from, to, this)(using mapCtx) } final class Subst1Map(from: Symbol, to: Type)(using Context) extends DeepTypeMap { - def apply(tp: Type): Type = subst1(tp, from, to, this) + def apply(tp: Type): Type = subst1(tp, from, to, this)(using mapCtx) } final class Subst2Map(from1: Symbol, to1: Type, from2: Symbol, to2: Type)(using Context) extends DeepTypeMap { - def apply(tp: Type): Type = subst2(tp, from1, to1, from2, to2, this) + def apply(tp: Type): Type = subst2(tp, from1, to1, from2, to2, this)(using mapCtx) } final class SubstMap(from: List[Symbol], to: List[Type])(using Context) extends DeepTypeMap { - def apply(tp: Type): Type = subst(tp, from, to, this) + def apply(tp: Type): Type = subst(tp, from, to, this)(using mapCtx) } final class SubstSymMap(from: List[Symbol], to: List[Symbol])(using Context) extends DeepTypeMap { - def apply(tp: Type): Type = substSym(tp, from, to, this) + def apply(tp: Type): Type = substSym(tp, from, to, this)(using mapCtx) } final class SubstThisMap(from: ClassSymbol, to: Type)(using Context) extends DeepTypeMap { - def apply(tp: Type): Type = substThis(tp, from, to, this) + def apply(tp: Type): Type = substThis(tp, from, to, this)(using mapCtx) } final class SubstRecThisMap(from: Type, to: Type)(using Context) extends DeepTypeMap { - def apply(tp: Type): Type = substRecThis(tp, from, to, this) + def apply(tp: Type): Type = substRecThis(tp, from, to, this)(using mapCtx) } final class SubstParamMap(from: ParamRef, to: Type)(using Context) extends DeepTypeMap { - def apply(tp: Type): Type = substParam(tp, from, to, this) + def apply(tp: Type): Type = substParam(tp, from, to, this)(using mapCtx) } final class SubstParamsMap(from: BindingType, to: List[Type])(using Context) extends DeepTypeMap { - def apply(tp: Type): Type = substParams(tp, from, to, this) + def apply(tp: Type): Type = substParams(tp, from, to, this)(using mapCtx) } /** An approximating substitution that can handle wildcards in the `to` list */ diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index 5354cf14f2bc..b9a2f484156c 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -100,6 +100,8 @@ object TypeOps: val sym = tp.symbol if (sym.isStatic && !sym.maybeOwner.seesOpaques || (tp.prefix `eq` NoPrefix)) tp else derivedSelect(tp, atVariance(variance max 0)(this(tp.prefix))) + case tp: LambdaType => + mapOverLambda(tp) // special cased common case case tp: ThisType => toPrefix(pre, cls, tp.cls) case _: BoundType => @@ -136,6 +138,9 @@ object TypeOps: tp2 case tp1 => tp1 } + case tp: AppliedType => + val normed = tp.tryNormalize + if normed.exists then normed else tp.map(simplify(_, theMap)) case tp: TypeParamRef => val tvar = ctx.typerState.constraint.typeVarOfParam(tp) if (tvar.exists) tvar else tp @@ -147,7 +152,7 @@ object TypeOps: simplify(l, theMap) & simplify(r, theMap) case OrType(l, r) if !ctx.mode.is(Mode.Type) => simplify(l, theMap) | simplify(r, theMap) - case _: AppliedType | _: MatchType => + case _: MatchType => val normed = tp.tryNormalize if (normed.exists) normed else mapOver case tp: MethodicType => diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 01486f89c969..3520195b89a5 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -111,33 +111,37 @@ object Types { */ def isProvisional(using Context): Boolean = mightBeProvisional && testProvisional - private def testProvisional(using Context) = { - val accu = new TypeAccumulator[Boolean] { - override def apply(x: Boolean, t: Type) = - x || t.mightBeProvisional && { - t.mightBeProvisional = t match { - case t: TypeVar => - !t.inst.exists || apply(x, t.inst) - case t: TypeRef => + private def testProvisional(using Context): Boolean = + class ProAcc extends TypeAccumulator[Boolean]: + override def apply(x: Boolean, t: Type) = x || test(t, this) + def test(t: Type, theAcc: TypeAccumulator[Boolean]): Boolean = + if t.mightBeProvisional then + t.mightBeProvisional = t match + case t: TypeRef => + !t.currentSymbol.isStatic && { (t: Type).mightBeProvisional = false // break cycles - t.symbol.is(Provisional) || - apply(x, t.prefix) || { - t.info match { - case info: AliasingBounds => apply(x, info.alias) - case TypeBounds(lo, hi) => apply(apply(x, lo), hi) + t.symbol.is(Provisional) + || test(t.prefix, theAcc) + || t.info.match + case info: AliasingBounds => test(info.alias, theAcc) + case TypeBounds(lo, hi) => test(lo, theAcc) || test(hi, theAcc) case _ => false - } - } - case t: LazyRef => - !t.completed || apply(x, t.ref) - case _ => - foldOver(x, t) - } - t.mightBeProvisional - } - } - accu.apply(false, this) - } + } + case t: TermRef => + !t.currentSymbol.isStatic && test(t.prefix, theAcc) + case t: AppliedType => + t.fold(false, (x, tp) => x || test(tp, theAcc)) + case t: TypeVar => + !t.inst.exists || test(t.inst, theAcc) + case t: LazyRef => + !t.completed || test(t.ref, theAcc) + case _ => + (if theAcc != null then theAcc else ProAcc()).foldOver(false, t) + end if + t.mightBeProvisional + end test + test(this, null) + end testProvisional /** Is this type different from NoType? */ final def exists: Boolean = this.ne(NoType) @@ -3214,11 +3218,8 @@ object Types { def newLikeThis(paramNames: List[ThisName], paramInfos: List[PInfo], resType: Type)(using Context): This = def substParams(pinfos: List[PInfo], to: This): List[PInfo] = pinfos match - case pinfo :: rest => - val pinfo1 = pinfo.subst(this, to).asInstanceOf[PInfo] - val rest1 = substParams(rest, to) - if (pinfo1 eq pinfo) && (rest1 eq rest) then pinfos - else pinfo1 :: rest1 + case pinfos @ (pinfo :: rest) => + pinfos.derivedCons(pinfo.subst(this, to).asInstanceOf[PInfo], substParams(rest, to)) case nil => nil companion(paramNames)( @@ -3282,32 +3283,36 @@ object Types { private var myDependencyStatus: DependencyStatus = Unknown private var myParamDependencyStatus: DependencyStatus = Unknown - private def depStatus(initial: DependencyStatus, tp: Type)(using Context): DependencyStatus = { - def combine(x: DependencyStatus, y: DependencyStatus) = { + private def depStatus(initial: DependencyStatus, tp: Type)(using Context): DependencyStatus = + class DepAcc extends TypeAccumulator[DependencyStatus]: + def apply(status: DependencyStatus, tp: Type) = compute(status, tp, this) + def combine(x: DependencyStatus, y: DependencyStatus) = val status = (x & StatusMask) max (y & StatusMask) val provisional = (x | y) & Provisional - (if (status == TrueDeps) status else status | provisional).toByte - } - val depStatusAcc = new TypeAccumulator[DependencyStatus] { - def apply(status: DependencyStatus, tp: Type) = - if (status == TrueDeps) status - else - tp match { - case TermParamRef(`thisLambdaType`, _) => TrueDeps - case tp: TypeRef => - val status1 = foldOver(status, tp) - tp.info match { // follow type alias to avoid dependency - case TypeAlias(alias) if status1 == TrueDeps && status != TrueDeps => - combine(apply(status, alias), FalseDeps) - case _ => - status1 - } - case tp: TypeVar if !tp.isInstantiated => combine(status, Provisional) - case _ => foldOver(status, tp) + (if status == TrueDeps then status else status | provisional).toByte + def compute(status: DependencyStatus, tp: Type, theAcc: TypeAccumulator[DependencyStatus]): DependencyStatus = + def applyPrefix(tp: NamedType) = + if tp.currentSymbol.isStatic then status + else compute(status, tp.prefix, theAcc) + if status == TrueDeps then status + else tp match + case tp: TypeRef => + val status1 = applyPrefix(tp) + tp.info match { // follow type alias to avoid dependency + case TypeAlias(alias) if status1 == TrueDeps => + combine(compute(status, alias, theAcc), FalseDeps) + case _ => + status1 } - } - depStatusAcc(initial, tp) - } + case tp: TermRef => applyPrefix(tp) + case tp: AppliedType => tp.fold(status, compute(_, _, theAcc)) + case tp: TypeVar if !tp.isInstantiated => combine(status, Provisional) + case TermParamRef(`thisLambdaType`, _) => TrueDeps + case _: ThisType | _: BoundType | NoPrefix => status + case _ => + (if theAcc != null then theAcc else DepAcc()).foldOver(status, tp) + compute(initial, tp, null) + end depStatus /** The dependency status of this method. Some examples: * @@ -3840,6 +3845,18 @@ object Types { superType } + inline def map(inline op: Type => Type)(using Context) = + def mapArgs(args: List[Type]): List[Type] = args match + case args @ (arg :: rest) => args.derivedCons(op(arg), mapArgs(rest)) + case nil => nil + derivedAppliedType(op(tycon), mapArgs(args)) + + inline def fold[T](x: T, inline op: (T, Type) => T)(using Context): T = + def foldArgs(x: T, args: List[Type]): T = args match + case arg :: rest => foldArgs(op(x, arg), rest) + case nil => x + foldArgs(op(x, tycon), args) + override def tryNormalize(using Context): Type = tycon match { case tycon: TypeRef => def tryMatchAlias = tycon.info match { @@ -4945,10 +4962,29 @@ object Types { protected def derivedLambdaType(tp: LambdaType)(formals: List[tp.PInfo], restpe: Type): Type = tp.derivedLambdaType(tp.paramNames, formals, restpe) + protected def mapArgs(args: List[Type], tparams: List[ParamInfo]): List[Type] = args match + case arg :: otherArgs if tparams.nonEmpty => + val arg1 = arg match + case arg: TypeBounds => this(arg) + case arg => atVariance(variance * tparams.head.paramVarianceSign)(this(arg)) + val otherArgs1 = mapArgs(otherArgs, tparams.tail) + if ((arg1 eq arg) && (otherArgs1 eq otherArgs)) args + else arg1 :: otherArgs1 + case nil => + nil + + protected def mapOverLambda(tp: LambdaType) = + val restpe = tp.resultType + val saved = variance + variance = if (defn.MatchCase.isInstance(restpe)) 0 else -variance + val ptypes1 = tp.paramInfos.mapConserve(this).asInstanceOf[List[tp.PInfo]] + variance = saved + derivedLambdaType(tp)(ptypes1, this(restpe)) + /** Map this function over given type */ def mapOver(tp: Type): Type = { - record(s"mapOver ${getClass}") - record("mapOver total") + record(s"TypeMap mapOver ${getClass}") + record("TypeMap mapOver total") val ctx = this.mapCtx // optimization for performance given Context = ctx tp match { @@ -4963,27 +4999,12 @@ object Types { // if `p <: q` then `p.A <: q.A`, and well-formedness requires that `A` is a member // of `p`'s upper bound. derivedSelect(tp, prefix1) - case _: ThisType - | _: BoundType - | NoPrefix => tp case tp: AppliedType => - def mapArgs(args: List[Type], tparams: List[ParamInfo]): List[Type] = args match { - case arg :: otherArgs if tparams.nonEmpty => - val arg1 = arg match { - case arg: TypeBounds => this(arg) - case arg => atVariance(variance * tparams.head.paramVarianceSign)(this(arg)) - } - val otherArgs1 = mapArgs(otherArgs, tparams.tail) - if ((arg1 eq arg) && (otherArgs1 eq otherArgs)) args - else arg1 :: otherArgs1 - case nil => - nil - } derivedAppliedType(tp, this(tp.tycon), mapArgs(tp.args, tp.tyconTypeParams)) - case tp: RefinedType => - derivedRefinedType(tp, this(tp.parent), this(tp.refinedInfo)) + case tp: LambdaType => + mapOverLambda(tp) case tp: AliasingBounds => derivedAlias(tp, atVariance(0)(this(tp.alias))) @@ -4994,9 +5015,6 @@ object Types { variance = -variance derivedTypeBounds(tp, lo1, this(tp.hi)) - case tp: RecType => - derivedRecType(tp, this(tp.parent)) - case tp: TypeVar => val inst = tp.instanceOpt if (inst.exists) apply(inst) else tp @@ -5004,16 +5022,25 @@ object Types { case tp: ExprType => derivedExprType(tp, this(tp.resultType)) - case tp: LambdaType => - def mapOverLambda = { - val restpe = tp.resultType - val saved = variance - variance = if (defn.MatchCase.isInstance(restpe)) 0 else -variance - val ptypes1 = tp.paramInfos.mapConserve(this).asInstanceOf[List[tp.PInfo]] - variance = saved - derivedLambdaType(tp)(ptypes1, this(restpe)) - } - mapOverLambda + case tp @ AnnotatedType(underlying, annot) => + val underlying1 = this(underlying) + if (underlying1 eq underlying) tp + else derivedAnnotatedType(tp, underlying1, mapOver(annot)) + + case _: ThisType + | _: BoundType + | NoPrefix => + tp + + case tp: ProtoType => + tp.map(this) + + case tp: RefinedType => + derivedRefinedType(tp, this(tp.parent), this(tp.refinedInfo)) + + case tp: RecType => + record("TypeMap.RecType") + derivedRecType(tp, this(tp.parent)) case tp @ SuperType(thistp, supertp) => derivedSuperType(tp, this(thistp), this(supertp)) @@ -5047,20 +5074,12 @@ object Types { case tp: SkolemType => derivedSkolemType(tp, this(tp.info)) - case tp @ AnnotatedType(underlying, annot) => - val underlying1 = this(underlying) - if (underlying1 eq underlying) tp - else derivedAnnotatedType(tp, underlying1, mapOver(annot)) - case tp: WildcardType => derivedWildcardType(tp, mapOver(tp.optBounds)) case tp: JavaArrayType => derivedJavaArrayType(tp, this(tp.elemType)) - case tp: ProtoType => - tp.map(this) - case _ => tp } diff --git a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala index 6920d972b387..168d8e84fc67 100644 --- a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala +++ b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala @@ -617,6 +617,10 @@ object ProtoTypes { wildApprox(tp.refinedInfo, theMap, seen, internal)) case tp: AliasingBounds => // default case, inlined for speed tp.derivedAlias(wildApprox(tp.alias, theMap, seen, internal)) + case tp: TypeBounds => + tp.derivedTypeBounds( + wildApprox(tp.lo, theMap, seen, internal), + wildApprox(tp.hi, theMap, seen, internal)) case tp @ TypeParamRef(tl, _) if internal.contains(tl) => tp case tp @ TypeParamRef(poly, pnum) => def wildApproxBounds(bounds: TypeBounds) =