Skip to content

Implement method type specialisation #630

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 37 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
c1f5842
Add a TypeSpecializer Phase
AlexSikia Mar 1, 2015
cd8eae0
Move TypeSpecializer before ElimByName and implement check for @speci…
AlexSikia Mar 17, 2015
af2ff52
Add method specialization on specified Types.
AlexSikia Jun 2, 2015
5a1e491
Implement type specialization with specified Types.
AlexSikia Jun 2, 2015
d76f8bb
Add InfoTransformer Trait to TypeSpecializer
AlexSikia Jun 2, 2015
7817fde
Add specialized methods dispatch
AlexSikia Apr 20, 2015
3366e45
Add specialised method dispatching
AlexSikia Jun 2, 2015
30566fc
Add `PreSpecializer` phase
AlexSikia Jun 3, 2015
0280142
Specialize methods defined inside of other methods
AlexSikia Jun 3, 2015
f56c7ae
Adapt instance of `TreeTypeMap` to map trees recursively
May 18, 2015
b6fcb63
Workaround https://github.com/lampepfl/dotty/issues/592 Insert casts …
DarkDimius May 20, 2015
a80c443
Add casts, and debug implementation
AlexSikia May 30, 2015
843f547
Do not look into scala- or java-defined symbols
AlexSikia Jun 3, 2015
deeea64
Add run test for specialisation
AlexSikia Jun 3, 2015
d40d01c
Handle `@specialized(AnyRef)`
AlexSikia Jun 4, 2015
4f8b423
Add `out/.keep` file back
AlexSikia Jun 4, 2015
8df6a1a
Clean up testing for Jenkins
AlexSikia Jun 4, 2015
aad8578
Clean up scalastyle on several files
AlexSikia Jun 5, 2015
15f1219
Check for `@specialized(AnyRef)`based on symbols
AlexSikia Jun 5, 2015
b5c5ced
Fix specialised method dispatch to recursive calls
AlexSikia Jun 8, 2015
013b7c3
Fix catching of specialized annotations
AlexSikia Jun 8, 2015
03c0c3c
Adapt specialization tests and clean up
AlexSikia Jun 8, 2015
ab3d717
Correct name mangling
AlexSikia Jun 22, 2015
07e9ae1
SpecializeNames: Duplicate scalac behaviour, sort tparams
DarkDimius Jun 25, 2015
68e4f6e
Allow to instantiate only some type params of a PolyType
DarkDimius Mar 9, 2015
ff9c583
Correct typos
AlexSikia Aug 3, 2015
266194f
Fix partial instantiation of PolyTypes
AlexSikia Jul 23, 2015
6599f59
Implement Partial Specialisation
AlexSikia Jul 26, 2015
ce2b561
Simplify code and clean up
AlexSikia Jul 26, 2015
4f7f3a7
Change Yspecialize behaviour
AlexSikia Jul 26, 2015
87d59fa
Reduce restrictions on methods allowed for specialization
AlexSikia Jul 26, 2015
cc653e4
Use a typemap when transforming DefDef's
AlexSikia Jul 26, 2015
0964663
Add test cases for specialisation
AlexSikia Jul 28, 2015
d48c520
Fix transformation of TypeApply in the specialised case
AlexSikia Aug 10, 2015
6579d82
Remove sorting of method args names when mangling for specialisation
AlexSikia Aug 11, 2015
1cc3c83
Clean up code
AlexSikia Aug 20, 2015
417fc96
Cast return value of `specializedFor` to `TermName`
AlexSikia Aug 21, 2015
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class Compiler {
List(new Pickler),
List(new FirstTransform,
new CheckReentrant),
List(new RefChecks,
List(new PreSpecializer,
new RefChecks,
new ElimRepeated,
new NormalizeFlags,
new ExtensionMethods,
Expand All @@ -53,6 +54,7 @@ class Compiler {
List(new PatternMatcher,
new ExplicitOuter,
new Splitter),
List(new TypeSpecializer),
List(new VCInlineMethods,
new SeqLiterals,
new InterceptedMethods,
Expand Down
4 changes: 2 additions & 2 deletions src/dotty/tools/dotc/ast/TreeTypeMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import dotty.tools.dotc.transform.SymUtils._
* gets two different denotations in the same period. Hence, if -Yno-double-bindings is
* set, we would get a data race assertion error.
*/
final class TreeTypeMap(
class TreeTypeMap(
val typeMap: Type => Type = IdentityTypeMap,
val treeMap: tpd.Tree => tpd.Tree = identity _,
val oldOwners: List[Symbol] = Nil,
Expand Down Expand Up @@ -75,7 +75,7 @@ final class TreeTypeMap(
updateDecls(prevStats.tail, newStats.tail)
}

override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = treeMap(tree) match {
override final def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = treeMap(tree) match {
case impl @ Template(constr, parents, self, _) =>
val tmap = withMappedSyms(localSyms(impl :: self :: Nil))
cpy.Template(impl)(
Expand Down
2 changes: 2 additions & 0 deletions src/dotty/tools/dotc/config/ScalaSettings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ class ScalaSettings extends Settings.SettingGroup {
val YprintSyms = BooleanSetting("-Yprint-syms", "when printing trees print info in symbols instead of corresponding info in trees.")
val YtestPickler = BooleanSetting("-Ytest-pickler", "self-test for pickling functionality; should be used with -Ystop-after:pickler")
val YcheckReentrant = BooleanSetting("-Ycheck-reentrant", "check that compiled program does not contain vars that can be accessed from a global root.")
val Yspecialize = IntSetting("-Yspecialize","Specialize methods with maximum this amount of polymorphic types.", 0, 0 to 10)

def stop = YstopAfter

/** Area-specific debug output.
Expand Down
1 change: 1 addition & 0 deletions src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ class Definitions {
lazy val TransientAnnot = ctx.requiredClass("scala.transient")
lazy val NativeAnnot = ctx.requiredClass("scala.native")
lazy val ScalaStrictFPAnnot = ctx.requiredClass("scala.annotation.strictfp")
lazy val SpecializedAnnot = ctx.requiredClass("scala.specialized")

// Annotation classes
lazy val AliasAnnot = ctx.requiredClass("dotty.annotation.internal.Alias")
Expand Down
9 changes: 5 additions & 4 deletions src/dotty/tools/dotc/core/NameOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package core
import java.security.MessageDigest
import scala.annotation.switch
import scala.io.Codec
import Names._, StdNames._, Contexts._, Symbols._, Flags._
import Names._, dotty.tools.dotc.core.StdNames._, Contexts._, Symbols._, Flags._
import Decorators.StringDecorator
import util.{Chars, NameTransformer}
import Chars.isOperatorPart
Expand Down Expand Up @@ -241,10 +241,11 @@ object NameOps {
case nme.clone_ => nme.clone_
}

def specializedFor(classTargs: List[Types.Type], classTargsNames: List[Name], methodTargs: List[Types.Type], methodTarsNames: List[Name])(implicit ctx: Context): name.ThisName = {
def specializedFor(classTargs: List[Types.Type], classTargsNames: List[Name], methodTargs: List[Types.Type], methodTargsNames: List[Name])(implicit ctx: Context): name.ThisName = {

def typeToTag(tp: Types.Type): Name = {
tp.classSymbol match {
if (tp eq null) nme.EMPTY
else tp.classSymbol match {
case t if t eq defn.IntClass => nme.specializedTypeNames.Int
case t if t eq defn.BooleanClass => nme.specializedTypeNames.Boolean
case t if t eq defn.ByteClass => nme.specializedTypeNames.Byte
Expand All @@ -258,7 +259,7 @@ object NameOps {
}
}

val methodTags: Seq[Name] = (methodTargs zip methodTarsNames).sortBy(_._2).map(x => typeToTag(x._1))
val methodTags: Seq[Name] = (methodTargs zip methodTargsNames).map(x => typeToTag(x._1))
val classTags: Seq[Name] = (classTargs zip classTargsNames).sortBy(_._2).map(x => typeToTag(x._1))

name.fromName(name ++ nme.specializedTypeNames.prefix ++
Expand Down
8 changes: 5 additions & 3 deletions src/dotty/tools/dotc/core/Names.scala
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,9 @@ object Names {
def compare(x: Name, y: Name): Int = {
if (x.isTermName && y.isTypeName) 1
else if (x.isTypeName && y.isTermName) -1
else if (x eq y) 0
else if (x.start == y.start && x.length == y.length) 0
else {
val until = x.length min y.length
val until = Math.min(x.length, y.length)
var i = 0

while (i < until && x(i) == y(i)) i = i + 1
Expand All @@ -364,7 +364,9 @@ object Names {
if (x(i) < y(i)) -1
else /*(x(i) > y(i))*/ 1
} else {
x.length - y.length
if (x.length < y.length) 1
else if (x.length > y.length) -1
else 0 // shouldn't happen, but still
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/dotty/tools/dotc/core/Phases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ object Phases {
private val explicitOuterCache = new PhaseCache(classOf[ExplicitOuter])
private val gettersCache = new PhaseCache(classOf[Getters])
private val genBCodeCache = new PhaseCache(classOf[GenBCode])
private val specializeCache = new PhaseCache(classOf[TypeSpecializer])

def typerPhase = typerCache.phase
def picklerPhase = picklerCache.phase
Expand All @@ -252,6 +253,7 @@ object Phases {
def explicitOuterPhase = explicitOuterCache.phase
def gettersPhase = gettersCache.phase
def genBCodePhase = genBCodeCache.phase
def specializePhase = specializeCache.phase

def isAfterTyper(phase: Phase): Boolean = phase.id > typerPhase.id
}
Expand Down
2 changes: 1 addition & 1 deletion src/dotty/tools/dotc/core/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ object Symbols {
(if(isDefinedInCurrentRun) lastDenot else denot).isTerm

final def isType(implicit ctx: Context): Boolean =
(if(isDefinedInCurrentRun) lastDenot else denot).isType
(if (isDefinedInCurrentRun) lastDenot else denot).isType

final def isClass: Boolean = isInstanceOf[ClassSymbol]

Expand Down
53 changes: 49 additions & 4 deletions src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ import Uniques._
import collection.{mutable, Seq, breakOut}
import config.Config
import config.Printers._
import dotty.tools.sameLength
import annotation.tailrec
import Flags.FlagSet
import typer.Mode
import language.implicitConversions
import scala.collection.mutable.ListBuffer

object Types {

Expand Down Expand Up @@ -2220,9 +2222,10 @@ object Types {

protected def computeSignature(implicit ctx: Context) = resultSignature

def instantiate(argTypes: List[Type])(implicit ctx: Context): Type =
def instantiate(argTypes: List[Type])(implicit ctx: Context): Type = {
assert(sameLength(argTypes, paramNames))
resultType.substParams(this, argTypes)

}
def instantiateBounds(argTypes: List[Type])(implicit ctx: Context): List[TypeBounds] =
paramBounds.mapConserve(_.substParams(this, argTypes).bounds)

Expand All @@ -2235,6 +2238,48 @@ object Types {
x => paramBounds mapConserve (_.subst(this, x).bounds),
x => resType.subst(this, x))

/** Instantiate only some type parameters.
* @param argNum which parameters should be instantiated
* @param argTypes which types should be used for Instatiation
* @return a PolyType with (this.paramNames - argNum.size) type parameters left abstract
*/
def instantiate(argNum: List[Int], argTypes: List[Type])(implicit ctx: Context) = {
// merge original args list with supplied one
def mergeArgs(pp: PolyType, nxt: Int, id: Int, until: Int, argT: List[Type], argN: List[Int], res: ListBuffer[Type]): List[Type] =
if (id < until && argT.nonEmpty) {
if (argN.head == id) // we replace this poly param by supplied one
mergeArgs(pp, nxt, id + 1, until, argT.tail, argN.tail, res += argT.head)
else { // we create a PolyParam that is still not instantiated
val nw = PolyParam(pp, nxt)
res += nw
mergeArgs(pp, nxt + 1, id + 1, until, argT, argN, res)
}
} else {
res ++= nxt.until(nxt + until - id).map(PolyParam(pp, _))
res.toList
}
def args(pp: PolyType) = mergeArgs(pp, 0, 0, argTypes.length + pp.paramNames.length, argTypes, argNum, ListBuffer.empty)

def pnames(origPnames: List[TypeName] = paramNames, argN: List[Int] = argNum, id: Int = 0, tmp: ListBuffer[TypeName] = ListBuffer.empty): List[TypeName] = {
if (argN.isEmpty) {
tmp ++= origPnames
tmp.toList
}
else if (id == argN.head) {
pnames(origPnames.tail, argN.tail, id + 1, tmp)
} else {
pnames(origPnames.tail, argN, id + 1, tmp += origPnames.head)
}
}

PolyType(pnames())(
x => {
val a = args(x)
paramBounds mapConserve (_.substParams(this, a).bounds)
},
x => resType.substParams(this, args(x)))
}

// need to override hashCode and equals to be object identity
// because paramNames by itself is not discriminatory enough
override def equals(other: Any) = this eq other.asInstanceOf[AnyRef]
Expand Down Expand Up @@ -2378,9 +2423,9 @@ object Types {
*
* @param origin The parameter that's tracked by the type variable.
* @param creatorState The typer state in which the variable was created.
* @param owningTree The function part of the TypeApply tree tree that introduces
* @param owningTree The function part of the TypeApply tree that introduces
* the type variable.
* @paran owner The current owner if the context where the variable was created.
* @param owner The current owner if the context where the variable was created.
*
* `owningTree` and `owner` are used to determine whether a type-variable can be instantiated
* at some given point. See `Inferencing#interpolateUndetVars`.
Expand Down
2 changes: 1 addition & 1 deletion src/dotty/tools/dotc/transform/FullParameterization.scala
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ trait FullParameterization {
* fully parameterized method definition derived from `originalDef`, which
* has `derived` as symbol and `fullyParameterizedType(originalDef.symbol.info)`
* as info.
* `abstractOverClass` defines weather the DefDef should abstract over type parameters
* `abstractOverClass` defines whether the DefDef should abstract over type parameters
* of class that contained original defDef
*/
def fullyParameterizedDef(derived: TermSymbol, originalDef: DefDef, abstractOverClass: Boolean = true)(implicit ctx: Context): Tree =
Expand Down
121 changes: 121 additions & 0 deletions src/dotty/tools/dotc/transform/PreSpecializer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package dotty.tools.dotc.transform

import dotty.tools.dotc.ast.Trees.{Ident, SeqLiteral, Typed}
import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.Annotations.Annotation
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.core.DenotTransformers.InfoTransformer
import dotty.tools.dotc.core.Names.Name
import dotty.tools.dotc.core.StdNames._
import dotty.tools.dotc.core.Symbols.{ClassSymbol, NoSymbol, Symbol}
import dotty.tools.dotc.core.Types.{ClassInfo, Type}
import dotty.tools.dotc.core.{Definitions, Flags}
import dotty.tools.dotc.transform.TreeTransforms.{TreeTransform, MiniPhaseTransform, TransformerInfo}

/**
* This phase retrieves all `@specialized` anotations,
* and stores them for the `TypeSpecializer` phase.
*/
class PreSpecializer extends MiniPhaseTransform {

override def phaseName: String = "prespecialize"

private var anyRefModule: Symbol = NoSymbol
private var specializableMapping: Map[Symbol, List[Type]] = _
private var specializableModule: Symbol = NoSymbol


override def prepareForUnit(tree: tpd.Tree)(implicit ctx: Context): TreeTransform = {
specializableModule = ctx.requiredModule("scala.Specializable")
anyRefModule = ctx.requiredModule("scala.package")
def specializableField(nm: String) = specializableModule.info.member(nm.toTermName).symbol

specializableMapping = Map(
specializableField("Primitives") -> List(defn.IntType, defn.LongType, defn.FloatType, defn.ShortType,
defn.DoubleType, defn.BooleanType, defn.UnitType, defn.CharType, defn.ByteType),
specializableField("Everything") -> List(defn.IntType, defn.LongType, defn.FloatType, defn.ShortType,
defn.DoubleType, defn.BooleanType, defn.UnitType, defn.CharType, defn.ByteType, defn.AnyRefType),
specializableField("Bits32AndUp") -> List(defn.IntType, defn.LongType, defn.FloatType, defn.DoubleType),
specializableField("Integral") -> List(defn.ByteType, defn.ShortType, defn.IntType, defn.LongType, defn.CharType),
specializableField("AllNumeric") -> List(defn.ByteType, defn.ShortType, defn.IntType, defn.LongType,
defn.CharType, defn.FloatType, defn.DoubleType),
specializableField("BestOfBreed") -> List(defn.IntType, defn.DoubleType, defn.BooleanType, defn.UnitType,
defn.AnyRefType)
)
this
}

private final def primitiveCompanionToPrimitive(companion: Type)(implicit ctx: Context) = {
if (companion.termSymbol eq anyRefModule.info.member(nme.AnyRef.toTermName).symbol) {
defn.AnyRefType
}
else {
val claz = companion.termSymbol.companionClass
assert(defn.ScalaValueClasses.contains(claz))
claz.typeRef
}
}

private def specializableToPrimitive(specializable: Type, name: Name)(implicit ctx: Context): List[Type] = {
if (specializable.termSymbol eq specializableModule.info.member(name).symbol) {
specializableMapping(specializable.termSymbol)
}
else Nil
}

def defn(implicit ctx: Context): Definitions = ctx.definitions

private def primitiveTypes(implicit ctx: Context) =
List(ctx.definitions.ByteType,
ctx.definitions.BooleanType,
ctx.definitions.ShortType,
ctx.definitions.IntType,
ctx.definitions.LongType,
ctx.definitions.FloatType,
ctx.definitions.DoubleType,
ctx.definitions.CharType,
ctx.definitions.UnitType
)

def getSpec(sym: Symbol)(implicit ctx: Context): List[Type] = {

def allowedToSpecialize(sym: Symbol): Boolean = {
sym.name != nme.asInstanceOf_ &&
!(sym is Flags.JavaDefined) &&
!sym.isPrimaryConstructor
}

if (allowedToSpecialize(sym)) {
val annotation = sym.denot.getAnnotation(defn.SpecializedAnnot).getOrElse(Nil)
annotation match {
case annot: Annotation =>
val args = annot.arguments
if (args.isEmpty) primitiveTypes
else args.head match {
case _ @ Typed(SeqLiteral(types), _) =>
types.map(t => primitiveCompanionToPrimitive(t.tpe))
case a @ Ident(groupName) => // Matches `@specialized` annotations on Specializable Groups
specializableToPrimitive(a.tpe.asInstanceOf[Type], groupName)
case _ => ctx.error("unexpected match on specialized annotation"); Nil
}
case nil => Nil
}
} else Nil
}

override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
val tparams = tree.tparams.map(_.symbol)
val requests = tparams.zipWithIndex.map{case(sym, i) => (i, getSpec(sym))}
if (requests.nonEmpty) sendRequests(requests, tree)
tree
}

def sendRequests(requests: List[(Int, List[Type])], tree: tpd.Tree)(implicit ctx: Context): Unit = {
requests.map {
case (index, types) if types.nonEmpty =>
ctx.specializePhase.asInstanceOf[TypeSpecializer].registerSpecializationRequest(tree.symbol)(index, types)
case _ =>
}
}
}
Loading