diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index bb52b6a5e342..5ac1c23d2e7f 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -947,6 +947,8 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] => def isStructuralTermSelectOrApply(tree: Tree)(using Context): Boolean = { def isStructuralTermSelect(tree: Select) = def hasRefinement(qualtpe: Type): Boolean = qualtpe.dealias match + case defn.PolyOrErasedFunctionOf(_) => + false case RefinedType(parent, rname, rinfo) => rname == tree.name || hasRefinement(parent) case tp: TypeProxy => @@ -959,10 +961,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] => false !tree.symbol.exists && tree.isTerm - && { - val qualType = tree.qualifier.tpe - hasRefinement(qualType) && !defn.isPolyOrErasedFunctionType(qualType) - } + && hasRefinement(tree.qualifier.tpe) def loop(tree: Tree): Boolean = tree match case TypeApply(fun, _) => loop(fun) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index a0584e97a026..f094063c97c7 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1108,7 +1108,7 @@ class Definitions { FunctionType(args.length, isContextual).appliedTo(args ::: resultType :: Nil) def unapply(ft: Type)(using Context): Option[(List[Type], Type, Boolean)] = { ft.dealias match - case RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) => + case ErasedFunctionOf(mt) => Some(mt.paramInfos, mt.resType, mt.isContextualMethod) case _ => val tsym = ft.dealias.typeSymbol @@ -1120,6 +1120,42 @@ class Definitions { } } + object PolyOrErasedFunctionOf { + /** Matches a refined `PolyFunction` or `ErasedFunction` type and extracts the apply info. + * + * Pattern: `(PolyFunction | ErasedFunction) { def apply: $mt }` + */ + def unapply(ft: Type)(using Context): Option[MethodicType] = ft.dealias match + case RefinedType(parent, nme.apply, mt: MethodicType) + if parent.derivesFrom(defn.PolyFunctionClass) || parent.derivesFrom(defn.ErasedFunctionClass) => + Some(mt) + case _ => None + } + + object PolyFunctionOf { + /** Matches a refined `PolyFunction` type and extracts the apply info. + * + * Pattern: `PolyFunction { def apply: $pt }` + */ + def unapply(ft: Type)(using Context): Option[PolyType] = ft.dealias match + case RefinedType(parent, nme.apply, pt: PolyType) + if parent.derivesFrom(defn.PolyFunctionClass) => + Some(pt) + case _ => None + } + + object ErasedFunctionOf { + /** Matches a refined `ErasedFunction` type and extracts the apply info. + * + * Pattern: `ErasedFunction { def apply: $mt }` + */ + def unapply(ft: Type)(using Context): Option[MethodType] = ft.dealias match + case RefinedType(parent, nme.apply, mt: MethodType) + if parent.derivesFrom(defn.ErasedFunctionClass) => + Some(mt) + case _ => None + } + object PartialFunctionOf { def apply(arg: Type, result: Type)(using Context): Type = PartialFunctionClass.typeRef.appliedTo(arg :: result :: Nil) @@ -1705,18 +1741,6 @@ class Definitions { def isFunctionNType(tp: Type)(using Context): Boolean = isNonRefinedFunction(tp.dropDependentRefinement) - /** Does `tp` derive from `PolyFunction` or `ErasedFunction`? */ - def isPolyOrErasedFunctionType(tp: Type)(using Context): Boolean = - isPolyFunctionType(tp) || isErasedFunctionType(tp) - - /** Does `tp` derive from `PolyFunction`? */ - def isPolyFunctionType(tp: Type)(using Context): Boolean = - tp.derivesFrom(defn.PolyFunctionClass) - - /** Does `tp` derive from `ErasedFunction`? */ - def isErasedFunctionType(tp: Type)(using Context): Boolean = - tp.derivesFrom(defn.ErasedFunctionClass) - /** Returns whether `tp` is an instance or a refined instance of: * - scala.FunctionN * - scala.ContextFunctionN @@ -1724,7 +1748,9 @@ class Definitions { * - PolyFunction */ def isFunctionType(tp: Type)(using Context): Boolean = - isFunctionNType(tp) || isPolyOrErasedFunctionType(tp) + isFunctionNType(tp) + || tp.derivesFrom(defn.PolyFunctionClass) // TODO check for refinement? + || tp.derivesFrom(defn.ErasedFunctionClass) // TODO check for refinement? private def withSpecMethods(cls: ClassSymbol, bases: List[Name], paramTypes: Set[TypeRef]) = for base <- bases; tp <- paramTypes do @@ -1825,7 +1851,7 @@ class Definitions { tp.stripTypeVar.dealias match case tp1: TypeParamRef if ctx.typerState.constraint.contains(tp1) => asContextFunctionType(TypeComparer.bounds(tp1).hiBound) - case tp1 @ RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) && mt.isContextualMethod => + case tp1 @ ErasedFunctionOf(mt) if mt.isContextualMethod => tp1 case tp1 => if tp1.typeSymbol.name.isContextFunction && isFunctionNType(tp1) then tp1 @@ -1845,7 +1871,7 @@ class Definitions { atPhase(erasurePhase)(unapply(tp)) else asContextFunctionType(tp) match - case RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) => + case ErasedFunctionOf(mt) => Some((mt.paramInfos, mt.resType, mt.erasedParams)) case tp1 if tp1.exists => val args = tp1.functionArgInfos @@ -1855,7 +1881,7 @@ class Definitions { /* Returns a list of erased booleans marking whether parameters are erased, for a function type. */ def erasedFunctionParameters(tp: Type)(using Context): List[Boolean] = tp.dealias match { - case RefinedType(parent, nme.apply, mt: MethodType) => mt.erasedParams + case ErasedFunctionOf(mt) => mt.erasedParams case tp if isFunctionNType(tp) => List.fill(functionArity(tp)) { false } case _ => Nil } diff --git a/compiler/src/dotty/tools/dotc/core/TypeApplications.scala b/compiler/src/dotty/tools/dotc/core/TypeApplications.scala index c1b2541c460b..8ce0da9bc50f 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeApplications.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeApplications.scala @@ -509,7 +509,7 @@ class TypeApplications(val self: Type) extends AnyVal { * Handles `ErasedFunction`s and poly functions gracefully. */ final def functionArgInfos(using Context): List[Type] = self.dealias match - case RefinedType(parent, nme.apply, mt: MethodType) if defn.isPolyOrErasedFunctionType(parent) => (mt.paramInfos :+ mt.resultType) + case defn.ErasedFunctionOf(mt) => (mt.paramInfos :+ mt.resultType) case _ => self.dropDependentRefinement.dealias.argInfos /** Argument types where existential types in arguments are disallowed */ diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 01709a9cd41d..8dac1f95e1b1 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -659,7 +659,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling isSubType(info1, info2) if defn.isFunctionType(tp2) then - if defn.isPolyFunctionType(tp2) then + if tp2.derivesFrom(defn.PolyFunctionClass) then // TODO should we handle ErasedFunction is this same way? tp1.member(nme.apply).info match case info1: PolyType => diff --git a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala index b0938cfb7c64..baa6bf21e64e 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala @@ -654,8 +654,8 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst else SuperType(eThis, eSuper) case ExprType(rt) => defn.FunctionType(0) - case RefinedType(parent, nme.apply, refinedInfo) if defn.isPolyOrErasedFunctionType(parent) => - eraseRefinedFunctionApply(refinedInfo) + case defn.PolyOrErasedFunctionOf(mt) => + eraseRefinedFunctionApply(mt) case tp: TypeVar if !tp.isInstantiated => assert(inSigName, i"Cannot erase uninstantiated type variable $tp") WildcardType @@ -936,7 +936,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst sigName(defn.FunctionOf(Nil, rt)) case tp: TypeVar if !tp.isInstantiated => tpnme.Uninstantiated - case tp @ RefinedType(parent, nme.apply, _) if defn.isPolyOrErasedFunctionType(parent) => + case tp @ defn.PolyOrErasedFunctionOf(_) => // we need this case rather than falling through to the default // because RefinedTypes <: TypeProxy and it would be caught by // the case immediately below diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 6b1a87ad35d1..30999b620732 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -1737,9 +1737,7 @@ object Types { if !tf1.exists then tf2 else if !tf2.exists then tf1 else NoType - case t if defn.isNonRefinedFunction(t) => - t - case t if defn.isErasedFunctionType(t) => + case t if defn.isFunctionType(t) => t case t @ SAMType(_, _) => t diff --git a/compiler/src/dotty/tools/dotc/transform/Erasure.scala b/compiler/src/dotty/tools/dotc/transform/Erasure.scala index b106b237cfb9..fce29afdd638 100644 --- a/compiler/src/dotty/tools/dotc/transform/Erasure.scala +++ b/compiler/src/dotty/tools/dotc/transform/Erasure.scala @@ -677,7 +677,7 @@ object Erasure { // Instead, we manually lookup the type of `apply` in the qualifier. inContext(preErasureCtx) { val qualTp = tree.qualifier.typeOpt.widen - if defn.isPolyOrErasedFunctionType(qualTp) then + if qualTp.derivesFrom(defn.PolyFunctionClass) || qualTp.derivesFrom(defn.ErasedFunctionClass) then eraseRefinedFunctionApply(qualTp.select(nme.apply).widen).classSymbol else NoSymbol diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index 05cca875b791..20290a2ee1f7 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -446,7 +446,11 @@ object TreeChecker { val tpe = tree.typeOpt // PolyFunction and ErasedFunction apply methods stay structural until Erasure - val isRefinedFunctionApply = (tree.name eq nme.apply) && defn.isPolyOrErasedFunctionType(tree.qualifier.typeOpt) + val isRefinedFunctionApply = (tree.name eq nme.apply) && { + val qualTpe = tree.qualifier.typeOpt + qualTpe.derivesFrom(defn.PolyFunctionClass) || qualTpe.derivesFrom(defn.ErasedFunctionClass) + } + // Outer selects are pickled specially so don't require a symbol val isOuterSelect = tree.name.is(OuterSelectName) val isPrimitiveArrayOp = ctx.erasedTypes && nme.isPrimitiveName(tree.name) diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index 11038d54fa73..85d1f084e6a0 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -105,7 +105,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): expected =:= defn.FunctionOf(actualArgs, actualRet, defn.isContextFunctionType(baseFun)) val arity: Int = - if defn.isErasedFunctionType(fun) then -1 // TODO support? + if fun.derivesFrom(defn.ErasedFunctionClass) then -1 // TODO support? else if defn.isFunctionNType(fun) then // TupledFunction[(...) => R, ?] fun.functionArgInfos match diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index efb715566068..77c170007188 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1322,7 +1322,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer (pt1.argInfos.init, typeTree(interpolateWildcards(pt1.argInfos.last.hiBound))) case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe)) - if (defn.isNonRefinedFunction(parent) || defn.isErasedFunctionType(parent)) && formals.length == defaultArity => + if defn.isNonRefinedFunction(parent) && formals.length == defaultArity => + (formals, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef)))) + case defn.ErasedFunctionOf(mt @ MethodTpe(_, formals, restpe)) if formals.length == defaultArity => (formals, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef)))) case SAMType(mt @ MethodTpe(_, formals, restpe), _) => (formals, @@ -3162,8 +3164,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer else formals.map(untpd.TypeTree) } - val erasedParams = pt.dealias match { - case RefinedType(parent, nme.apply, mt: MethodType) => mt.erasedParams + val erasedParams = pt match { + case defn.ErasedFunctionOf(mt: MethodType) => mt.erasedParams case _ => paramTypes.map(_ => false) } diff --git a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala index 4e1d75624f2c..c3f7445dd22a 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala @@ -1779,7 +1779,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler def isContextFunctionType: Boolean = dotc.core.Symbols.defn.isContextFunctionType(self) def isErasedFunctionType: Boolean = - dotc.core.Symbols.defn.isErasedFunctionType(self) + self.derivesFrom(dotc.core.Symbols.defn.ErasedFunctionClass) def isDependentFunctionType: Boolean = val tpNoRefinement = self.dropDependentRefinement tpNoRefinement != self