Skip to content

Commit f6e8146

Browse files
timotheeandresnicolasstucki
authored andcommitted
Add scala.annotation.MainAnnotation
1 parent 25f4eec commit f6e8146

File tree

86 files changed

+2508
-33
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

86 files changed

+2508
-33
lines changed

compiler/src/dotty/tools/dotc/ast/MainProxies.scala

Lines changed: 336 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,34 @@ package dotty.tools.dotc
22
package ast
33

44
import core._
5-
import Symbols._, Types._, Contexts._, Flags._, Constants._
6-
import StdNames.nme
7-
8-
/** Generate proxy classes for @main functions.
9-
* A function like
10-
*
11-
* @main def f(x: S, ys: T*) = ...
12-
*
13-
* would be translated to something like
14-
*
15-
* import CommandLineParser._
16-
* class f {
17-
* @static def main(args: Array[String]): Unit =
18-
* try
19-
* f(
20-
* parseArgument[S](args, 0),
21-
* parseRemainingArguments[T](args, 1): _*
22-
* )
23-
* catch case err: ParseError => showError(err)
24-
* }
25-
*/
26-
object MainProxies {
5+
import Symbols._, Types._, Contexts._, Decorators._, util.Spans._, Flags._, Constants._
6+
import StdNames.{nme, tpnme}
7+
import ast.Trees._
8+
import Names.{Name, TermName}
9+
import Comments.Comment
10+
import NameKinds.DefaultGetterName
11+
import Annotations.Annotation
2712

28-
def mainProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
13+
object MainProxies {
14+
/** Generate proxy classes for @main functions.
15+
* A function like
16+
*
17+
* @main def f(x: S, ys: T*) = ...
18+
*
19+
* would be translated to something like
20+
*
21+
* import CommandLineParser._
22+
* class f {
23+
* @static def main(args: Array[String]): Unit =
24+
* try
25+
* f(
26+
* parseArgument[S](args, 0),
27+
* parseRemainingArguments[T](args, 1): _*
28+
* )
29+
* catch case err: ParseError => showError(err)
30+
* }
31+
*/
32+
def mainProxiesOld(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
2933
import tpd._
3034
def mainMethods(stats: List[Tree]): List[Symbol] = stats.flatMap {
3135
case stat: DefDef if stat.symbol.hasAnnotation(defn.MainAnnot) =>
@@ -35,11 +39,11 @@ object MainProxies {
3539
case _ =>
3640
Nil
3741
}
38-
mainMethods(stats).flatMap(mainProxy)
42+
mainMethods(stats).flatMap(mainProxyOld)
3943
}
4044

4145
import untpd._
42-
def mainProxy(mainFun: Symbol)(using Context): List[TypeDef] = {
46+
def mainProxyOld(mainFun: Symbol)(using Context): List[TypeDef] = {
4347
val mainAnnotSpan = mainFun.getAnnotation(defn.MainAnnot).get.tree.span
4448
def pos = mainFun.sourcePos
4549
val argsRef = Ident(nme.args)
@@ -114,4 +118,311 @@ object MainProxies {
114118
}
115119
result
116120
}
121+
122+
private type DefaultValueSymbols = Map[Int, Symbol]
123+
private type ParameterAnnotationss = Seq[Seq[Annotation]]
124+
125+
/**
126+
* Generate proxy classes for main functions.
127+
* A function like
128+
*
129+
* /**
130+
* * Lorem ipsum dolor sit amet
131+
* * consectetur adipiscing elit.
132+
* *
133+
* * @param x my param x
134+
* * @param ys all my params y
135+
* */
136+
* @main(80) def f(
137+
* @main.Alias("myX") x: S,
138+
* ys: T*
139+
* ) = ...
140+
*
141+
* would be translated to something like
142+
*
143+
* final class f {
144+
* static def main(args: Array[String]): Unit = {
145+
* val cmd = new main(80).command(
146+
* args,
147+
* "f",
148+
* "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
149+
* new scala.annotation.MainAnnotation.ParameterInfos("x", "S")
150+
* .withDocumentation("my param x")
151+
* .withAnnotations(new scala.main.Alias("myX")),
152+
* new scala.annotation.MainAnnotation.ParameterInfos("ys", "T")
153+
* .withDocumentation("all my params y")
154+
* )
155+
*
156+
* val args0: () => S = cmd.argGetter[S]("x", None)
157+
* val args1: () => Seq[T] = cmd.varargGetter[T]("ys")
158+
*
159+
* cmd.run(f(args0(), args1()*))
160+
* }
161+
* }
162+
*/
163+
def mainProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
164+
import tpd._
165+
166+
/**
167+
* Computes the symbols of the default values of the function. Since they cannot be infered anymore at this
168+
* point of the compilation, they must be explicitely passed by [[mainProxy]].
169+
*/
170+
def defaultValueSymbols(scope: Tree, funSymbol: Symbol): DefaultValueSymbols =
171+
scope match {
172+
case TypeDef(_, template: Template) =>
173+
template.body.flatMap((_: Tree) match {
174+
case dd: DefDef if dd.name.is(DefaultGetterName) && dd.name.firstPart == funSymbol.name =>
175+
val DefaultGetterName.NumberedInfo(index) = dd.name.info
176+
List(index -> dd.symbol)
177+
case _ => Nil
178+
}).toMap
179+
case _ => Map.empty
180+
}
181+
182+
/** Computes the list of main methods present in the code. */
183+
def mainMethods(scope: Tree, stats: List[Tree]): List[(Symbol, ParameterAnnotationss, DefaultValueSymbols, Option[Comment])] = stats.flatMap {
184+
case stat: DefDef =>
185+
val sym = stat.symbol
186+
sym.annotations.filter(_.matches(defn.MainAnnot)) match {
187+
case Nil =>
188+
Nil
189+
case _ :: Nil =>
190+
val paramAnnotations = stat.paramss.flatMap(_.map(
191+
valdef => valdef.symbol.annotations.filter(_.matches(defn.MainAnnotParameterAnnotation))
192+
))
193+
(sym, paramAnnotations.toVector, defaultValueSymbols(scope, sym), stat.rawComment) :: Nil
194+
case mainAnnot :: others =>
195+
report.error(s"method cannot have multiple main annotations", mainAnnot.tree)
196+
Nil
197+
}
198+
case stat @ TypeDef(_, impl: Template) if stat.symbol.is(Module) =>
199+
mainMethods(stat, impl.body)
200+
case _ =>
201+
Nil
202+
}
203+
204+
// Assuming that the top-level object was already generated, all main methods will have a scope
205+
mainMethods(EmptyTree, stats).flatMap(mainProxy)
206+
}
207+
208+
def mainProxy(mainFun: Symbol, paramAnnotations: ParameterAnnotationss, defaultValueSymbols: DefaultValueSymbols, docComment: Option[Comment])(using Context): List[TypeDef] = {
209+
val mainAnnot = mainFun.getAnnotation(defn.MainAnnot).get
210+
def pos = mainFun.sourcePos
211+
val cmdName: TermName = Names.termName("cmd")
212+
213+
val documentation = new Documentation(docComment)
214+
215+
/** A literal value (Boolean, Int, String, etc.) */
216+
inline def lit(any: Any): Literal = Literal(Constant(any))
217+
218+
/** None */
219+
inline def none: Tree = ref(defn.NoneModule.termRef)
220+
221+
/** Some(value) */
222+
inline def some(value: Tree): Tree = Apply(ref(defn.SomeClass.companionModule.termRef), value)
223+
224+
/** () => value */
225+
def unitToValue(value: Tree): Tree =
226+
val anonName = nme.ANON_FUN
227+
val defdef = DefDef(anonName, List(Nil), TypeTree(), value)
228+
Block(defdef, Closure(Nil, Ident(anonName), EmptyTree))
229+
230+
/**
231+
* Creates a list of references and definitions of arguments, the first referencing the second.
232+
* The goal is to create the
233+
* `val args0: () => S = cmd.argGetter[S]("x", None)`
234+
* part of the code.
235+
* For each tuple, the first element is a ref to `args0`, the second is the whole definition, the third
236+
* is the ParameterInfos definition associated to this argument.
237+
*/
238+
def createArgs(mt: MethodType, cmdName: TermName): List[(Tree, ValDef, Tree)] =
239+
mt.paramInfos.zip(mt.paramNames).zipWithIndex.map {
240+
case ((formal, paramName), n) =>
241+
val argName = nme.args ++ n.toString
242+
val isRepeated = formal.isRepeatedParam
243+
244+
val (argRef, formalType, getterSym) = {
245+
val argRef0 = Apply(Ident(argName), Nil)
246+
if formal.isRepeatedParam then
247+
(repeated(argRef0), formal.argTypes.head, defn.MainAnnotCommand_varargGetter)
248+
else (argRef0, formal, defn.MainAnnotCommand_argGetter)
249+
}
250+
251+
// The ParameterInfos
252+
val parameterInfos = {
253+
val param = paramName.toString
254+
val paramInfosTree = New(
255+
TypeTree(defn.MainAnnotParameterInfos.typeRef),
256+
// Arguments to be passed to ParameterInfos' constructor
257+
List(List(lit(param), lit(formalType.show)))
258+
)
259+
260+
/*
261+
* Assignations to be made after the creation of the ParameterInfos.
262+
* For example:
263+
* args0paramInfos.withDocumentation("my param x")
264+
* is represented by the pair
265+
* defn.MainAnnotationParameterInfos_withDocumentation -> List(lit("my param x"))
266+
*/
267+
var assignations: List[(Symbol, List[Tree])] = Nil
268+
for (doc <- documentation.argDocs.get(param))
269+
assignations = (defn.MainAnnotationParameterInfos_withDocumentation -> List(lit(doc))) :: assignations
270+
271+
val instanciatedAnnots = paramAnnotations(n).map(instanciateAnnotation).toList
272+
if instanciatedAnnots.nonEmpty then
273+
assignations = (defn.MainAnnotationParameterInfos_withAnnotations -> instanciatedAnnots) :: assignations
274+
275+
assignations.foldLeft[Tree](paramInfosTree){ case (tree, (setterSym, values)) => Apply(Select(tree, setterSym.name), values) }
276+
}
277+
278+
val argParams =
279+
if formal.isRepeatedParam then
280+
List(lit(paramName.toString))
281+
else
282+
val defaultValueGetterOpt = defaultValueSymbols.get(n) match {
283+
case None =>
284+
none
285+
case Some(dvSym) =>
286+
some(unitToValue(ref(dvSym.termRef)))
287+
}
288+
List(lit(paramName.toString), defaultValueGetterOpt)
289+
290+
val argDef = ValDef(
291+
argName,
292+
TypeTree(),
293+
Apply(TypeApply(Select(Ident(cmdName), getterSym.name), TypeTree(formalType) :: Nil), argParams),
294+
)
295+
296+
(argRef, argDef, parameterInfos)
297+
}
298+
end createArgs
299+
300+
/** Turns an annotation (e.g. `@main(40)`) into an instance of the class (e.g. `new scala.main(40)`). */
301+
def instanciateAnnotation(annot: Annotation): Tree =
302+
val argss = {
303+
def recurse(t: tpd.Tree, acc: List[List[Tree]]): List[List[Tree]] = t match {
304+
case Apply(t, args: List[tpd.Tree]) => recurse(t, extractArgs(args) :: acc)
305+
case _ => acc
306+
}
307+
308+
def extractArgs(args: List[tpd.Tree]): List[Tree] =
309+
args.flatMap {
310+
case Typed(SeqLiteral(varargs, _), _) => varargs.map(arg => TypedSplice(arg))
311+
case arg: Select if arg.name.is(DefaultGetterName) => Nil // Ignore default values, they will be added later by the compiler
312+
case arg => List(TypedSplice(arg))
313+
}
314+
315+
recurse(annot.tree, Nil)
316+
}
317+
318+
New(TypeTree(annot.symbol.typeRef), argss)
319+
end instanciateAnnotation
320+
321+
var result: List[TypeDef] = Nil
322+
if (!mainFun.owner.isStaticOwner)
323+
report.error(s"main method is not statically accessible", pos)
324+
else {
325+
var args: List[ValDef] = Nil
326+
var mainCall: Tree = ref(mainFun.termRef)
327+
var parameterInfoss: List[Tree] = Nil
328+
329+
mainFun.info match {
330+
case _: ExprType =>
331+
case mt: MethodType =>
332+
if (mt.isImplicitMethod) {
333+
report.error(s"main method cannot have implicit parameters", pos)
334+
}
335+
else mt.resType match {
336+
case restpe: MethodType =>
337+
report.error(s"main method cannot be curried", pos)
338+
Nil
339+
case _ =>
340+
val (argRefs, argVals, paramInfoss) = createArgs(mt, cmdName).unzip3
341+
args = argVals
342+
mainCall = Apply(mainCall, argRefs)
343+
parameterInfoss = paramInfoss
344+
}
345+
case _: PolyType =>
346+
report.error(s"main method cannot have type parameters", pos)
347+
case _ =>
348+
report.error(s"main can only annotate a method", pos)
349+
}
350+
351+
val cmd = ValDef(
352+
cmdName,
353+
TypeTree(),
354+
Apply(
355+
Select(instanciateAnnotation(mainAnnot), defn.MainAnnot_command.name),
356+
Ident(nme.args) :: lit(mainFun.showName) :: lit(documentation.mainDoc) :: parameterInfoss
357+
)
358+
)
359+
val run = Apply(Select(Ident(cmdName), defn.MainAnnotCommand_run.name), mainCall)
360+
val body = Block(cmd :: args, run)
361+
val mainArg = ValDef(nme.args, TypeTree(defn.ArrayType.appliedTo(defn.StringType)), EmptyTree)
362+
.withFlags(Param)
363+
/** Replace typed `Ident`s that have been typed with a TypeSplice with the reference to the symbol.
364+
* The annotations will be retype-checked in another scope that may not have the same imports.
365+
*/
366+
def insertTypeSplices = new TreeMap {
367+
override def transform(tree: Tree)(using Context): Tree = tree match
368+
case tree: tpd.Ident @unchecked => TypedSplice(tree)
369+
case tree => super.transform(tree)
370+
}
371+
val annots = mainFun.annotations
372+
.filterNot(_.matches(defn.MainAnnot))
373+
.map(annot => insertTypeSplices.transform(annot.tree))
374+
val mainMeth = DefDef(nme.main, (mainArg :: Nil) :: Nil, TypeTree(defn.UnitType), body)
375+
.withFlags(JavaStatic)
376+
.withAnnotations(annots)
377+
val mainTempl = Template(emptyConstructor, Nil, Nil, EmptyValDef, mainMeth :: Nil)
378+
val mainCls = TypeDef(mainFun.name.toTypeName, mainTempl)
379+
.withFlags(Final | Invisible)
380+
if (!ctx.reporter.hasErrors) result = mainCls.withSpan(mainAnnot.tree.span.toSynthetic) :: Nil
381+
}
382+
result
383+
}
384+
385+
/** A class responsible for extracting the docstrings of a method. */
386+
private class Documentation(docComment: Option[Comment]):
387+
import util.CommentParsing._
388+
389+
/** The main part of the documentation. */
390+
lazy val mainDoc: String = _mainDoc
391+
/** The parameters identified by @param. Maps from parameter name to its documentation. */
392+
lazy val argDocs: Map[String, String] = _argDocs
393+
394+
private var _mainDoc: String = ""
395+
private var _argDocs: Map[String, String] = Map()
396+
397+
docComment match {
398+
case Some(comment) => if comment.isDocComment then parseDocComment(comment.raw) else _mainDoc = comment.raw
399+
case None =>
400+
}
401+
402+
private def cleanComment(raw: String): String =
403+
var lines: Seq[String] = raw.trim.split('\n').toSeq
404+
lines = lines.map(l => l.substring(skipLineLead(l, -1), l.length).trim)
405+
var s = lines.foldLeft("") {
406+
case ("", s2) => s2
407+
case (s1, "") if s1.last == '\n' => s1 // Multiple newlines are kept as single newlines
408+
case (s1, "") => s1 + '\n'
409+
case (s1, s2) if s1.last == '\n' => s1 + s2
410+
case (s1, s2) => s1 + ' ' + s2
411+
}
412+
s.replaceAll(raw"\[\[", "").replaceAll(raw"\]\]", "").trim
413+
414+
private def parseDocComment(raw: String): Unit =
415+
// Positions of the sections (@) in the docstring
416+
val tidx: List[(Int, Int)] = tagIndex(raw)
417+
418+
// Parse main comment
419+
var mainComment: String = raw.substring(skipLineLead(raw, 0), startTag(raw, tidx))
420+
_mainDoc = cleanComment(mainComment)
421+
422+
// Parse arguments comments
423+
val argsCommentsSpans: Map[String, (Int, Int)] = paramDocs(raw, "@param", tidx)
424+
val argsCommentsTextSpans = argsCommentsSpans.view.mapValues(extractSectionText(raw, _))
425+
val argsCommentsTexts = argsCommentsTextSpans.mapValues({ case (beg, end) => raw.substring(beg, end) })
426+
_argDocs = argsCommentsTexts.mapValues(cleanComment(_)).toMap
427+
end Documentation
117428
}

0 commit comments

Comments
 (0)