diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index f837ec233bb6..be8648206f09 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1036,6 +1036,40 @@ object desugar { name } + /** Strip parens and empty blocks around the body of `tree`. */ + def normalizePolyFunction(tree: PolyFunction)(using Context): PolyFunction = + def stripped(body: Tree): Tree = body match + case Parens(body1) => + stripped(body1) + case Block(Nil, body1) => + stripped(body1) + case _ => body + cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction] + + /** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R + * Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R } + */ + def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = + val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked + val funFlags = fun match + case fun: FunctionWithMods => + fun.mods.flags + case _ => EmptyFlags + + // TODO: make use of this in the desugaring when pureFuns is enabled. + // val isImpure = funFlags.is(Impure) + + // Function flags to be propagated to each parameter in the desugared method type. + val paramFlags = funFlags.toTermFlags & Given + val vparams = vparamTypes.zipWithIndex.map: + case (p: ValDef, _) => p.withAddedFlags(paramFlags) + case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags) + + RefinedTypeTree(ref(defn.PolyFunctionType), List( + DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic) + )).withSpan(tree.span) + end makePolyFunctionType + /** Invent a name for an anonympus given of type or template `impl`. */ def inventGivenOrExtensionName(impl: Tree)(using Context): SimpleName = val str = impl match @@ -1429,17 +1463,20 @@ object desugar { } /** Make closure corresponding to function. - * params => body + * [tparams] => params => body * ==> - * def $anonfun(params) = body + * def $anonfun[tparams](params) = body * Closure($anonfun) */ - def makeClosure(params: List[ValDef], body: Tree, tpt: Tree | Null = null, isContextual: Boolean, span: Span)(using Context): Block = + def makeClosure(tparams: List[TypeDef], vparams: List[ValDef], body: Tree, tpt: Tree | Null = null, span: Span)(using Context): Block = + val paramss: List[ParamClause] = + if tparams.isEmpty then vparams :: Nil + else tparams :: vparams :: Nil Block( - DefDef(nme.ANON_FUN, params :: Nil, if (tpt == null) TypeTree() else tpt, body) + DefDef(nme.ANON_FUN, paramss, if (tpt == null) TypeTree() else tpt, body) .withSpan(span) .withMods(synthetic | Artifact), - Closure(Nil, Ident(nme.ANON_FUN), if (isContextual) ContextualEmptyTree else EmptyTree)) + Closure(Nil, Ident(nme.ANON_FUN), EmptyTree)) /** If `nparams` == 1, expand partial function * @@ -1728,62 +1765,6 @@ object desugar { } } - def makePolyFunction(targs: List[Tree], body: Tree, pt: Type): Tree = body match { - case Parens(body1) => - makePolyFunction(targs, body1, pt) - case Block(Nil, body1) => - makePolyFunction(targs, body1, pt) - case Function(vargs, res) => - assert(targs.nonEmpty) - // TODO: Figure out if we need a `PolyFunctionWithMods` instead. - val mods = body match { - case body: FunctionWithMods => body.mods - case _ => untpd.EmptyModifiers - } - val polyFunctionTpt = ref(defn.PolyFunctionType) - val applyTParams = targs.asInstanceOf[List[TypeDef]] - if (ctx.mode.is(Mode.Type)) { - // Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R - // Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R } - - val applyVParams = vargs.zipWithIndex.map { - case (p: ValDef, _) => p.withAddedFlags(mods.flags) - case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags.toTermFlags) - } - RefinedTypeTree(polyFunctionTpt, List( - DefDef(nme.apply, applyTParams :: applyVParams :: Nil, res, EmptyTree).withFlags(Synthetic) - )) - } - else { - // Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body - // with pt [S_1, ..., S_M] -> (O_1, ..., O_N) => R - // Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R2 = body } - // where R2 is R, with all references to S_1..S_M replaced with T1..T_M. - - def typeTree(tp: Type) = tp match - case RefinedType(parent, nme.apply, PolyType(_, mt)) if parent.typeSymbol eq defn.PolyFunctionClass => - var bail = false - def mapper(tp: Type, topLevel: Boolean = false): Tree = tp match - case tp: TypeRef => ref(tp) - case tp: TypeParamRef => Ident(applyTParams(tp.paramNum).name) - case AppliedType(tycon, args) => AppliedTypeTree(mapper(tycon), args.map(mapper(_))) - case _ => if topLevel then TypeTree() else { bail = true; genericEmptyTree } - val mapped = mapper(mt.resultType, topLevel = true) - if bail then TypeTree() else mapped - case _ => TypeTree() - - val applyVParams = vargs.asInstanceOf[List[ValDef]] - .map(varg => varg.withAddedFlags(mods.flags | Param)) - New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef, - List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, typeTree(pt), res)) - )) - } - case _ => - // may happen for erroneous input. An error will already have been reported. - assert(ctx.reporter.errorsReported) - EmptyTree - } - // begin desugar // Special case for `Parens` desugaring: unlike all the desugarings below, @@ -1796,8 +1777,6 @@ object desugar { } val desugared = tree match { - case PolyFunction(targs, body) => - makePolyFunction(targs, body, pt) orElse tree case SymbolLit(str) => Apply( ref(defn.ScalaSymbolClass.companionModule.termRef), diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index 1725ed5e3f94..d5c01ae7b7b6 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -420,10 +420,7 @@ trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped] case Closure(_, meth, _) => true case Block(Nil, expr) => isContextualClosure(expr) case Block(DefDef(nme.ANON_FUN, params :: _, _, _) :: Nil, cl: Closure) => - if params.isEmpty then - cl.tpt.eq(untpd.ContextualEmptyTree) || defn.isContextFunctionType(cl.tpt.typeOpt) - else - isUsingClause(params) + isUsingClause(params) case _ => false } diff --git a/compiler/src/dotty/tools/dotc/ast/Trees.scala b/compiler/src/dotty/tools/dotc/ast/Trees.scala index 54c15b9909fa..bd172e8db6d3 100644 --- a/compiler/src/dotty/tools/dotc/ast/Trees.scala +++ b/compiler/src/dotty/tools/dotc/ast/Trees.scala @@ -1192,7 +1192,6 @@ object Trees { @sharable val EmptyTree: Thicket = genericEmptyTree @sharable val EmptyValDef: ValDef = genericEmptyValDef - @sharable val ContextualEmptyTree: Thicket = new EmptyTree() // an empty tree marking a contextual closure // ----- Auxiliary creation methods ------------------ diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index e3488034fef8..d868cd039f30 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -151,7 +151,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case class CapturesAndResult(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree /** Short-lived usage in typer, does not need copy/transform/fold infrastructure */ - case class DependentTypeTree(tp: List[Symbol] => Type)(implicit @constructorOnly src: SourceFile) extends Tree + case class DependentTypeTree(tp: (List[TypeSymbol], List[TermSymbol]) => Type)(implicit @constructorOnly src: SourceFile) extends Tree @sharable object EmptyTypeIdent extends Ident(tpnme.EMPTY)(NoSource) with WithoutTypeOrPos[Untyped] { override def isEmpty: Boolean = true diff --git a/compiler/src/dotty/tools/dotc/core/NameOps.scala b/compiler/src/dotty/tools/dotc/core/NameOps.scala index 04440c9e9b39..7cc602a20141 100644 --- a/compiler/src/dotty/tools/dotc/core/NameOps.scala +++ b/compiler/src/dotty/tools/dotc/core/NameOps.scala @@ -236,10 +236,12 @@ object NameOps { */ def isPlainFunction(using Context): Boolean = functionArity >= 0 - /** Is a function name that contains `mustHave` as a substring */ - private def isSpecificFunction(mustHave: String)(using Context): Boolean = + /** Is a function name that contains `mustHave` as a substring + * and has arity `minArity` or greater. + */ + private def isSpecificFunction(mustHave: String, minArity: Int = 0)(using Context): Boolean = val suffixStart = functionSuffixStart - isFunctionPrefix(suffixStart, mustHave) && funArity(suffixStart) >= 0 + isFunctionPrefix(suffixStart, mustHave) && funArity(suffixStart) >= minArity def isContextFunction(using Context): Boolean = isSpecificFunction("Context") def isImpureFunction(using Context): Boolean = isSpecificFunction("Impure") diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 04dfbbb26ef7..6a93376bb1e3 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -1872,6 +1872,8 @@ object Types { if alwaysDependent || mt.isResultDependent then RefinedType(funType, nme.apply, mt) else funType + case poly @ PolyType(_, mt: MethodType) if !mt.isParamDependent => + RefinedType(defn.PolyFunctionType, nme.apply, poly) } /** The signature of this type. This is by default NotAMethod, diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 493a6e1cc18e..6470f1ca4a04 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -1511,6 +1511,7 @@ object Parsers { TermLambdaTypeTree(params.asInstanceOf[List[ValDef]], resultType) else if imods.isOneOf(Given | Impure) || erasedArgs.contains(true) then if imods.is(Given) && params.isEmpty then + imods &~= Given syntaxError(em"context function types require at least one parameter", paramSpan) FunctionWithMods(params, resultType, imods, erasedArgs.toList) else if !ctx.settings.YkindProjector.isDefault then diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index f3540502597c..784ab39e032e 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -297,9 +297,9 @@ class PlainPrinter(_ctx: Context) extends Printer { protected def paramsText(lam: LambdaType): Text = { val erasedParams = lam.erasedParams - def paramText(name: Name, tp: Type, erased: Boolean) = - keywordText("erased ").provided(erased) ~ toText(name) ~ lambdaHash(lam) ~ toTextRHS(tp, isParameter = true) - Text(lam.paramNames.lazyZip(lam.paramInfos).lazyZip(erasedParams).map(paramText), ", ") + def paramText(ref: ParamRef, erased: Boolean) = + keywordText("erased ").provided(erased) ~ ParamRefNameString(ref) ~ lambdaHash(lam) ~ toTextRHS(ref.underlying, isParameter = true) + Text(lam.paramRefs.lazyZip(erasedParams).map(paramText), ", ") } protected def ParamRefNameString(name: Name): String = nameString(name) @@ -363,7 +363,7 @@ class PlainPrinter(_ctx: Context) extends Printer { case tp @ ConstantType(value) => toText(value) case pref: TermParamRef => - nameString(pref.binder.paramNames(pref.paramNum)) ~ lambdaHash(pref.binder) + ParamRefNameString(pref) ~ lambdaHash(pref.binder) case tp: RecThis => val idx = openRecs.reverse.indexOf(tp.binder) if (idx >= 0) selfRecName(idx + 1) diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index 98478fae92e8..3ceb274acd43 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -174,7 +174,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { ~ " " ~ argText(args.last) } - private def toTextMethodAsFunction(info: Type, isPure: Boolean, refs: Text = Str("")): Text = info match + protected def toTextMethodAsFunction(info: Type, isPure: Boolean, refs: Text = Str("")): Text = info match case info: MethodType => val capturesRoot = refs == rootSetText changePrec(GlobalPrec) { diff --git a/compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala b/compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala index ad56f29287fc..1d9e5310a954 100644 --- a/compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala +++ b/compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala @@ -196,6 +196,7 @@ enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMe case AmbiguousExtensionMethodID // errorNumber 180 case UnqualifiedCallToAnyRefMethodID // errorNumber: 181 case NotConstantID // errorNumber: 182 + case ClosureCannotHaveInternalParameterDependenciesID // errorNumber: 183 def errorNumber = ordinal - 1 diff --git a/compiler/src/dotty/tools/dotc/reporting/Message.scala b/compiler/src/dotty/tools/dotc/reporting/Message.scala index a1fe6773c1d2..a536c5871b2a 100644 --- a/compiler/src/dotty/tools/dotc/reporting/Message.scala +++ b/compiler/src/dotty/tools/dotc/reporting/Message.scala @@ -51,6 +51,13 @@ object Message: */ private class Seen(disambiguate: Boolean): + /** The set of lambdas that were opened at some point during printing. */ + private val openedLambdas = new collection.mutable.HashSet[LambdaType] + + /** Register that `tp` was opened during printing. */ + def openLambda(tp: LambdaType): Unit = + openedLambdas += tp + val seen = new collection.mutable.HashMap[SeenKey, List[Recorded]]: override def default(key: SeenKey) = Nil @@ -89,8 +96,22 @@ object Message: val existing = seen(key) lazy val dealiased = followAlias(entry) - // alts: The alternatives in `existing` that are equal, or follow (an alias of) `entry` - var alts = existing.dropWhile(alt => dealiased ne followAlias(alt)) + /** All lambda parameters with the same name are given the same superscript as + * long as their corresponding binder has been printed. + * See tests/neg/lambda-rename.scala for test cases. + */ + def sameSuperscript(cur: Recorded, existing: Recorded) = + (cur eq existing) || + (cur, existing).match + case (cur: ParamRef, existing: ParamRef) => + (cur.paramName eq existing.paramName) && + openedLambdas.contains(cur.binder) && + openedLambdas.contains(existing.binder) + case _ => + false + + // The length of alts corresponds to the number of superscripts we need to print. + var alts = existing.dropWhile(alt => !sameSuperscript(dealiased, followAlias(alt))) if alts.isEmpty then alts = entry :: existing seen(key) = alts @@ -208,10 +229,20 @@ object Message: case tp: SkolemType => seen.record(tp.repr.toString, isType = true, tp) case _ => super.toTextRef(tp) + override def toTextMethodAsFunction(info: Type, isPure: Boolean, refs: Text): Text = + info match + case info: LambdaType => + seen.openLambda(info) + case _ => + super.toTextMethodAsFunction(info, isPure, refs) + override def toText(tp: Type): Text = if !tp.exists || tp.isErroneous then seen.nonSensical = true tp match case tp: TypeRef if useSourceModule(tp.symbol) => Str("object ") ~ super.toText(tp) + case tp: LambdaType => + seen.openLambda(tp) + super.toText(tp) case _ => super.toText(tp) override def toText(sym: Symbol): Text = diff --git a/compiler/src/dotty/tools/dotc/reporting/messages.scala b/compiler/src/dotty/tools/dotc/reporting/messages.scala index 104304c3409c..1a10c3950f87 100644 --- a/compiler/src/dotty/tools/dotc/reporting/messages.scala +++ b/compiler/src/dotty/tools/dotc/reporting/messages.scala @@ -2920,3 +2920,10 @@ class MatchTypeScrutineeCannotBeHigherKinded(tp: Type)(using Context) extends TypeMsg(MatchTypeScrutineeCannotBeHigherKindedID) : def msg(using Context) = i"the scrutinee of a match type cannot be higher-kinded" def explain(using Context) = "" + +class ClosureCannotHaveInternalParameterDependencies(mt: Type)(using Context) + extends TypeMsg(ClosureCannotHaveInternalParameterDependenciesID): + def msg(using Context) = + i"""cannot turn method type $mt into closure + |because it has internal parameter dependencies""" + def explain(using Context) = "" diff --git a/compiler/src/dotty/tools/dotc/typer/Checking.scala b/compiler/src/dotty/tools/dotc/typer/Checking.scala index 305acea188da..c969e5983992 100644 --- a/compiler/src/dotty/tools/dotc/typer/Checking.scala +++ b/compiler/src/dotty/tools/dotc/typer/Checking.scala @@ -412,7 +412,7 @@ object Checking { case tree: RefTree => checkRef(tree, tree.symbol) foldOver(x, tree) - case tree: This => + case tree: This if tree.tpe.classSymbol == refineCls => selfRef(tree) case tree: TypeTree => val checkType = new TypeAccumulator[Unit] { diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index df708057dd71..0c6525dd92a7 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1692,7 +1692,6 @@ class Namer { typer: Typer => def valOrDefDefSig(mdef: ValOrDefDef, sym: Symbol, paramss: List[List[Symbol]], paramFn: Type => Type)(using Context): Type = { def inferredType = inferredResultType(mdef, sym, paramss, paramFn, WildcardType) - lazy val termParamss = paramss.collect { case TermSymbols(vparams) => vparams } val tptProto = mdef.tpt match { case _: untpd.DerivedTypeTree => @@ -1700,7 +1699,10 @@ class Namer { typer: Typer => case TypeTree() => checkMembersOK(inferredType, mdef.srcPos) case DependentTypeTree(tpFun) => - val tpe = tpFun(termParamss.head) + // A lambda has at most one type parameter list followed by exactly one term parameter list. + val tpe = (paramss: @unchecked) match + case TypeSymbols(tparams) :: TermSymbols(vparams) :: Nil => tpFun(tparams, vparams) + case TermSymbols(vparams) :: Nil => tpFun(Nil, vparams) if (isFullyDefined(tpe, ForceDegree.none)) tpe else typedAheadExpr(mdef.rhs, tpe).tpe case TypedSplice(tpt: TypeTree) if !isFullyDefined(tpt.tpe, ForceDegree.none) => @@ -1724,7 +1726,8 @@ class Namer { typer: Typer => // So fixing levels at instantiation avoids the soundness problem but apparently leads // to type inference problems since it comes too late. if !Config.checkLevelsOnConstraints then - val hygienicType = TypeOps.avoid(rhsType, termParamss.flatten) + val termParams = paramss.collect { case TermSymbols(vparams) => vparams }.flatten + val hygienicType = TypeOps.avoid(rhsType, termParams) if (!hygienicType.isValueType || !(hygienicType <:< tpt.tpe)) report.error( em"""return type ${tpt.tpe} of lambda cannot be made hygienic diff --git a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala index 6ac45cbcf04d..0e3270281c9b 100644 --- a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala +++ b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala @@ -22,7 +22,11 @@ trait TypeAssigner { */ def qualifyingClass(tree: untpd.Tree, qual: Name, packageOK: Boolean)(using Context): Symbol = { def qualifies(sym: Symbol) = - sym.isClass && ( + sym.isClass && + // `this` in a polymorphic function type never refers to the desugared refinement. + // In other refinements, `this` does refer to the refinement but is deprecated + // (see `Checking#checkRefinementNonCyclic`). + !(sym.isRefinementClass && sym.derivesFrom(defn.PolyFunctionClass)) && ( qual.isEmpty || sym.name == qual || sym.is(Module) && sym.name.stripModuleClassSuffix == qual) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 458e60ddfa38..d75dffcc2e90 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1323,14 +1323,14 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer (pt1.argInfos.init, typeTree(interpolateWildcards(pt1.argInfos.last.hiBound))) case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe)) if (defn.isNonRefinedFunction(parent) || defn.isErasedFunctionType(parent)) && formals.length == defaultArity => - (formals, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef)))) + (formals, untpd.DependentTypeTree((_, syms) => restpe.substParams(mt, syms.map(_.termRef)))) case pt1 @ SAMType(mt @ MethodTpe(_, formals, _)) if !SAMType.isParamDependentRec(mt) => val restpe = mt.resultType match case mt: MethodType => mt.toFunctionType(isJava = pt1.classSymbol.is(JavaDefined)) case tp => tp (formals, if (mt.isResultDependent) - untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef))) + untpd.DependentTypeTree((_, syms) => restpe.substParams(mt, syms.map(_.termRef))) else typeTree(restpe)) case _ => @@ -1625,12 +1625,32 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer ) cpy.ValDef(param)(tpt = paramTpt) if isErased then param0.withAddedFlags(Flags.Erased) else param0 - desugared = desugar.makeClosure(inferredParams, fnBody, resultTpt, isContextual, tree.span) + desugared = desugar.makeClosure(Nil, inferredParams, fnBody, resultTpt, tree.span) typed(desugared, pt) .showing(i"desugared fun $tree --> $desugared with pt = $pt", typr) } + + def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree = + val tree1 = desugar.normalizePolyFunction(tree) + if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt) + else typedPolyFunctionValue(tree1, pt) + + def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree = + val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked + val untpd.Function(vparams: List[untpd.ValDef] @unchecked, body) = fun: @unchecked + + val resultTpt = pt.dealias match + case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass => + untpd.DependentTypeTree((tsyms, vsyms) => + mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef))) + case _ => untpd.TypeTree() + + val desugared = desugar.makeClosure(tparams, vparams, body, resultTpt, tree.span) + typed(desugared, pt) + end typedPolyFunctionValue + def typedClosure(tree: untpd.Closure, pt: Type)(using Context): Tree = { val env1 = tree.env mapconserve (typed(_)) val meth1 = typedUnadapted(tree.meth) @@ -1658,12 +1678,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer TypeTree(targetTpe) case _ => if (mt.isParamDependent) - errorTree(tree, - em"""cannot turn method type $mt into closure - |because it has internal parameter dependencies""") - else if ((tree.tpt `eq` untpd.ContextualEmptyTree) && mt.paramNames.isEmpty) - // Note implicitness of function in target type since there are no method parameters that indicate it. - TypeTree(defn.FunctionOf(Nil, mt.resType, isContextual = true)) + errorTree(tree, ClosureCannotHaveInternalParameterDependencies(mt)) else if hasCaptureConversionArg(mt.resType) then errorTree(tree, em"""cannot turn method type $mt into closure @@ -1671,6 +1686,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer else EmptyTree } + case poly @ PolyType(_, mt: MethodType) => + if (mt.isParamDependent) + errorTree(tree, ClosureCannotHaveInternalParameterDependencies(poly)) + else + // Polymorphic SAMs are not currently supported (#6904). + EmptyTree case tp => if !tp.isErroneous then throw new java.lang.Error(i"internal error: closing over non-method $tp, pos = ${tree.span}") @@ -2428,7 +2449,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case rhs => typedExpr(rhs, tpt1.tpe.widenExpr) } val vdef1 = assignType(cpy.ValDef(vdef)(name, tpt1, rhs1), sym) - postProcessInfo(sym) + postProcessInfo(vdef1, sym) vdef1.setDefTree } @@ -2537,19 +2558,31 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val ddef2 = assignType(cpy.DefDef(ddef)(name, paramss1, tpt1, rhs1), sym) - postProcessInfo(sym) + postProcessInfo(ddef2, sym) ddef2.setDefTree //todo: make sure dependent method types do not depend on implicits or by-name params } /** (1) Check that the signature of the class member does not return a repeated parameter type * (2) If info is an erased class, set erased flag of member + * (3) Check that erased classes are not parameters of polymorphic functions. */ - private def postProcessInfo(sym: Symbol)(using Context): Unit = + private def postProcessInfo(mdef: MemberDef, sym: Symbol)(using Context): Unit = if (!sym.isOneOf(Synthetic | InlineProxy | Param) && sym.info.finalResultType.isRepeatedParam) report.error(em"Cannot return repeated parameter type ${sym.info.finalResultType}", sym.srcPos) if !sym.is(Module) && !sym.isConstructor && sym.info.finalResultType.isErasedClass then sym.setFlag(Erased) + if + sym.info.isInstanceOf[PolyType] && + ((sym.name eq nme.ANON_FUN) || + (sym.name eq nme.apply) && sym.owner.derivesFrom(defn.PolyFunctionClass)) + then + mdef match + case DefDef(_, _ :: vparams :: Nil, _, _) => + vparams.foreach: vparam => + if vparam.symbol.is(Erased) then + report.error(em"Implementation restriction: erased classes are not allowed in a poly function definition", vparam.srcPos) + case _ => def typedTypeDef(tdef: untpd.TypeDef, sym: Symbol)(using Context): Tree = { val TypeDef(name, rhs) = tdef @@ -2696,19 +2729,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer // check value class constraints checkDerivedValueClass(cls, body1) - // check PolyFunction constraints (no erased functions!) - if parents1.exists(_.tpe.classSymbol eq defn.PolyFunctionClass) then - body1.foreach { - case ddef: DefDef => - ddef.paramss.foreach { params => - val erasedParam = params.collectFirst { case vdef: ValDef if vdef.symbol.is(Erased) => vdef } - erasedParam.foreach { p => - report.error(em"Implementation restriction: erased classes are not allowed in a poly function definition", p.srcPos) - } - } - case _ => - } - val effectiveOwner = cls.owner.skipWeakOwner if !cls.isRefinementClass && !cls.isAllOf(PrivateLocal) @@ -3060,6 +3080,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case tree: untpd.Block => typedBlock(desugar.block(tree), pt)(using ctx.fresh.setNewScope) case tree: untpd.If => typedIf(tree, pt) case tree: untpd.Function => typedFunction(tree, pt) + case tree: untpd.PolyFunction => typedPolyFunction(tree, pt) case tree: untpd.Closure => typedClosure(tree, pt) case tree: untpd.Import => typedImport(tree) case tree: untpd.Export => typedExport(tree) @@ -3104,6 +3125,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val ifpt = defn.asContextFunctionType(pt) val result = if ifpt.exists + && defn.functionArity(ifpt) > 0 // ContextFunction0 is only used after ElimByName && xtree.isTerm && !untpd.isContextualClosure(xtree) && !ctx.mode.is(Mode.Pattern) diff --git a/project/Build.scala b/project/Build.scala index 4356d541eb8d..7a2e6aae46c1 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -526,6 +526,9 @@ object Build { // Settings shared between scala3-compiler and scala3-compiler-bootstrapped lazy val commonDottyCompilerSettings = Seq( + // Note: bench/profiles/projects.yml should be updated accordingly. + Compile / scalacOptions ++= Seq("-Yexplicit-nulls", "-Ysafe-init"), + // Generate compiler.properties, used by sbt (Compile / resourceGenerators) += Def.task { import java.util._ @@ -804,9 +807,6 @@ object Build { ) }, - // Note: bench/profiles/projects.yml should be updated accordingly. - Compile / scalacOptions ++= Seq("-Yexplicit-nulls", "-Ysafe-init"), - repl := (Compile / console).value, Compile / console / scalacOptions := Nil, // reset so that we get stock REPL behaviour! E.g. avoid -unchecked being enabled ) diff --git a/tests/neg-custom-args/fatal-warnings/refinements-this.scala b/tests/neg-custom-args/fatal-warnings/refinements-this.scala new file mode 100644 index 000000000000..20c1820c0041 --- /dev/null +++ b/tests/neg-custom-args/fatal-warnings/refinements-this.scala @@ -0,0 +1,3 @@ +class Outer: + type X = { type O = Outer.this.type } // ok + type Y = { type O = this.type } // error diff --git a/tests/neg/lambda-rename.check b/tests/neg/lambda-rename.check new file mode 100644 index 000000000000..e45a184ef31c --- /dev/null +++ b/tests/neg/lambda-rename.check @@ -0,0 +1,34 @@ +-- [E007] Type Mismatch Error: tests/neg/lambda-rename.scala:4:33 ------------------------------------------------------ +4 |val a: (x: Int) => Bar[x.type] = ??? : ((x: Int) => Foo[x.type]) // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: (x: Int) => Foo[x.type] + | Required: (x: Int) => Bar[x.type] + | + | longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg/lambda-rename.scala:7:33 ------------------------------------------------------ +7 |val b: HK[[X] =>> Foo[(X, X)]] = ??? : HK[[X] =>> Bar[(X, X)]] // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: HK[[X] =>> Bar[(X, X)]] + | Required: HK[[X] =>> Foo[(X, X)]] + | + | longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg/lambda-rename.scala:10:33 ----------------------------------------------------- +10 |val c: HK[[X] =>> Foo[(X, X)]] = ??? : HK[[Y] =>> Foo[(X, X)]] // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: HK[[Y] =>> Foo[(X, X)]] + | Required: HK[[X²] =>> Foo[(X², X²)]] + | + | where: X is a class + | X² is a type variable + | + | longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg/lambda-rename.scala:12:33 ----------------------------------------------------- +12 |val d: HK[[Y] =>> Foo[(X, X)]] = ??? : HK[[X] =>> Foo[(X, X)]] // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: HK[[X] =>> Foo[(X, X)]] + | Required: HK[[Y] =>> Foo[(X², X²)]] + | + | where: X is a type variable + | X² is a class + | + | longer explanation available when compiling with `-explain` diff --git a/tests/neg/lambda-rename.scala b/tests/neg/lambda-rename.scala new file mode 100644 index 000000000000..586d8ae28bdf --- /dev/null +++ b/tests/neg/lambda-rename.scala @@ -0,0 +1,12 @@ +class Foo[T] +class Bar[T] + +val a: (x: Int) => Bar[x.type] = ??? : ((x: Int) => Foo[x.type]) // error + +trait HK[F <: AnyKind] +val b: HK[[X] =>> Foo[(X, X)]] = ??? : HK[[X] =>> Bar[(X, X)]] // error + +class X +val c: HK[[X] =>> Foo[(X, X)]] = ??? : HK[[Y] =>> Foo[(X, X)]] // error + +val d: HK[[Y] =>> Foo[(X, X)]] = ??? : HK[[X] =>> Foo[(X, X)]] // error diff --git a/tests/neg/polymorphic-functions.scala b/tests/neg/polymorphic-functions.scala index d9783baee967..b949cf04194c 100644 --- a/tests/neg/polymorphic-functions.scala +++ b/tests/neg/polymorphic-functions.scala @@ -2,4 +2,6 @@ object Test { val pv0: [T] => List[T] = ??? // error val pv1: Any = [T] => Nil // error val pv2: [T] => List[T] = [T] => Nil // error // error + + val intraDep = [T] => (x: T, y: List[x.type]) => List(y) // error } diff --git a/tests/neg/polymorphic-functions1.check b/tests/neg/polymorphic-functions1.check index 7374075de072..eef268c298cf 100644 --- a/tests/neg/polymorphic-functions1.check +++ b/tests/neg/polymorphic-functions1.check @@ -1,7 +1,7 @@ --- [E007] Type Mismatch Error: tests/neg/polymorphic-functions1.scala:1:53 --------------------------------------------- +-- [E007] Type Mismatch Error: tests/neg/polymorphic-functions1.scala:1:33 --------------------------------------------- 1 |val f: [T] => (x: T) => x.type = [T] => (x: Int) => x // error - | ^ - | Found: [T] => (x: Int) => Int - | Required: [T] => (x: T) => x.type + | ^^^^^^^^^^^^^^^^^^^^ + | Found: [T] => (x: Int) => x.type + | Required: [T] => (x: T) => x.type | | longer explanation available when compiling with `-explain` diff --git a/tests/pos/i16756.scala b/tests/pos/i16756.scala new file mode 100644 index 000000000000..fa54dccd7eee --- /dev/null +++ b/tests/pos/i16756.scala @@ -0,0 +1,16 @@ +class DependentPoly { + + sealed trait Col[V] { + + trait Wrapper + val wrapper: Wrapper = ??? + } + + object Col1 extends Col[Int] + + object Col2 extends Col[Double] + + val polyFn: [C <: DependentPoly.this.Col[?]] => (x: C) => x.Wrapper = + [C <: Col[?]] => (x: C) => (x.wrapper: x.Wrapper) +} + diff --git a/tests/pos/polymorphic-functions-this.scala b/tests/pos/polymorphic-functions-this.scala new file mode 100644 index 000000000000..91e1b38ed714 --- /dev/null +++ b/tests/pos/polymorphic-functions-this.scala @@ -0,0 +1,10 @@ +trait Foo: + type X + def x: X + val f: [T <: this.X] => (T, this.X) => (T, this.X) = + [T <: this.X] => (x: T, y: this.X) => (x, y) + f(x, x) + + val g: [T <: this.type] => (T, this.type) => (T, this.type) = + [T <: this.type] => (x: T, y: this.type) => (x, y) + g(this, this) diff --git a/tests/run/polymorphic-functions.scala b/tests/run/polymorphic-functions.scala index 35b1469f2c3a..b1fcdb349413 100644 --- a/tests/run/polymorphic-functions.scala +++ b/tests/run/polymorphic-functions.scala @@ -85,6 +85,13 @@ object Test extends App { val v0a: String = v0 assert(v0 == "foo") + // Used to fail with: Found: ... => List[T] + // Expected: ... => List[x.type] + val md2: [T] => (x: T) => List[x.type] = [T] => (x: T) => List(x) + val x = 1 + val v1 = md2(x) + val v1a: List[x.type] = v1 + // Contextual trait Show[T] { def show(t: T): String } implicit val si: Show[Int] =