Skip to content

Commit 14ae862

Browse files
committed
Intrinsify StringContext.f
Ported the current macro implementation
1 parent c40ef6e commit 14ae862

File tree

9 files changed

+152
-281
lines changed

9 files changed

+152
-281
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,9 @@ class Definitions {
553553
@tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length)
554554
@tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq)
555555

556+
@tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps")
557+
@tu lazy val StringOps_format: Symbol = StringOps.requiredMethod(nme.format)
558+
556559
@tu lazy val ArrayType: TypeRef = requiredClassRef("scala.Array")
557560
def ArrayClass(using Context): ClassSymbol = ArrayType.symbol.asClass
558561
@tu lazy val Array_apply : Symbol = ArrayClass.requiredMethod(nme.apply)
@@ -733,9 +736,6 @@ class Definitions {
733736
@tu lazy val StringContextModule_standardInterpolator: Symbol = StringContextModule.requiredMethod(nme.standardInterpolator)
734737
@tu lazy val StringContextModule_processEscapes: Symbol = StringContextModule.requiredMethod(nme.processEscapes)
735738

736-
@tu lazy val InternalStringContextMacroModule: Symbol = requiredModule("dotty.internal.StringContextMacro")
737-
@tu lazy val InternalStringContextMacroModule_f: Symbol = InternalStringContextMacroModule.requiredMethod(nme.f)
738-
739739
@tu lazy val PartialFunctionClass: ClassSymbol = requiredClass("scala.PartialFunction")
740740
@tu lazy val PartialFunction_isDefinedAt: Symbol = PartialFunctionClass.requiredMethod(nme.isDefinedAt)
741741
@tu lazy val PartialFunction_applyOrElse: Symbol = PartialFunctionClass.requiredMethod(nme.applyOrElse)

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ object StdNames {
477477
val flagsFromBits : N = "flagsFromBits"
478478
val flatMap: N = "flatMap"
479479
val foreach: N = "foreach"
480+
val format: N = "format"
480481
val fromDigits: N = "fromDigits"
481482
val fromProduct: N = "fromProduct"
482483
val genericArrayOps: N = "genericArrayOps"

library/src-bootstrapped/dotty/internal/StringContextMacro.scala renamed to compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala

Lines changed: 55 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
1-
// ALWAYS KEEP THIS FILE IN src-bootstrapped, DO NOT MOVE TO src
2-
3-
package dotty.internal
4-
5-
import scala.quoted._
6-
7-
object StringContextMacro {
8-
9-
/** Implementation of scala.StringContext.f used in Dotty */
10-
inline def f(inline sc: StringContext)(inline args: Any*): String = ${ interpolate('sc, 'args) }
1+
package dotty.tools.dotc
2+
package transform.localopt
3+
4+
import dotty.tools.dotc.ast.Trees._
5+
import dotty.tools.dotc.ast.tpd
6+
import dotty.tools.dotc.core.Decorators._
7+
import dotty.tools.dotc.core.Constants.Constant
8+
import dotty.tools.dotc.core.Contexts._
9+
import dotty.tools.dotc.core.StdNames._
10+
import dotty.tools.dotc.core.NameKinds._
11+
import dotty.tools.dotc.core.Symbols._
12+
import dotty.tools.dotc.core.Types._
13+
14+
// Ported from old dotty.internal.StringContextMacro
15+
// TODO: port Scala 2 logic? (see https://github.com/scala/scala/blob/2.13.x/src/compiler/scala/tools/reflect/FormatInterpolator.scala#L74)
16+
object StringContextChecker {
17+
import tpd._
1118

1219
/** This trait defines a tool to report errors/warnings that do not depend on Position. */
1320
trait Reporter {
@@ -51,53 +58,52 @@ object StringContextMacro {
5158
def restoreReported() : Unit
5259
}
5360

54-
/** Interpolates the arguments to the formatting String given inside a StringContext
55-
*
56-
* @param strCtxExpr the Expr that holds the StringContext which contains all the chunks of the formatting string
57-
* @param args the Expr that holds the sequence of arguments to interpolate to the String in the correct format
58-
* @return the Expr containing the formatted and interpolated String or an error/warning if the parameters are not correct
59-
*/
60-
private def interpolate(strCtxExpr: Expr[StringContext], argsExpr: Expr[Seq[Any]])(using qctx: QuoteContext): Expr[String] = {
61-
import qctx.tasty._
62-
val sourceFile = strCtxExpr.unseal.pos.sourceFile
63-
64-
val (partsExpr, parts) = strCtxExpr match {
65-
case Expr.StringContext(p1 as Consts(p2)) => (p1.toList, p2.toList)
66-
case _ => report.throwError("Expected statically known String Context", strCtxExpr)
61+
/** Check the format of the parts of the f".." arguments and returns the string parts of the StringContext */
62+
def checkedParts(strContext_f: Tree, args0: Tree)(using Context): String = {
63+
64+
val (partsExpr, parts) = strContext_f match {
65+
case TypeApply(Select(Apply(_, (parts: SeqLiteral) :: Nil), _), _) =>
66+
(parts.elems, parts.elems.map { case Literal(Constant(str: String)) => str } )
67+
case _ =>
68+
report.error("Expected statically known String Context", strContext_f.srcPos)
69+
return ""
6770
}
6871

69-
val args = argsExpr match {
70-
case Varargs(args) => args
71-
case _ => report.throwError("Expected statically known argument list", argsExpr)
72+
val args = args0 match {
73+
case args: SeqLiteral => args.elems
74+
case _ =>
75+
report.error("Expected statically known argument list", args0.srcPos)
76+
return ""
7277
}
7378

7479
val reporter = new Reporter{
7580
private[this] var reported = false
7681
private[this] var oldReported = false
7782
def partError(message : String, index : Int, offset : Int) : Unit = {
7883
reported = true
79-
val positionStart = partsExpr(index).unseal.pos.start + offset
80-
error(message, sourceFile, positionStart, positionStart)
84+
val pos = partsExpr(index).sourcePos
85+
val posOffset = pos.withSpan(pos.span.shift(offset))
86+
report.error(message, posOffset)
8187
}
8288
def partWarning(message : String, index : Int, offset : Int) : Unit = {
8389
reported = true
84-
val positionStart = partsExpr(index).unseal.pos.start + offset
85-
warning(message, sourceFile, positionStart, positionStart)
90+
val pos = partsExpr(index).sourcePos
91+
val posOffset = pos.withSpan(pos.span.shift(offset))
92+
report.warning(message, posOffset)
8693
}
8794

8895
def argError(message : String, index : Int) : Unit = {
8996
reported = true
90-
error(message, args(index).unseal.pos)
97+
report.error(message, args(index).srcPos)
9198
}
9299

93100
def strCtxError(message : String) : Unit = {
94101
reported = true
95-
val positionStart = strCtxExpr.unseal.pos.start
96-
error(message, sourceFile, positionStart, positionStart)
102+
report.error(message, strContext_f.srcPos)
97103
}
98104
def argsError(message : String) : Unit = {
99105
reported = true
100-
error(message, argsExpr.unseal.pos)
106+
report.error(message, args0.srcPos)
101107
}
102108

103109
def hasReported() : Boolean = {
@@ -114,18 +120,11 @@ object StringContextMacro {
114120
}
115121
}
116122

117-
interpolate(parts, args, argsExpr, reporter)
123+
checked(parts, args, reporter)
118124
}
119125

120-
/** Helper function for the interpolate function above
121-
*
122-
* @param partsExpr the list of parts enumerated as Expr
123-
* @param args the list of arguments enumerated as Expr
124-
* @param reporter the reporter to return any error/warning when a problem is encountered
125-
* @return the Expr containing the formatted and interpolated String or an error/warning report if the parameters are not correct
126-
*/
127-
def interpolate(parts0 : List[String], args : Seq[Expr[Any]], argsExpr: Expr[Seq[Any]], reporter : Reporter)(using qctx: QuoteContext) : Expr[String] = {
128-
import qctx.tasty._
126+
def checked(parts0: List[String], args: List[Tree], reporter: Reporter)(using Context): String = {
127+
129128

130129
/** Checks if the number of arguments are the same as the number of formatting strings
131130
*
@@ -585,21 +584,21 @@ object StringContextMacro {
585584
* nothing otherwise
586585
*/
587586
def checkTypeWithArgs(argument : (Type, Int), conversionChar : Char, partIndex : Int, flags : List[(Char, Int)]) = {
588-
val booleans = List(Type.of[Boolean], Type.of[Null])
589-
val dates = List(Type.of[Long], Type.of[java.util.Calendar], Type.of[java.util.Date])
590-
val floatingPoints = List(Type.of[Double], Type.of[Float], Type.of[java.math.BigDecimal])
591-
val integral = List(Type.of[Int], Type.of[Long], Type.of[Short], Type.of[Byte], Type.of[java.math.BigInteger])
592-
val character = List(Type.of[Char], Type.of[Byte], Type.of[Short], Type.of[Int])
587+
val booleans = List(defn.BooleanType, defn.NullType)
588+
val dates = List(defn.LongType, requiredClass("java.util.Calendar").typeRef, requiredClass("java.util.Date").typeRef)
589+
val floatingPoints = List(defn.DoubleType, defn.FloatType, requiredClass("java.math.BigDecimal").typeRef)
590+
val integral = List(defn.IntType, defn.LongType, defn.ShortType, defn.ByteType, requiredClass("java.math.BigInteger").typeRef)
591+
val character = List(defn.CharType, defn.ByteType, defn.ShortType, defn.IntType)
593592

594593
val (argType, argIndex) = argument
595594
conversionChar match {
596595
case 'c' | 'C' => checkSubtype(argType, "Char", argIndex, character : _*)
597596
case 'd' | 'o' | 'x' | 'X' => {
598597
checkSubtype(argType, "Int", argIndex, integral : _*)
599598
if (conversionChar != 'd') {
600-
val notAllowedFlagOnCondition = List(('+', !(argType <:< Type.of[java.math.BigInteger]), "only use '+' for BigInt conversions to o, x, X"),
601-
(' ', !(argType <:< Type.of[java.math.BigInteger]), "only use ' ' for BigInt conversions to o, x, X"),
602-
('(', !(argType <:< Type.of[java.math.BigInteger]), "only use '(' for BigInt conversions to o, x, X"),
599+
val notAllowedFlagOnCondition = List(('+', !(argType <:< requiredClass("java.math.BigInteger").typeRef), "only use '+' for BigInt conversions to o, x, X"),
600+
(' ', !(argType <:< requiredClass("java.math.BigInteger").typeRef), "only use ' ' for BigInt conversions to o, x, X"),
601+
('(', !(argType <:< requiredClass("java.math.BigInteger").typeRef), "only use '(' for BigInt conversions to o, x, X"),
603602
(',', true, "',' only allowed for d conversion of integral types"))
604603
checkFlags(partIndex, flags, notAllowedFlagOnCondition : _*)
605604
}
@@ -608,7 +607,7 @@ object StringContextMacro {
608607
case 't' | 'T' => checkSubtype(argType, "Date", argIndex, dates : _*)
609608
case 'b' | 'B' => checkSubtype(argType, "Boolean", argIndex, booleans : _*)
610609
case 'h' | 'H' | 'S' | 's' =>
611-
if (!(argType <:< Type.of[java.util.Formattable]))
610+
if !(argType <:< requiredClass("java.util.Formattable").typeRef) then
612611
for {flag <- flags ; if (flag._1 == '#')}
613612
reporter.argError("type mismatch;\n found : " + argType.widen.show.stripPrefix("scala.Predef.").stripPrefix("java.lang.").stripPrefix("scala.") + "\n required: java.util.Formattable", argIndex)
614613
case 'n' | '%' =>
@@ -647,7 +646,7 @@ object StringContextMacro {
647646
* @param maxArgumentIndex an Option containing the maximum argument index possible, None if no args are specified
648647
* @return a list with all the elements of the conversion per formatting string
649648
*/
650-
def checkPart(part : String, start : Int, argument : Option[(Int, Expr[Any])], maxArgumentIndex : Option[Int]) : List[(Option[(Type, Int)], Char, List[(Char, Int)])] = {
649+
def checkPart(part : String, start : Int, argument : Option[(Int, Tree)], maxArgumentIndex : Option[Int]) : List[(Option[(Type, Int)], Char, List[(Char, Int)])] = {
651650
reporter.resetReported()
652651
val hasFormattingSubstring = getFormattingSubstring(part, part.size, start)
653652
if (hasFormattingSubstring.nonEmpty) {
@@ -658,7 +657,7 @@ object StringContextMacro {
658657
case Some(argIndex, arg) => {
659658
val (hasArgumentIndex, argumentIndex, flags, hasWidth, width, hasPrecision, precision, hasRelative, relativeIndex, conversion) = getFormatSpecifiers(part, argIndex, argIndex + 1, false, formattingStart)
660659
if (!reporter.hasReported()){
661-
val conversionWithType = checkFormatSpecifiers(argIndex + 1, hasArgumentIndex, argumentIndex, Some(argIndex + 1), start == 0, maxArgumentIndex, hasRelative, hasWidth, width, hasPrecision, precision, flags, conversion, Some(arg.unseal.tpe), part)
660+
val conversionWithType = checkFormatSpecifiers(argIndex + 1, hasArgumentIndex, argumentIndex, Some(argIndex + 1), start == 0, maxArgumentIndex, hasRelative, hasWidth, width, hasPrecision, precision, flags, conversion, Some(arg.tpe), part)
662661
nextStart = conversion + 1
663662
conversionWithType :: checkPart(part, nextStart, argument, maxArgumentIndex)
664663
} else checkPart(part, conversion + 1, argument, maxArgumentIndex)
@@ -710,7 +709,6 @@ object StringContextMacro {
710709
}
711710
}
712711

713-
// macro expansion
714-
'{(${Expr(parts.mkString)}).format(${argsExpr}: _*)}
712+
parts.mkString
715713
}
716714
}

compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ import dotty.tools.dotc.core.Decorators._
77
import dotty.tools.dotc.core.Constants.Constant
88
import dotty.tools.dotc.core.Contexts._
99
import dotty.tools.dotc.core.StdNames._
10+
import dotty.tools.dotc.core.NameKinds._
1011
import dotty.tools.dotc.core.Symbols._
11-
import dotty.tools.dotc.core.Types.MethodType
12+
import dotty.tools.dotc.core.Types._
1213
import dotty.tools.dotc.transform.MegaPhase.MiniPhase
1314

1415
/**
@@ -116,6 +117,7 @@ class StringInterpolatorOpt extends MiniPhase {
116117
val sym = tree.symbol
117118
val isInterpolatedMethod = // Test names first to avoid loading scala.StringContext if not used
118119
(sym.name == nme.raw_ && sym.eq(defn.StringContext_raw)) ||
120+
(sym.name == nme.f && sym.eq(defn.StringContext_f)) ||
119121
(sym.name == nme.s && sym.eq(defn.StringContext_s))
120122
if (isInterpolatedMethod)
121123
tree match {
@@ -132,6 +134,11 @@ class StringInterpolatorOpt extends MiniPhase {
132134
if (!str.const.stringValue.isEmpty) concat(str)
133135
}
134136
result
137+
case Apply(intp, args :: Nil) if sym.eq(defn.StringContext_f) =>
138+
val partsStr = StringContextChecker.checkedParts(intp, args).mkString
139+
resolveConstructor(defn.StringOps.typeRef, List(Literal(Constant(partsStr))))
140+
.select(nme.format)
141+
.appliedTo(args)
135142
// Starting with Scala 2.13, s and raw are macros in the standard
136143
// library, so we need to expand them manually.
137144
// sc.s(args) --> standardInterpolator(processEscapes, args, sc.parts)

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3257,21 +3257,10 @@ class Typer extends Namer
32573257
if ((inlined ne tree) && errorCount == ctx.reporter.errorCount) readaptSimplified(inlined)
32583258
else inlined
32593259
}
3260-
else if tree.symbol.name == nme.f && tree.symbol == defn.StringContext_f then
3261-
// To avoid forcing StringContext_f when compiling StingContex
3262-
// we test the name before accession symbol StringContext_f.
3263-
3264-
// As scala.StringContext.f is defined in the standard library which
3265-
// we currently do not bootstrap we cannot implement the macro in the library.
3266-
// To overcome the current limitation we intercept the call and rewrite it into
3267-
// a call to dotty.internal.StringContext.f which we can implement using the new macros.
3268-
// As the macro is implemented in the bootstrapped library, it can only be used from the bootstrapped compiler.
3269-
val Apply(TypeApply(Select(sc, _), _), args) = tree
3270-
val newCall = ref(defn.InternalStringContextMacroModule_f).appliedTo(sc).appliedToArgs(args).withSpan(tree.span)
3271-
readaptSimplified(Inliner.inlineCall(newCall))
32723260
else if (tree.symbol.isScala2Macro &&
3273-
// raw and s are eliminated by the StringInterpolatorOpt phase
3261+
// `raw`, `f` and `s` are eliminated by the StringInterpolatorOpt phase
32743262
tree.symbol != defn.StringContext_raw &&
3263+
tree.symbol != defn.StringContext_f &&
32753264
tree.symbol != defn.StringContext_s)
32763265
if (ctx.settings.XignoreScala2Macros.value) {
32773266
report.warning("Scala 2 macro cannot be used in Dotty, this call will crash at runtime. See https://dotty.epfl.ch/docs/reference/dropped-features/macros.html", tree.srcPos.startPos)

library/src-non-bootstrapped/dotty/internal/StringContextMacro.scala

Lines changed: 0 additions & 13 deletions
This file was deleted.

tests/neg/f-interpolator-neg.scala

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
object Test {
2+
3+
def numberArgumentsTests(s : String, d : Int) = {
4+
new StringContext().f() // error
5+
new StringContext("", " is ", "%2d years old").f(s) // error
6+
new StringContext("", " is ", "%2d years old").f(s, d, d) // error
7+
new StringContext("", "").f() // error
8+
}
9+
10+
def interpolationMismatches(s : String, f : Double, b : Boolean) = {
11+
f"$s%b" // error
12+
f"$s%c" // error
13+
f"$f%c" // error
14+
f"$s%x" // error
15+
f"$b%d" // error
16+
f"$s%d" // error
17+
f"$f%o" // error
18+
f"$s%e" // error
19+
f"$b%f" // error
20+
f"$s%i" // error
21+
}
22+
23+
def flagMismatches(s : String, c : Char, d : Int, f : Double, t : java.util.Date) = {
24+
f"$s%+ 0,(s" // error
25+
f"$c%#+ 0,(c" // error
26+
f"$d%#d" // error
27+
f"$d%,x" // error
28+
f"$d%+ (x" // error
29+
f"$f%,(a" // error
30+
f"$t%#+ 0,(tT" // error
31+
f"%-#+ 0,(n" // error
32+
f"%#+ 0,(%" // error
33+
}
34+
35+
def badPrecisions(c : Char, d : Int, f : Double, t : java.util.Date) = {
36+
f"$c%.2c" // error
37+
f"$d%.2d" // error
38+
f"%.2%" // error
39+
f"%.2n" // error
40+
f"$f%.2a" // error
41+
f"$t%.2tT" // error
42+
}
43+
44+
def badIndexes() = {
45+
f"%<s" // error
46+
f"%<c" // error
47+
f"%<tT" // error
48+
f"${8}%d ${9}%d %3$$d" // error
49+
f"${8}%d ${9}%d%0$$d" // error
50+
}
51+
52+
def warnings(s : String) = {
53+
f"${8}%d ${9}%1$$d"
54+
f"$s%s $s%s %1$$<s"
55+
f"$s%s $s%1$$s"
56+
}
57+
58+
def badArgTypes(s : String) = {
59+
f"$s%#s" // error
60+
}
61+
62+
def misunderstoodConversions(t : java.util.Date, s : String) = {
63+
f"$t%tG" // error
64+
f"$t%t" // error
65+
f"$s%10.5" // error
66+
}
67+
68+
def otherBrainFailures(d : Int) = {
69+
f"${d}random-leading-junk%d" // error
70+
f"%1$$n"
71+
f"%1$$d" // error
72+
f"blablablabla %% %.2d" // error
73+
f"blablablabla %.2b %%" // error
74+
75+
f"ana${3}%.2f%2${true}%bb" // error
76+
f"ac{2c{2{c.ca "
77+
78+
f"b%c.%2ii%iin" // error
79+
f"b}22%2.c<{%{" // error
80+
f"%%bci.2${'i'}%..2c2" // error
81+
}
82+
83+
}

0 commit comments

Comments
 (0)