diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 45ec0b60ce0c..6f8aed16fdc3 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -90,6 +90,7 @@ class Compiler { List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements. List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations + new ElimPolyFunction, // Rewrite PolyFunction subclasses to FunctionN subclasses new TailRec, // Rewrite tail recursion to loops new Mixin, // Expand trait fields and trait initializers new LazyVals, // Expand lazy vals diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 74998424ac43..f8c4b11ee8a9 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1418,6 +1418,41 @@ object desugar { } } + def makePolyFunction(targs: List[Tree], body: Tree): Tree = body match { + case Function(vargs, res) => + // TODO: Figure out if we need a `PolyFunctionWithMods` instead. + val mods = body match { + case body: FunctionWithMods => body.mods + case _ => untpd.EmptyModifiers + } + val polyFunctionTpt = ref(defn.PolyFunctionType) + val applyTParams = targs.asInstanceOf[List[TypeDef]] + if (ctx.mode.is(Mode.Type)) { + // Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R + // Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R } + + val applyVParams = vargs.zipWithIndex.map { case (p, n) => + makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags) + } + RefinedTypeTree(polyFunctionTpt, List( + DefDef(nme.apply, applyTParams, List(applyVParams), res, EmptyTree) + )) + } else { + // Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body + // Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N) = body } + + val applyVParams = vargs.asInstanceOf[List[ValDef]] + .map(varg => varg.withAddedFlags(mods.flags | Param)) + New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef, + List(DefDef(nme.apply, applyTParams, List(applyVParams), TypeTree(), res)) + )) + } + case _ => + // may happen for erroneous input. An error will already have been reported. + assert(ctx.reporter.errorsReported) + EmptyTree + } + // begin desugar // Special case for `Parens` desugaring: unlike all the desugarings below, @@ -1430,6 +1465,8 @@ object desugar { } val desugared = tree match { + case PolyFunction(targs, body) => + makePolyFunction(targs, body) orElse tree case SymbolLit(str) => Literal(Constant(scala.Symbol(str))) case InterpolatedString(id, segments) => diff --git a/compiler/src/dotty/tools/dotc/ast/Trees.scala b/compiler/src/dotty/tools/dotc/ast/Trees.scala index 1b526deddf09..9ef16d2ee9cc 100644 --- a/compiler/src/dotty/tools/dotc/ast/Trees.scala +++ b/compiler/src/dotty/tools/dotc/ast/Trees.scala @@ -331,6 +331,7 @@ object Trees { } def withFlags(flags: FlagSet): ThisTree[Untyped] = withMods(untpd.Modifiers(flags)) + def withAddedFlags(flags: FlagSet): ThisTree[Untyped] = withMods(rawMods | flags) def setComment(comment: Option[Comment]): this.type = { comment.map(putAttachment(DocComment, _)) diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index 1e9c67c6c920..cabcf79c71c3 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -72,6 +72,12 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers)(implicit @constructorOnly src: SourceFile) extends Function(args, body) + /** A polymorphic function type */ + case class PolyFunction(targs: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends Tree { + override def isTerm = body.isTerm + override def isType = body.isType + } + /** A function created from a wildcard expression * @param placeholderParams a list of definitions of synthetic parameters. * @param body the function body where wildcards are replaced by @@ -491,6 +497,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case tree: Function if (args eq tree.args) && (body eq tree.body) => tree case _ => finalize(tree, untpd.Function(args, body)(tree.source)) } + def PolyFunction(tree: Tree)(targs: List[Tree], body: Tree)(implicit ctx: Context): Tree = tree match { + case tree: PolyFunction if (targs eq tree.targs) && (body eq tree.body) => tree + case _ => finalize(tree, untpd.PolyFunction(targs, body)(tree.source)) + } def InfixOp(tree: Tree)(left: Tree, op: Ident, right: Tree)(implicit ctx: Context): Tree = tree match { case tree: InfixOp if (left eq tree.left) && (op eq tree.op) && (right eq tree.right) => tree case _ => finalize(tree, untpd.InfixOp(left, op, right)(tree.source)) @@ -579,6 +589,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { cpy.InterpolatedString(tree)(id, segments.mapConserve(transform)) case Function(args, body) => cpy.Function(tree)(transform(args), transform(body)) + case PolyFunction(targs, body) => + cpy.PolyFunction(tree)(transform(targs), transform(body)) case InfixOp(left, op, right) => cpy.InfixOp(tree)(transform(left), op, transform(right)) case PostfixOp(od, op) => @@ -634,6 +646,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { this(x, segments) case Function(args, body) => this(this(x, args), body) + case PolyFunction(targs, body) => + this(this(x, targs), body) case InfixOp(left, op, right) => this(this(this(x, left), op), right) case PostfixOp(od, op) => diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 27d34b00cf9f..a834a2805755 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1035,6 +1035,9 @@ class Definitions { if (n <= MaxImplementedFunctionArity && (!isContextual || ctx.erasedTypes) && !isErased) ImplementedFunctionType(n) else FunctionClass(n, isContextual, isErased).typeRef + lazy val PolyFunctionClass = ctx.requiredClass("scala.PolyFunction") + def PolyFunctionType = PolyFunctionClass.typeRef + /** If `cls` is a class in the scala package, its name, otherwise EmptyTypeName */ def scalaClassName(cls: Symbol)(implicit ctx: Context): TypeName = if (cls.isClass && cls.owner == ScalaPackageClass) cls.asClass.name else EmptyTypeName diff --git a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala index 3315bacdfb7e..4dad0c56fec9 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala @@ -196,10 +196,27 @@ object TypeErasure { MethodType(Nil, defn.BoxedUnitType) else if (sym.isAnonymousFunction && einfo.paramInfos.length > MaxImplementedFunctionArity) MethodType(nme.ALLARGS :: Nil, JavaArrayType(defn.ObjectType) :: Nil, einfo.resultType) + else if (sym.name == nme.apply && sym.owner.derivesFrom(defn.PolyFunctionClass)) { + // The erasure of `apply` in subclasses of PolyFunction has to match + // the erasure of FunctionN#apply, since after `ElimPolyFunction` we replace + // a `PolyFunction` parent by a `FunctionN` parent. + einfo.derivedLambdaType( + paramInfos = einfo.paramInfos.map(_ => defn.ObjectType), + resType = defn.ObjectType + ) + } else einfo case einfo => - einfo + // Erase the parameters of `apply` in subclasses of PolyFunction + // Preserve PolyFunction argument types to support PolyFunctions with + // PolyFunction arguments + if (sym.is(TermParam) && sym.owner.name == nme.apply + && sym.owner.owner.derivesFrom(defn.PolyFunctionClass) + && !(tp <:< defn.PolyFunctionType)) { + defn.ObjectType + } else + einfo } } @@ -383,6 +400,7 @@ class TypeErasure(isJava: Boolean, semiEraseVCs: Boolean, isConstructor: Boolean * - otherwise, if T is a type parameter coming from Java, []Object * - otherwise, Object * - For a term ref p.x, the type # x. + * - For a refined type scala.PolyFunction { def apply[...](x_1, ..., x_N): R }, scala.FunctionN * - For a typeref scala.Any, scala.AnyVal, scala.Singleton, scala.Tuple, or scala.*: : |java.lang.Object| * - For a typeref scala.Unit, |scala.runtime.BoxedUnit|. * - For a typeref scala.FunctionN, where N > MaxImplementedFunctionArity, scala.FunctionXXL @@ -429,6 +447,12 @@ class TypeErasure(isJava: Boolean, semiEraseVCs: Boolean, isConstructor: Boolean SuperType(this(thistpe), this(supertpe)) case ExprType(rt) => defn.FunctionType(0) + case RefinedType(parent, nme.apply, refinedInfo) if parent.typeSymbol eq defn.PolyFunctionClass => + assert(refinedInfo.isInstanceOf[PolyType]) + val res = refinedInfo.resultType + val paramss = res.paramNamess + assert(paramss.length == 1) + this(defn.FunctionType(paramss.head.length, isContextual = res.isImplicitMethod, isErased = res.isErasedMethod)) case tp: TypeProxy => this(tp.underlying) case AndType(tp1, tp2) => @@ -581,6 +605,11 @@ class TypeErasure(isJava: Boolean, semiEraseVCs: Boolean, isConstructor: Boolean case tp: TypeVar => val inst = tp.instanceOpt if (inst.exists) sigName(inst) else tpnme.Uninstantiated + case tp @ RefinedType(parent, nme.apply, _) if parent.typeSymbol eq defn.PolyFunctionClass => + // we need this case rather than falling through to the default + // because RefinedTypes <: TypeProxy and it would be caught by + // the case immediately below + sigName(this(tp)) case tp: TypeProxy => sigName(tp.underlying) case _: ErrorType | WildcardType | NoType => diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index f4b11170725c..493f7d1fd8e4 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -3194,6 +3194,7 @@ object Types { companion.eq(ContextualMethodType) || companion.eq(ErasedContextualMethodType) + def computeSignature(implicit ctx: Context): Signature = { val params = if (isErasedMethod) Nil else paramInfos resultSignature.prepend(params, isJavaMethod) diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 5982bcdaf4cd..07af3d873522 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -850,9 +850,13 @@ object Parsers { */ def toplevelTyp(): Tree = rejectWildcardType(typ()) - /** Type ::= FunTypeMods FunArgTypes `=>' Type - * | HkTypeParamClause `=>>' Type + /** Type ::= FunType + * | HkTypeParamClause ‘=>>’ Type + * | MatchType * | InfixType + * FunType ::= { 'erased' | 'given' } (MonoFunType | PolyFunType) + * MonoFunType ::= FunArgTypes ‘=>’ Type + * PolyFunType ::= HKTypeParamClause '=>' Type * FunArgTypes ::= InfixType * | `(' [ FunArgType {`,' FunArgType } ] `)' * | '(' TypedFunParam {',' TypedFunParam } ')' @@ -924,7 +928,18 @@ object Parsers { val tparams = typeParamClause(ParamOwner.TypeParam) if (in.token == TLARROW) atSpan(start, in.skipToken())(LambdaTypeTree(tparams, toplevelTyp())) - else { accept(TLARROW); typ() } + else if (in.token == ARROW) { + val arrowOffset = in.skipToken() + val body = toplevelTyp() + atSpan(start, arrowOffset) { + body match { + case _: Function => PolyFunction(tparams, body) + case _ => + syntaxError("Implementation restriction: polymorphic function types must have a value parameter", arrowOffset) + Ident(nme.ERROR.toTypeName) + } + } + } else { accept(TLARROW); typ() } } else infixType() @@ -1223,6 +1238,7 @@ object Parsers { * | `throw' Expr * | `return' [Expr] * | ForExpr + * | HkTypeParamClause ‘=>’ Expr * | [SimpleExpr `.'] id `=' Expr * | SimpleExpr1 ArgumentExprs `=' Expr * | Expr2 @@ -1323,6 +1339,19 @@ object Parsers { atSpan(in.skipToken()) { Return(if (isExprIntro) expr() else EmptyTree, EmptyTree) } case FOR => forExpr() + case LBRACKET => + val start = in.offset + val tparams = typeParamClause(ParamOwner.TypeParam) + val arrowOffset = accept(ARROW) + val body = expr() + atSpan(start, arrowOffset) { + body match { + case _: Function => PolyFunction(tparams, body) + case _ => + syntaxError("Implementation restriction: polymorphic function literals must have a value parameter", arrowOffset) + errorTermTree + } + } case _ => if (isIdent(nme.inline) && !in.inModifierPosition() && in.lookaheadIn(canStartExpressionTokens)) { val start = in.skipToken() diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index 21612a1e32bb..d165b9eda205 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -558,6 +558,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { (keywordText("erased ") provided isErased) ~ argsText ~ " => " ~ toText(body) } + case PolyFunction(targs, body) => + val targsText = "[" ~ Text(targs.map((arg: Tree) => toText(arg)), ", ") ~ "]" + changePrec(GlobalPrec) { + targsText ~ " => " ~ toText(body) + } case InfixOp(l, op, r) => val opPrec = parsing.precedence(op.name) changePrec(opPrec) { toText(l) ~ " " ~ toText(op) ~ " " ~ toText(r) } diff --git a/compiler/src/dotty/tools/dotc/transform/ElimErasedValueType.scala b/compiler/src/dotty/tools/dotc/transform/ElimErasedValueType.scala index 546e8319f7f5..18a731545eed 100644 --- a/compiler/src/dotty/tools/dotc/transform/ElimErasedValueType.scala +++ b/compiler/src/dotty/tools/dotc/transform/ElimErasedValueType.scala @@ -83,11 +83,19 @@ class ElimErasedValueType extends MiniPhase with InfoTransformer { thisPhase => override def matches(sym1: Symbol, sym2: Symbol) = sym1.signature == sym2.signature } + def checkNoConflict(sym1: Symbol, sym2: Symbol, info: Type)(implicit ctx: Context): Unit = { val site = root.thisType val info1 = site.memberInfo(sym1) val info2 = site.memberInfo(sym2) - if (!info1.matchesLoosely(info2)) + // PolyFunction apply methods will be eliminated later during + // ElimPolyFunction, so we let them pass here. + def bothPolyApply = + sym1.name == nme.apply && + (sym1.owner.derivesFrom(defn.PolyFunctionClass) || + sym2.owner.derivesFrom(defn.PolyFunctionClass)) + + if (!info1.matchesLoosely(info2) && !bothPolyApply) ctx.error(DoubleDefinition(sym1, sym2, root), root.sourcePos) } val earlyCtx = ctx.withPhase(ctx.elimRepeatedPhase.next) diff --git a/compiler/src/dotty/tools/dotc/transform/ElimPolyFunction.scala b/compiler/src/dotty/tools/dotc/transform/ElimPolyFunction.scala new file mode 100644 index 000000000000..6ba6591abc05 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/ElimPolyFunction.scala @@ -0,0 +1,68 @@ +package dotty.tools.dotc +package transform + +import ast.{Trees, tpd} +import core._, core.Decorators._ +import MegaPhase._, Phases.Phase +import Types._, Contexts._, Constants._, Names._, NameOps._, Flags._, DenotTransformers._ +import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._, Scopes._, Denotations._ +import TypeErasure.ErasedValueType, ValueClasses._ + +/** This phase rewrite PolyFunction subclasses to FunctionN subclasses + * + * class Foo extends PolyFunction { + * def apply(x_1: P_1, ..., x_N: P_N): R = rhs + * } + * becomes: + * class Foo extends FunctionN { + * def apply(x_1: P_1, ..., x_N: P_N): R = rhs + * } + */ +class ElimPolyFunction extends MiniPhase with DenotTransformer { + + import tpd._ + + override def phaseName: String = ElimPolyFunction.name + + override def runsAfter = Set(Erasure.name) + + override def changesParents: Boolean = true // Replaces PolyFunction by FunctionN + + override def transform(ref: SingleDenotation)(implicit ctx: Context) = ref match { + case ref: ClassDenotation if ref.symbol != defn.PolyFunctionClass && ref.derivesFrom(defn.PolyFunctionClass) => + val cinfo = ref.classInfo + val newParent = functionTypeOfPoly(cinfo) + val newParents = cinfo.classParents.map(parent => + if (parent.typeSymbol == defn.PolyFunctionClass) + newParent + else + parent + ) + ref.copySymDenotation(info = cinfo.derivedClassInfo(classParents = newParents)) + case _ => + ref + } + + def functionTypeOfPoly(cinfo: ClassInfo)(implicit ctx: Context): Type = { + val applyMeth = cinfo.decls.lookup(nme.apply).info + val arity = applyMeth.paramNamess.head.length + defn.FunctionType(arity) + } + + override def transformTemplate(tree: Template)(implicit ctx: Context): Tree = { + val newParents = tree.parents.mapconserve(parent => + if (parent.tpe.typeSymbol == defn.PolyFunctionClass) { + val cinfo = tree.symbol.owner.asClass.classInfo + tpd.TypeTree(functionTypeOfPoly(cinfo)) + } + else + parent + ) + cpy.Template(tree)(parents = newParents) + } +} + +object ElimPolyFunction { + val name = "elimPolyFunction" +} + diff --git a/compiler/src/dotty/tools/dotc/transform/Erasure.scala b/compiler/src/dotty/tools/dotc/transform/Erasure.scala index ef9dd7d72947..800d812a14e9 100644 --- a/compiler/src/dotty/tools/dotc/transform/Erasure.scala +++ b/compiler/src/dotty/tools/dotc/transform/Erasure.scala @@ -415,9 +415,20 @@ object Erasure { * e.m -> e.[]m if `m` is an array operation other than `clone`. */ override def typedSelect(tree: untpd.Select, pt: Type)(implicit ctx: Context): Tree = { + val qual1 = typed(tree.qualifier, AnySelectionProto) def mapOwner(sym: Symbol): Symbol = { - def recur(owner: Symbol): Symbol = + // PolyFunction apply Selects will not have a symbol, so deduce the owner + // from the typed qual. + def polyOwner: Symbol = + if (sym.exists || tree.name != nme.apply) NoSymbol + else { + val owner = qual1.tpe.widen.typeSymbol + if (defn.isFunctionClass(owner)) owner else NoSymbol + } + + polyOwner orElse { + val owner = sym.owner if (defn.specialErasure.contains(owner)) { assert(sym.isConstructor, s"${sym.showLocated}") defn.specialErasure(owner) @@ -425,12 +436,12 @@ object Erasure { defn.erasedFunctionClass(owner) else owner - recur(sym.owner) + } } val origSym = tree.symbol val owner = mapOwner(origSym) - val sym = if (owner eq origSym.owner) origSym else owner.info.decl(origSym.name).symbol + val sym = if (owner eq origSym.maybeOwner) origSym else owner.info.decl(tree.name).symbol assert(sym.exists, origSym.showLocated) def select(qual: Tree, sym: Symbol): Tree = @@ -474,7 +485,7 @@ object Erasure { } } - checkNotErased(recur(typed(tree.qualifier, AnySelectionProto))) + checkNotErased(recur(qual1)) } override def typedThis(tree: untpd.This)(implicit ctx: Context): Tree = diff --git a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala index 5e80c4b2be72..cc82c6c49d92 100644 --- a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala +++ b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala @@ -54,7 +54,8 @@ trait TypeAssigner { required = EmptyFlagConjunction, excluded = Private) .suchThat(decl.matches(_)) val inheritedInfo = inherited.info - if (inheritedInfo.exists && + val isPolyFunctionApply = decl.name == nme.apply && (parent <:< defn.PolyFunctionType) + if (isPolyFunctionApply || inheritedInfo.exists && decl.info.widenExpr <:< inheritedInfo.widenExpr && !(inheritedInfo.widenExpr <:< decl.info.widenExpr)) { val r = RefinedType(parent, decl.name, decl.info) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 985dd512fa15..ff9961df0cd9 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1271,7 +1271,9 @@ class Typer extends Namer typr.println(s"adding refinement $refinement") checkRefinementNonCyclic(refinement, refineCls, seen) val rsym = refinement.symbol - if (rsym.info.isInstanceOf[PolyType] && rsym.allOverriddenSymbols.isEmpty) + val polymorphicRefinementAllowed = + tpt1.tpe.typeSymbol == defn.PolyFunctionClass && rsym.name == nme.apply + if (!polymorphicRefinementAllowed && rsym.info.isInstanceOf[PolyType] && rsym.allOverriddenSymbols.isEmpty) ctx.error(PolymorphicMethodMissingTypeInParent(rsym, tpt1.symbol), refinement.sourcePos) val member = refineCls.info.member(rsym.name) diff --git a/docs/docs/internals/syntax.md b/docs/docs/internals/syntax.md index a8a02c5fe03e..0a058d9ac39c 100644 --- a/docs/docs/internals/syntax.md +++ b/docs/docs/internals/syntax.md @@ -139,10 +139,13 @@ ClassQualifier ::= ‘[’ id ‘]’ ### Types ```ebnf -Type ::= { ‘erased’ | ‘given’} FunArgTypes ‘=>’ Type Function(ts, t) - | HkTypeParamClause ‘=>>’ Type TypeLambda(ps, t) +Type ::= FunType + | HkTypeParamClause ‘=>>’ Type TypeLambda(ps, t) | MatchType | InfixType +FunType ::= { 'erased' | 'given' } (MonoFunType | PolyFunType) +MonoFunType ::= FunArgTypes ‘=>’ Type Function(ts, t) +PolyFunType :: = HKTypeParamClause '=>' Type PolyFunction(ps, t) FunArgTypes ::= InfixType | ‘(’ [ FunArgType {‘,’ FunArgType } ] ‘)’ | ‘(’ TypedFunParam {‘,’ TypedFunParam } ‘)’ @@ -195,6 +198,7 @@ Expr1 ::= ‘if’ ‘(’ Expr ‘)’ {nl} | ‘throw’ Expr Throw(expr) | ‘return’ [Expr] Return(expr?) | ForExpr + | HkTypeParamClause ‘=>’ Expr PolyFunction(ts, expr) | [SimpleExpr ‘.’] id ‘=’ Expr Assign(expr, expr) | SimpleExpr1 ArgumentExprs ‘=’ Expr Assign(expr, expr) | Expr2 diff --git a/library/src/scala/PolyFunction.scala b/library/src/scala/PolyFunction.scala new file mode 100644 index 000000000000..c6168a88d7f7 --- /dev/null +++ b/library/src/scala/PolyFunction.scala @@ -0,0 +1,10 @@ +package scala + +/** Marker trait for polymorphic function types. + * + * This is the only trait that can be refined with a polymorphic method, + * as long as that method is called `apply`, e.g.: + * PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R } + * This type will be erased to FunctionN. + */ +trait PolyFunction diff --git a/tests/neg/bad-selftype.scala b/tests/neg/bad-selftype.scala index 46a538d1f879..ce64e78ca043 100644 --- a/tests/neg/bad-selftype.scala +++ b/tests/neg/bad-selftype.scala @@ -2,5 +2,5 @@ trait x0[T] { self: x0 => } // error trait x1[T] { self: (=> String) => } // error -trait x2[T] { self: ([X] => X) => } // error +trait x2[T] { self: ([X] =>> X) => } // error diff --git a/tests/neg/i4373.scala b/tests/neg/i4373.scala index 84b6a71666b2..20a6c2595c9b 100644 --- a/tests/neg/i4373.scala +++ b/tests/neg/i4373.scala @@ -17,7 +17,7 @@ object Test { type T1 = _ // error type T2 = _[Int] // error type T3 = _ { type S } // error - type T4 = [X] => _ // error + type T4 = [X] =>> _ // error // Open questions: type T5 = TypeConstr[_ { type S }] // error diff --git a/tests/neg/i6385a.scala b/tests/neg/i6385a.scala index 17761cd076ed..2c52a44937c1 100644 --- a/tests/neg/i6385a.scala +++ b/tests/neg/i6385a.scala @@ -7,5 +7,5 @@ object Test { def f[F[_]](x: Box[F]) = ??? def db: Box[D] = ??? def cb: Box[C] = db // error - f[[X] => C[X]](db) // error -} \ No newline at end of file + f[[X] =>> C[X]](db) // error +} diff --git a/tests/neg/polymorphic-functions.scala b/tests/neg/polymorphic-functions.scala new file mode 100644 index 000000000000..d9783baee967 --- /dev/null +++ b/tests/neg/polymorphic-functions.scala @@ -0,0 +1,5 @@ +object Test { + val pv0: [T] => List[T] = ??? // error + val pv1: Any = [T] => Nil // error + val pv2: [T] => List[T] = [T] => Nil // error // error +} diff --git a/tests/run/polymorphic-functions.scala b/tests/run/polymorphic-functions.scala new file mode 100644 index 000000000000..5167eab1580a --- /dev/null +++ b/tests/run/polymorphic-functions.scala @@ -0,0 +1,93 @@ +object Test extends App { + // Types + type F0 = [T] => List[T] => Option[T] + type F1 = [F[_], G[_], T] => (F[T], F[T] => G[T]) => G[T] + type F11 = [F[_[_]], G[_[_]], T[_]] => (F[T], [U[_]] => F[U] => G[U]) => G[T] + type F2 = [T, U] => (T, U) => Either[T, U] + + // Terms + val t0 = [T] => (ts: List[T]) => ts.headOption + val t0a: F0 = t0 + assert(t0(List(1, 2, 3)) == Some(1)) + + val t1 = [F[_], G[_], T] => (ft: F[T], f: F[T] => G[T]) => f(ft) + val t1a: F1 = t1 + assert(t1(List(1, 2, 3), (ts: List[Int]) => ts.headOption) == Some(1)) + + val t11 = [F[_[_]], G[_[_]], T[_]] => (fl: F[T], f: [U[_]] => F[U] => G[U]) => f(fl) + val t11a: F11 = t11 + case class C11[F[_]](is: F[Int]) + case class D11[F[_]](is: F[Int]) + assert(t11[F = C11](C11(List(1, 2, 3)), [U[_]] => (c: C11[U]) => D11(c.is)) == D11(List(1, 2, 3))) + + val t2 = [T, U] => (t: T, u: U) => Left(t) + val t2a: F2 = t2 + assert(t2(23, "foo") == Left(23)) + + // Polymorphic idenity + val pid = [T] => (t: T) => t + + // Method with poly function argument + def m[T](f: [U] => U => U, t: T) = f(t) + val m0 = m(pid, 23) + + // Constructor with poly function argument + class C[T](f: [U] => U => U, t: T) { val v: T = f(t) } + val c0 = new C(pid, 23) + + // Function with poly function argument + val mf = (f: [U] => U => U, t: Int) => f(t) + val mf0 = mf(pid, 23) + + // Poly function with poly function argument + val pf = [T] => (f: [U] => U => U, t: T) => f(t) + val pf0 = pf(pid, 23) + + // Poly function with AnyVal arguments + val pf2 = [T] => (f: [U] => U => U, t: Int) => f(t) + val pf20 = pf2(pid, 23) + + // Implment/override + val phd = [T] => (ts: List[T]) => ts.headOption + + trait A { + val is: List[Int] + def m1(f: [T] => List[T] => Option[T]): Option[Int] + def m2(f: [T] => List[T] => Option[T]): Option[Int] = f(is) + } + + class B(val is: List[Int]) extends A { + def m1(f: [T] => List[T] => Option[T]): Option[Int] = f(is) + override def m2(f: [T] => List[T] => Option[T]): Option[Int] = f(is) + } + + assert(new B(List(1, 2, 3)).m1(phd) == Some(1)) + assert(new B(List(1, 2, 3)).m2(phd) == Some(1)) + + // Overload + class O(is: List[Int]) { + def m(f: [T] => List[T] => Option[T]): (Option[Int], Boolean) = (f(is), true) + def m(f: [T] => (List[T], T) => Option[T]): (Option[Int], Boolean) = (is.headOption.flatMap(f(is, _)), false) + } + + assert(new O(List(1, 2, 3)).m(phd) == (Some(1), true)) + assert(new O(List(1, 2, 3)).m([T] => (ts: List[T], t: T) => Some(t)) == (Some(1), false)) + + // Dependent + trait Entry[V] { type Key; val key: Key ; val value: V } + def extractKey[V](e: Entry[V]): e.Key = e.key + val md = [V] => (e: Entry[V]) => extractKey(e) + val eis = new Entry[Int] { type Key = String ; val key = "foo" ; val value = 23 } + val v0 = md(eis) + val v0a: String = v0 + assert(v0 == "foo") + + // Contextual + trait Show[T] { def show(t: T): String } + implicit val si: Show[Int] = + new Show[Int] { + def show(t: Int): String = t.toString + } + val s = [T] => (t: T) => given (st: Show[T]) => st.show(t) + assert(s(23) == "23") +}