@@ -3590,14 +3590,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
3590
3590
3591
3591
private def pushDownDeferredEvidenceParams (tpe : Type , params : List [untpd.ValDef ], span : Span )(using Context ): Type = tpe.dealias match {
3592
3592
case tpe : MethodType =>
3593
- MethodType (tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3593
+ tpe.derivedLambdaType (tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3594
3594
case tpe : PolyType =>
3595
- PolyType (tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3595
+ tpe.derivedLambdaType (tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3596
3596
case tpe : RefinedType =>
3597
- // TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement
3598
- RefinedType (pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span))
3597
+ tpe.derivedRefinedType(
3598
+ pushDownDeferredEvidenceParams(tpe.parent, params, span),
3599
+ tpe.refinedName,
3600
+ pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span)
3601
+ )
3599
3602
case tpe @ AppliedType (tycon, args) if defn.isFunctionType(tpe) && args.size > 1 =>
3600
- AppliedType ( tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
3603
+ tpe.derivedAppliedType( tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
3601
3604
case tpe =>
3602
3605
val paramNames = params.map(_.name)
3603
3606
val paramTpts = params.map(_.tpt)
@@ -3606,18 +3609,52 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
3606
3609
typed(ctxFunction).tpe
3607
3610
}
3608
3611
3609
- private def addDownDeferredEvidenceParams (tree : Tree , pt : Type )(using Context ): (Tree , Type ) = {
3612
+ private def extractTopMethodTermParams (tpe : Type )(using Context ): (List [TermName ], List [Type ]) = tpe match {
3613
+ case tpe : MethodType =>
3614
+ tpe.paramNames -> tpe.paramInfos
3615
+ case tpe : RefinedType if defn.isFunctionType(tpe.parent) =>
3616
+ extractTopMethodTermParams(tpe.refinedInfo)
3617
+ case _ =>
3618
+ Nil -> Nil
3619
+ }
3620
+
3621
+ private def removeTopMethodTermParams (tpe : Type )(using Context ): Type = tpe match {
3622
+ case tpe : MethodType =>
3623
+ tpe.resultType
3624
+ case tpe : RefinedType if defn.isFunctionType(tpe.parent) =>
3625
+ tpe.derivedRefinedType(tpe.parent, tpe.refinedName, removeTopMethodTermParams(tpe.refinedInfo))
3626
+ case tpe : AppliedType if defn.isFunctionType(tpe) =>
3627
+ tpe.args.last
3628
+ case _ =>
3629
+ tpe
3630
+ }
3631
+
3632
+ private def healToPolyFunctionType (tree : Tree )(using Context ): Tree = tree match {
3633
+ case defdef : DefDef if defdef.name == nme.apply && defdef.paramss.forall(_.forall(_.symbol.flags.is(TypeParam ))) && defdef.paramss.size == 1 =>
3634
+ val (names, types) = extractTopMethodTermParams(defdef.tpt.tpe)
3635
+ val newTpe = removeTopMethodTermParams(defdef.tpt.tpe)
3636
+ val newParams = names.lazyZip(types).map((name, tpe) => SyntheticValDef (name, TypeTree (tpe), flags = SyntheticTermParam ))
3637
+ val newDefDef = cpy.DefDef (defdef)(paramss = defdef.paramss ++ List (newParams), tpt = untpd.TypeTree (newTpe))
3638
+ val nestedCtx = ctx.fresh.setNewTyperState()
3639
+ typed(newDefDef)(using nestedCtx)
3640
+ case _ => tree
3641
+ }
3642
+
3643
+ private def addDeferredEvidenceParams (tree : Tree , pt : Type )(using Context ): (Tree , Type ) = {
3610
3644
tree.getAttachment(desugar.PolyFunctionApply ) match
3611
3645
case Some (params) if params.nonEmpty =>
3612
3646
tree.removeAttachment(desugar.PolyFunctionApply )
3613
3647
val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span)
3614
3648
TypeTree (tpe).withSpan(tree.span) -> tpe
3649
+ // case Some(params) if params.isEmpty =>
3650
+ // println(s"tree: $tree")
3651
+ // healToPolyFunctionType(tree) -> pt
3615
3652
case _ => tree -> pt
3616
3653
}
3617
3654
3618
3655
/** Interpolate and simplify the type of the given tree. */
3619
3656
protected def simplify (tree : Tree , pt : Type , locked : TypeVars )(using Context ): Tree =
3620
- val (tree1, pt1) = addDownDeferredEvidenceParams (tree, pt)
3657
+ val (tree1, pt1) = addDeferredEvidenceParams (tree, pt)
3621
3658
if ! tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
3622
3659
if ! tree1.tpe.widen.isInstanceOf [MethodOrPoly ] // wait with simplifying until method is fully applied
3623
3660
|| tree1.isDef // ... unless tree is a definition
0 commit comments