@@ -2,30 +2,34 @@ package dotty.tools.dotc
2
2
package ast
3
3
4
4
import core ._
5
- import Symbols ._ , Types ._ , Contexts ._ , Flags ._ , Constants ._
6
- import StdNames .nme
7
-
8
- /** Generate proxy classes for @main functions.
9
- * A function like
10
- *
11
- * @main def f(x: S, ys: T*) = ...
12
- *
13
- * would be translated to something like
14
- *
15
- * import CommandLineParser._
16
- * class f {
17
- * @static def main(args: Array[String]): Unit =
18
- * try
19
- * f(
20
- * parseArgument[S](args, 0),
21
- * parseRemainingArguments[T](args, 1): _*
22
- * )
23
- * catch case err: ParseError => showError(err)
24
- * }
25
- */
26
- object MainProxies {
5
+ import Symbols ._ , Types ._ , Contexts ._ , Decorators ._ , util .Spans ._ , Flags ._ , Constants ._
6
+ import StdNames .{nme , tpnme }
7
+ import ast .Trees ._
8
+ import Names .{Name , TermName }
9
+ import Comments .Comment
10
+ import NameKinds .DefaultGetterName
11
+ import Annotations .Annotation
27
12
28
- def mainProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
13
+ object MainProxies {
14
+ /** Generate proxy classes for @main functions.
15
+ * A function like
16
+ *
17
+ * @main def f(x: S, ys: T*) = ...
18
+ *
19
+ * would be translated to something like
20
+ *
21
+ * import CommandLineParser._
22
+ * class f {
23
+ * @static def main(args: Array[String]): Unit =
24
+ * try
25
+ * f(
26
+ * parseArgument[S](args, 0),
27
+ * parseRemainingArguments[T](args, 1): _*
28
+ * )
29
+ * catch case err: ParseError => showError(err)
30
+ * }
31
+ */
32
+ def mainProxiesOld (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
29
33
import tpd ._
30
34
def mainMethods (stats : List [Tree ]): List [Symbol ] = stats.flatMap {
31
35
case stat : DefDef if stat.symbol.hasAnnotation(defn.MainAnnot ) =>
@@ -35,11 +39,11 @@ object MainProxies {
35
39
case _ =>
36
40
Nil
37
41
}
38
- mainMethods(stats).flatMap(mainProxy )
42
+ mainMethods(stats).flatMap(mainProxyOld )
39
43
}
40
44
41
45
import untpd ._
42
- def mainProxy (mainFun : Symbol )(using Context ): List [TypeDef ] = {
46
+ def mainProxyOld (mainFun : Symbol )(using Context ): List [TypeDef ] = {
43
47
val mainAnnotSpan = mainFun.getAnnotation(defn.MainAnnot ).get.tree.span
44
48
def pos = mainFun.sourcePos
45
49
val argsRef = Ident (nme.args)
@@ -114,4 +118,311 @@ object MainProxies {
114
118
}
115
119
result
116
120
}
121
+
122
+ private type DefaultValueSymbols = Map [Int , Symbol ]
123
+ private type ParameterAnnotationss = Seq [Seq [Annotation ]]
124
+
125
+ /**
126
+ * Generate proxy classes for main functions.
127
+ * A function like
128
+ *
129
+ * /* *
130
+ * * Lorem ipsum dolor sit amet
131
+ * * consectetur adipiscing elit.
132
+ * *
133
+ * * @param x my param x
134
+ * * @param ys all my params y
135
+ * */
136
+ * @main(80) def f(
137
+ * @main.Alias("myX") x: S,
138
+ * ys: T*
139
+ * ) = ...
140
+ *
141
+ * would be translated to something like
142
+ *
143
+ * final class f {
144
+ * static def main(args: Array[String]): Unit = {
145
+ * val cmd = new main(80).command(
146
+ * args,
147
+ * "f",
148
+ * "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
149
+ * new scala.annotation.MainAnnotation.ParameterInfos("x", "S")
150
+ * .withDocumentation("my param x")
151
+ * .withAnnotations(new scala.main.Alias("myX")),
152
+ * new scala.annotation.MainAnnotation.ParameterInfos("ys", "T")
153
+ * .withDocumentation("all my params y")
154
+ * )
155
+ *
156
+ * val args0: () => S = cmd.argGetter[S]("x", None)
157
+ * val args1: () => Seq[T] = cmd.varargGetter[T]("ys")
158
+ *
159
+ * cmd.run(f(args0(), args1()*))
160
+ * }
161
+ * }
162
+ */
163
+ def mainProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
164
+ import tpd ._
165
+
166
+ /**
167
+ * Computes the symbols of the default values of the function. Since they cannot be infered anymore at this
168
+ * point of the compilation, they must be explicitely passed by [[mainProxy ]].
169
+ */
170
+ def defaultValueSymbols (scope : Tree , funSymbol : Symbol ): DefaultValueSymbols =
171
+ scope match {
172
+ case TypeDef (_, template : Template ) =>
173
+ template.body.flatMap((_ : Tree ) match {
174
+ case dd : DefDef if dd.name.is(DefaultGetterName ) && dd.name.firstPart == funSymbol.name =>
175
+ val DefaultGetterName .NumberedInfo (index) = dd.name.info
176
+ List (index -> dd.symbol)
177
+ case _ => Nil
178
+ }).toMap
179
+ case _ => Map .empty
180
+ }
181
+
182
+ /** Computes the list of main methods present in the code. */
183
+ def mainMethods (scope : Tree , stats : List [Tree ]): List [(Symbol , ParameterAnnotationss , DefaultValueSymbols , Option [Comment ])] = stats.flatMap {
184
+ case stat : DefDef =>
185
+ val sym = stat.symbol
186
+ sym.annotations.filter(_.matches(defn.MainAnnot )) match {
187
+ case Nil =>
188
+ Nil
189
+ case _ :: Nil =>
190
+ val paramAnnotations = stat.paramss.flatMap(_.map(
191
+ valdef => valdef.symbol.annotations.filter(_.matches(defn.MainAnnotParameterAnnotation ))
192
+ ))
193
+ (sym, paramAnnotations.toVector, defaultValueSymbols(scope, sym), stat.rawComment) :: Nil
194
+ case mainAnnot :: others =>
195
+ report.error(s " method cannot have multiple main annotations " , mainAnnot.tree)
196
+ Nil
197
+ }
198
+ case stat @ TypeDef (_, impl : Template ) if stat.symbol.is(Module ) =>
199
+ mainMethods(stat, impl.body)
200
+ case _ =>
201
+ Nil
202
+ }
203
+
204
+ // Assuming that the top-level object was already generated, all main methods will have a scope
205
+ mainMethods(EmptyTree , stats).flatMap(mainProxy)
206
+ }
207
+
208
+ def mainProxy (mainFun : Symbol , paramAnnotations : ParameterAnnotationss , defaultValueSymbols : DefaultValueSymbols , docComment : Option [Comment ])(using Context ): List [TypeDef ] = {
209
+ val mainAnnot = mainFun.getAnnotation(defn.MainAnnot ).get
210
+ def pos = mainFun.sourcePos
211
+ val cmdName : TermName = Names .termName(" cmd" )
212
+
213
+ val documentation = new Documentation (docComment)
214
+
215
+ /** A literal value (Boolean, Int, String, etc.) */
216
+ inline def lit (any : Any ): Literal = Literal (Constant (any))
217
+
218
+ /** None */
219
+ inline def none : Tree = ref(defn.NoneModule .termRef)
220
+
221
+ /** Some(value) */
222
+ inline def some (value : Tree ): Tree = Apply (ref(defn.SomeClass .companionModule.termRef), value)
223
+
224
+ /** () => value */
225
+ def unitToValue (value : Tree ): Tree =
226
+ val anonName = nme.ANON_FUN
227
+ val defdef = DefDef (anonName, List (Nil ), TypeTree (), value)
228
+ Block (defdef, Closure (Nil , Ident (anonName), EmptyTree ))
229
+
230
+ /**
231
+ * Creates a list of references and definitions of arguments, the first referencing the second.
232
+ * The goal is to create the
233
+ * `val args0: () => S = cmd.argGetter[S]("x", None)`
234
+ * part of the code.
235
+ * For each tuple, the first element is a ref to `args0`, the second is the whole definition, the third
236
+ * is the ParameterInfos definition associated to this argument.
237
+ */
238
+ def createArgs (mt : MethodType , cmdName : TermName ): List [(Tree , ValDef , Tree )] =
239
+ mt.paramInfos.zip(mt.paramNames).zipWithIndex.map {
240
+ case ((formal, paramName), n) =>
241
+ val argName = nme.args ++ n.toString
242
+ val isRepeated = formal.isRepeatedParam
243
+
244
+ val (argRef, formalType, getterSym) = {
245
+ val argRef0 = Apply (Ident (argName), Nil )
246
+ if formal.isRepeatedParam then
247
+ (repeated(argRef0), formal.argTypes.head, defn.MainAnnotCommand_varargGetter )
248
+ else (argRef0, formal, defn.MainAnnotCommand_argGetter )
249
+ }
250
+
251
+ // The ParameterInfos
252
+ val parameterInfos = {
253
+ val param = paramName.toString
254
+ val paramInfosTree = New (
255
+ TypeTree (defn.MainAnnotParameterInfos .typeRef),
256
+ // Arguments to be passed to ParameterInfos' constructor
257
+ List (List (lit(param), lit(formalType.show)))
258
+ )
259
+
260
+ /*
261
+ * Assignations to be made after the creation of the ParameterInfos.
262
+ * For example:
263
+ * args0paramInfos.withDocumentation("my param x")
264
+ * is represented by the pair
265
+ * defn.MainAnnotationParameterInfos_withDocumentation -> List(lit("my param x"))
266
+ */
267
+ var assignations : List [(Symbol , List [Tree ])] = Nil
268
+ for (doc <- documentation.argDocs.get(param))
269
+ assignations = (defn.MainAnnotationParameterInfos_withDocumentation -> List (lit(doc))) :: assignations
270
+
271
+ val instanciatedAnnots = paramAnnotations(n).map(instanciateAnnotation).toList
272
+ if instanciatedAnnots.nonEmpty then
273
+ assignations = (defn.MainAnnotationParameterInfos_withAnnotations -> instanciatedAnnots) :: assignations
274
+
275
+ assignations.foldLeft[Tree ](paramInfosTree){ case (tree, (setterSym, values)) => Apply (Select (tree, setterSym.name), values) }
276
+ }
277
+
278
+ val argParams =
279
+ if formal.isRepeatedParam then
280
+ List (lit(paramName.toString))
281
+ else
282
+ val defaultValueGetterOpt = defaultValueSymbols.get(n) match {
283
+ case None =>
284
+ none
285
+ case Some (dvSym) =>
286
+ some(unitToValue(ref(dvSym.termRef)))
287
+ }
288
+ List (lit(paramName.toString), defaultValueGetterOpt)
289
+
290
+ val argDef = ValDef (
291
+ argName,
292
+ TypeTree (),
293
+ Apply (TypeApply (Select (Ident (cmdName), getterSym.name), TypeTree (formalType) :: Nil ), argParams),
294
+ )
295
+
296
+ (argRef, argDef, parameterInfos)
297
+ }
298
+ end createArgs
299
+
300
+ /** Turns an annotation (e.g. `@main(40)`) into an instance of the class (e.g. `new scala.main(40)`). */
301
+ def instanciateAnnotation (annot : Annotation ): Tree =
302
+ val argss = {
303
+ def recurse (t : tpd.Tree , acc : List [List [Tree ]]): List [List [Tree ]] = t match {
304
+ case Apply (t, args : List [tpd.Tree ]) => recurse(t, extractArgs(args) :: acc)
305
+ case _ => acc
306
+ }
307
+
308
+ def extractArgs (args : List [tpd.Tree ]): List [Tree ] =
309
+ args.flatMap {
310
+ case Typed (SeqLiteral (varargs, _), _) => varargs.map(arg => TypedSplice (arg))
311
+ case arg : Select if arg.name.is(DefaultGetterName ) => Nil // Ignore default values, they will be added later by the compiler
312
+ case arg => List (TypedSplice (arg))
313
+ }
314
+
315
+ recurse(annot.tree, Nil )
316
+ }
317
+
318
+ New (TypeTree (annot.symbol.typeRef), argss)
319
+ end instanciateAnnotation
320
+
321
+ var result : List [TypeDef ] = Nil
322
+ if (! mainFun.owner.isStaticOwner)
323
+ report.error(s " main method is not statically accessible " , pos)
324
+ else {
325
+ var args : List [ValDef ] = Nil
326
+ var mainCall : Tree = ref(mainFun.termRef)
327
+ var parameterInfoss : List [Tree ] = Nil
328
+
329
+ mainFun.info match {
330
+ case _ : ExprType =>
331
+ case mt : MethodType =>
332
+ if (mt.isImplicitMethod) {
333
+ report.error(s " main method cannot have implicit parameters " , pos)
334
+ }
335
+ else mt.resType match {
336
+ case restpe : MethodType =>
337
+ report.error(s " main method cannot be curried " , pos)
338
+ Nil
339
+ case _ =>
340
+ val (argRefs, argVals, paramInfoss) = createArgs(mt, cmdName).unzip3
341
+ args = argVals
342
+ mainCall = Apply (mainCall, argRefs)
343
+ parameterInfoss = paramInfoss
344
+ }
345
+ case _ : PolyType =>
346
+ report.error(s " main method cannot have type parameters " , pos)
347
+ case _ =>
348
+ report.error(s " main can only annotate a method " , pos)
349
+ }
350
+
351
+ val cmd = ValDef (
352
+ cmdName,
353
+ TypeTree (),
354
+ Apply (
355
+ Select (instanciateAnnotation(mainAnnot), defn.MainAnnot_command .name),
356
+ Ident (nme.args) :: lit(mainFun.showName) :: lit(documentation.mainDoc) :: parameterInfoss
357
+ )
358
+ )
359
+ val run = Apply (Select (Ident (cmdName), defn.MainAnnotCommand_run .name), mainCall)
360
+ val body = Block (cmd :: args, run)
361
+ val mainArg = ValDef (nme.args, TypeTree (defn.ArrayType .appliedTo(defn.StringType )), EmptyTree )
362
+ .withFlags(Param )
363
+ /** Replace typed `Ident`s that have been typed with a TypeSplice with the reference to the symbol.
364
+ * The annotations will be retype-checked in another scope that may not have the same imports.
365
+ */
366
+ def insertTypeSplices = new TreeMap {
367
+ override def transform (tree : Tree )(using Context ): Tree = tree match
368
+ case tree : tpd.Ident @ unchecked => TypedSplice (tree)
369
+ case tree => super .transform(tree)
370
+ }
371
+ val annots = mainFun.annotations
372
+ .filterNot(_.matches(defn.MainAnnot ))
373
+ .map(annot => insertTypeSplices.transform(annot.tree))
374
+ val mainMeth = DefDef (nme.main, (mainArg :: Nil ) :: Nil , TypeTree (defn.UnitType ), body)
375
+ .withFlags(JavaStatic )
376
+ .withAnnotations(annots)
377
+ val mainTempl = Template (emptyConstructor, Nil , Nil , EmptyValDef , mainMeth :: Nil )
378
+ val mainCls = TypeDef (mainFun.name.toTypeName, mainTempl)
379
+ .withFlags(Final | Invisible )
380
+ if (! ctx.reporter.hasErrors) result = mainCls.withSpan(mainAnnot.tree.span.toSynthetic) :: Nil
381
+ }
382
+ result
383
+ }
384
+
385
+ /** A class responsible for extracting the docstrings of a method. */
386
+ private class Documentation (docComment : Option [Comment ]):
387
+ import util .CommentParsing ._
388
+
389
+ /** The main part of the documentation. */
390
+ lazy val mainDoc : String = _mainDoc
391
+ /** The parameters identified by @param. Maps from parameter name to its documentation. */
392
+ lazy val argDocs : Map [String , String ] = _argDocs
393
+
394
+ private var _mainDoc : String = " "
395
+ private var _argDocs : Map [String , String ] = Map ()
396
+
397
+ docComment match {
398
+ case Some (comment) => if comment.isDocComment then parseDocComment(comment.raw) else _mainDoc = comment.raw
399
+ case None =>
400
+ }
401
+
402
+ private def cleanComment (raw : String ): String =
403
+ var lines : Seq [String ] = raw.trim.split('\n ' ).toSeq
404
+ lines = lines.map(l => l.substring(skipLineLead(l, - 1 ), l.length).trim)
405
+ var s = lines.foldLeft(" " ) {
406
+ case (" " , s2) => s2
407
+ case (s1, " " ) if s1.last == '\n ' => s1 // Multiple newlines are kept as single newlines
408
+ case (s1, " " ) => s1 + '\n '
409
+ case (s1, s2) if s1.last == '\n ' => s1 + s2
410
+ case (s1, s2) => s1 + ' ' + s2
411
+ }
412
+ s.replaceAll(raw " \[\[ " , " " ).replaceAll(raw " \]\] " , " " ).trim
413
+
414
+ private def parseDocComment (raw : String ): Unit =
415
+ // Positions of the sections (@) in the docstring
416
+ val tidx : List [(Int , Int )] = tagIndex(raw)
417
+
418
+ // Parse main comment
419
+ var mainComment : String = raw.substring(skipLineLead(raw, 0 ), startTag(raw, tidx))
420
+ _mainDoc = cleanComment(mainComment)
421
+
422
+ // Parse arguments comments
423
+ val argsCommentsSpans : Map [String , (Int , Int )] = paramDocs(raw, " @param" , tidx)
424
+ val argsCommentsTextSpans = argsCommentsSpans.view.mapValues(extractSectionText(raw, _))
425
+ val argsCommentsTexts = argsCommentsTextSpans.mapValues({ case (beg, end) => raw.substring(beg, end) })
426
+ _argDocs = argsCommentsTexts.mapValues(cleanComment(_)).toMap
427
+ end Documentation
117
428
}
0 commit comments