Skip to content

Backport "improvement: Support using directives in worksheets" to 3.3 LTS #354

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/config/ScalaSettings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ private sealed trait WarningSettings:
private val WvalueDiscard: Setting[Boolean] = BooleanSetting("-Wvalue-discard", "Warn when non-Unit expression results are unused.")
private val WNonUnitStatement = BooleanSetting("-Wnonunit-statement", "Warn when block statements are non-Unit expressions.")
private val WenumCommentDiscard = BooleanSetting("-Wenum-comment-discard", "Warn when a comment ambiguously assigned to multiple enum cases is discarded.")
private val WtoStringInterpolated = BooleanSetting("-Wtostring-interpolated", "Warn a standard interpolator used toString on a reference type.")
private val Wunused: Setting[List[ChoiceWithHelp[String]]] = MultiChoiceHelpSetting(
name = "-Wunused",
helpArg = "warning",
Expand Down Expand Up @@ -288,6 +289,7 @@ private sealed trait WarningSettings:
def valueDiscard(using Context): Boolean = allOr(WvalueDiscard)
def nonUnitStatement(using Context): Boolean = allOr(WNonUnitStatement)
def enumCommentDiscard(using Context): Boolean = allOr(WenumCommentDiscard)
def toStringInterpolated(using Context): Boolean = allOr(WtoStringInterpolated)
def checkInit(using Context): Boolean = allOr(YcheckInit)

/** -X "Extended" or "Advanced" settings */
Expand Down
164 changes: 92 additions & 72 deletions compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import scala.util.matching.Regex.Match

import PartialFunction.cond

import dotty.tools.dotc.ast.tpd.{Match => _, *}
import dotty.tools.dotc.core.Contexts.*
import dotty.tools.dotc.core.Symbols.*
Expand All @@ -30,8 +28,9 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List
def argType(argi: Int, types: Type*): Type =
require(argi < argc, s"$argi out of range picking from $types")
val tpe = argTypes(argi)
types.find(t => argConformsTo(argi, tpe, t))
.orElse(types.find(t => argConvertsTo(argi, tpe, t)))
types.find(t => t != defn.AnyType && argConformsTo(argi, tpe, t))
.orElse(types.find(t => t != defn.AnyType && argConvertsTo(argi, tpe, t)))
.orElse(types.find(t => t == defn.AnyType && argConformsTo(argi, tpe, t)))
.getOrElse {
report.argError(s"Found: ${tpe.show}, Required: ${types.map(_.show).mkString(", ")}", argi)
actuals += args(argi)
Expand Down Expand Up @@ -64,50 +63,57 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List

/** For N part strings and N-1 args to interpolate, normalize parts and check arg types.
*
* Returns normalized part strings and args, where args correcpond to conversions in tail of parts.
* Returns normalized part strings and args, where args correspond to conversions in tail of parts.
*/
def checked: (List[String], List[Tree]) =
val amended = ListBuffer.empty[String]
val convert = ListBuffer.empty[Conversion]

def checkPart(part: String, n: Int): Unit =
val matches = formatPattern.findAllMatchIn(part)

def insertStringConversion(): Unit =
amended += "%s" + part
val cv = Conversion.stringXn(n)
cv.accepts(argType(n-1, defn.AnyType))
convert += cv
cv.lintToString(argTypes(n-1))

def errorLeading(op: Conversion) = op.errorAt(Spec):
s"conversions must follow a splice; ${Conversion.literalHelp}"

def accept(op: Conversion): Unit =
if !op.isLeading then errorLeading(op)
op.accepts(argType(n-1, op.acceptableVariants*))
amended += part
convert += op
op.lintToString(argTypes(n-1))

// after the first part, a leading specifier is required for the interpolated arg; %s is supplied if needed
if n == 0 then amended += part
else if !matches.hasNext then insertStringConversion()
else
val cv = Conversion(matches.next(), n)
if cv.isLiteral then insertStringConversion()
else if cv.isIndexed then
if cv.index.getOrElse(-1) == n then accept(cv) else insertStringConversion()
else if !cv.isError then accept(cv)

// any remaining conversions in this part must be either literals or indexed
while matches.hasNext do
val cv = Conversion(matches.next(), n)
if n == 0 && cv.hasFlag('<') then cv.badFlag('<', "No last arg")
else if !cv.isLiteral && !cv.isIndexed then errorLeading(cv)
end checkPart

@tailrec
def loop(remaining: List[String], n: Int): Unit =
remaining match
case part0 :: more =>
def badPart(t: Throwable): String = "".tap(_ => report.partError(t.getMessage.nn, index = n, offset = 0))
val part = try StringContext.processEscapes(part0) catch badPart
val matches = formatPattern.findAllMatchIn(part)

def insertStringConversion(): Unit =
amended += "%s" + part
convert += Conversion(formatPattern.findAllMatchIn("%s").next(), n) // improve
argType(n-1, defn.AnyType)
def errorLeading(op: Conversion) = op.errorAt(Spec)(s"conversions must follow a splice; ${Conversion.literalHelp}")
def accept(op: Conversion): Unit =
if !op.isLeading then errorLeading(op)
op.accepts(argType(n-1, op.acceptableVariants*))
amended += part
convert += op

// after the first part, a leading specifier is required for the interpolated arg; %s is supplied if needed
if n == 0 then amended += part
else if !matches.hasNext then insertStringConversion()
else
val cv = Conversion(matches.next(), n)
if cv.isLiteral then insertStringConversion()
else if cv.isIndexed then
if cv.index.getOrElse(-1) == n then accept(cv) else insertStringConversion()
else if !cv.isError then accept(cv)

// any remaining conversions in this part must be either literals or indexed
while matches.hasNext do
val cv = Conversion(matches.next(), n)
if n == 0 && cv.hasFlag('<') then cv.badFlag('<', "No last arg")
else if !cv.isLiteral && !cv.isIndexed then errorLeading(cv)

loop(more, n + 1)
case Nil => ()
end loop
def loop(remaining: List[String], n: Int): Unit = remaining match
case part0 :: remaining =>
def badPart(t: Throwable): String = "".tap(_ => report.partError(t.getMessage.nn, index = n, offset = 0))
val part = try StringContext.processEscapes(part0) catch badPart
checkPart(part, n)
loop(remaining, n + 1)
case Nil =>

loop(parts, n = 0)
if reported then (Nil, Nil)
Expand All @@ -125,10 +131,8 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List
def intOf(g: SpecGroup): Option[Int] = group(g).map(_.toInt)

extension (inline value: Boolean)
inline def or(inline body: => Unit): Boolean = value || { body ; false }
inline def orElse(inline body: => Unit): Boolean = value || { body ; true }
inline def and(inline body: => Unit): Boolean = value && { body ; true }
inline def but(inline body: => Unit): Boolean = value && { body ; false }
inline infix def or(inline body: => Unit): Boolean = value || { body; false }
inline infix def and(inline body: => Unit): Boolean = value && { body; true }

enum Kind:
case StringXn, HashXn, BooleanXn, CharacterXn, IntegralXn, FloatingPointXn, DateTimeXn, LiteralXn, ErrorXn
Expand All @@ -147,9 +151,10 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List
// the conversion char is the head of the op string (but see DateTimeXn)
val cc: Char =
kind match
case ErrorXn => if op.isEmpty then '?' else op(0)
case DateTimeXn => if op.length > 1 then op(1) else '?'
case _ => op(0)
case ErrorXn => if op.isEmpty then '?' else op(0)
case DateTimeXn => if op.length <= 1 then '?' else op(1)
case StringXn => if op.isEmpty then 's' else op(0) // accommodate the default %s
case _ => op(0)

def isIndexed: Boolean = index.nonEmpty || hasFlag('<')
def isError: Boolean = kind == ErrorXn
Expand Down Expand Up @@ -209,18 +214,28 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List
// is the specifier OK with the given arg
def accepts(arg: Type): Boolean =
kind match
case BooleanXn => arg == defn.BooleanType orElse warningAt(CC)("Boolean format is null test for non-Boolean")
case IntegralXn =>
arg == BigIntType || !cond(cc) {
case 'o' | 'x' | 'X' if hasAnyFlag("+ (") => "+ (".filter(hasFlag).foreach(bad => badFlag(bad, s"only use '$bad' for BigInt conversions to o, x, X")) ; true
}
case BooleanXn if arg != defn.BooleanType =>
warningAt(CC):
"""non-Boolean value formats as "true" for non-null references and boxed primitives, otherwise "false""""
true
case IntegralXn if arg != BigIntType =>
cc match
case 'o' | 'x' | 'X' if hasAnyFlag("+ (") =>
"+ (".filter(hasFlag).foreach: bad =>
badFlag(bad, s"only use '$bad' for BigInt conversions to o, x, X")
false
case _ => true
case _ => true

def lintToString(arg: Type): Unit =
if ctx.settings.Whas.toStringInterpolated && kind == StringXn && !(arg.widen =:= defn.StringType) && !arg.isPrimitiveValueType
then warningAt(CC)("interpolation uses toString")

// what arg type if any does the conversion accept
def acceptableVariants: List[Type] =
kind match
case StringXn => if hasFlag('#') then FormattableType :: Nil else defn.AnyType :: Nil
case BooleanXn => defn.BooleanType :: defn.NullType :: Nil
case BooleanXn => defn.BooleanType :: defn.NullType :: defn.AnyType :: Nil // warn if not boolean
case HashXn => defn.AnyType :: Nil
case CharacterXn => defn.CharType :: defn.ByteType :: defn.ShortType :: defn.IntType :: Nil
case IntegralXn => defn.IntType :: defn.LongType :: defn.ByteType :: defn.ShortType :: BigIntType :: Nil
Expand Down Expand Up @@ -249,25 +264,30 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List

object Conversion:
def apply(m: Match, i: Int): Conversion =
def kindOf(cc: Char) = cc match
case 's' | 'S' => StringXn
case 'h' | 'H' => HashXn
case 'b' | 'B' => BooleanXn
case 'c' | 'C' => CharacterXn
case 'd' | 'o' |
'x' | 'X' => IntegralXn
case 'e' | 'E' |
'f' |
'g' | 'G' |
'a' | 'A' => FloatingPointXn
case 't' | 'T' => DateTimeXn
case '%' | 'n' => LiteralXn
case _ => ErrorXn
end kindOf
m.group(CC) match
case Some(cc) => new Conversion(m, i, kindOf(cc(0))).tap(_.verify)
case None => new Conversion(m, i, ErrorXn).tap(_.errorAt(Spec)(s"Missing conversion operator in '${m.matched}'; $literalHelp"))
case Some(cc) =>
val xn = cc(0) match
case 's' | 'S' => StringXn
case 'h' | 'H' => HashXn
case 'b' | 'B' => BooleanXn
case 'c' | 'C' => CharacterXn
case 'd' | 'o' |
'x' | 'X' => IntegralXn
case 'e' | 'E' |
'f' |
'g' | 'G' |
'a' | 'A' => FloatingPointXn
case 't' | 'T' => DateTimeXn
case '%' | 'n' => LiteralXn
case _ => ErrorXn
new Conversion(m, i, xn)
.tap(_.verify)
case None =>
new Conversion(m, i, ErrorXn)
.tap(_.errorAt(Spec)(s"Missing conversion operator in '${m.matched}'; $literalHelp"))
end apply
// construct a default %s conversion
def stringXn(i: Int): Conversion = new Conversion(formatPattern.findAllMatchIn("%").next(), i, StringXn)
val literalHelp = "use %% for literal %, %n for newline"
end Conversion

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,22 @@ class StringInterpolatorOpt extends MiniPhase:
def mkConcat(strs: List[Literal], elems: List[Tree]): Tree =
val stri = strs.iterator
val elemi = elems.iterator
var result: Tree = stri.next
var result: Tree = stri.next()
def concat(tree: Tree): Unit =
result = result.select(defn.String_+).appliedTo(tree).withSpan(tree.span)
while elemi.hasNext
do
concat(elemi.next)
val str = stri.next
val elem = elemi.next()
lintToString(elem)
concat(elem)
val str = stri.next()
if !str.const.stringValue.isEmpty then concat(str)
result
end mkConcat
def lintToString(t: Tree): Unit =
val arg: Type = t.tpe
if ctx.settings.Whas.toStringInterpolated && !(arg.widen =:= defn.StringType) && !arg.isPrimitiveValueType
then report.warning("interpolation uses toString", t.srcPos)
val sym = tree.symbol
// Test names first to avoid loading scala.StringContext if not used, and common names first
val isInterpolatedMethod =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,16 @@ final class PcInlineValueProvider(
text
)(startOffset, endOffset)
val startPos = new l.Position(
range.getStart.getLine,
range.getStart.getCharacter - (startOffset - startWithSpace)
range.getStart.nn.getLine,
range.getStart.nn.getCharacter - (startOffset - startWithSpace)
)
val endPos =
if (endWithSpace - 1 >= 0 && text(endWithSpace - 1) == '\n')
new l.Position(range.getEnd.getLine + 1, 0)
new l.Position(range.getEnd.nn.getLine + 1, 0)
else
new l.Position(
range.getEnd.getLine,
range.getEnd.getCharacter + endWithSpace - endOffset
range.getEnd.nn.getLine,
range.getEnd.nn.getCharacter + endWithSpace - endOffset
)

new l.Range(startPos, endPos)
Expand Down Expand Up @@ -129,15 +129,15 @@ final class PcInlineValueProvider(
end defAndRefs

private def stripIndentPrefix(rhs: String, refIndent: String, defIndent: String): String =
val rhsLines = rhs.split("\n").toList
val rhsLines = rhs.split("\n").nn.toList
rhsLines match
case h :: Nil => rhs
case h :: t =>
val noPrefixH = h.stripPrefix(refIndent)
val noPrefixH = h.nn.stripPrefix(refIndent)
if noPrefixH.startsWith("{") then
noPrefixH ++ t.map(refIndent ++ _.stripPrefix(defIndent)).mkString("\n","\n", "")
noPrefixH ++ t.map(refIndent ++ _.nn.stripPrefix(defIndent)).mkString("\n","\n", "")
else
((" " ++ h) :: t).map(refIndent ++ _.stripPrefix(defIndent)).mkString("\n", "\n", "")
((" " ++ h.nn) :: t).map(refIndent ++ _.nn.stripPrefix(defIndent)).mkString("\n", "\n", "")
case Nil => rhs

private def definitionRequiresBrackets(tree: Tree)(using Context): Boolean =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,20 @@ class ScalaCliCompletions(
):
def unapply(path: List[Tree]) =
def scalaCliDep = CoursierComplete.isScalaCliDep(
pos.lineContent.take(pos.column).stripPrefix("/*<script>*/")
pos.lineContent.take(pos.column).stripPrefix("/*<script>*/").dropWhile(c => c == ' ' || c == '\t')
)

lazy val supportsUsing =
val filename = pos.source.file.path
filename.endsWith(".sc.scala") ||
filename.endsWith(".worksheet.sc")

path match
case Nil | (_: PackageDef) :: _ => scalaCliDep
// generated script file will end with .sc.scala
case (_: TypeDef) :: (_: PackageDef) :: Nil if pos.source.file.path.endsWith(".sc.scala") =>
case (_: TypeDef) :: (_: PackageDef) :: Nil if supportsUsing =>
scalaCliDep
case (_: Template) :: (_: TypeDef) :: Nil if pos.source.file.path.endsWith(".sc.scala") =>
case (_: Template) :: (_: TypeDef) :: Nil if supportsUsing =>
scalaCliDep
case head :: next => None

Expand Down
8 changes: 4 additions & 4 deletions tests/neg/f-interpolator-neg.check
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
7 | new StringContext("", "").f() // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
| too few arguments for interpolated string
-- [E209] Interpolation Error: tests/neg/f-interpolator-neg.scala:11:7 -------------------------------------------------
11 | f"$s%b" // error
| ^
| Found: (s : String), Required: Boolean, Null
-- [E209] Interpolation Warning: tests/neg/f-interpolator-neg.scala:11:9 -----------------------------------------------
11 | f"$s%b" // warn only
| ^
| non-Boolean value formats as "true" for non-null references and boxed primitives, otherwise "false"
-- [E209] Interpolation Error: tests/neg/f-interpolator-neg.scala:12:7 -------------------------------------------------
12 | f"$s%c" // error
| ^
Expand Down
2 changes: 1 addition & 1 deletion tests/neg/f-interpolator-neg.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ object Test {
}

def interpolationMismatches(s : String, f : Double, b : Boolean) = {
f"$s%b" // error
f"$s%b" // warn only
f"$s%c" // error
f"$f%c" // error
f"$s%x" // error
Expand Down
36 changes: 36 additions & 0 deletions tests/warn/tostring-interpolated.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//> using options -Wtostring-interpolated
//> abusing options -Wconf:cat=w-flag-tostring-interpolated:e -Wtostring-interpolated

case class C(x: Int)

trait T {
def c = C(42)
def f = f"$c" // warn
def s = s"$c" // warn
def r = raw"$c" // warn

def format = f"${c.x}%d in $c or $c%s" // warn using c.toString // warn

def bool = f"$c%b" // warn just a null check

def oops = s"${null} slipped thru my fingers" // warn

def ok = s"${c.toString}"

def sb = new StringBuilder().append("hello")
def greeting = s"$sb, world" // warn
}

class Mitigations {

val s = "hello, world"
val i = 42
def shown = println("shown")

def ok = s"$s is ok"
def jersey = s"number $i"
def unitized = s"unfortunately $shown" // maybe tell them about unintended ()?

def nopct = f"$s is ok"
def nofmt = f"number $i"
}
Loading