Skip to content

Commit be8467f

Browse files
committed
Add test with @newMain annotation definition
As defined in #13727 adapted to new API.
1 parent ab5ff9f commit be8467f

File tree

1 file changed

+319
-0
lines changed

1 file changed

+319
-0
lines changed
Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
import scala.annotation.*
2+
import collection.mutable
3+
4+
@newMain def happyBirthday(age: Int, name: String, others: String*) =
5+
val suffix =
6+
age % 100 match
7+
case 11 | 12 | 13 => "th"
8+
case _ =>
9+
age % 10 match
10+
case 1 => "st"
11+
case 2 => "nd"
12+
case 3 => "rd"
13+
case _ => "th"
14+
val bldr = new StringBuilder(s"Happy $age$suffix birthday, $name")
15+
for other <- others do bldr.append(" and ").append(other)
16+
println(bldr)
17+
18+
19+
object Test:
20+
def callMain(args: Array[String]): Unit =
21+
val clazz = Class.forName("happyBirthday")
22+
val method = clazz.getMethod("main", classOf[Array[String]])
23+
method.invoke(null, args)
24+
25+
def main(args: Array[String]): Unit =
26+
callMain(Array("23", "Lisa", "Peter"))
27+
end Test
28+
29+
30+
31+
@experimental
32+
final class newMain extends MainAnnotation:
33+
import newMain._
34+
import MainAnnotation._
35+
36+
override type Parser[T] = util.CommandLineParser.FromString[T]
37+
override type Result = Any
38+
39+
override def command(args: Array[String], commandName: String, documentation: String, parameterInfos: ParameterInfo*) =
40+
new Command[Parser, Result]:
41+
private enum ArgumentKind {
42+
case SimpleArgument, OptionalArgument, VarArgument
43+
}
44+
45+
private val argMarker = "--"
46+
private val shortArgMarker = "-"
47+
48+
/**
49+
* The name of the special argument to display the method's help.
50+
* If one of the method's parameters is called the same, will be ignored.
51+
*/
52+
private val helpArg = "help"
53+
private var helpIsOverridden = false
54+
55+
/**
56+
* The short name of the special argument to display the method's help.
57+
* If one of the method's parameters uses the same short name, will be ignored.
58+
*/
59+
private val shortHelpArg = 'h'
60+
private var shortHelpIsOverridden = false
61+
62+
private val maxUsageLineLength = 120
63+
64+
/** A map from argument canonical name (the name of the parameter in the method definition) to parameter informations */
65+
private val nameToParameterInfo: Map[String, ParameterInfo] = parameterInfos.map(infos => infos.name -> infos).toMap
66+
67+
private val (positionalArgs, byNameArgs, invalidByNameArgs) = {
68+
val namesToCanonicalName: Map[String, String] = parameterInfos.flatMap(
69+
infos =>
70+
var names = getAlternativeNames(infos)
71+
val canonicalName = infos.name
72+
if nameIsValid(canonicalName) then names = canonicalName +: names
73+
names.map(_ -> canonicalName)
74+
).toMap
75+
val shortNamesToCanonicalName: Map[Char, String] = parameterInfos.flatMap(
76+
infos =>
77+
var names = getShortNames(infos)
78+
val canonicalName = infos.name
79+
if shortNameIsValid(canonicalName) then names = canonicalName(0) +: names
80+
names.map(_ -> canonicalName)
81+
).toMap
82+
83+
helpIsOverridden = namesToCanonicalName.exists((name, _) => name == helpArg)
84+
shortHelpIsOverridden = shortNamesToCanonicalName.exists((name, _) => name == shortHelpArg)
85+
86+
def getCanonicalArgName(arg: String): Option[String] =
87+
if arg.startsWith(argMarker) && arg.length > argMarker.length then
88+
namesToCanonicalName.get(arg.drop(argMarker.length))
89+
else if arg.startsWith(shortArgMarker) && arg.length == shortArgMarker.length + 1 then
90+
shortNamesToCanonicalName.get(arg(shortArgMarker.length))
91+
else
92+
None
93+
94+
def isArgName(arg: String): Boolean =
95+
val isFullName = arg.startsWith(argMarker)
96+
val isShortName = arg.startsWith(shortArgMarker) && arg.length == shortArgMarker.length + 1 && shortNameIsValid(arg(shortArgMarker.length))
97+
isFullName || isShortName
98+
99+
def recurse(remainingArgs: Seq[String], pa: mutable.Queue[String], bna: Seq[(String, String)], ia: Seq[String]): (mutable.Queue[String], Seq[(String, String)], Seq[String]) =
100+
remainingArgs match {
101+
case Seq() =>
102+
(pa, bna, ia)
103+
case argName +: argValue +: rest if isArgName(argName) =>
104+
getCanonicalArgName(argName) match {
105+
case Some(canonicalName) => recurse(rest, pa, bna :+ (canonicalName -> argValue), ia)
106+
case None => recurse(rest, pa, bna, ia :+ argName)
107+
}
108+
case arg +: rest =>
109+
recurse(rest, pa :+ arg, bna, ia)
110+
}
111+
112+
val (pa, bna, ia) = recurse(args.toSeq, mutable.Queue.empty, Vector(), Vector())
113+
val nameToArgValues: Map[String, Seq[String]] = if bna.isEmpty then Map.empty else bna.groupMapReduce(_._1)(p => List(p._2))(_ ++ _)
114+
(pa, nameToArgValues, ia)
115+
}
116+
117+
/** The kind of the arguments. Used to display help about the main method. */
118+
private val argKinds = new mutable.ArrayBuffer[ArgumentKind]
119+
120+
/** A buffer for all errors */
121+
private val errors = new mutable.ArrayBuffer[String]
122+
123+
/** Issue an error, and return an uncallable getter */
124+
private def error(msg: String): None.type =
125+
errors += msg
126+
None
127+
128+
private inline def nameIsValid(name: String): Boolean =
129+
name.length > 1 // TODO add more checks for illegal characters
130+
131+
private inline def shortNameIsValid(name: String): Boolean =
132+
name.length == 1 && shortNameIsValid(name(0))
133+
134+
private inline def shortNameIsValid(shortName: Char): Boolean =
135+
('A' <= shortName && shortName <= 'Z') || ('a' <= shortName && shortName <= 'z')
136+
137+
private def getNameWithMarker(name: String | Char): String = name match {
138+
case c: Char => shortArgMarker + c
139+
case s: String if shortNameIsValid(s) => shortArgMarker + s
140+
case s => argMarker + s
141+
}
142+
143+
private def convert[T](argName: String, arg: String, p: Parser[T]): Option[T] =
144+
p.fromStringOption(arg).orElse(error(s"invalid argument for $argName: $arg"))
145+
146+
private def usage(): Unit =
147+
def argsUsage: Seq[String] =
148+
for ((infos, kind) <- parameterInfos.zip(argKinds))
149+
yield {
150+
val canonicalName = getNameWithMarker(infos.name)
151+
val shortNames = getShortNames(infos).map(getNameWithMarker)
152+
val alternativeNames = getAlternativeNames(infos).map(getNameWithMarker)
153+
val namesPrint = (canonicalName +: alternativeNames ++: shortNames).mkString("[", " | ", "]")
154+
155+
kind match {
156+
case ArgumentKind.SimpleArgument => s"$namesPrint <${infos.typeName}>"
157+
case ArgumentKind.OptionalArgument => s"[$namesPrint <${infos.typeName}>]"
158+
case ArgumentKind.VarArgument => s"[<${infos.typeName}> [<${infos.typeName}> [...]]]"
159+
}
160+
}
161+
162+
def wrapArgumentUsages(argsUsage: Seq[String], maxLength: Int): Seq[String] = {
163+
def recurse(args: Seq[String], currentLine: String, acc: Vector[String]): Seq[String] =
164+
(args, currentLine) match {
165+
case (Nil, "") => acc
166+
case (Nil, l) => (acc :+ l)
167+
case (arg +: t, "") => recurse(t, arg, acc)
168+
case (arg +: t, l) if l.length + 1 + arg.length <= maxLength => recurse(t, s"$l $arg", acc)
169+
case (arg +: t, l) => recurse(t, arg, acc :+ l)
170+
}
171+
172+
recurse(argsUsage, "", Vector()).toList
173+
}
174+
175+
val usageBeginning = s"Usage: $commandName "
176+
val argsOffset = usageBeginning.length
177+
val usages = wrapArgumentUsages(argsUsage, maxUsageLineLength - argsOffset)
178+
179+
println(usageBeginning + usages.mkString("\n" + " " * argsOffset))
180+
end usage
181+
182+
private def explain(): Unit =
183+
inline def shiftLines(s: Seq[String], shift: Int): String = s.map(" " * shift + _).mkString("\n")
184+
185+
def wrapLongLine(line: String, maxLength: Int): List[String] = {
186+
def recurse(s: String, acc: Vector[String]): Seq[String] =
187+
val lastSpace = s.trim.nn.lastIndexOf(' ', maxLength)
188+
if ((s.length <= maxLength) || (lastSpace < 0))
189+
acc :+ s
190+
else {
191+
val (shortLine, rest) = s.splitAt(lastSpace)
192+
recurse(rest.trim.nn, acc :+ shortLine)
193+
}
194+
195+
recurse(line, Vector()).toList
196+
}
197+
198+
if (documentation.nonEmpty)
199+
println(wrapLongLine(documentation, maxUsageLineLength).mkString("\n"))
200+
if (nameToParameterInfo.nonEmpty) {
201+
val argNameShift = 2
202+
val argDocShift = argNameShift + 2
203+
204+
println("Arguments:")
205+
for ((infos, kind) <- parameterInfos.zip(argKinds))
206+
val canonicalName = getNameWithMarker(infos.name)
207+
val shortNames = getShortNames(infos).map(getNameWithMarker)
208+
val alternativeNames = getAlternativeNames(infos).map(getNameWithMarker)
209+
val otherNames = (alternativeNames ++: shortNames) match {
210+
case Seq() => ""
211+
case names => names.mkString("(", ", ", ") ")
212+
}
213+
val argDoc = StringBuilder(" " * argNameShift)
214+
argDoc.append(s"$canonicalName $otherNames- ${infos.typeName}")
215+
216+
kind match {
217+
case ArgumentKind.OptionalArgument => argDoc.append(" (optional)")
218+
case ArgumentKind.VarArgument => argDoc.append(" (vararg)")
219+
case _ =>
220+
}
221+
222+
val doc = infos.documentation
223+
if (doc.nonEmpty) {
224+
val shiftedDoc =
225+
doc.split("\n").nn
226+
.map(line => shiftLines(wrapLongLine(line.nn, maxUsageLineLength - argDocShift), argDocShift))
227+
.mkString("\n")
228+
argDoc.append("\n").append(shiftedDoc)
229+
}
230+
231+
232+
println(argDoc)
233+
}
234+
end explain
235+
236+
private def getAliases(paramInfo: ParameterInfo): Seq[String] =
237+
paramInfo.annotations.collect{ case a: Alias => a }.flatMap(_.aliases)
238+
239+
private def getAlternativeNames(paramInfo: ParameterInfo): Seq[String] =
240+
getAliases(paramInfo).filter(nameIsValid(_))
241+
242+
private def getShortNames(paramInfo: ParameterInfo): Seq[Char] =
243+
getAliases(paramInfo).filter(shortNameIsValid(_)).map(_(0))
244+
245+
private def getInvalidNames(paramInfo: ParameterInfo): Seq[String | Char] =
246+
getAliases(paramInfo).filter(name => !nameIsValid(name) && !shortNameIsValid(name))
247+
248+
def parseArg[T](idx: Int, optDefaultGetter: Option[() => T])(using p: Parser[T]): Option[T] =
249+
val name = parameterInfos(idx).name
250+
251+
argKinds += (if optDefaultGetter.nonEmpty then ArgumentKind.OptionalArgument else ArgumentKind.SimpleArgument)
252+
253+
byNameArgs.get(name) match {
254+
case Some(Nil) =>
255+
throw AssertionError(s"$name present in byNameArgs, but it has no argument value")
256+
case Some(argValues) =>
257+
if argValues.length > 1 then
258+
// Do not accept multiple values
259+
// Remove this test to take last given argument
260+
error(s"more than one value for $name: ${argValues.mkString(", ")}")
261+
else
262+
convert(name, argValues.last, p)
263+
case None =>
264+
if positionalArgs.length > 0 then
265+
convert(name, positionalArgs.dequeue, p)
266+
else if optDefaultGetter.nonEmpty then
267+
optDefaultGetter.map(_())
268+
else
269+
error(s"missing argument for $name")
270+
}
271+
end parseArg
272+
273+
def parseVararg[T](using p: Parser[T]): Option[Seq[T]] =
274+
argKinds += ArgumentKind.VarArgument
275+
val name = parameterInfos.last.name
276+
277+
val byNameGetters = byNameArgs.getOrElse(name, Seq()).map(arg => convert(name, arg, p))
278+
val positionalGetters = positionalArgs.removeAll.map(arg => convert(name, arg, p))
279+
// First take arguments passed by name, then those passed by position
280+
Some(byNameGetters.flatten ++ positionalGetters.flatten)
281+
282+
override def run(f: => Result): Unit =
283+
// Check aliases unicity
284+
val nameAndCanonicalName = nameToParameterInfo.toList.flatMap {
285+
case (canonicalName, infos) => (canonicalName +: getAlternativeNames(infos) ++: getShortNames(infos)).map(_ -> canonicalName)
286+
}
287+
val nameToCanonicalNames = nameAndCanonicalName.groupMap(_._1)(_._2)
288+
289+
for (name, canonicalNames) <- nameToCanonicalNames if canonicalNames.length > 1
290+
do throw IllegalArgumentException(s"$name is used for multiple parameters: ${canonicalNames.mkString(", ")}")
291+
292+
// Check aliases validity
293+
val problematicNames = nameToParameterInfo.toList.flatMap((_, infos) => getInvalidNames(infos))
294+
if problematicNames.length > 0 then throw IllegalArgumentException(s"The following aliases are invalid: ${problematicNames.mkString(", ")}")
295+
296+
// Handle unused and invalid args
297+
for (remainingArg <- positionalArgs) error(s"unused argument: $remainingArg")
298+
for (invalidArg <- invalidByNameArgs) error(s"unknown argument name: $invalidArg")
299+
300+
val displayHelp =
301+
(!helpIsOverridden && args.contains(getNameWithMarker(helpArg))) || (!shortHelpIsOverridden && args.contains(getNameWithMarker(shortHelpArg)))
302+
303+
if displayHelp then
304+
usage()
305+
println()
306+
explain()
307+
else if errors.nonEmpty then
308+
for msg <- errors do println(s"Error: $msg")
309+
usage()
310+
else
311+
f
312+
end run
313+
end command
314+
end newMain
315+
316+
object newMain:
317+
@experimental
318+
final class Alias(val aliases: String*) extends MainAnnotation.ParameterAnnotation
319+
end newMain

0 commit comments

Comments
 (0)