Skip to content

Refactor function type logic #18193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jul 13, 2023
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
&& tree.isTerm
&& {
val qualType = tree.qualifier.tpe
hasRefinement(qualType) && !defn.isRefinedFunctionType(qualType)
hasRefinement(qualType) && !defn.isPolyOrErasedFunctionType(qualType)
}
def loop(tree: Tree): Boolean = tree match
case TypeApply(fun, _) =>
Expand Down
12 changes: 6 additions & 6 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class CheckCaptures extends Recheck, SymTransformer:
capt.println(i"solving $t")
refs.solve()
traverse(parent)
case t @ RefinedType(_, nme.apply, rinfo) if defn.isFunctionOrPolyType(t) =>
case t @ RefinedType(_, nme.apply, rinfo) if defn.isFunctionType(t) =>
traverse(rinfo)
case tp: TypeVar =>
case tp: TypeRef =>
Expand Down Expand Up @@ -638,7 +638,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case expected @ CapturingType(eparent, refs) =>
CapturingType(recur(eparent), refs, boxed = expected.isBoxed)
case expected @ defn.FunctionOf(args, resultType, isContextual)
if defn.isNonRefinedFunction(expected) && defn.isFunctionType(actual) && !defn.isNonRefinedFunction(actual) =>
if defn.isNonRefinedFunction(expected) && defn.isFunctionNType(actual) && !defn.isNonRefinedFunction(actual) =>
val expected1 = toDepFun(args, resultType, isContextual)
expected1
case _ =>
Expand Down Expand Up @@ -707,7 +707,7 @@ class CheckCaptures extends Recheck, SymTransformer:
val (eargs, eres) = expected.dealias.stripCapturing match
case defn.FunctionOf(eargs, eres, _) => (eargs, eres)
case expected: MethodType => (expected.paramInfos, expected.resType)
case expected @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(expected) => (rinfo.paramInfos, rinfo.resType)
case expected @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionNType(expected) => (rinfo.paramInfos, rinfo.resType)
case _ => (aargs.map(_ => WildcardType), WildcardType)
val aargs1 = aargs.zipWithConserve(eargs) { (aarg, earg) => adapt(aarg, earg, !covariant) }
val ares1 = adapt(ares, eres, covariant)
Expand Down Expand Up @@ -769,7 +769,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
adaptFun(actual, args.init, args.last, expected, covariant, insertBox,
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionOrPolyType(actual) =>
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
adaptFun(actual, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox,
(aargs1, ares1) =>
Expand All @@ -779,7 +779,7 @@ class CheckCaptures extends Recheck, SymTransformer:
adaptFun(actual, actual.paramInfos, actual.resType, expected, covariant, insertBox,
(aargs1, ares1) =>
actual.derivedLambdaType(paramInfos = aargs1, resType = ares1))
case actual @ RefinedType(p, nme, rinfo: PolyType) if defn.isFunctionOrPolyType(actual) =>
case actual @ RefinedType(p, nme, rinfo: PolyType) if defn.isFunctionType(actual) =>
adaptTypeFun(actual, rinfo.resType, expected, covariant, insertBox,
ares1 =>
val rinfo1 = rinfo.derivedLambdaType(rinfo.paramNames, rinfo.paramInfos, ares1)
Expand Down Expand Up @@ -996,7 +996,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case CapturingType(parent, refs) =>
healCaptureSet(refs)
traverse(parent)
case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
traverse(rinfo)
case tp: TermLambda =>
val saved = allowed
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ extends tpd.TreeTraverser:
val boxedRes = recur(res)
if boxedRes eq res then tp
else tp1.derivedAppliedType(tycon, args.init :+ boxedRes)
case tp1 @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionOrPolyType(tp1) =>
case tp1 @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(tp1) =>
val boxedRinfo = recur(rinfo)
if boxedRinfo eq rinfo then tp
else boxedRinfo.toFunctionType(isJava = false, alwaysDependent = true)
Expand Down Expand Up @@ -231,7 +231,7 @@ extends tpd.TreeTraverser:
tp.derivedAppliedType(tycon1, args1 :+ res1)
else
tp.derivedAppliedType(tycon1, args.mapConserve(arg => this(arg)))
case tp @ RefinedType(core, rname, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
case tp @ RefinedType(core, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
val rinfo1 = apply(rinfo)
if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true)
else tp
Expand Down Expand Up @@ -329,7 +329,7 @@ extends tpd.TreeTraverser:
args.last, CaptureSet.empty, currentCs ++ outerCs)
tp.derivedAppliedType(tycon1, args1 :+ resType1)
tp1.capturing(outerCs)
case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionType(tp) =>
propagateDepFunctionResult(mapOver(tp), currentCs ++ outerCs)
.capturing(outerCs)
case _ =>
Expand Down
28 changes: 17 additions & 11 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,7 @@ class Definitions {
def PolyFunctionType = PolyFunctionClass.typeRef

lazy val ErasedFunctionClass = requiredClass("scala.runtime.ErasedFunction")
def ErasedFunctionType = ErasedFunctionClass.typeRef

/** If `cls` is a class in the scala package, its name, otherwise EmptyTypeName */
def scalaClassName(cls: Symbol)(using Context): TypeName = cls.denot match
Expand Down Expand Up @@ -1709,21 +1710,29 @@ class Definitions {
* - scala.FunctionN
* - scala.ContextFunctionN
*/
def isFunctionType(tp: Type)(using Context): Boolean =
def isFunctionNType(tp: Type)(using Context): Boolean =
isNonRefinedFunction(tp.dropDependentRefinement)

/** Is `tp` a specialized, refined function type? Either an `ErasedFunction` or a `PolyFunction`. */
def isRefinedFunctionType(tp: Type)(using Context): Boolean =
tp.derivesFrom(defn.PolyFunctionClass) || isErasedFunctionType(tp)
/** Does `tp` derive from `PolyFunction` or `ErasedFunction`? */
def isPolyOrErasedFunctionType(tp: Type)(using Context): Boolean =
isPolyFunctionType(tp) || isErasedFunctionType(tp)

/** Does `tp` derive from `PolyFunction`? */
def isPolyFunctionType(tp: Type)(using Context): Boolean =
tp.derivesFrom(defn.PolyFunctionClass)

/** Does `tp` derive from `ErasedFunction`? */
def isErasedFunctionType(tp: Type)(using Context): Boolean =
tp.derivesFrom(defn.ErasedFunctionClass)

/** Returns whether `tp` is an instance or a refined instance of:
* - scala.FunctionN
* - scala.ContextFunctionN
* - ErasedFunction
* - PolyFunction
*/
def isFunctionOrPolyType(tp: Type)(using Context): Boolean =
isFunctionType(tp) || isRefinedFunctionType(tp)
def isFunctionType(tp: Type)(using Context): Boolean =
isFunctionNType(tp) || isPolyOrErasedFunctionType(tp)

private def withSpecMethods(cls: ClassSymbol, bases: List[Name], paramTypes: Set[TypeRef]) =
if !ctx.settings.Yscala2Stdlib.value then
Expand Down Expand Up @@ -1830,7 +1839,7 @@ class Definitions {
case tp1 @ RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) && mt.isContextualMethod =>
tp1
case tp1 =>
if tp1.typeSymbol.name.isContextFunction && isFunctionType(tp1) then tp1
if tp1.typeSymbol.name.isContextFunction && isFunctionNType(tp1) then tp1
else NoType

/** Is `tp` an context function type? */
Expand Down Expand Up @@ -1858,13 +1867,10 @@ class Definitions {
/* Returns a list of erased booleans marking whether parameters are erased, for a function type. */
def erasedFunctionParameters(tp: Type)(using Context): List[Boolean] = tp.dealias match {
case RefinedType(parent, nme.apply, mt: MethodType) => mt.erasedParams
case tp if isFunctionType(tp) => List.fill(functionArity(tp)) { false }
case tp if isFunctionNType(tp) => List.fill(functionArity(tp)) { false }
case _ => Nil
}

def isErasedFunctionType(tp: Type)(using Context): Boolean =
tp.derivesFrom(defn.ErasedFunctionClass)

/** A whitelist of Scala-2 classes that are known to be pure */
def isAssuredNoInits(sym: Symbol): Boolean =
(sym `eq` SomeClass) || isTupleClass(sym)
Expand Down
3 changes: 1 addition & 2 deletions compiler/src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -509,8 +509,7 @@ class TypeApplications(val self: Type) extends AnyVal {
* Handles `ErasedFunction`s and poly functions gracefully.
*/
final def functionArgInfos(using Context): List[Type] = self.dealias match
case RefinedType(parent, nme.apply, mt: MethodType) if defn.isErasedFunctionType(parent) => (mt.paramInfos :+ mt.resultType)
case RefinedType(parent, nme.apply, mt: MethodType) if parent.typeSymbol eq defn.PolyFunctionClass => (mt.paramInfos :+ mt.resultType)
case RefinedType(parent, nme.apply, mt: MethodType) if defn.isPolyOrErasedFunctionType(parent) => (mt.paramInfos :+ mt.resultType)
case _ => self.dropDependentRefinement.dealias.argInfos

/** Argument types where existential types in arguments are disallowed */
Expand Down
20 changes: 11 additions & 9 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -666,15 +666,17 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
isSubType(info1, info2)

if defn.isFunctionType(tp2) then
tp1w.widenDealias match
case tp1: RefinedType =>
return isSubInfo(tp1.refinedInfo, tp2.refinedInfo)
case _ =>
else if tp2.parent.typeSymbol == defn.PolyFunctionClass then
tp1.member(nme.apply).info match
case info1: PolyType =>
return isSubInfo(info1, tp2.refinedInfo)
case _ =>
if defn.isPolyFunctionType(tp2) then
// TODO should we handle ErasedFunction is this same way?
tp1.member(nme.apply).info match
case info1: PolyType =>
return isSubInfo(info1, tp2.refinedInfo)
case _ =>
else
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure why there is a difference between the two comparisons. Would it work to always use the second one, which tests for RefinedType?

Copy link
Contributor Author

@nicolasstucki nicolasstucki Jul 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will try to merge the two cases into one (in #18200).

tp1w.widenDealias match
case tp1: RefinedType =>
return isSubInfo(tp1.refinedInfo, tp2.refinedInfo)
case _ =>

val skipped2 = skipMatching(tp1w, tp2)
if (skipped2 eq tp2) || !Config.fastPathForRefinedSubtype then
Expand Down
37 changes: 14 additions & 23 deletions compiler/src/dotty/tools/dotc/core/TypeErasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -560,21 +560,16 @@ object TypeErasure {
case _ => false
}

/** The erasure of `PolyFunction { def apply: $applyInfo }` */
def erasePolyFunctionApply(applyInfo: Type)(using Context): Type =
assert(applyInfo.isInstanceOf[PolyType])
val res = applyInfo.resultType
val paramss = res.paramNamess
assert(paramss.length == 1)
erasure(defn.FunctionType(paramss.head.length,
isContextual = res.isImplicitMethod))

def eraseErasedFunctionApply(erasedFn: MethodType)(using Context): Type =
val fnType = defn.FunctionType(
n = erasedFn.erasedParams.count(_ == false),
isContextual = erasedFn.isContextualMethod,
)
erasure(fnType)
/** The erasure of `(PolyFunction | ErasedFunction) { def apply: $applyInfo }` */
def eraseRefinedFunctionApply(applyInfo: Type)(using Context): Type =
def functionType(info: Type): Type = info match {
case info: PolyType =>
functionType(info.resultType)
case info: MethodType =>
assert(!info.resultType.isInstanceOf[MethodicType])
defn.FunctionType(n = info.erasedParams.count(_ == false))
}
erasure(functionType(applyInfo))
}

import TypeErasure._
Expand All @@ -592,7 +587,7 @@ import TypeErasure._
*/
class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConstructor: Boolean, isSymbol: Boolean, inSigName: Boolean) {

/** The erasure |T| of a type T.
/** The erasure |T| of a type T.
*
* If computing the erasure of T requires erasing a WildcardType or an
* uninstantiated type variable, then an exception signaling an internal
Expand Down Expand Up @@ -659,10 +654,8 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
else SuperType(eThis, eSuper)
case ExprType(rt) =>
defn.FunctionType(0)
case RefinedType(parent, nme.apply, refinedInfo) if parent.typeSymbol eq defn.PolyFunctionClass =>
erasePolyFunctionApply(refinedInfo)
case RefinedType(parent, nme.apply, refinedInfo: MethodType) if defn.isErasedFunctionType(parent) =>
eraseErasedFunctionApply(refinedInfo)
case RefinedType(parent, nme.apply, refinedInfo) if defn.isPolyOrErasedFunctionType(parent) =>
eraseRefinedFunctionApply(refinedInfo)
case tp: TypeVar if !tp.isInstantiated =>
assert(inSigName, i"Cannot erase uninstantiated type variable $tp")
WildcardType
Expand Down Expand Up @@ -943,13 +936,11 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
sigName(defn.FunctionOf(Nil, rt))
case tp: TypeVar if !tp.isInstantiated =>
tpnme.Uninstantiated
case tp @ RefinedType(parent, nme.apply, _) if parent.typeSymbol eq defn.PolyFunctionClass =>
case tp @ RefinedType(parent, nme.apply, _) if defn.isPolyOrErasedFunctionType(parent) =>
// we need this case rather than falling through to the default
// because RefinedTypes <: TypeProxy and it would be caught by
// the case immediately below
sigName(this(tp))
case tp @ RefinedType(parent, nme.apply, refinedInfo) if defn.isErasedFunctionType(parent) =>
sigName(this(tp))
case tp: TypeProxy =>
sigName(tp.underlying)
case tp: WildcardType =>
Expand Down
39 changes: 22 additions & 17 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ object Types {
}
findMember(name, pre, required, excluded)
}

/** The implicit members with given name. If there are none and the denotation
* contains private members, also look for shadowed non-private implicits.
*/
Expand Down Expand Up @@ -1875,20 +1875,25 @@ object Types {
* @param alwaysDependent if true, always create a dependent function type.
*/
def toFunctionType(isJava: Boolean, dropLast: Int = 0, alwaysDependent: Boolean = false)(using Context): Type = this match {
case mt: MethodType if !mt.isParamDependent =>
val formals1 = if (dropLast == 0) mt.paramInfos else mt.paramInfos dropRight dropLast
val isContextual = mt.isContextualMethod && !ctx.erasedTypes
val result1 = mt.nonDependentResultApprox match {
case res: MethodType => res.toFunctionType(isJava)
case res => res
}
val funType = defn.FunctionOf(
formals1 mapConserve (_.translateFromRepeated(toArray = isJava)),
result1, isContextual)
if alwaysDependent || mt.isResultDependent then
RefinedType(funType, nme.apply, mt)
else funType
case poly @ PolyType(_, mt: MethodType) if !mt.isParamDependent =>
case mt: MethodType =>
assert(!mt.isParamDependent)
def nonDependentFunType =
val formals1 = if (dropLast == 0) mt.paramInfos else mt.paramInfos dropRight dropLast
val isContextual = mt.isContextualMethod && !ctx.erasedTypes
val result1 = mt.nonDependentResultApprox match {
case res: MethodType => res.toFunctionType(isJava)
case res => res
}
defn.FunctionOf(
formals1 mapConserve (_.translateFromRepeated(toArray = isJava)),
result1, isContextual)
if mt.hasErasedParams then
RefinedType(defn.ErasedFunctionType, nme.apply, mt)
else if alwaysDependent || mt.isResultDependent then
RefinedType(nonDependentFunType, nme.apply, mt)
else nonDependentFunType
case poly @ PolyType(_, mt: MethodType) =>
assert(!mt.isParamDependent)
RefinedType(defn.PolyFunctionType, nme.apply, poly)
}

Expand Down Expand Up @@ -4071,9 +4076,9 @@ object Types {
def addInto(tp: Type): Type = tp match
case tp @ AppliedType(tycon, args) if tycon.typeSymbol == defn.RepeatedParamClass =>
tp.derivedAppliedType(tycon, addInto(args.head) :: Nil)
case tp @ AppliedType(tycon, args) if defn.isFunctionType(tp) =>
case tp @ AppliedType(tycon, args) if defn.isFunctionNType(tp) =>
wrapConvertible(tp.derivedAppliedType(tycon, args.init :+ addInto(args.last)))
case tp @ RefinedType(parent, rname, rinfo) if defn.isFunctionOrPolyType(tp) =>
case tp @ RefinedType(parent, rname, rinfo) if defn.isFunctionType(tp) =>
wrapConvertible(tp.derivedRefinedType(parent, rname, addInto(rinfo)))
case tp: MethodOrPoly =>
tp.derivedLambdaType(resType = addInto(tp.resType))
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
if !printDebug && appliedText(tp.asInstanceOf[HKLambda].resType).isEmpty =>
// don't eta contract if the application would be printed specially
toText(tycon)
case tp: RefinedType if defn.isFunctionOrPolyType(tp) && !printDebug =>
case tp: RefinedType if defn.isFunctionType(tp) && !printDebug =>
toTextMethodAsFunction(tp.refinedInfo,
isPure = Feature.pureFunsEnabled && !tp.typeSymbol.name.isImpureFunction)
case tp: TypeRef =>
Expand Down Expand Up @@ -771,7 +771,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
override protected def toTextCapturing(tp: Type, refsText: Text, boxText: Text): Text = tp match
case tp: AppliedType if defn.isFunctionSymbol(tp.typeSymbol) && !printDebug =>
boxText ~ toTextFunction(tp, refsText)
case tp: RefinedType if defn.isFunctionOrPolyType(tp) && !printDebug =>
case tp: RefinedType if defn.isFunctionType(tp) && !printDebug =>
boxText ~ toTextMethodAsFunction(tp.refinedInfo, isPure = !tp.typeSymbol.name.isImpureFunction, refsText)
case _ =>
super.toTextCapturing(tp, refsText, boxText)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/transform/BetaReduce.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ object BetaReduce:
recur(expr, argss)
case _ => None
tree match
case Apply(Select(fn, nme.apply), args) if defn.isFunctionType(fn.tpe) =>
case Apply(Select(fn, nme.apply), args) if defn.isFunctionNType(fn.tpe) =>
recur(fn, List(args)) match
case Some(reduced) =>
seq(bindingsBuf.result(), reduced).withSpan(tree.span)
Expand Down
6 changes: 2 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/Erasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -679,10 +679,8 @@ object Erasure {
// Instead, we manually lookup the type of `apply` in the qualifier.
inContext(preErasureCtx) {
val qualTp = tree.qualifier.typeOpt.widen
if qualTp.derivesFrom(defn.PolyFunctionClass) then
erasePolyFunctionApply(qualTp.select(nme.apply).widen).classSymbol
else if defn.isErasedFunctionType(qualTp) then
eraseErasedFunctionApply(qualTp.select(nme.apply).widen.asInstanceOf[MethodType]).classSymbol
if defn.isPolyOrErasedFunctionType(qualTp) then
eraseRefinedFunctionApply(qualTp.select(nme.apply).widen).classSymbol
else
NoSymbol
}
Expand Down
Loading