Skip to content

Commit 233fd90

Browse files
bishaboshalihaoyi
andcommitted
SIP 61 - copy phase and annotation from com-lihaoyi/unroll
also copy tests as sbt-scripted tests. Co-authored-by: Jamie Thompson <bishbashboshjt@gmail.com> Co-authored-by: Li Haoyi <haoyi.sg@gmail.com>
1 parent 8b27ecb commit 233fd90

File tree

111 files changed

+2246
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

111 files changed

+2246
-0
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class Compiler {
6161
List(new InstrumentCoverage) :: // Perform instrumentation for code coverage (if -coverage-out is set)
6262
List(new CrossVersionChecks, // Check issues related to deprecated and experimental
6363
new FirstTransform, // Some transformations to put trees into a canonical form
64+
new UnrollDefs, // Unroll annotated methods
6465
new CheckReentrant, // Internal use only: Check that compiled program has no data races involving global vars
6566
new ElimPackagePrefixes, // Eliminate references to package prefixes in Select nodes
6667
new CookComments, // Cook the comments: expand variables, doc, etc.

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,7 @@ class Definitions {
10361036
@tu lazy val MigrationAnnot: ClassSymbol = requiredClass("scala.annotation.migration")
10371037
@tu lazy val NowarnAnnot: ClassSymbol = requiredClass("scala.annotation.nowarn")
10381038
@tu lazy val UnusedAnnot: ClassSymbol = requiredClass("scala.annotation.unused")
1039+
@tu lazy val UnrollAnnot: ClassSymbol = requiredClass("scala.annotation.unroll")
10391040
@tu lazy val TransparentTraitAnnot: ClassSymbol = requiredClass("scala.annotation.transparentTrait")
10401041
@tu lazy val NativeAnnot: ClassSymbol = requiredClass("scala.native")
10411042
@tu lazy val RepeatedAnnot: ClassSymbol = requiredClass("scala.annotation.internal.Repeated")
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
package dotty.tools.dotc.transform
2+
3+
import dotty.tools.dotc.*
4+
import core.*
5+
import MegaPhase.MiniPhase
6+
import Contexts.*
7+
import Symbols.*
8+
import Flags.*
9+
import SymDenotations.*
10+
import Decorators.*
11+
import ast.Trees.*
12+
import ast.tpd
13+
import StdNames.nme
14+
import Names.*
15+
import Constants.Constant
16+
import dotty.tools.dotc.core.NameKinds.DefaultGetterName
17+
import dotty.tools.dotc.core.Types.{MethodType, NamedType, PolyType, Type}
18+
import dotty.tools.dotc.core.Symbols
19+
20+
import scala.language.implicitConversions
21+
22+
class UnrollDefs extends MiniPhase {
23+
import tpd._
24+
25+
val phaseName = "unroll"
26+
27+
override val runsAfter = Set(FirstTransform.name)
28+
29+
def copyParam(p: ValDef, parent: Symbol)(using Context) = {
30+
implicitly[Context].typeAssigner.assignType(
31+
cpy.ValDef(p)(p.name, p.tpt, p.rhs),
32+
Symbols.newSymbol(parent, p.name, p.symbol.flags, p.symbol.info)
33+
)
34+
}
35+
36+
def copyParam2(p: TypeDef, parent: Symbol)(using Context) = {
37+
implicitly[Context].typeAssigner.assignType(
38+
cpy.TypeDef(p)(p.name, p.rhs),
39+
Symbols.newSymbol(parent, p.name, p.symbol.flags, p.symbol.info)
40+
)
41+
}
42+
43+
def findUnrollAnnotations(params: List[Symbol])(using Context): List[Int] = {
44+
params
45+
.zipWithIndex
46+
.collect {
47+
case (v, i) if v.annotations.exists(_.symbol.fullName.toString == "scala.annotation.unroll") =>
48+
i
49+
}
50+
}
51+
def isTypeClause(p: ParamClause) = p.headOption.exists(_.isInstanceOf[TypeDef])
52+
def generateSingleForwarder(defdef: DefDef,
53+
prevMethodType: Type,
54+
paramIndex: Int,
55+
nextParamIndex: Int,
56+
nextSymbol: Symbol,
57+
annotatedParamListIndex: Int,
58+
paramLists: List[ParamClause],
59+
isCaseApply: Boolean)
60+
(using Context) = {
61+
62+
def truncateMethodType0(tpe: Type, n: Int): Type = {
63+
tpe match{
64+
case pt: PolyType => PolyType(pt.paramNames, pt.paramInfos, truncateMethodType0(pt.resType, n + 1))
65+
case mt: MethodType =>
66+
if (n == annotatedParamListIndex) MethodType(mt.paramInfos.take(paramIndex), mt.resType)
67+
else MethodType(mt.paramInfos, truncateMethodType0(mt.resType, n + 1))
68+
}
69+
}
70+
71+
val truncatedMethodType = truncateMethodType0(prevMethodType, 0)
72+
val forwarderDefSymbol = Symbols.newSymbol(
73+
defdef.symbol.owner,
74+
defdef.name,
75+
defdef.symbol.flags &~
76+
HasDefaultParams &~
77+
(if (nextParamIndex == -1) Flags.EmptyFlags else Deferred) |
78+
Invisible,
79+
truncatedMethodType
80+
)
81+
82+
val newParamLists: List[ParamClause] = paramLists.zipWithIndex.map{ case (ps, i) =>
83+
if (i == annotatedParamListIndex) ps.take(paramIndex).map(p => copyParam(p.asInstanceOf[ValDef], forwarderDefSymbol))
84+
else {
85+
if (isTypeClause(ps)) ps.map(p => copyParam2(p.asInstanceOf[TypeDef], forwarderDefSymbol))
86+
else ps.map(p => copyParam(p.asInstanceOf[ValDef], forwarderDefSymbol))
87+
}
88+
}
89+
forwarderDefSymbol.setParamssFromDefs(newParamLists)
90+
91+
val defaultOffset = paramLists
92+
.iterator
93+
.take(annotatedParamListIndex)
94+
.filter(!isTypeClause(_))
95+
.map(_.size)
96+
.sum
97+
98+
val defaultCalls = Range(paramIndex, nextParamIndex).map(n =>
99+
val inner = if (defdef.symbol.isConstructor) {
100+
ref(defdef.symbol.owner.companionModule)
101+
.select(DefaultGetterName(defdef.name, n + defaultOffset))
102+
} else if (isCaseApply) {
103+
ref(defdef.symbol.owner.companionModule)
104+
.select(DefaultGetterName(termName("<init>"), n + defaultOffset))
105+
} else {
106+
This(defdef.symbol.owner.asClass)
107+
.select(DefaultGetterName(defdef.name, n + defaultOffset))
108+
}
109+
110+
newParamLists
111+
.take(annotatedParamListIndex)
112+
.map(_.map(p => ref(p.symbol)))
113+
.foldLeft[Tree](inner){
114+
case (lhs: Tree, newParams) =>
115+
if (newParams.headOption.exists(_.isInstanceOf[TypeTree])) TypeApply(lhs, newParams)
116+
else Apply(lhs, newParams)
117+
}
118+
)
119+
120+
val forwarderInner: Tree = This(defdef.symbol.owner.asClass).select(nextSymbol)
121+
122+
val forwarderCallArgs =
123+
newParamLists.zipWithIndex.map{case (ps, i) =>
124+
if (i == annotatedParamListIndex) ps.map(p => ref(p.symbol)).take(nextParamIndex) ++ defaultCalls
125+
else ps.map(p => ref(p.symbol))
126+
}
127+
128+
lazy val forwarderCall0 = forwarderCallArgs.foldLeft[Tree](forwarderInner){
129+
case (lhs: Tree, newParams) =>
130+
if (newParams.headOption.exists(_.isInstanceOf[TypeTree])) TypeApply(lhs, newParams)
131+
else Apply(lhs, newParams)
132+
}
133+
134+
lazy val forwarderCall =
135+
if (!defdef.symbol.isConstructor) forwarderCall0
136+
else Block(List(forwarderCall0), Literal(Constant(())))
137+
138+
val forwarderDef = implicitly[Context].typeAssigner.assignType(
139+
cpy.DefDef(defdef)(
140+
name = forwarderDefSymbol.name,
141+
paramss = newParamLists,
142+
tpt = defdef.tpt,
143+
rhs = if (nextParamIndex == -1) EmptyTree else forwarderCall
144+
),
145+
forwarderDefSymbol
146+
)
147+
148+
forwarderDef
149+
}
150+
151+
def generateFromProduct(startParamIndices: List[Int], paramCount: Int, defdef: DefDef)(using Context) = {
152+
cpy.DefDef(defdef)(
153+
name = defdef.name,
154+
paramss = defdef.paramss,
155+
tpt = defdef.tpt,
156+
rhs = Match(
157+
ref(defdef.paramss.head.head.asInstanceOf[ValDef].symbol).select(termName("productArity")),
158+
startParamIndices.map { paramIndex =>
159+
val Apply(select, args) = defdef.rhs: @unchecked
160+
CaseDef(
161+
Literal(Constant(paramIndex)),
162+
EmptyTree,
163+
Apply(
164+
select,
165+
args.take(paramIndex) ++
166+
Range(paramIndex, paramCount).map(n =>
167+
ref(defdef.symbol.owner.companionModule)
168+
.select(DefaultGetterName(defdef.symbol.owner.primaryConstructor.name.toTermName, n))
169+
)
170+
)
171+
)
172+
} ++ Seq(
173+
CaseDef(
174+
EmptyTree,
175+
EmptyTree,
176+
defdef.rhs
177+
)
178+
)
179+
)
180+
).setDefTree
181+
}
182+
183+
def generateSyntheticDefs(tree: Tree)(using Context): (Option[Symbol], Seq[Tree]) = tree match{
184+
case defdef: DefDef if defdef.paramss.nonEmpty =>
185+
import dotty.tools.dotc.core.NameOps.isConstructorName
186+
187+
val isCaseCopy =
188+
defdef.name.toString == "copy" && defdef.symbol.owner.is(CaseClass)
189+
190+
val isCaseApply =
191+
defdef.name.toString == "apply" && defdef.symbol.owner.companionClass.is(CaseClass)
192+
193+
val isCaseFromProduct = defdef.name.toString == "fromProduct" && defdef.symbol.owner.companionClass.is(CaseClass)
194+
195+
val annotated =
196+
if (isCaseCopy) defdef.symbol.owner.primaryConstructor
197+
else if (isCaseApply) defdef.symbol.owner.companionClass.primaryConstructor
198+
else if (isCaseFromProduct) defdef.symbol.owner.companionClass.primaryConstructor
199+
else defdef.symbol
200+
201+
202+
annotated
203+
.paramSymss
204+
.zipWithIndex
205+
.flatMap{case (paramClause, paramClauseIndex) =>
206+
val annotationIndices = findUnrollAnnotations(paramClause)
207+
if (annotationIndices.isEmpty) None
208+
else Some((paramClauseIndex, annotationIndices))
209+
} match{
210+
case Nil => (None, Nil)
211+
case Seq((paramClauseIndex, annotationIndices)) =>
212+
val paramCount = annotated.paramSymss(paramClauseIndex).size
213+
if (isCaseFromProduct) {
214+
(Some(defdef.symbol), Seq(generateFromProduct(annotationIndices, paramCount, defdef)))
215+
} else {
216+
if (defdef.symbol.is(Deferred)){
217+
(
218+
Some(defdef.symbol),
219+
(-1 +: annotationIndices :+ paramCount).sliding(2).toList.foldLeft((Seq.empty[DefDef], defdef.symbol))((m, v) => ((m, v): @unchecked) match {
220+
case ((defdefs, nextSymbol), Seq(paramIndex, nextParamIndex)) =>
221+
val forwarder = generateSingleForwarder(
222+
defdef,
223+
defdef.symbol.info,
224+
nextParamIndex,
225+
paramIndex,
226+
nextSymbol,
227+
paramClauseIndex,
228+
defdef.paramss,
229+
isCaseApply
230+
)
231+
(forwarder +: defdefs, forwarder.symbol)
232+
})._1
233+
)
234+
235+
}else{
236+
237+
(
238+
None,
239+
(annotationIndices :+ paramCount).sliding(2).toList.reverse.foldLeft((Seq.empty[DefDef], defdef.symbol))((m, v) => ((m, v): @unchecked) match {
240+
case ((defdefs, nextSymbol), Seq(paramIndex, nextParamIndex)) =>
241+
val forwarder = generateSingleForwarder(
242+
defdef,
243+
defdef.symbol.info,
244+
paramIndex,
245+
nextParamIndex,
246+
nextSymbol,
247+
paramClauseIndex,
248+
defdef.paramss,
249+
isCaseApply
250+
)
251+
(forwarder +: defdefs, forwarder.symbol)
252+
})._1
253+
)
254+
}
255+
}
256+
257+
case multiple => sys.error("Cannot have multiple parameter lists containing `@unroll` annotation")
258+
}
259+
260+
case _ => (None, Nil)
261+
}
262+
263+
override def transformTemplate(tmpl: tpd.Template)(using Context): tpd.Tree = {
264+
265+
val (removed0, generatedDefs) = tmpl.body.map(generateSyntheticDefs).unzip
266+
val (_, generatedConstr) = generateSyntheticDefs(tmpl.constr)
267+
val removed = removed0.flatten
268+
269+
super.transformTemplate(
270+
cpy.Template(tmpl)(
271+
tmpl.constr,
272+
tmpl.parents,
273+
tmpl.derived,
274+
tmpl.self,
275+
tmpl.body.filter(t => !removed.contains(t.symbol)) ++ generatedDefs.flatten ++ generatedConstr
276+
)
277+
)
278+
}
279+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
package scala.annotation
2+
3+
@experimental("under review as part of SIP-61")
4+
final class unroll extends scala.annotation.StaticAnnotation
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
lazy val utils = project.in(file("utils"))
2+
3+
lazy val sharedSettings = Seq(
4+
scalacOptions ++= Seq("-Ycheck:all", "-experimental")
5+
)
6+
7+
lazy val v1 = project.in(file("v1"))
8+
.settings(sharedSettings)
9+
10+
lazy val v1_app = project.in(file("v1_app")).dependsOn(utils)
11+
.settings(sharedSettings)
12+
.settings(
13+
Compile / unmanagedClasspath := Seq(
14+
Attributed.blank((v1 / Compile / classDirectory).value)
15+
),
16+
)
17+
18+
lazy val v2 = project.in(file("v2"))
19+
.settings(sharedSettings)
20+
21+
lazy val v2_app = project.in(file("v2_app")).dependsOn(utils)
22+
.settings(sharedSettings)
23+
.settings(
24+
Runtime / unmanagedClasspath := Seq(
25+
// add v1_app, compiled against v1, to the classpath
26+
Attributed.blank((v1_app / Runtime / classDirectory).value)
27+
),
28+
Compile / unmanagedClasspath := Seq(
29+
Attributed.blank((v2 / Compile / classDirectory).value)
30+
),
31+
)
32+
33+
lazy val v3 = project.in(file("v3"))
34+
.settings(sharedSettings)
35+
36+
lazy val v3_app = project.in(file("v3_app")).dependsOn(utils)
37+
.settings(sharedSettings)
38+
.settings(
39+
Runtime / unmanagedClasspath := Seq(
40+
// add v1_app, compiled against v1, to the classpath
41+
Attributed.blank((v1_app / Runtime / classDirectory).value),
42+
// add v2_app, compiled against v2, to the classpath
43+
Attributed.blank((v2_app / Runtime / classDirectory).value),
44+
),
45+
Compile / unmanagedClasspath := Seq(
46+
Attributed.blank((v3 / Compile / classDirectory).value)
47+
),
48+
)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import sbt._
2+
import Keys._
3+
4+
object DottyInjectedPlugin extends AutoPlugin {
5+
override def requires = plugins.JvmPlugin
6+
override def trigger = allRequirements
7+
8+
override val projectSettings = Seq(
9+
scalaVersion := sys.props("plugin.scalaVersion"),
10+
scalacOptions += "-source:3.0-migration"
11+
)
12+
}

sbt-test/unroll-annot/caseclass/test

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# compile and run a basic version of Unrolled (v1), and an app that uses it
2+
> v1/compile
3+
> v1_app/runMain unroll.UnrollTestMainV1
4+
# add a field to the case class (v2), and update the app to use it,
5+
# and ensure the old version (v1) still links
6+
> v2/compile
7+
> v2_app/runMain unroll.UnrollTestMainV1
8+
> v2_app/runMain unroll.UnrollTestMainV2
9+
# add a field to the case class (v3), and update the app to use it,
10+
# and ensure the old versions (v1, v2) still link
11+
> v3/compile
12+
> v3_app/runMain unroll.UnrollTestMainV1
13+
> v3_app/runMain unroll.UnrollTestMainV2
14+
> v3_app/runMain unroll.UnrollTestMainV3
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package unroll
2+
3+
object TestUtils {
4+
def logAssertStartsWith(actual: String, expected: String): Unit = {
5+
assert(actual.startsWith(expected))
6+
val suffix = {
7+
val suffix0 = actual.stripPrefix(expected)
8+
if (suffix0.isEmpty) "" else s""" + "$suffix0""""
9+
}
10+
println(s"""Assertion passed: found "$expected"$suffix""")
11+
}
12+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package unroll
2+
3+
case class Unrolled(s: String, n: Int = 1){
4+
def foo = s + n
5+
}

0 commit comments

Comments
 (0)