|
| 1 | +package dotty.tools.dotc |
| 2 | +package transform.localopt |
| 3 | + |
| 4 | +import scala.annotation.tailrec |
| 5 | +import scala.collection.mutable.{ListBuffer, Stack} |
| 6 | +import scala.reflect.{ClassTag, classTag} |
| 7 | +import scala.util.chaining.* |
| 8 | +import scala.util.matching.Regex.Match |
| 9 | + |
| 10 | +import java.util.{Calendar, Date, Formattable} |
| 11 | + |
| 12 | +import StringContextChecker.InterpolationReporter |
| 13 | + |
| 14 | +/** Formatter string checker. */ |
| 15 | +abstract class FormatChecker(using reporter: InterpolationReporter): |
| 16 | + |
| 17 | + // Pick the first runtime type which the i'th arg can satisfy. |
| 18 | + // If conversion is required, implementation must emit it. |
| 19 | + def argType(argi: Int, types: ClassTag[?]*): ClassTag[?] |
| 20 | + |
| 21 | + // count of args, for checking indexes |
| 22 | + def argc: Int |
| 23 | + |
| 24 | + val allFlags = "-#+ 0,(<" |
| 25 | + val formatPattern = """%(?:(\d+)\$)?([-#+ 0,(<]+)?(\d+)?(\.\d+)?([tT]?[%a-zA-Z])?""".r |
| 26 | + |
| 27 | + // ordinal is the regex group index in the format pattern |
| 28 | + enum SpecGroup: |
| 29 | + case Spec, Index, Flags, Width, Precision, CC |
| 30 | + import SpecGroup.* |
| 31 | + |
| 32 | + /** For N part strings and N-1 args to interpolate, normalize parts and check arg types. |
| 33 | + * |
| 34 | + * Returns parts, possibly updated with explicit leading "%s", |
| 35 | + * and conversions for each arg. |
| 36 | + * |
| 37 | + * Implementation must emit conversions required by invocations of `argType`. |
| 38 | + */ |
| 39 | + def checked(parts0: List[String]): (List[String], List[Conversion]) = |
| 40 | + val amended = ListBuffer.empty[String] |
| 41 | + val convert = ListBuffer.empty[Conversion] |
| 42 | + |
| 43 | + @tailrec |
| 44 | + def loop(parts: List[String], n: Int): Unit = |
| 45 | + parts match |
| 46 | + case part0 :: more => |
| 47 | + def badPart(t: Throwable): String = "".tap(_ => reporter.partError(t.getMessage, index = n, offset = 0)) |
| 48 | + val part = try StringContext.processEscapes(part0) catch badPart |
| 49 | + val matches = formatPattern.findAllMatchIn(part) |
| 50 | + |
| 51 | + def insertStringConversion(): Unit = |
| 52 | + amended += "%s" + part |
| 53 | + convert += Conversion(formatPattern.findAllMatchIn("%s").next(), n) // improve |
| 54 | + argType(n-1, classTag[Any]) |
| 55 | + def errorLeading(op: Conversion) = op.errorAt(Spec)(s"conversions must follow a splice; ${Conversion.literalHelp}") |
| 56 | + def accept(op: Conversion): Unit = |
| 57 | + if !op.isLeading then errorLeading(op) |
| 58 | + op.accepts(argType(n-1, op.acceptableVariants*)) |
| 59 | + amended += part |
| 60 | + convert += op |
| 61 | + |
| 62 | + // after the first part, a leading specifier is required for the interpolated arg; %s is supplied if needed |
| 63 | + if n == 0 then amended += part |
| 64 | + else if !matches.hasNext then insertStringConversion() |
| 65 | + else |
| 66 | + val cv = Conversion(matches.next(), n) |
| 67 | + if cv.isLiteral then insertStringConversion() |
| 68 | + else if cv.isIndexed then |
| 69 | + if cv.index.getOrElse(-1) == n then accept(cv) |
| 70 | + else |
| 71 | + // either some other arg num, or '<' |
| 72 | + //c.warning(op.groupPos(Index), "Index is not this arg") |
| 73 | + insertStringConversion() |
| 74 | + else if !cv.isError then accept(cv) |
| 75 | + |
| 76 | + // any remaining conversions in this part must be either literals or indexed |
| 77 | + while matches.hasNext do |
| 78 | + val cv = Conversion(matches.next(), n) |
| 79 | + if n == 0 && cv.hasFlag('<') then cv.badFlag('<', "No last arg") |
| 80 | + else if !cv.isLiteral && !cv.isIndexed then errorLeading(cv) |
| 81 | + |
| 82 | + loop(more, n + 1) |
| 83 | + case Nil => () |
| 84 | + end loop |
| 85 | + |
| 86 | + loop(parts0, n = 0) |
| 87 | + (amended.toList, convert.toList) |
| 88 | + end checked |
| 89 | + |
| 90 | + extension (descriptor: Match) |
| 91 | + def at(g: SpecGroup): Int = descriptor.start(g.ordinal) |
| 92 | + def offset(g: SpecGroup, i: Int = 0): Int = at(g) + i |
| 93 | + def group(g: SpecGroup): Option[String] = Option(descriptor.group(g.ordinal)) |
| 94 | + def stringOf(g: SpecGroup): String = group(g).getOrElse("") |
| 95 | + def intOf(g: SpecGroup): Option[Int] = group(g).map(_.toInt) |
| 96 | + |
| 97 | + extension (inline value: Boolean) |
| 98 | + inline def or(inline body: => Unit): Boolean = value || { body ; false } |
| 99 | + inline def orElse(inline body: => Unit): Boolean = value || { body ; true } |
| 100 | + inline def but(inline body: => Unit): Boolean = value && { body ; false } |
| 101 | + inline def and(inline body: => Unit): Boolean = value && { body ; true } |
| 102 | + |
| 103 | + /** A conversion specifier matched in the argi'th string part, |
| 104 | + * with `argc` arguments to interpolate. |
| 105 | + */ |
| 106 | + sealed abstract class Conversion: |
| 107 | + // the match for this descriptor |
| 108 | + def descriptor: Match |
| 109 | + // the part number for reporting errors |
| 110 | + def argi: Int |
| 111 | + |
| 112 | + // the descriptor fields |
| 113 | + val index: Option[Int] = descriptor.intOf(Index) |
| 114 | + val flags: String = descriptor.stringOf(Flags) |
| 115 | + val width: Option[Int] = descriptor.intOf(Width) |
| 116 | + val precision: Option[Int] = descriptor.group(Precision).map(_.drop(1).toInt) |
| 117 | + val op: String = descriptor.stringOf(CC) |
| 118 | + |
| 119 | + // the conversion char is the head of the op string (but see DateTimeXn) |
| 120 | + val cc: Char = if isError then '?' else op(0) |
| 121 | + |
| 122 | + def isError: Boolean = false |
| 123 | + def isIndexed: Boolean = index.nonEmpty || hasFlag('<') |
| 124 | + def isLiteral: Boolean = false |
| 125 | + |
| 126 | + // descriptor is at index 0 of the part string |
| 127 | + def isLeading: Boolean = descriptor.at(Spec) == 0 |
| 128 | + |
| 129 | + // true if passes. Default checks flags and index |
| 130 | + def verify: Boolean = goodFlags && goodIndex |
| 131 | + |
| 132 | + // is the specifier OK with the given arg |
| 133 | + def accepts(arg: ClassTag[?]): Boolean = true |
| 134 | + |
| 135 | + // what arg type if any does the conversion accept |
| 136 | + def acceptableVariants: List[ClassTag[?]] |
| 137 | + |
| 138 | + // what flags does the conversion accept? defaults to all |
| 139 | + protected def okFlags: String = allFlags |
| 140 | + |
| 141 | + def hasFlag(f: Char) = flags.contains(f) |
| 142 | + def hasAnyFlag(fs: String) = fs.exists(hasFlag) |
| 143 | + |
| 144 | + def badFlag(f: Char, msg: String) = |
| 145 | + val i = flags.indexOf(f) match { case -1 => 0 case j => j } |
| 146 | + errorAt(Flags, i)(msg) |
| 147 | + |
| 148 | + def errorAt(g: SpecGroup, i: Int = 0)(msg: String) = reporter.partError(msg, argi, descriptor.offset(g, i)) |
| 149 | + def warningAt(g: SpecGroup, i: Int = 0)(msg: String) = reporter.partWarning(msg, argi, descriptor.offset(g, i)) |
| 150 | + |
| 151 | + def noFlags = flags.isEmpty or errorAt(Flags)("flags not allowed") |
| 152 | + def noWidth = width.isEmpty or errorAt(Width)("width not allowed") |
| 153 | + def noPrecision = precision.isEmpty or errorAt(Precision)("precision not allowed") |
| 154 | + def only_-(msg: String) = |
| 155 | + val badFlags = flags.filterNot { case '-' | '<' => true case _ => false } |
| 156 | + badFlags.isEmpty or badFlag(badFlags(0), s"Only '-' allowed for $msg") |
| 157 | + def goodFlags = |
| 158 | + val badFlags = flags.filterNot(okFlags.contains) |
| 159 | + for f <- badFlags do badFlag(f, s"Illegal flag '$f'") |
| 160 | + badFlags.isEmpty |
| 161 | + def goodIndex = |
| 162 | + if index.nonEmpty && hasFlag('<') then warningAt(Index)("Argument index ignored if '<' flag is present") |
| 163 | + val okRange = index.map(i => i > 0 && i <= argc).getOrElse(true) |
| 164 | + okRange || hasFlag('<') or errorAt(Index)("Argument index out of range") |
| 165 | + object Conversion: |
| 166 | + def apply(m: Match, i: Int): Conversion = |
| 167 | + def badCC(msg: String) = ErrorXn(m, i).tap(error => error.errorAt(if (error.op.isEmpty) Spec else CC)(msg)) |
| 168 | + def cv(cc: Char) = cc match |
| 169 | + case 's' | 'S' => StringXn(m, i) |
| 170 | + case 'h' | 'H' => HashXn(m, i) |
| 171 | + case 'b' | 'B' => BooleanXn(m, i) |
| 172 | + case 'c' | 'C' => CharacterXn(m, i) |
| 173 | + case 'd' | 'o' | |
| 174 | + 'x' | 'X' => IntegralXn(m, i) |
| 175 | + case 'e' | 'E' | |
| 176 | + 'f' | |
| 177 | + 'g' | 'G' | |
| 178 | + 'a' | 'A' => FloatingPointXn(m, i) |
| 179 | + case 't' | 'T' => DateTimeXn(m, i) |
| 180 | + case '%' | 'n' => LiteralXn(m, i) |
| 181 | + case _ => badCC(s"illegal conversion character '$cc'") |
| 182 | + end cv |
| 183 | + m.group(CC) match |
| 184 | + case Some(cc) => cv(cc(0)).tap(_.verify) |
| 185 | + case None => badCC(s"Missing conversion operator in '${m.matched}'; $literalHelp") |
| 186 | + end apply |
| 187 | + val literalHelp = "use %% for literal %, %n for newline" |
| 188 | + end Conversion |
| 189 | + abstract class GeneralXn extends Conversion |
| 190 | + // s | S |
| 191 | + class StringXn(val descriptor: Match, val argi: Int) extends GeneralXn: |
| 192 | + val acceptableVariants = |
| 193 | + if hasFlag('#') then classTag[Formattable] :: Nil |
| 194 | + else classTag[Any] :: Nil |
| 195 | + override protected def okFlags = "-#<" |
| 196 | + // b | B |
| 197 | + class BooleanXn(val descriptor: Match, val argi: Int) extends GeneralXn: |
| 198 | + val FakeNullTag: ClassTag[?] = null |
| 199 | + val acceptableVariants = classTag[Boolean] :: FakeNullTag :: Nil |
| 200 | + override def accepts(arg: ClassTag[?]): Boolean = |
| 201 | + arg == classTag[Boolean] orElse warningAt(CC)("Boolean format is null test for non-Boolean") |
| 202 | + override protected def okFlags = "-<" |
| 203 | + // h | H |
| 204 | + class HashXn(val descriptor: Match, val argi: Int) extends GeneralXn: |
| 205 | + val acceptableVariants = classTag[Any] :: Nil |
| 206 | + override protected def okFlags = "-<" |
| 207 | + // %% | %n |
| 208 | + class LiteralXn(val descriptor: Match, val argi: Int) extends Conversion: |
| 209 | + override def isLiteral = true |
| 210 | + override def verify = op match |
| 211 | + case "%" => super.verify && noPrecision and width.foreach(_ => warningAt(Width)("width ignored on literal")) |
| 212 | + case "n" => noFlags && noWidth && noPrecision |
| 213 | + override protected val okFlags = "-" |
| 214 | + override def acceptableVariants = Nil |
| 215 | + class CharacterXn(val descriptor: Match, val argi: Int) extends Conversion: |
| 216 | + override def verify = super.verify && noPrecision && only_-("c conversion") |
| 217 | + val acceptableVariants = classTag[Char] :: classTag[Byte] :: classTag[Short] :: classTag[Int] :: Nil |
| 218 | + class IntegralXn(val descriptor: Match, val argi: Int) extends Conversion: |
| 219 | + override def verify = |
| 220 | + def d_# = cc == 'd' && hasFlag('#') and badFlag('#', "# not allowed for d conversion") |
| 221 | + def x_comma = cc != 'd' && hasFlag(',') and badFlag(',', "',' only allowed for d conversion of integral types") |
| 222 | + super.verify && noPrecision && !d_# && !x_comma |
| 223 | + val acceptableVariants = classTag[Int] :: classTag[Long] :: classTag[Byte] :: classTag[Short] :: classTag[BigInt] :: Nil |
| 224 | + override def accepts(arg: ClassTag[?]): Boolean = |
| 225 | + arg == classTag[BigInt] || { |
| 226 | + cc match |
| 227 | + case 'o' | 'x' | 'X' if hasAnyFlag("+ (") => "+ (".filter(hasFlag).foreach(bad => badFlag(bad, s"only use '$bad' for BigInt conversions to o, x, X")) ; false |
| 228 | + case _ => true |
| 229 | + } |
| 230 | + class FloatingPointXn(val descriptor: Match, val argi: Int) extends Conversion: |
| 231 | + override def verify = super.verify && (cc match { |
| 232 | + case 'a' | 'A' => |
| 233 | + val badFlags = ",(".filter(hasFlag) |
| 234 | + noPrecision && badFlags.isEmpty or badFlags.foreach(badf => badFlag(badf, s"'$badf' not allowed for a, A")) |
| 235 | + case _ => true |
| 236 | + }) |
| 237 | + val acceptableVariants = classTag[Double] :: classTag[Float] :: classTag[BigDecimal] :: Nil |
| 238 | + class DateTimeXn(val descriptor: Match, val argi: Int) extends Conversion: |
| 239 | + override val cc: Char = if op.length > 1 then op(1) else '?' |
| 240 | + def hasCC = op.length == 2 or errorAt(CC)("Date/time conversion must have two characters") |
| 241 | + def goodCC = "HIklMSLNpzZsQBbhAaCYyjmdeRTrDFc".contains(cc) or errorAt(CC, 1)(s"'$cc' doesn't seem to be a date or time conversion") |
| 242 | + override def verify = super.verify && hasCC && goodCC && noPrecision && only_-("date/time conversions") |
| 243 | + val acceptableVariants = classTag[Long] :: classTag[Calendar] :: classTag[Date] :: Nil |
| 244 | + class ErrorXn(val descriptor: Match, val argi: Int) extends Conversion: |
| 245 | + override def isError = true |
| 246 | + override def verify = false |
| 247 | + override def acceptableVariants = Nil |
0 commit comments