diff --git a/src/dotty/tools/dotc/Compiler.scala b/src/dotty/tools/dotc/Compiler.scala index f753b7614a45..b141971497f8 100644 --- a/src/dotty/tools/dotc/Compiler.scala +++ b/src/dotty/tools/dotc/Compiler.scala @@ -42,7 +42,8 @@ class Compiler { List(new Pickler), List(new FirstTransform, new CheckReentrant), - List(new RefChecks, + List(new PreSpecializer, + new RefChecks, new ElimRepeated, new NormalizeFlags, new ExtensionMethods, @@ -53,6 +54,7 @@ class Compiler { List(new PatternMatcher, new ExplicitOuter, new Splitter), + List(new TypeSpecializer), List(new VCInlineMethods, new SeqLiterals, new InterceptedMethods, diff --git a/src/dotty/tools/dotc/ast/TreeTypeMap.scala b/src/dotty/tools/dotc/ast/TreeTypeMap.scala index d714a3d21f40..b8b5451b9abc 100644 --- a/src/dotty/tools/dotc/ast/TreeTypeMap.scala +++ b/src/dotty/tools/dotc/ast/TreeTypeMap.scala @@ -31,7 +31,7 @@ import dotty.tools.dotc.transform.SymUtils._ * gets two different denotations in the same period. Hence, if -Yno-double-bindings is * set, we would get a data race assertion error. */ -final class TreeTypeMap( +class TreeTypeMap( val typeMap: Type => Type = IdentityTypeMap, val treeMap: tpd.Tree => tpd.Tree = identity _, val oldOwners: List[Symbol] = Nil, @@ -75,7 +75,7 @@ final class TreeTypeMap( updateDecls(prevStats.tail, newStats.tail) } - override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = treeMap(tree) match { + override final def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = treeMap(tree) match { case impl @ Template(constr, parents, self, _) => val tmap = withMappedSyms(localSyms(impl :: self :: Nil)) cpy.Template(impl)( diff --git a/src/dotty/tools/dotc/config/ScalaSettings.scala b/src/dotty/tools/dotc/config/ScalaSettings.scala index 05fefc8b41f4..fa32a138f5cc 100644 --- a/src/dotty/tools/dotc/config/ScalaSettings.scala +++ b/src/dotty/tools/dotc/config/ScalaSettings.scala @@ -152,6 +152,8 @@ class ScalaSettings extends Settings.SettingGroup { val YprintSyms = BooleanSetting("-Yprint-syms", "when printing trees print info in symbols instead of corresponding info in trees.") val YtestPickler = BooleanSetting("-Ytest-pickler", "self-test for pickling functionality; should be used with -Ystop-after:pickler") val YcheckReentrant = BooleanSetting("-Ycheck-reentrant", "check that compiled program does not contain vars that can be accessed from a global root.") + val Yspecialize = IntSetting("-Yspecialize","Specialize methods with maximum this amount of polymorphic types.", 0, 0 to 10) + def stop = YstopAfter /** Area-specific debug output. diff --git a/src/dotty/tools/dotc/core/Definitions.scala b/src/dotty/tools/dotc/core/Definitions.scala index fcd9ef224f10..29328a828cbd 100644 --- a/src/dotty/tools/dotc/core/Definitions.scala +++ b/src/dotty/tools/dotc/core/Definitions.scala @@ -338,6 +338,7 @@ class Definitions { lazy val TransientAnnot = ctx.requiredClass("scala.transient") lazy val NativeAnnot = ctx.requiredClass("scala.native") lazy val ScalaStrictFPAnnot = ctx.requiredClass("scala.annotation.strictfp") + lazy val SpecializedAnnot = ctx.requiredClass("scala.specialized") // Annotation classes lazy val AliasAnnot = ctx.requiredClass("dotty.annotation.internal.Alias") diff --git a/src/dotty/tools/dotc/core/NameOps.scala b/src/dotty/tools/dotc/core/NameOps.scala index 593d5f036197..8d3982a33d7d 100644 --- a/src/dotty/tools/dotc/core/NameOps.scala +++ b/src/dotty/tools/dotc/core/NameOps.scala @@ -4,7 +4,7 @@ package core import java.security.MessageDigest import scala.annotation.switch import scala.io.Codec -import Names._, StdNames._, Contexts._, Symbols._, Flags._ +import Names._, dotty.tools.dotc.core.StdNames._, Contexts._, Symbols._, Flags._ import Decorators.StringDecorator import util.{Chars, NameTransformer} import Chars.isOperatorPart @@ -241,10 +241,11 @@ object NameOps { case nme.clone_ => nme.clone_ } - def specializedFor(classTargs: List[Types.Type], classTargsNames: List[Name], methodTargs: List[Types.Type], methodTarsNames: List[Name])(implicit ctx: Context): name.ThisName = { + def specializedFor(classTargs: List[Types.Type], classTargsNames: List[Name], methodTargs: List[Types.Type], methodTargsNames: List[Name])(implicit ctx: Context): name.ThisName = { def typeToTag(tp: Types.Type): Name = { - tp.classSymbol match { + if (tp eq null) nme.EMPTY + else tp.classSymbol match { case t if t eq defn.IntClass => nme.specializedTypeNames.Int case t if t eq defn.BooleanClass => nme.specializedTypeNames.Boolean case t if t eq defn.ByteClass => nme.specializedTypeNames.Byte @@ -258,7 +259,7 @@ object NameOps { } } - val methodTags: Seq[Name] = (methodTargs zip methodTarsNames).sortBy(_._2).map(x => typeToTag(x._1)) + val methodTags: Seq[Name] = (methodTargs zip methodTargsNames).map(x => typeToTag(x._1)) val classTags: Seq[Name] = (classTargs zip classTargsNames).sortBy(_._2).map(x => typeToTag(x._1)) name.fromName(name ++ nme.specializedTypeNames.prefix ++ diff --git a/src/dotty/tools/dotc/core/Names.scala b/src/dotty/tools/dotc/core/Names.scala index 12def107626f..93a080e77e94 100644 --- a/src/dotty/tools/dotc/core/Names.scala +++ b/src/dotty/tools/dotc/core/Names.scala @@ -353,9 +353,9 @@ object Names { def compare(x: Name, y: Name): Int = { if (x.isTermName && y.isTypeName) 1 else if (x.isTypeName && y.isTermName) -1 - else if (x eq y) 0 + else if (x.start == y.start && x.length == y.length) 0 else { - val until = x.length min y.length + val until = Math.min(x.length, y.length) var i = 0 while (i < until && x(i) == y(i)) i = i + 1 @@ -364,7 +364,9 @@ object Names { if (x(i) < y(i)) -1 else /*(x(i) > y(i))*/ 1 } else { - x.length - y.length + if (x.length < y.length) 1 + else if (x.length > y.length) -1 + else 0 // shouldn't happen, but still } } } diff --git a/src/dotty/tools/dotc/core/Phases.scala b/src/dotty/tools/dotc/core/Phases.scala index 8d5ec08f70af..5fae2626d7b6 100644 --- a/src/dotty/tools/dotc/core/Phases.scala +++ b/src/dotty/tools/dotc/core/Phases.scala @@ -240,6 +240,7 @@ object Phases { private val explicitOuterCache = new PhaseCache(classOf[ExplicitOuter]) private val gettersCache = new PhaseCache(classOf[Getters]) private val genBCodeCache = new PhaseCache(classOf[GenBCode]) + private val specializeCache = new PhaseCache(classOf[TypeSpecializer]) def typerPhase = typerCache.phase def picklerPhase = picklerCache.phase @@ -252,6 +253,7 @@ object Phases { def explicitOuterPhase = explicitOuterCache.phase def gettersPhase = gettersCache.phase def genBCodePhase = genBCodeCache.phase + def specializePhase = specializeCache.phase def isAfterTyper(phase: Phase): Boolean = phase.id > typerPhase.id } diff --git a/src/dotty/tools/dotc/core/Symbols.scala b/src/dotty/tools/dotc/core/Symbols.scala index 2b4f806dd3ee..2b516a3448fb 100644 --- a/src/dotty/tools/dotc/core/Symbols.scala +++ b/src/dotty/tools/dotc/core/Symbols.scala @@ -404,7 +404,7 @@ object Symbols { (if(isDefinedInCurrentRun) lastDenot else denot).isTerm final def isType(implicit ctx: Context): Boolean = - (if(isDefinedInCurrentRun) lastDenot else denot).isType + (if (isDefinedInCurrentRun) lastDenot else denot).isType final def isClass: Boolean = isInstanceOf[ClassSymbol] diff --git a/src/dotty/tools/dotc/core/Types.scala b/src/dotty/tools/dotc/core/Types.scala index 0e9f5d9b217a..279c240b03c7 100644 --- a/src/dotty/tools/dotc/core/Types.scala +++ b/src/dotty/tools/dotc/core/Types.scala @@ -29,10 +29,12 @@ import Uniques._ import collection.{mutable, Seq, breakOut} import config.Config import config.Printers._ +import dotty.tools.sameLength import annotation.tailrec import Flags.FlagSet import typer.Mode import language.implicitConversions +import scala.collection.mutable.ListBuffer object Types { @@ -2220,9 +2222,10 @@ object Types { protected def computeSignature(implicit ctx: Context) = resultSignature - def instantiate(argTypes: List[Type])(implicit ctx: Context): Type = + def instantiate(argTypes: List[Type])(implicit ctx: Context): Type = { + assert(sameLength(argTypes, paramNames)) resultType.substParams(this, argTypes) - + } def instantiateBounds(argTypes: List[Type])(implicit ctx: Context): List[TypeBounds] = paramBounds.mapConserve(_.substParams(this, argTypes).bounds) @@ -2235,6 +2238,48 @@ object Types { x => paramBounds mapConserve (_.subst(this, x).bounds), x => resType.subst(this, x)) + /** Instantiate only some type parameters. + * @param argNum which parameters should be instantiated + * @param argTypes which types should be used for Instatiation + * @return a PolyType with (this.paramNames - argNum.size) type parameters left abstract + */ + def instantiate(argNum: List[Int], argTypes: List[Type])(implicit ctx: Context) = { + // merge original args list with supplied one + def mergeArgs(pp: PolyType, nxt: Int, id: Int, until: Int, argT: List[Type], argN: List[Int], res: ListBuffer[Type]): List[Type] = + if (id < until && argT.nonEmpty) { + if (argN.head == id) // we replace this poly param by supplied one + mergeArgs(pp, nxt, id + 1, until, argT.tail, argN.tail, res += argT.head) + else { // we create a PolyParam that is still not instantiated + val nw = PolyParam(pp, nxt) + res += nw + mergeArgs(pp, nxt + 1, id + 1, until, argT, argN, res) + } + } else { + res ++= nxt.until(nxt + until - id).map(PolyParam(pp, _)) + res.toList + } + def args(pp: PolyType) = mergeArgs(pp, 0, 0, argTypes.length + pp.paramNames.length, argTypes, argNum, ListBuffer.empty) + + def pnames(origPnames: List[TypeName] = paramNames, argN: List[Int] = argNum, id: Int = 0, tmp: ListBuffer[TypeName] = ListBuffer.empty): List[TypeName] = { + if (argN.isEmpty) { + tmp ++= origPnames + tmp.toList + } + else if (id == argN.head) { + pnames(origPnames.tail, argN.tail, id + 1, tmp) + } else { + pnames(origPnames.tail, argN, id + 1, tmp += origPnames.head) + } + } + + PolyType(pnames())( + x => { + val a = args(x) + paramBounds mapConserve (_.substParams(this, a).bounds) + }, + x => resType.substParams(this, args(x))) + } + // need to override hashCode and equals to be object identity // because paramNames by itself is not discriminatory enough override def equals(other: Any) = this eq other.asInstanceOf[AnyRef] @@ -2378,9 +2423,9 @@ object Types { * * @param origin The parameter that's tracked by the type variable. * @param creatorState The typer state in which the variable was created. - * @param owningTree The function part of the TypeApply tree tree that introduces + * @param owningTree The function part of the TypeApply tree that introduces * the type variable. - * @paran owner The current owner if the context where the variable was created. + * @param owner The current owner if the context where the variable was created. * * `owningTree` and `owner` are used to determine whether a type-variable can be instantiated * at some given point. See `Inferencing#interpolateUndetVars`. diff --git a/src/dotty/tools/dotc/transform/FullParameterization.scala b/src/dotty/tools/dotc/transform/FullParameterization.scala index e9057e885c47..c123daabdb38 100644 --- a/src/dotty/tools/dotc/transform/FullParameterization.scala +++ b/src/dotty/tools/dotc/transform/FullParameterization.scala @@ -139,7 +139,7 @@ trait FullParameterization { * fully parameterized method definition derived from `originalDef`, which * has `derived` as symbol and `fullyParameterizedType(originalDef.symbol.info)` * as info. - * `abstractOverClass` defines weather the DefDef should abstract over type parameters + * `abstractOverClass` defines whether the DefDef should abstract over type parameters * of class that contained original defDef */ def fullyParameterizedDef(derived: TermSymbol, originalDef: DefDef, abstractOverClass: Boolean = true)(implicit ctx: Context): Tree = diff --git a/src/dotty/tools/dotc/transform/PreSpecializer.scala b/src/dotty/tools/dotc/transform/PreSpecializer.scala new file mode 100644 index 000000000000..c91ece37eb4c --- /dev/null +++ b/src/dotty/tools/dotc/transform/PreSpecializer.scala @@ -0,0 +1,121 @@ +package dotty.tools.dotc.transform + +import dotty.tools.dotc.ast.Trees.{Ident, SeqLiteral, Typed} +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.Annotations.Annotation +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Decorators._ +import dotty.tools.dotc.core.DenotTransformers.InfoTransformer +import dotty.tools.dotc.core.Names.Name +import dotty.tools.dotc.core.StdNames._ +import dotty.tools.dotc.core.Symbols.{ClassSymbol, NoSymbol, Symbol} +import dotty.tools.dotc.core.Types.{ClassInfo, Type} +import dotty.tools.dotc.core.{Definitions, Flags} +import dotty.tools.dotc.transform.TreeTransforms.{TreeTransform, MiniPhaseTransform, TransformerInfo} + +/** + * This phase retrieves all `@specialized` anotations, + * and stores them for the `TypeSpecializer` phase. + */ +class PreSpecializer extends MiniPhaseTransform { + + override def phaseName: String = "prespecialize" + + private var anyRefModule: Symbol = NoSymbol + private var specializableMapping: Map[Symbol, List[Type]] = _ + private var specializableModule: Symbol = NoSymbol + + + override def prepareForUnit(tree: tpd.Tree)(implicit ctx: Context): TreeTransform = { + specializableModule = ctx.requiredModule("scala.Specializable") + anyRefModule = ctx.requiredModule("scala.package") + def specializableField(nm: String) = specializableModule.info.member(nm.toTermName).symbol + + specializableMapping = Map( + specializableField("Primitives") -> List(defn.IntType, defn.LongType, defn.FloatType, defn.ShortType, + defn.DoubleType, defn.BooleanType, defn.UnitType, defn.CharType, defn.ByteType), + specializableField("Everything") -> List(defn.IntType, defn.LongType, defn.FloatType, defn.ShortType, + defn.DoubleType, defn.BooleanType, defn.UnitType, defn.CharType, defn.ByteType, defn.AnyRefType), + specializableField("Bits32AndUp") -> List(defn.IntType, defn.LongType, defn.FloatType, defn.DoubleType), + specializableField("Integral") -> List(defn.ByteType, defn.ShortType, defn.IntType, defn.LongType, defn.CharType), + specializableField("AllNumeric") -> List(defn.ByteType, defn.ShortType, defn.IntType, defn.LongType, + defn.CharType, defn.FloatType, defn.DoubleType), + specializableField("BestOfBreed") -> List(defn.IntType, defn.DoubleType, defn.BooleanType, defn.UnitType, + defn.AnyRefType) + ) + this + } + + private final def primitiveCompanionToPrimitive(companion: Type)(implicit ctx: Context) = { + if (companion.termSymbol eq anyRefModule.info.member(nme.AnyRef.toTermName).symbol) { + defn.AnyRefType + } + else { + val claz = companion.termSymbol.companionClass + assert(defn.ScalaValueClasses.contains(claz)) + claz.typeRef + } + } + + private def specializableToPrimitive(specializable: Type, name: Name)(implicit ctx: Context): List[Type] = { + if (specializable.termSymbol eq specializableModule.info.member(name).symbol) { + specializableMapping(specializable.termSymbol) + } + else Nil + } + + def defn(implicit ctx: Context): Definitions = ctx.definitions + + private def primitiveTypes(implicit ctx: Context) = + List(ctx.definitions.ByteType, + ctx.definitions.BooleanType, + ctx.definitions.ShortType, + ctx.definitions.IntType, + ctx.definitions.LongType, + ctx.definitions.FloatType, + ctx.definitions.DoubleType, + ctx.definitions.CharType, + ctx.definitions.UnitType + ) + + def getSpec(sym: Symbol)(implicit ctx: Context): List[Type] = { + + def allowedToSpecialize(sym: Symbol): Boolean = { + sym.name != nme.asInstanceOf_ && + !(sym is Flags.JavaDefined) && + !sym.isPrimaryConstructor + } + + if (allowedToSpecialize(sym)) { + val annotation = sym.denot.getAnnotation(defn.SpecializedAnnot).getOrElse(Nil) + annotation match { + case annot: Annotation => + val args = annot.arguments + if (args.isEmpty) primitiveTypes + else args.head match { + case _ @ Typed(SeqLiteral(types), _) => + types.map(t => primitiveCompanionToPrimitive(t.tpe)) + case a @ Ident(groupName) => // Matches `@specialized` annotations on Specializable Groups + specializableToPrimitive(a.tpe.asInstanceOf[Type], groupName) + case _ => ctx.error("unexpected match on specialized annotation"); Nil + } + case nil => Nil + } + } else Nil + } + + override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { + val tparams = tree.tparams.map(_.symbol) + val requests = tparams.zipWithIndex.map{case(sym, i) => (i, getSpec(sym))} + if (requests.nonEmpty) sendRequests(requests, tree) + tree + } + + def sendRequests(requests: List[(Int, List[Type])], tree: tpd.Tree)(implicit ctx: Context): Unit = { + requests.map { + case (index, types) if types.nonEmpty => + ctx.specializePhase.asInstanceOf[TypeSpecializer].registerSpecializationRequest(tree.symbol)(index, types) + case _ => + } + } +} diff --git a/src/dotty/tools/dotc/transform/TypeSpecializer.scala b/src/dotty/tools/dotc/transform/TypeSpecializer.scala new file mode 100644 index 000000000000..af52f6552a1f --- /dev/null +++ b/src/dotty/tools/dotc/transform/TypeSpecializer.scala @@ -0,0 +1,386 @@ +package dotty.tools.dotc.transform + +import dotty.tools.dotc.ast.{tpd, TreeTypeMap} +import dotty.tools.dotc.ast.Trees._ +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.DenotTransformers.InfoTransformer +import dotty.tools.dotc.core.Names.TermName +import dotty.tools.dotc.core.Symbols.Symbol +import dotty.tools.dotc.core.{NameOps, Symbols, Flags} +import dotty.tools.dotc.core.Types._ +import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, MiniPhaseTransform} +import scala.collection.mutable +import dotty.tools.dotc.core.StdNames.nme +import dotty.tools._ + +import scala.collection.mutable.ListBuffer + +class TypeSpecializer extends MiniPhaseTransform with InfoTransformer { + + import tpd._ + override def phaseName = "specialize" + + private def primitiveTypes(implicit ctx: Context) = + List(defn.ByteType, + defn.BooleanType, + defn.ShortType, + defn.IntType, + defn.LongType, + defn.FloatType, + defn.DoubleType, + defn.CharType, + defn.UnitType) + + private def defn(implicit ctx:Context) = ctx.definitions + + /** + * Methods requested for specialization + * Generic Symbol => List[ (position in list of args, specialized type requested) ] + */ + private val specializationRequests: mutable.HashMap[Symbols.Symbol, List[(Int, List[Type])]] = mutable.HashMap.empty + + /** + * A list of instantiation values of generics (for recursive polymorphic methods) + */ + private val genericToInstantiation: mutable.HashMap[Symbols.Symbol, Type] = mutable.HashMap.empty + + /** + * A map that links symbols to their specialized variants. + * Each symbol maps to another map, from the list of specialization types to the specialized symbol. + * Generic symbol => + * Map{ List of [ Tuple(position in list of args, specialized Type) ] for each variant => Specialized Symbol } + */ + private val newSymbolMap: mutable.HashMap[Symbol, mutable.HashMap[List[(Int, Type)], Symbols.Symbol]] = mutable.HashMap.empty + + /** + * A map from specialised symbols to the indices of their remaining generic types + */ + private val newSymToGenIndices: mutable.HashMap[Symbol, List[Int]] = mutable.HashMap.empty + + /** + * A list of symbols gone through the specialisation pipeline + * Is used to make calls to transformInfo idempotent + */ + private val processed: ListBuffer[Symbol] = ListBuffer.empty + + def allowedToSpecialize(sym: Symbol, numOfTypes: Int)(implicit ctx: Context) = + numOfTypes > 0 && + sym.name != nme.asInstanceOf_ && + !newSymbolMap.contains(sym) && + !(sym is Flags.JavaDefined) && + !sym.isPrimaryConstructor + + + def getSpecTypes(method: Symbol, poly: PolyType)(implicit ctx: Context): List[(Int, List[Type])] = { + + val requested = specializationRequests.getOrElse(method, List.empty).toMap + if (requested.nonEmpty) { + poly.paramNames.zipWithIndex.map{case(name, i) => (i, requested.getOrElse(i, Nil))} + } + else { + if (ctx.settings.Yspecialize.value > 0) { + val filteredPrims = primitiveTypes.filter(tpe => poly.paramBounds.forall(_.contains(tpe))) + List.range(0, Math.min(poly.paramNames.length, ctx.settings.Yspecialize.value)).map(i => (i, filteredPrims)) + } + else Nil + } + } + + def requestedSpecialization(decl: Symbol)(implicit ctx: Context): Boolean = + ctx.settings.Yspecialize.value != 0 || specializationRequests.contains(decl) + + def registerSpecializationRequest(method: Symbols.Symbol)(index: Int, arguments: List[Type]) + (implicit ctx: Context) = { + if (ctx.phaseId > this.treeTransformPhase.id) + assert(ctx.phaseId <= this.treeTransformPhase.id) + val prev = specializationRequests.getOrElse(method, List.empty) + specializationRequests.put(method, (index, arguments) :: prev) + } + + override def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context): Type = { + + def enterNewSyms(newDecls: List[Symbol], classInfo: ClassInfo) = { + val decls = classInfo.decls.cloneScope + newDecls.foreach(decls.enter) + classInfo.derivedClassInfo(decls = decls) + } + + def specializeMethods(sym: Symbol) = { + processed += sym + sym.info match { + case classInfo: ClassInfo => + val newDecls = classInfo.decls + .filter(_.symbol.isCompleted) // We do not want to force symbols. Unforced symbol are not used in the source + .filterNot(_.isConstructor) + .filter(requestedSpecialization) + .flatMap(decl => { + decl.info.widen match { + case poly: PolyType if allowedToSpecialize(decl.symbol, poly.paramNames.length) => + generateMethodSpecializations(getSpecTypes(decl, poly), List.empty)(poly, decl) + case _ => Nil + } + }) + + if (newDecls.nonEmpty) enterNewSyms(newDecls.toList, classInfo) + else tp + case poly: PolyType if allowedToSpecialize(sym, poly.paramNames.length) => + if (sym.owner.info.isInstanceOf[ClassInfo]) { + transformInfo(sym.owner.info, sym.owner) + tp + } + else if (requestedSpecialization(sym) && + allowedToSpecialize(sym, poly.paramNames.length)) { + generateMethodSpecializations(getSpecTypes(sym, poly), List.empty)(poly, sym) + tp + } + else tp + case _ => tp + } + } + + def generateMethodSpecializations(specTypes: List[(Int, List[Type])], instantiations: List[(Int, Type)]) + (poly: PolyType, decl: Symbol) + (implicit ctx: Context): List[Symbol] = { + if (specTypes.nonEmpty) { + specTypes.head match{ + case (i, tpes) if tpes.nonEmpty => + tpes.flatMap(tpe => + generateMethodSpecializations(specTypes.tail, (i, tpe) :: instantiations)(poly, decl) + ) + case (i, nil) => + generateMethodSpecializations(specTypes.tail, instantiations)(poly, decl) + } + } + else { + if (instantiations.isEmpty) Nil + else generateSpecializedSymbol(instantiations.reverse, poly, decl) :: Nil + } + } + + def generateSpecializedSymbol(instantiations: List[(Int, Type)], poly: PolyType, decl: Symbol) + (implicit ctx: Context): Symbol = { + val indices = instantiations.map(_._1) + val instanceTypes = instantiations.map(_._2) + val newSym = ctx.newSymbol( + decl.owner, + NameOps.NameDecorator(decl.name) + .specializedFor(Nil, Nil, instanceTypes, instanceTypes.map(_.asInstanceOf[NamedType].name)) + .asInstanceOf[TermName], + decl.flags | Flags.Synthetic, + { if (indices.length != poly.paramNames.length) // Partial Specialisation case + poly.instantiate(indices, instanceTypes) // Returns a PolyType with uninstantiated types kept generic + else + poly.instantiate(instanceTypes) // Returns a MethodType, no polymorphic type remains + } + ) + + val map = newSymbolMap.getOrElse(decl, mutable.HashMap.empty) + map.put(instantiations, newSym) + newSymbolMap.put(decl, map) + + newSymToGenIndices.put(newSym, indices) + + newSym + } + + if (!processed.contains(sym) && + (sym ne defn.ScalaPredefModule.moduleClass) && + !(sym is Flags.JavaDefined) && + !(sym is Flags.Scala2x) && + !(sym is Flags.Package) && + !sym.isAnonymousClass) { + specializeMethods(sym) + } else tp + } + + override def transformDefDef(tree: DefDef)(implicit ctx: Context, info: TransformerInfo): Tree = { + + tree.tpe.widen match { + + case poly: PolyType + if !(tree.symbol.isPrimaryConstructor + || (tree.symbol is Flags.Label) + ) => + val origTParams = tree.tparams.map(_.symbol) + val origVParams = tree.vparamss.flatten.map(_.symbol) + + def specialize(decl : Symbol): List[Tree] = { + + def makeTypesList(origTSyms: List[Symbol], instantiation: Map[Int, Type], pt: PolyType): List[Type] = { + var holePos = -1 + origTSyms.zipWithIndex.map { + case (_, i) => instantiation.getOrElse(i, { + holePos += 1 + PolyParam(pt, holePos) + }).widen + } + } + + if (newSymbolMap.contains(decl)) { + val specInfo = newSymbolMap(decl) + val newSyms = specInfo.values.toList + val instantiationss = specInfo.keys.toArray + var index = -1 + ctx.debuglog(s"specializing ${tree.symbol} for $origTParams") + newSyms.map { newSym => + index += 1 + val newSymType = newSym.info.widenDealias + polyDefDef(newSym.asTerm, { tparams => vparams => { + val instTypes = newSymType match { + case pt: PolyType => + makeTypesList(origTParams, instantiationss(index).toMap, pt) // Will add missing PolyParams + case _ => instantiationss(index).map(_._2) + } + + val treemap: (Tree => Tree) = _ match { + case Return(t, from) if from.symbol == tree.symbol => Return(t, ref(newSym)) + case t: TypeApply => + (origTParams zip instTypes) + .foreach{case (genTpe, instTpe) => genericToInstantiation.put(genTpe, instTpe)} + transformTypeApply(t) + case t: Apply => + (origTParams zip instTypes) + .foreach{case (genTpe, instTpe) => genericToInstantiation.put(genTpe, instTpe)} + transformApply(t) + case t => t + } + + val abstractPolyType = tree.symbol.info.widenDealias.asInstanceOf[PolyType] + val vparamTpes = vparams.flatten.map(_.tpe) + val typemap = new TypeMap { + override def apply(tp: Type): Type = { + val t = mapOver(tp) + .substDealias(origTParams, instTypes) + .substParams(abstractPolyType, instTypes) + .subst(origVParams, vparamTpes) + newSymType match { + case mt: MethodType if tparams.isEmpty => + t.substParams(newSymType.asInstanceOf[MethodType], vparamTpes) + case pt: PolyType => + t.substParams(newSymType.asInstanceOf[PolyType], tparams) + .substParams(newSymType.resultType.asInstanceOf[MethodType], vparamTpes) + case _ => t + } + } + } + + val typesReplaced = new TreeTypeMap( + treeMap = treemap, + typeMap = typemap, + oldOwners = tree.symbol :: Nil, + newOwners = newSym :: Nil + ).transform(tree.rhs) + + val tp = new TreeMap() { + // needed to workaround https://github.com/lampepfl/dotty/issues/592 + override def transform(tree1: Tree)(implicit ctx: Context) = super.transform(tree1) match { + case t @ Apply(fun, args) => + assert(sameLength(args, fun.tpe.widen.firstParamTypes), + s"Wrong number of parameters." + + s"Expected: ${fun.tpe.widen.firstParamTypes.length}." + + s"Found: ${args.length}") + val newArgs = (args zip fun.tpe.widen.firstParamTypes).map{ + case(arg, tpe) => + assert(tpe.widen ne NoType, "Bad cast when specializing") + arg.ensureConforms(typemap(tpe.widen)) + } + if (sameTypes(args, newArgs)) { + t + } + else tpd.Apply(fun, newArgs) + case t: ValDef => + cpy.ValDef(t)(rhs = if (t.rhs.isEmpty) EmptyTree else + t.rhs.ensureConforms(t.tpt.tpe)) + case t: DefDef => + cpy.DefDef(t)(rhs = if (t.rhs.isEmpty) EmptyTree else + t.rhs.ensureConforms(t.tpt.tpe)) + case t: TypeTree => + t.tpe match { + case pp: PolyParam => + TypeTree(tparams(pp.paramNum)) + case _ => t + } + case t => t + }} + val expectedTypeFixed = tp.transform(typesReplaced) + if (expectedTypeFixed ne EmptyTree) { + expectedTypeFixed.ensureConforms(typemap(newSym.info.widen.finalResultType.widenDealias)) + } + else expectedTypeFixed + }}) + } + } else Nil + } + val specializedTrees = specialize(tree.symbol) + Thicket(tree :: specializedTrees) + case _ => tree + } + } + + def rewireTree(tree: Tree)(implicit ctx: Context): Tree = { + assert(tree.isInstanceOf[TypeApply]) + val TypeApply(fun,args) = tree + if (newSymbolMap.contains(fun.symbol)){ + val newSymInfo = newSymbolMap(fun.symbol) + val betterDefs = newSymInfo.filter{ + case (instantiations, symbol) => { + instantiations.forall { + case (ord, specTpe) => + args(ord).tpe <:< specTpe + }}}.toList + + if (betterDefs.length > 1) { + ctx.debuglog(s"Several specialized variants fit for ${fun.symbol.name} of ${fun.symbol.owner}." + + s" Defaulting to no specialization.") + tree + } + + else if (betterDefs.nonEmpty) { + val newFunSym = betterDefs.head._2 + ctx.debuglog(s"method ${fun.symbol.name} of ${fun.symbol.owner} rewired to specialized variant") + val prefix = fun match { + case Select(pre, name) => + pre + case t @ Ident(_) if t.tpe.isInstanceOf[TermRef] => + val tp = t.tpe.asInstanceOf[TermRef] + if (tp.prefix ne NoPrefix) + ref(tp.prefix.termSymbol) + else EmptyTree + case _ => EmptyTree + } + if (prefix ne EmptyTree) prefix.select(newFunSym) + else ref(newFunSym) + } else tree + } else tree + } + + override def transformTypeApply(tree: tpd.TypeApply)(implicit ctx: Context, info: TransformerInfo): Tree = { + val TypeApply(fun, _) = tree + if (fun.tpe.widenDealias.isParameterless) rewireTree(tree) + else tree + } + + override def transformApply(tree: Apply)(implicit ctx: Context, info: TransformerInfo): Tree = { + val Apply(fun, args) = tree + fun match { + case fun: TypeApply => + val TypeApply(_, typeArgs) = fun + val newFun = rewireTree(fun) + if (fun ne newFun) { + newFun.symbol.info.widenDealias match { + case pt: PolyType => // Need to apply types to the remaining generics first + val tpeOfRemainingGenerics = + typeArgs.zipWithIndex.filterNot(x => newSymToGenIndices(newFun.symbol).contains(x._2)).map(_._1) + assert(tpeOfRemainingGenerics.nonEmpty, + s"Remaining generics on ${newFun.symbol.name} not properly instantiated: missing types") + Apply(TypeApply(newFun, tpeOfRemainingGenerics), args) + case _ => + Apply(newFun, args) + } + } else tree + case fun : Apply => + Apply(transformApply(fun), args) + case _ => tree + } + } +} diff --git a/src/dotty/tools/dotc/transform/ValueClasses.scala b/src/dotty/tools/dotc/transform/ValueClasses.scala index 93005c57ae26..f0762b406064 100644 --- a/src/dotty/tools/dotc/transform/ValueClasses.scala +++ b/src/dotty/tools/dotc/transform/ValueClasses.scala @@ -22,6 +22,7 @@ object ValueClasses { def isMethodWithExtension(d: SymDenotation)(implicit ctx: Context) = d.isRealMethod && + !(d.initial.validFor.firstPhaseId > ctx.extensionMethodsPhase.id) && isDerivedValueClass(d.owner) && !d.isConstructor && !d.is(SuperAccessor) && diff --git a/src/dotty/tools/dotc/typer/Applications.scala b/src/dotty/tools/dotc/typer/Applications.scala index c45db4ccc827..3a4374843d11 100644 --- a/src/dotty/tools/dotc/typer/Applications.scala +++ b/src/dotty/tools/dotc/typer/Applications.scala @@ -206,7 +206,7 @@ trait Applications extends Compatibility { self: Typer => /** @param pnames The list of parameter names that are missing arguments * @param args The list of arguments that are not yet passed, or that are waiting to be dropped * @param nameToArg A map from as yet unseen names to named arguments - * @param toDrop A set of names that have already be passed as named arguments + * @param toDrop A set of names that have already been passed as named arguments * * For a well-typed application we have the invariants * diff --git a/test/dotc/tests.scala b/test/dotc/tests.scala index b39d0e928951..db8775b59e00 100644 --- a/test/dotc/tests.scala +++ b/test/dotc/tests.scala @@ -91,6 +91,7 @@ class tests extends CompilerTest { @Test def pos_packageObj = compileFile(posDir, "i0239", twice) @Test def pos_anonClassSubtyping = compileFile(posDir, "anonClassSubtyping", twice) @Test def pos_extmethods = compileFile(posDir, "extmethods", twice) + @Test def pos_specialization = compileDir(posDir, "specialization", twice) @Test def pos_all = compileFiles(posDir) // twice omitted to make tests run faster diff --git a/tests/pos/specialization/anyRef_specialization.scala b/tests/pos/specialization/anyRef_specialization.scala new file mode 100644 index 000000000000..c41d0ba484b8 --- /dev/null +++ b/tests/pos/specialization/anyRef_specialization.scala @@ -0,0 +1,6 @@ +object Test { + def foo[@specialized(AnyRef) T](t: T): T = t + def main (args: Array[String]) = { + foo(5) + } +} diff --git a/tests/pos/specialization/bounds_specialization.scala b/tests/pos/specialization/bounds_specialization.scala new file mode 100644 index 000000000000..9cc72b7853b3 --- /dev/null +++ b/tests/pos/specialization/bounds_specialization.scala @@ -0,0 +1,21 @@ +object bounds_specialization { + class Foo[@specialized K] { + def bar[@specialized U](u: U) = { + def dough[@specialized V](v: V) = { + println("innerMethod") + } + dough(1.toShort) + dough('c') + } + bar(2.toShort) + bar('d') + } + + def kung[@specialized(Int, Double) T <: AnyRef](t: T): T = { + t + } + + def fu[@specialized(Int, Double) T >: Nothing](t: T): T = { + t + } +} \ No newline at end of file diff --git a/tests/pos/specialization/genericClass_specialization.scala b/tests/pos/specialization/genericClass_specialization.scala new file mode 100644 index 000000000000..5c8de766f48c --- /dev/null +++ b/tests/pos/specialization/genericClass_specialization.scala @@ -0,0 +1,6 @@ +object genericClass_specialization { + class A[T] { + def foo[@specialized(Int, Char, Double) U](b: U) = b + } + def foobar[@specialized(Char) X] = new A[X].foo(2) +} diff --git a/tests/pos/specialization/method_in_class_specialization.scala b/tests/pos/specialization/method_in_class_specialization.scala new file mode 100644 index 000000000000..9b9c74f4fdec --- /dev/null +++ b/tests/pos/specialization/method_in_class_specialization.scala @@ -0,0 +1,42 @@ +object method_in_class_specialization { + class A { + def foo[@specialized(Int, Long) T](a: T) = List() + } + class B[K] { + def foo[@specialized T](b: T) = List() + } + class C extends A { + override def foo[@specialized T](c: T) = super.foo(c) + def bar[@specialized(Float, Char) U](c: U) = super.foo(c) + def man[@specialized V](c: V) = List() + } + class D extends B { + override def foo[@specialized T](d: T) = super.foo(d) + def bar[@specialized U](d: U) = super.foo(d) + def man[@specialized V](d: V) = List() + } + class E[U] extends B { + override def foo[@specialized T](e: T) = super.foo(e) + def bar[@specialized U](e: U) = super.foo(e) + def man[@specialized V](e: V) = List() + } + + val a = new A + val b = new B[Int] + val c = new C + val d = new D + val e = new E[Char] + + a.foo(1) + a.foo(1.toLong) + a.foo("foo") + b.foo(2) + c.foo(3) + d.foo(4) + e.foo(5) + + c.bar('d') + c.bar(6) + e.bar('d') + e.bar(7) +} \ No newline at end of file diff --git a/tests/pos/specialization/method_in_method_specialization.scala b/tests/pos/specialization/method_in_method_specialization.scala new file mode 100644 index 000000000000..af5e6c535d1d --- /dev/null +++ b/tests/pos/specialization/method_in_method_specialization.scala @@ -0,0 +1,19 @@ +object method_in_method_specialization { + def outer[@specialized(Int) O](o: O): O = { + def inner[@specialized(Int, Char) I](i: I): O = i match { + case o2: O => o2 + case _ => o + } + inner(1) + inner('c') + } + + outer(2) + outer('d') + + def outer2[@specialized(Int) O](o: O): Int = { + def inner2[@specialized(Int) I] (i: I) = 1 + inner2(42) + } + outer2(1) +} \ No newline at end of file diff --git a/tests/pos/specialization/multi_specialization.scala b/tests/pos/specialization/multi_specialization.scala new file mode 100644 index 000000000000..7ad357253850 --- /dev/null +++ b/tests/pos/specialization/multi_specialization.scala @@ -0,0 +1,9 @@ +object multi_specialization { + def one[@specialized T](n: T): T = n + def two[@specialized(Int, Double) T,@specialized(Double, Int) U](n: T, m: U): (T,U) = (n,m) + def three[@specialized(Int, Double) T,@specialized(Double, Int) U, V](n: T, m: U, o: V): (T,U,V) = (n,m,o) + + one(1) + two(1,2) + two('a', null) +} \ No newline at end of file diff --git a/tests/pos/specialization/mutual_specialization.scala b/tests/pos/specialization/mutual_specialization.scala new file mode 100644 index 000000000000..849aaba2b459 --- /dev/null +++ b/tests/pos/specialization/mutual_specialization.scala @@ -0,0 +1,6 @@ +object mutual_specialization { + class A[T] { + def foo[@specialized(Double) U](b: U, n: Int): Unit = if (n > 0) bar(b, n-1) + def bar[@specialized(Double) V](a: V, n: Int): Unit = if (n > 0) foo(a, n-1) + } +} diff --git a/tests/pos/specialization/nothing_specialization.scala b/tests/pos/specialization/nothing_specialization.scala new file mode 100644 index 000000000000..39fb35e273db --- /dev/null +++ b/tests/pos/specialization/nothing_specialization.scala @@ -0,0 +1,7 @@ +object nothing_specialization { + def ret_nothing[@specialized(Char) T] = { + def apply[@specialized(Char) X](xs : X*) : List[X] = List(xs:_*) + def apply6[@specialized(Char) X](xs : Nothing*) : List[Nothing] = List(xs: _*) + def apply2[@specialized(Long) U] = 1.asInstanceOf[U] + } +} diff --git a/tests/pos/specialization/partial_specialization.scala b/tests/pos/specialization/partial_specialization.scala new file mode 100644 index 000000000000..989db1f9df01 --- /dev/null +++ b/tests/pos/specialization/partial_specialization.scala @@ -0,0 +1,16 @@ +trait partial_specialization { + def foo1stOutOf1[@specialized(Int, Char) T](t: T) = ??? + def foo1stOutOf2[@specialized(Int, Char) T, U](t: T, u: U): T = t + def foo2ndOutOf2[T, @specialized(Int, Char) U](t: T, u: U) = ??? + def fooAllOutOf2[@specialized(Int, Char) T, @specialized(Int, Char) U](t: T, u: U) = ??? + def foo1st3rdOutOf3[@specialized(Int, Char) T, U, @specialized(Int, Char) V](t: T, u: U, v: V) = ??? + + def main(args: Array[String]) = { + foo1stOutOf2(1, 2.0) + foo1stOutOf2(1, 2.0) + foo1st3rdOutOf3(1, 2, 'c') + foo2ndOutOf2(1, 'c') + fooAllOutOf2('a','b') + fooAllOutOf2(1.0,1.0) + } +} \ No newline at end of file diff --git a/tests/pos/specialization/recursive_specialization.scala b/tests/pos/specialization/recursive_specialization.scala new file mode 100644 index 000000000000..a48e8a358292 --- /dev/null +++ b/tests/pos/specialization/recursive_specialization.scala @@ -0,0 +1,12 @@ +object recursive_specialization { + class Spec { + def plus[@specialized T](a: T, b:T)(ev: Numeric[T]): T = plus(b, a)(ev) + } + + class IntSpec extends Spec { + lazy val res = plus(1,2)(Numeric.IntIsIntegral) + } + + def main(args: Array[String]) = { + } +} diff --git a/tests/pos/specialization/return_specialization.scala b/tests/pos/specialization/return_specialization.scala new file mode 100644 index 000000000000..8cfddbbd1f3d --- /dev/null +++ b/tests/pos/specialization/return_specialization.scala @@ -0,0 +1,6 @@ +object return_specialization { + def qwa[@specialized(Int) T](a: (T, T) => T, b: T): T = { + if(a ne this) return a(b, b) + else b + } +} diff --git a/tests/pos/specialization/simple_specialization.scala b/tests/pos/specialization/simple_specialization.scala new file mode 100644 index 000000000000..e835ddee559b --- /dev/null +++ b/tests/pos/specialization/simple_specialization.scala @@ -0,0 +1,14 @@ +trait simple_specialization { + def printer1[@specialized(Int, Long) T](a: T) = { + println(a.toString) + } + def printer2[@specialized(Int, Long) T, U](a: T, b: U) = { + println(a.toString + b.toString) + } + def print(i: Int) = { + printer1(i) + println(" ---- ") + printer2(i,i) + } + print(9) +} diff --git a/tests/pos/specialization/specializable_specialization.scala b/tests/pos/specialization/specializable_specialization.scala new file mode 100644 index 000000000000..c13b51fc6b0d --- /dev/null +++ b/tests/pos/specialization/specializable_specialization.scala @@ -0,0 +1,15 @@ +import Specializable._ + +object specializable_specialization { + def foo[@specialized(Primitives) T](t: T): T = t + def foo2[@specialized(Everything) T](t: T): T = t + def foo3[@specialized(Bits32AndUp) T](t: T): T = t + def foo4[@specialized(Integral) T](t: T): T = t + def foo5[@specialized(AllNumeric) T](t: T): T = t + def foo6[@specialized(BestOfBreed) T](t: T): T = t + + def main(args: Array[String]) = { + foo('c') + foo5('c') + } +} diff --git a/tests/pos/specialization/subtype_specialization.scala b/tests/pos/specialization/subtype_specialization.scala new file mode 100644 index 000000000000..1ce6d570f6b8 --- /dev/null +++ b/tests/pos/specialization/subtype_specialization.scala @@ -0,0 +1,13 @@ +object subtype_specialization { + + class Seq[+A] + + case class FirstName[T](name: String) extends Seq[Char] {} + + def foo[@specialized(Char) A](stuff: Seq[A]): Seq[A] = { + stuff + } + + val s: Seq[FirstName] = foo[FirstName](new Seq[FirstName]) + +} diff --git a/tests/pos/specialization/this_specialization.scala b/tests/pos/specialization/this_specialization.scala new file mode 100644 index 000000000000..40244f661f7a --- /dev/null +++ b/tests/pos/specialization/this_specialization.scala @@ -0,0 +1,14 @@ +trait Foo[@specialized +A] { +// all those examples trigger bugs due to https://github.com/lampepfl/dotty/issues/592 + def bop[@specialized B >: A]: Foo[B] = new Bar[B](this) + def gwa[@specialized B >: A]: Foo[B] = this + def gwd[@specialized B >: A]: Foo[B] = { + val d: Foo[B] = this + d + } + //def bip[@specialized C >: A, @specialized D >: A]: Foo[D] = new Cho[D, C](new Bar[C](this)) +} + +case class Bar[@specialized a](tl: Foo[a]) extends Foo[a] + +//case class Cho[@specialized c, @specialized d](tl: Bar[d]) extends Foo[c] diff --git a/tests/pos/specialization/type_test.scala b/tests/pos/specialization/type_test.scala new file mode 100644 index 000000000000..570cfdf33d54 --- /dev/null +++ b/tests/pos/specialization/type_test.scala @@ -0,0 +1,3 @@ +object type_test { + def typeTest(i: Char): Unit = i.isInstanceOf[Int] +} \ No newline at end of file diff --git a/tests/run/method-specialization.check b/tests/run/method-specialization.check new file mode 100644 index 000000000000..46894d008a35 --- /dev/null +++ b/tests/run/method-specialization.check @@ -0,0 +1,7 @@ +10 +82 +3 +10 +int +double,int +int,class java.lang.Object \ No newline at end of file diff --git a/tests/run/method-specialization.scala b/tests/run/method-specialization.scala new file mode 100644 index 000000000000..ba95ac51f585 --- /dev/null +++ b/tests/run/method-specialization.scala @@ -0,0 +1,55 @@ +object Test extends dotty.runtime.LegacyApp { + + class Foo { + def foo[@specialized U](u: U) = u + } + class Bar { + def bar[@specialized U,@specialized V](u: U, v: V) = v + } + class Baz { + def baz[@specialized(Int, Char) V](v: V): V = v + } + class Kung { + def kung[@specialized U, V](u: U, v: V) = println(u.getClass) + } + + override def main(args: Array[String]): Unit = { + /** + * Expected output is: + * + * 10 + * 82 + * 3 + * 10 + * int + * double,int + * int,class java.lang.Object + */ + + val a = new Foo + val b = new Bar + val c = new Baz + val d = new Kung + val foo_methods = a.getClass.getMethods + val bar_methods = b.getClass.getMethods + val baz_methods = c.getClass.getMethods + val kung_methods = d.getClass.getMethods + println(foo_methods.filter(_.toString.contains("foo")).length) + println(bar_methods.filter(_.toString.contains("bar")).length) + println(baz_methods.filter(_.toString.contains("baz")).length) + println(kung_methods.filter(_.toString.contains("kung")).length) + + val baz_int_param = baz_methods.filter(_.toString.contains("$mIc$sp")).head.getParameterTypes.mkString(",") + val bar_int_double_params = bar_methods.filter(s => s.toString.contains("$mDIc$sp")) + val kung_int_gen_params = kung_methods.filter(s => s.toString.contains("$mIc$sp")) + println(baz_int_param) + println(bar_int_double_params.head.getParameterTypes.mkString(",")) + println(kung_int_gen_params.head.getParameterTypes.mkString(",")) + + def genericKung[A](a: A) = d.kung(a, a) + genericKung(1) + + d.kung(1, 1) + d.kung(1.0, 1.0) + } +} \ No newline at end of file