From 5f19c1d71656c8408a0a84c6037410aa7d5fd66f Mon Sep 17 00:00:00 2001 From: Wojciech Mazur Date: Thu, 20 Jun 2024 15:04:16 +0200 Subject: [PATCH 1/2] Remove erasure logic from ContextFunctionType [Cherry-picked 350dfa76ec7b4a3f9747b7ef35243b6b1f2ebe95][modified] --- .../dotty/tools/dotc/core/Definitions.scala | 19 ++++++++----------- .../dotty/tools/dotc/transform/Bridges.scala | 17 ++++++++--------- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 1c3a2c02bd6d..6e6321d10d9c 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1882,17 +1882,14 @@ class Definitions { */ object ContextFunctionType: def unapply(tp: Type)(using Context): Option[(List[Type], Type, List[Boolean])] = - if ctx.erasedTypes then - atPhase(erasurePhase)(unapply(tp)) - else - asContextFunctionType(tp) match - case ErasedFunctionOf(mt) => - Some((mt.paramInfos, mt.resType, mt.erasedParams)) - case tp1 if tp1.exists => - val args = tp1.functionArgInfos - val erasedParams = erasedFunctionParameters(tp1) - Some((args.init, args.last, erasedParams)) - case _ => None + asContextFunctionType(tp) match + case PolyFunctionOf(mt: MethodType) => + Some((mt.paramInfos, mt.resType, mt.erasedParams)) + case tp1 if tp1.exists => + val args = tp1.functionArgInfos + val erasedParams = List.fill(functionArity(tp1)) { false } + Some((args.init, args.last, erasedParams)) + case _ => None /* Returns a list of erased booleans marking whether parameters are erased, for a function type. */ def erasedFunctionParameters(tp: Type)(using Context): List[Boolean] = tp.dealias match { diff --git a/compiler/src/dotty/tools/dotc/transform/Bridges.scala b/compiler/src/dotty/tools/dotc/transform/Bridges.scala index 569b16681cde..0156b6c26c40 100644 --- a/compiler/src/dotty/tools/dotc/transform/Bridges.scala +++ b/compiler/src/dotty/tools/dotc/transform/Bridges.scala @@ -129,25 +129,24 @@ class Bridges(root: ClassSymbol, thisPhase: DenotTransformer)(using Context) { assert(ctx.typer.isInstanceOf[Erasure.Typer]) ctx.typer.typed(untpd.cpy.Apply(ref)(ref, args), member.info.finalResultType) else - val defn.ContextFunctionType(argTypes, resType, erasedParams) = tp: @unchecked - val anonFun = newAnonFun(ctx.owner, - MethodType( - argTypes.zip(erasedParams.padTo(argTypes.length, false)) - .flatMap((t, e) => if e then None else Some(t)), - resType), - coord = ctx.owner.coord) + val mtWithoutErasedParams = atPhase(erasurePhase) { + val defn.ContextFunctionType(argTypes, resType, erasedParams) = tp.dealias: @unchecked + val paramInfos = argTypes.zip(erasedParams).collect { case (argType, erased) if !erased => argType } + MethodType(paramInfos, resType) + } + val anonFun = newAnonFun(ctx.owner, mtWithoutErasedParams, coord = ctx.owner.coord) anonFun.info = transformInfo(anonFun, anonFun.info) def lambdaBody(refss: List[List[Tree]]) = val refs :: Nil = refss: @unchecked val expandedRefs = refs.map(_.withSpan(ctx.owner.span.endPos)) match case (bunchedParam @ Ident(nme.ALLARGS)) :: Nil => - argTypes.indices.toList.map(n => + mtWithoutErasedParams.paramInfos.indices.toList.map(n => bunchedParam .select(nme.primitive.arrayApply) .appliedTo(Literal(Constant(n)))) case refs1 => refs1 - expand(args ::: expandedRefs, resType, n - 1)(using ctx.withOwner(anonFun)) + expand(args ::: expandedRefs, mtWithoutErasedParams.resType, n - 1)(using ctx.withOwner(anonFun)) val unadapted = Closure(anonFun, lambdaBody) cpy.Block(unadapted)(unadapted.stats, From e0e421e268804c6e137b0970e580aa5c26f54266 Mon Sep 17 00:00:00 2001 From: Wojciech Mazur Date: Thu, 20 Jun 2024 15:06:40 +0200 Subject: [PATCH 2/2] Filter/count erased parameters directly on parameters types We can filter the erased parameters by looking at the `ErasedParamAnnot`. [Cherry-picked 7c0a848d0ed6330873a39129e6d02da169150fa7][modified] --- compiler/src/dotty/tools/dotc/ast/tpd.scala | 2 +- .../src/dotty/tools/dotc/core/Definitions.scala | 7 +++---- .../src/dotty/tools/dotc/transform/Bridges.scala | 4 ++-- .../dotc/transform/ContextFunctionResults.scala | 16 +++++++++------- .../dotty/tools/dotc/typer/ErrorReporting.scala | 2 +- compiler/src/dotty/tools/dotc/typer/Namer.scala | 2 +- 6 files changed, 17 insertions(+), 16 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index dfa04de22c17..79e47ca7b8df 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -1146,7 +1146,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { def etaExpandCFT(using Context): Tree = def expand(target: Tree, tp: Type)(using Context): Tree = tp match - case defn.ContextFunctionType(argTypes, resType, _) => + case defn.ContextFunctionType(argTypes, resType) => val anonFun = newAnonFun( ctx.owner, MethodType.companion(isContextual = true)(argTypes, resType), diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 6e6321d10d9c..c5a798e2dcd7 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1881,14 +1881,13 @@ class Definitions { * types `As`, the result type `B` and a whether the type is an erased context function. */ object ContextFunctionType: - def unapply(tp: Type)(using Context): Option[(List[Type], Type, List[Boolean])] = + def unapply(tp: Type)(using Context): Option[(List[Type], Type)] = asContextFunctionType(tp) match case PolyFunctionOf(mt: MethodType) => - Some((mt.paramInfos, mt.resType, mt.erasedParams)) + Some((mt.paramInfos, mt.resType)) case tp1 if tp1.exists => val args = tp1.functionArgInfos - val erasedParams = List.fill(functionArity(tp1)) { false } - Some((args.init, args.last, erasedParams)) + Some((args.init, args.last)) case _ => None /* Returns a list of erased booleans marking whether parameters are erased, for a function type. */ diff --git a/compiler/src/dotty/tools/dotc/transform/Bridges.scala b/compiler/src/dotty/tools/dotc/transform/Bridges.scala index 0156b6c26c40..94f7b405c027 100644 --- a/compiler/src/dotty/tools/dotc/transform/Bridges.scala +++ b/compiler/src/dotty/tools/dotc/transform/Bridges.scala @@ -130,8 +130,8 @@ class Bridges(root: ClassSymbol, thisPhase: DenotTransformer)(using Context) { ctx.typer.typed(untpd.cpy.Apply(ref)(ref, args), member.info.finalResultType) else val mtWithoutErasedParams = atPhase(erasurePhase) { - val defn.ContextFunctionType(argTypes, resType, erasedParams) = tp.dealias: @unchecked - val paramInfos = argTypes.zip(erasedParams).collect { case (argType, erased) if !erased => argType } + val defn.ContextFunctionType(argTypes, resType) = tp.dealias: @unchecked + val paramInfos = argTypes.filterNot(_.hasAnnotation(defn.ErasedParamAnnot)) MethodType(paramInfos, resType) } val anonFun = newAnonFun(ctx.owner, mtWithoutErasedParams, coord = ctx.owner.coord) diff --git a/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala b/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala index b4eb71c541d3..41453aa62b50 100644 --- a/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala +++ b/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala @@ -20,7 +20,7 @@ object ContextFunctionResults: */ def annotateContextResults(mdef: DefDef)(using Context): Unit = def contextResultCount(rhs: Tree, tp: Type): Int = tp match - case defn.ContextFunctionType(_, resTpe, _) => + case defn.ContextFunctionType(_, resTpe) => rhs match case closureDef(meth) => 1 + contextResultCount(meth.rhs, resTpe) case _ => 0 @@ -58,7 +58,8 @@ object ContextFunctionResults: */ def contextResultsAreErased(sym: Symbol)(using Context): Boolean = def allErased(tp: Type): Boolean = tp.dealias match - case defn.ContextFunctionType(_, resTpe, erasedParams) => !erasedParams.contains(false) && allErased(resTpe) + case defn.ContextFunctionType(argTpes, resTpe) => + argTpes.forall(_.hasAnnotation(defn.ErasedParamAnnot)) && allErased(resTpe) case _ => true contextResultCount(sym) > 0 && allErased(sym.info.finalResultType) @@ -72,7 +73,7 @@ object ContextFunctionResults: integrateContextResults(rt, crCount) case tp: MethodOrPoly => tp.derivedLambdaType(resType = integrateContextResults(tp.resType, crCount)) - case defn.ContextFunctionType(argTypes, resType, erasedParams) => + case defn.ContextFunctionType(argTypes, resType) => MethodType(argTypes, integrateContextResults(resType, crCount - 1)) /** The total number of parameters of method `sym`, not counting @@ -83,9 +84,10 @@ object ContextFunctionResults: def contextParamCount(tp: Type, crCount: Int): Int = if crCount == 0 then 0 else - val defn.ContextFunctionType(params, resTpe, erasedParams) = tp: @unchecked + val defn.ContextFunctionType(params, resTpe) = tp: @unchecked val rest = contextParamCount(resTpe, crCount - 1) - if erasedParams.contains(true) then erasedParams.count(_ == false) + rest else params.length + rest + val nonErasedParams = params.count(!_.hasAnnotation(defn.ErasedParamAnnot)) + nonErasedParams + rest def normalParamCount(tp: Type): Int = tp.widenExpr.stripPoly match case mt @ MethodType(pnames) => @@ -103,7 +105,7 @@ object ContextFunctionResults: def recur(tp: Type, n: Int): Type = if n == 0 then tp else tp match - case defn.ContextFunctionType(_, resTpe, _) => recur(resTpe, n - 1) + case defn.ContextFunctionType(_, resTpe) => recur(resTpe, n - 1) recur(meth.info.finalResultType, depth) /** Should selection `tree` be eliminated since it refers to an `apply` @@ -118,7 +120,7 @@ object ContextFunctionResults: case Select(qual, name) => if name == nme.apply then qual.tpe match - case defn.ContextFunctionType(_, _, _) => + case defn.ContextFunctionType(_, _) => integrateSelect(qual, n + 1) case _ if defn.isContextFunctionClass(tree.symbol.maybeOwner) => // for TermRefs integrateSelect(qual, n + 1) diff --git a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala index 25cbfdfec600..339d1f2f7bc6 100644 --- a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala +++ b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala @@ -167,7 +167,7 @@ object ErrorReporting { val normPt = normalize(pt, pt) def contextFunctionCount(tp: Type): Int = tp.stripped match - case defn.ContextFunctionType(_, restp, _) => 1 + contextFunctionCount(restp) + case defn.ContextFunctionType(_, restp) => 1 + contextFunctionCount(restp) case _ => 0 def strippedTpCount = contextFunctionCount(tree.tpe) - contextFunctionCount(normTp) def strippedPtCount = contextFunctionCount(pt) - contextFunctionCount(normPt) diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index df708057dd71..9aad20113154 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1885,7 +1885,7 @@ class Namer { typer: Typer => val originalTp = defaultParamType val approxTp = wildApprox(originalTp) approxTp.stripPoly match - case atp @ defn.ContextFunctionType(_, resType, _) + case atp @ defn.ContextFunctionType(_, resType) if !defn.isNonRefinedFunction(atp) // in this case `resType` is lying, gives us only the non-dependent upper bound || resType.existsPart(_.isInstanceOf[WildcardType], StopAt.Static, forceLazy = false) => originalTp