Skip to content

Commit 60b2d02

Browse files
committed
Improve defn.PolyFunctionOf extractor
* Only match `RefinedType` representing the `PolyFunction`. This will allow us to use `derivedRefinedType` on the function type. * Only match the refinement if it is a `MethodOrPoly`. `ExprType` is not a valid `PolyFunction` refinement. * Remove `dealias` in `PolyFunctionOf` extractor. There was only one case where this was necessary and it added unnecessary overhead.
1 parent 6e370a9 commit 60b2d02

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,7 @@ class CheckCaptures extends Recheck, SymTransformer:
808808

809809
try
810810
val eres = expected.dealias.stripCapturing match
811-
case RefinedType(_, _, rinfo: PolyType) => rinfo.resType
811+
case defn.PolyFunctionOf(rinfo: PolyType) => rinfo.resType
812812
case expected: PolyType => expected.resType
813813
case _ => WildcardType
814814

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,11 +1140,12 @@ class Definitions {
11401140
*
11411141
* Pattern: `PolyFunction { def apply: $mt }`
11421142
*/
1143-
def unapply(ft: Type)(using Context): Option[MethodicType] = ft.dealias match
1144-
case RefinedType(parent, nme.apply, mt: MethodicType)
1145-
if parent.derivesFrom(defn.PolyFunctionClass) =>
1146-
Some(mt)
1147-
case _ => None
1143+
def unapply(tpe: RefinedType)(using Context): Option[MethodOrPoly] =
1144+
tpe.refinedInfo match
1145+
case mt: MethodOrPoly
1146+
if tpe.refinedName == nme.apply && tpe.parent.derivesFrom(defn.PolyFunctionClass) =>
1147+
Some(mt)
1148+
case _ => None
11481149

11491150
private def isValidPolyFunctionInfo(info: Type)(using Context): Boolean =
11501151
def isValidMethodType(info: Type) = info match

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1648,10 +1648,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16481648
def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
16491649
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
16501650
val untpd.Function(vparams: List[untpd.ValDef] @unchecked, body) = fun: @unchecked
1651+
val dpt = pt.dealias
16511652

16521653
// If the expected type is a polymorphic function with the same number of
16531654
// type and value parameters, then infer the types of value parameters from the expected type.
1654-
val inferredVParams = pt match
1655+
val inferredVParams = dpt match
16551656
case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType))
16561657
if tparams.lengthCompare(poly.paramNames) == 0 && vparams.lengthCompare(mt.paramNames) == 0 =>
16571658
vparams.zipWithConserve(mt.paramInfos): (vparam, formal) =>
@@ -1667,7 +1668,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16671668
case _ =>
16681669
vparams
16691670

1670-
val resultTpt = pt.dealias match
1671+
val resultTpt = dpt match
16711672
case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType)) =>
16721673
untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) =>
16731674
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))

0 commit comments

Comments
 (0)