Skip to content

Commit da16f43

Browse files
authored
Replace is{Poly|Erased}FunctionType with {PolyOrErased,Poly,Erased}FunctionOf (#18207)
2 parents b7e797d + 163cdf5 commit da16f43

File tree

11 files changed

+68
-42
lines changed

11 files changed

+68
-42
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,8 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
954954
def isStructuralTermSelectOrApply(tree: Tree)(using Context): Boolean = {
955955
def isStructuralTermSelect(tree: Select) =
956956
def hasRefinement(qualtpe: Type): Boolean = qualtpe.dealias match
957+
case defn.PolyOrErasedFunctionOf(_) =>
958+
false
957959
case RefinedType(parent, rname, rinfo) =>
958960
rname == tree.name || hasRefinement(parent)
959961
case tp: TypeProxy =>
@@ -966,10 +968,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
966968
false
967969
!tree.symbol.exists
968970
&& tree.isTerm
969-
&& {
970-
val qualType = tree.qualifier.tpe
971-
hasRefinement(qualType) && !defn.isPolyOrErasedFunctionType(qualType)
972-
}
971+
&& hasRefinement(tree.qualifier.tpe)
973972
def loop(tree: Tree): Boolean = tree match
974973
case TypeApply(fun, _) =>
975974
loop(fun)

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,7 +1115,7 @@ class Definitions {
11151115
FunctionType(args.length, isContextual).appliedTo(args ::: resultType :: Nil)
11161116
def unapply(ft: Type)(using Context): Option[(List[Type], Type, Boolean)] = {
11171117
ft.dealias match
1118-
case RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) =>
1118+
case ErasedFunctionOf(mt) =>
11191119
Some(mt.paramInfos, mt.resType, mt.isContextualMethod)
11201120
case _ =>
11211121
val tsym = ft.dealias.typeSymbol
@@ -1127,6 +1127,42 @@ class Definitions {
11271127
}
11281128
}
11291129

1130+
object PolyOrErasedFunctionOf {
1131+
/** Matches a refined `PolyFunction` or `ErasedFunction` type and extracts the apply info.
1132+
*
1133+
* Pattern: `(PolyFunction | ErasedFunction) { def apply: $mt }`
1134+
*/
1135+
def unapply(ft: Type)(using Context): Option[MethodicType] = ft.dealias match
1136+
case RefinedType(parent, nme.apply, mt: MethodicType)
1137+
if parent.derivesFrom(defn.PolyFunctionClass) || parent.derivesFrom(defn.ErasedFunctionClass) =>
1138+
Some(mt)
1139+
case _ => None
1140+
}
1141+
1142+
object PolyFunctionOf {
1143+
/** Matches a refined `PolyFunction` type and extracts the apply info.
1144+
*
1145+
* Pattern: `PolyFunction { def apply: $pt }`
1146+
*/
1147+
def unapply(ft: Type)(using Context): Option[PolyType] = ft.dealias match
1148+
case RefinedType(parent, nme.apply, pt: PolyType)
1149+
if parent.derivesFrom(defn.PolyFunctionClass) =>
1150+
Some(pt)
1151+
case _ => None
1152+
}
1153+
1154+
object ErasedFunctionOf {
1155+
/** Matches a refined `ErasedFunction` type and extracts the apply info.
1156+
*
1157+
* Pattern: `ErasedFunction { def apply: $mt }`
1158+
*/
1159+
def unapply(ft: Type)(using Context): Option[MethodType] = ft.dealias match
1160+
case RefinedType(parent, nme.apply, mt: MethodType)
1161+
if parent.derivesFrom(defn.ErasedFunctionClass) =>
1162+
Some(mt)
1163+
case _ => None
1164+
}
1165+
11301166
object PartialFunctionOf {
11311167
def apply(arg: Type, result: Type)(using Context): Type =
11321168
PartialFunctionClass.typeRef.appliedTo(arg :: result :: Nil)
@@ -1714,26 +1750,16 @@ class Definitions {
17141750
def isFunctionNType(tp: Type)(using Context): Boolean =
17151751
isNonRefinedFunction(tp.dropDependentRefinement)
17161752

1717-
/** Does `tp` derive from `PolyFunction` or `ErasedFunction`? */
1718-
def isPolyOrErasedFunctionType(tp: Type)(using Context): Boolean =
1719-
isPolyFunctionType(tp) || isErasedFunctionType(tp)
1720-
1721-
/** Does `tp` derive from `PolyFunction`? */
1722-
def isPolyFunctionType(tp: Type)(using Context): Boolean =
1723-
tp.derivesFrom(defn.PolyFunctionClass)
1724-
1725-
/** Does `tp` derive from `ErasedFunction`? */
1726-
def isErasedFunctionType(tp: Type)(using Context): Boolean =
1727-
tp.derivesFrom(defn.ErasedFunctionClass)
1728-
17291753
/** Returns whether `tp` is an instance or a refined instance of:
17301754
* - scala.FunctionN
17311755
* - scala.ContextFunctionN
17321756
* - ErasedFunction
17331757
* - PolyFunction
17341758
*/
17351759
def isFunctionType(tp: Type)(using Context): Boolean =
1736-
isFunctionNType(tp) || isPolyOrErasedFunctionType(tp)
1760+
isFunctionNType(tp)
1761+
|| tp.derivesFrom(defn.PolyFunctionClass) // TODO check for refinement?
1762+
|| tp.derivesFrom(defn.ErasedFunctionClass) // TODO check for refinement?
17371763

17381764
private def withSpecMethods(cls: ClassSymbol, bases: List[Name], paramTypes: Set[TypeRef]) =
17391765
if !ctx.settings.Yscala2Stdlib.value then
@@ -1837,7 +1863,7 @@ class Definitions {
18371863
tp.stripTypeVar.dealias match
18381864
case tp1: TypeParamRef if ctx.typerState.constraint.contains(tp1) =>
18391865
asContextFunctionType(TypeComparer.bounds(tp1).hiBound)
1840-
case tp1 @ RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) && mt.isContextualMethod =>
1866+
case tp1 @ ErasedFunctionOf(mt) if mt.isContextualMethod =>
18411867
tp1
18421868
case tp1 =>
18431869
if tp1.typeSymbol.name.isContextFunction && isFunctionNType(tp1) then tp1
@@ -1857,7 +1883,7 @@ class Definitions {
18571883
atPhase(erasurePhase)(unapply(tp))
18581884
else
18591885
asContextFunctionType(tp) match
1860-
case RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) =>
1886+
case ErasedFunctionOf(mt) =>
18611887
Some((mt.paramInfos, mt.resType, mt.erasedParams))
18621888
case tp1 if tp1.exists =>
18631889
val args = tp1.functionArgInfos
@@ -1867,7 +1893,7 @@ class Definitions {
18671893

18681894
/* Returns a list of erased booleans marking whether parameters are erased, for a function type. */
18691895
def erasedFunctionParameters(tp: Type)(using Context): List[Boolean] = tp.dealias match {
1870-
case RefinedType(parent, nme.apply, mt: MethodType) => mt.erasedParams
1896+
case ErasedFunctionOf(mt) => mt.erasedParams
18711897
case tp if isFunctionNType(tp) => List.fill(functionArity(tp)) { false }
18721898
case _ => Nil
18731899
}

compiler/src/dotty/tools/dotc/core/TypeApplications.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ class TypeApplications(val self: Type) extends AnyVal {
509509
* Handles `ErasedFunction`s and poly functions gracefully.
510510
*/
511511
final def functionArgInfos(using Context): List[Type] = self.dealias match
512-
case RefinedType(parent, nme.apply, mt: MethodType) if defn.isPolyOrErasedFunctionType(parent) => (mt.paramInfos :+ mt.resultType)
512+
case defn.ErasedFunctionOf(mt) => (mt.paramInfos :+ mt.resultType)
513513
case _ => self.dropDependentRefinement.dealias.argInfos
514514

515515
/** Argument types where existential types in arguments are disallowed */

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
666666
isSubType(info1, info2)
667667

668668
if defn.isFunctionType(tp2) then
669-
if defn.isPolyFunctionType(tp2) then
669+
if tp2.derivesFrom(defn.PolyFunctionClass) then
670670
// TODO should we handle ErasedFunction is this same way?
671671
tp1.member(nme.apply).info match
672672
case info1: PolyType =>

compiler/src/dotty/tools/dotc/core/TypeErasure.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -654,8 +654,8 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
654654
else SuperType(eThis, eSuper)
655655
case ExprType(rt) =>
656656
defn.FunctionType(0)
657-
case RefinedType(parent, nme.apply, refinedInfo) if defn.isPolyOrErasedFunctionType(parent) =>
658-
eraseRefinedFunctionApply(refinedInfo)
657+
case defn.PolyOrErasedFunctionOf(mt) =>
658+
eraseRefinedFunctionApply(mt)
659659
case tp: TypeVar if !tp.isInstantiated =>
660660
assert(inSigName, i"Cannot erase uninstantiated type variable $tp")
661661
WildcardType
@@ -936,7 +936,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
936936
sigName(defn.FunctionOf(Nil, rt))
937937
case tp: TypeVar if !tp.isInstantiated =>
938938
tpnme.Uninstantiated
939-
case tp @ RefinedType(parent, nme.apply, _) if defn.isPolyOrErasedFunctionType(parent) =>
939+
case tp @ defn.PolyOrErasedFunctionOf(_) =>
940940
// we need this case rather than falling through to the default
941941
// because RefinedTypes <: TypeProxy and it would be caught by
942942
// the case immediately below

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1747,9 +1747,7 @@ object Types {
17471747
if !tf1.exists then tf2
17481748
else if !tf2.exists then tf1
17491749
else NoType
1750-
case t if defn.isNonRefinedFunction(t) =>
1751-
t
1752-
case t if defn.isErasedFunctionType(t) =>
1750+
case t if defn.isFunctionType(t) =>
17531751
t
17541752
case t @ SAMType(_, _) =>
17551753
t

compiler/src/dotty/tools/dotc/transform/Erasure.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ object Erasure {
679679
// Instead, we manually lookup the type of `apply` in the qualifier.
680680
inContext(preErasureCtx) {
681681
val qualTp = tree.qualifier.typeOpt.widen
682-
if defn.isPolyOrErasedFunctionType(qualTp) then
682+
if qualTp.derivesFrom(defn.PolyFunctionClass) || qualTp.derivesFrom(defn.ErasedFunctionClass) then
683683
eraseRefinedFunctionApply(qualTp.select(nme.apply).widen).classSymbol
684684
else
685685
NoSymbol

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,11 @@ object TreeChecker {
447447
val tpe = tree.typeOpt
448448

449449
// PolyFunction and ErasedFunction apply methods stay structural until Erasure
450-
val isRefinedFunctionApply = (tree.name eq nme.apply) && defn.isPolyOrErasedFunctionType(tree.qualifier.typeOpt)
450+
val isRefinedFunctionApply = (tree.name eq nme.apply) && {
451+
val qualTpe = tree.qualifier.typeOpt
452+
qualTpe.derivesFrom(defn.PolyFunctionClass) || qualTpe.derivesFrom(defn.ErasedFunctionClass)
453+
}
454+
451455
// Outer selects are pickled specially so don't require a symbol
452456
val isOuterSelect = tree.name.is(OuterSelectName)
453457
val isPrimitiveArrayOp = ctx.erasedTypes && nme.isPrimitiveName(tree.name)

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
105105
expected =:= defn.FunctionOf(actualArgs, actualRet,
106106
defn.isContextFunctionType(baseFun))
107107
val arity: Int =
108-
if defn.isErasedFunctionType(fun) then -1 // TODO support?
108+
if fun.derivesFrom(defn.ErasedFunctionClass) then -1 // TODO support?
109109
else if defn.isFunctionNType(fun) then
110110
// TupledFunction[(...) => R, ?]
111111
fun.functionArgInfos match

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,7 +1327,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
13271327

13281328
(pt1.argInfos.init, typeTree(interpolateWildcards(pt1.argInfos.last.hiBound)))
13291329
case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe))
1330-
if (defn.isNonRefinedFunction(parent) || defn.isErasedFunctionType(parent)) && formals.length == defaultArity =>
1330+
if defn.isNonRefinedFunction(parent) && formals.length == defaultArity =>
1331+
(formals, untpd.InLambdaTypeTree(isResult = true, (_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
1332+
case defn.ErasedFunctionOf(mt @ MethodTpe(_, formals, restpe)) if formals.length == defaultArity =>
13311333
(formals, untpd.InLambdaTypeTree(isResult = true, (_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
13321334
case SAMType(mt @ MethodTpe(_, formals, _), samParent) =>
13331335
val restpe = mt.resultType match
@@ -1649,11 +1651,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16491651
// If the expected type is a polymorphic function with the same number of
16501652
// type and value parameters, then infer the types of value parameters from the expected type.
16511653
val inferredVParams = pt match
1652-
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType))
1653-
if (parent.typeSymbol eq defn.PolyFunctionClass)
1654-
&& tparams.lengthCompare(poly.paramNames) == 0
1655-
&& vparams.lengthCompare(mt.paramNames) == 0
1656-
=>
1654+
case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType))
1655+
if tparams.lengthCompare(poly.paramNames) == 0 && vparams.lengthCompare(mt.paramNames) == 0 =>
16571656
vparams.zipWithConserve(mt.paramInfos): (vparam, formal) =>
16581657
// Unlike in typedFunctionValue, `formal` cannot be a TypeBounds since
16591658
// it must be a valid method parameter type.
@@ -1668,7 +1667,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16681667
vparams
16691668

16701669
val resultTpt = pt.dealias match
1671-
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass =>
1670+
case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType)) =>
16721671
untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) =>
16731672
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
16741673
case _ => untpd.TypeTree()
@@ -3235,8 +3234,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
32353234
else formals.map(untpd.TypeTree)
32363235
}
32373236

3238-
val erasedParams = pt.dealias match {
3239-
case RefinedType(parent, nme.apply, mt: MethodType) => mt.erasedParams
3237+
val erasedParams = pt match {
3238+
case defn.ErasedFunctionOf(mt: MethodType) => mt.erasedParams
32403239
case _ => paramTypes.map(_ => false)
32413240
}
32423241

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1788,7 +1788,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
17881788
def isContextFunctionType: Boolean =
17891789
dotc.core.Symbols.defn.isContextFunctionType(self)
17901790
def isErasedFunctionType: Boolean =
1791-
dotc.core.Symbols.defn.isErasedFunctionType(self)
1791+
self.derivesFrom(dotc.core.Symbols.defn.ErasedFunctionClass)
17921792
def isDependentFunctionType: Boolean =
17931793
val tpNoRefinement = self.dropDependentRefinement
17941794
tpNoRefinement != self

0 commit comments

Comments
 (0)