diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index ad2676624b0f..a80fcc59a806 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -1152,7 +1152,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { def etaExpandCFT(using Context): Tree = def expand(target: Tree, tp: Type)(using Context): Tree = tp match - case defn.ContextFunctionType(argTypes, resType, _) => + case defn.ContextFunctionType(argTypes, resType) => val anonFun = newAnonFun( ctx.owner, MethodType.companion(isContextual = true)(argTypes, resType), diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index b4df6bcd4ca5..7a88af1a052a 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1878,18 +1878,14 @@ class Definitions { * types `As`, the result type `B` and a whether the type is an erased context function. */ object ContextFunctionType: - def unapply(tp: Type)(using Context): Option[(List[Type], Type, List[Boolean])] = - if ctx.erasedTypes then - atPhase(erasurePhase)(unapply(tp)) - else - asContextFunctionType(tp) match - case PolyFunctionOf(mt: MethodType) => - Some((mt.paramInfos, mt.resType, mt.erasedParams)) - case tp1 if tp1.exists => - val args = tp1.functionArgInfos - val erasedParams = List.fill(functionArity(tp1)) { false } - Some((args.init, args.last, erasedParams)) - case _ => None + def unapply(tp: Type)(using Context): Option[(List[Type], Type)] = + asContextFunctionType(tp) match + case PolyFunctionOf(mt: MethodType) => + Some((mt.paramInfos, mt.resType)) + case tp1 if tp1.exists => + val args = tp1.functionArgInfos + Some((args.init, args.last)) + case _ => None /** A whitelist of Scala-2 classes that are known to be pure */ def isAssuredNoInits(sym: Symbol): Boolean = diff --git a/compiler/src/dotty/tools/dotc/transform/Bridges.scala b/compiler/src/dotty/tools/dotc/transform/Bridges.scala index 569b16681cde..94f7b405c027 100644 --- a/compiler/src/dotty/tools/dotc/transform/Bridges.scala +++ b/compiler/src/dotty/tools/dotc/transform/Bridges.scala @@ -129,25 +129,24 @@ class Bridges(root: ClassSymbol, thisPhase: DenotTransformer)(using Context) { assert(ctx.typer.isInstanceOf[Erasure.Typer]) ctx.typer.typed(untpd.cpy.Apply(ref)(ref, args), member.info.finalResultType) else - val defn.ContextFunctionType(argTypes, resType, erasedParams) = tp: @unchecked - val anonFun = newAnonFun(ctx.owner, - MethodType( - argTypes.zip(erasedParams.padTo(argTypes.length, false)) - .flatMap((t, e) => if e then None else Some(t)), - resType), - coord = ctx.owner.coord) + val mtWithoutErasedParams = atPhase(erasurePhase) { + val defn.ContextFunctionType(argTypes, resType) = tp.dealias: @unchecked + val paramInfos = argTypes.filterNot(_.hasAnnotation(defn.ErasedParamAnnot)) + MethodType(paramInfos, resType) + } + val anonFun = newAnonFun(ctx.owner, mtWithoutErasedParams, coord = ctx.owner.coord) anonFun.info = transformInfo(anonFun, anonFun.info) def lambdaBody(refss: List[List[Tree]]) = val refs :: Nil = refss: @unchecked val expandedRefs = refs.map(_.withSpan(ctx.owner.span.endPos)) match case (bunchedParam @ Ident(nme.ALLARGS)) :: Nil => - argTypes.indices.toList.map(n => + mtWithoutErasedParams.paramInfos.indices.toList.map(n => bunchedParam .select(nme.primitive.arrayApply) .appliedTo(Literal(Constant(n)))) case refs1 => refs1 - expand(args ::: expandedRefs, resType, n - 1)(using ctx.withOwner(anonFun)) + expand(args ::: expandedRefs, mtWithoutErasedParams.resType, n - 1)(using ctx.withOwner(anonFun)) val unadapted = Closure(anonFun, lambdaBody) cpy.Block(unadapted)(unadapted.stats, diff --git a/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala b/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala index b9478fb893a0..01a77427698a 100644 --- a/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala +++ b/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala @@ -20,7 +20,7 @@ object ContextFunctionResults: */ def annotateContextResults(mdef: DefDef)(using Context): Unit = def contextResultCount(rhs: Tree, tp: Type): Int = tp match - case defn.ContextFunctionType(_, resTpe, _) => + case defn.ContextFunctionType(_, resTpe) => rhs match case closureDef(meth) => 1 + contextResultCount(meth.rhs, resTpe) case _ => 0 @@ -58,7 +58,8 @@ object ContextFunctionResults: */ def contextResultsAreErased(sym: Symbol)(using Context): Boolean = def allErased(tp: Type): Boolean = tp.dealias match - case defn.ContextFunctionType(_, resTpe, erasedParams) => !erasedParams.contains(false) && allErased(resTpe) + case defn.ContextFunctionType(argTpes, resTpe) => + argTpes.forall(_.hasAnnotation(defn.ErasedParamAnnot)) && allErased(resTpe) case _ => true contextResultCount(sym) > 0 && allErased(sym.info.finalResultType) @@ -72,7 +73,7 @@ object ContextFunctionResults: integrateContextResults(rt, crCount) case tp: MethodOrPoly => tp.derivedLambdaType(resType = integrateContextResults(tp.resType, crCount)) - case defn.ContextFunctionType(argTypes, resType, erasedParams) => + case defn.ContextFunctionType(argTypes, resType) => MethodType(argTypes, integrateContextResults(resType, crCount - 1)) /** The total number of parameters of method `sym`, not counting @@ -83,10 +84,10 @@ object ContextFunctionResults: def contextParamCount(tp: Type, crCount: Int): Int = if crCount == 0 then 0 else - val defn.ContextFunctionType(params, resTpe, erasedParams) = tp: @unchecked + val defn.ContextFunctionType(params, resTpe) = tp: @unchecked val rest = contextParamCount(resTpe, crCount - 1) - // TODO use mt.nonErasedParamCount - if erasedParams.contains(true) then erasedParams.count(_ == false) + rest else params.length + rest // TODO use mt.nonErasedParamCount + val nonErasedParams = params.count(!_.hasAnnotation(defn.ErasedParamAnnot)) + nonErasedParams + rest def normalParamCount(tp: Type): Int = tp.widenExpr.stripPoly match case mt @ MethodType(pnames) => mt.nonErasedParamCount + normalParamCount(mt.resType) @@ -100,7 +101,7 @@ object ContextFunctionResults: def recur(tp: Type, n: Int): Type = if n == 0 then tp else tp match - case defn.ContextFunctionType(_, resTpe, _) => recur(resTpe, n - 1) + case defn.ContextFunctionType(_, resTpe) => recur(resTpe, n - 1) recur(meth.info.finalResultType, depth) /** Should selection `tree` be eliminated since it refers to an `apply` @@ -115,7 +116,7 @@ object ContextFunctionResults: case Select(qual, name) => if name == nme.apply then qual.tpe match - case defn.ContextFunctionType(_, _, _) => + case defn.ContextFunctionType(_, _) => integrateSelect(qual, n + 1) case _ if defn.isContextFunctionClass(tree.symbol.maybeOwner) => // for TermRefs integrateSelect(qual, n + 1) diff --git a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala index 25cbfdfec600..339d1f2f7bc6 100644 --- a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala +++ b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala @@ -167,7 +167,7 @@ object ErrorReporting { val normPt = normalize(pt, pt) def contextFunctionCount(tp: Type): Int = tp.stripped match - case defn.ContextFunctionType(_, restp, _) => 1 + contextFunctionCount(restp) + case defn.ContextFunctionType(_, restp) => 1 + contextFunctionCount(restp) case _ => 0 def strippedTpCount = contextFunctionCount(tree.tpe) - contextFunctionCount(normTp) def strippedPtCount = contextFunctionCount(pt) - contextFunctionCount(normPt) diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index b43240a1fbb1..8ed881ca0d81 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1893,7 +1893,7 @@ class Namer { typer: Typer => val originalTp = defaultParamType val approxTp = wildApprox(originalTp) approxTp.stripPoly match - case atp @ defn.ContextFunctionType(_, resType, _) + case atp @ defn.ContextFunctionType(_, resType) if !defn.isNonRefinedFunction(atp) // in this case `resType` is lying, gives us only the non-dependent upper bound || resType.existsPart(_.isInstanceOf[WildcardType], StopAt.Static, forceLazy = false) => originalTp