Skip to content

Commit dc2b0b3

Browse files
committed
Emit efficient code for switch over strings
The pattern matcher will now emit `Match` with `String` scrutinee as well as the existing `Int` scrutinee. The JVM backend handles this case by emitting bytecode that switches on the String's `hashCode` (this matches what Java does). The Scala.js backend already handles `String` matches. The approach is similar to scala/scala#8451 (see scala/bug#11740 too), except that instead of doing a transformation on the AST, we just emit the right bytecode straight away. This is desirable since it means that Scala.js (and any other backend) can choose their own optimised strategy for compiling a match on strings.
1 parent cca5f8f commit dc2b0b3

File tree

7 files changed

+295
-75
lines changed

7 files changed

+295
-75
lines changed

compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala

Lines changed: 156 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package backend
33
package jvm
44

55
import scala.annotation.switch
6+
import scala.collection.mutable.SortedMap
67

78
import scala.tools.asm
89
import scala.tools.asm.{Handle, Label, Opcodes}
@@ -826,61 +827,170 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
826827
generatedType
827828
}
828829

829-
/*
830-
* A Match node contains one or more case clauses,
831-
* each case clause lists one or more Int values to use as keys, and a code block.
832-
* Except the "default" case clause which (if it exists) doesn't list any Int key.
833-
*
834-
* On a first pass over the case clauses, we flatten the keys and their targets (the latter represented with asm.Labels).
835-
* That representation allows JCodeMethodV to emit a lookupswitch or a tableswitch.
836-
*
837-
* On a second pass, we emit the switch blocks, one for each different target.
830+
/* A Match node contains one or more case clauses, each case clause lists one or more
831+
* Int/String values to use as keys, and a code block. The exception is the "default" case
832+
* clause which doesn't list any key (there is exactly one of these per match).
838833
*/
839834
private def genMatch(tree: Match): BType = tree match {
840835
case Match(selector, cases) =>
841836
lineNumber(tree)
842-
genLoad(selector, INT)
843837
val generatedType = tpeTK(tree)
838+
val postMatch = new asm.Label
844839

845-
var flatKeys: List[Int] = Nil
846-
var targets: List[asm.Label] = Nil
847-
var default: asm.Label = null
848-
var switchBlocks: List[(asm.Label, Tree)] = Nil
849-
850-
// collect switch blocks and their keys, but don't emit yet any switch-block.
851-
for (caze @ CaseDef(pat, guard, body) <- cases) {
852-
assert(guard == tpd.EmptyTree, guard)
853-
val switchBlockPoint = new asm.Label
854-
switchBlocks ::= (switchBlockPoint, body)
855-
pat match {
856-
case Literal(value) =>
857-
flatKeys ::= value.intValue
858-
targets ::= switchBlockPoint
859-
case Ident(nme.WILDCARD) =>
860-
assert(default == null, s"multiple default targets in a Match node, at ${tree.span}")
861-
default = switchBlockPoint
862-
case Alternative(alts) =>
863-
alts foreach {
864-
case Literal(value) =>
865-
flatKeys ::= value.intValue
866-
targets ::= switchBlockPoint
867-
case _ =>
868-
abort(s"Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}")
869-
}
870-
case _ =>
871-
abort(s"Invalid pattern in Match node: $tree at: ${tree.span}")
840+
// Only two possible selector types exist in `Match` trees at this point: Int and String
841+
if (tpeTK(selector) == INT) {
842+
843+
/* On a first pass over the case clauses, we flatten the keys and their
844+
* targets (the latter represented with asm.Labels). That representation
845+
* allows JCodeMethodV to emit a lookupswitch or a tableswitch.
846+
*
847+
* On a second pass, we emit the switch blocks, one for each different target.
848+
*/
849+
850+
var flatKeys: List[Int] = Nil
851+
var targets: List[asm.Label] = Nil
852+
var default: asm.Label = null
853+
var switchBlocks: List[(asm.Label, Tree)] = Nil
854+
855+
genLoad(selector, INT)
856+
857+
// collect switch blocks and their keys, but don't emit any switch-block yet.
858+
for (caze @ CaseDef(pat, guard, body) <- cases) {
859+
assert(guard == tpd.EmptyTree, guard)
860+
val switchBlockPoint = new asm.Label
861+
switchBlocks ::= (switchBlockPoint, body)
862+
pat match {
863+
case Literal(value) =>
864+
flatKeys ::= value.intValue
865+
targets ::= switchBlockPoint
866+
case Ident(nme.WILDCARD) =>
867+
assert(default == null, s"multiple default targets in a Match node, at ${tree.span}")
868+
default = switchBlockPoint
869+
case Alternative(alts) =>
870+
alts foreach {
871+
case Literal(value) =>
872+
flatKeys ::= value.intValue
873+
targets ::= switchBlockPoint
874+
case _ =>
875+
abort(s"Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}")
876+
}
877+
case _ =>
878+
abort(s"Invalid pattern in Match node: $tree at: ${tree.span}")
879+
}
872880
}
873-
}
874881

875-
bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY)
882+
bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY)
876883

877-
// emit switch-blocks.
878-
val postMatch = new asm.Label
879-
for (sb <- switchBlocks.reverse) {
880-
val (caseLabel, caseBody) = sb
881-
markProgramPoint(caseLabel)
882-
genLoad(caseBody, generatedType)
883-
bc goTo postMatch
884+
// emit switch-blocks.
885+
for (sb <- switchBlocks.reverse) {
886+
val (caseLabel, caseBody) = sb
887+
markProgramPoint(caseLabel)
888+
genLoad(caseBody, generatedType)
889+
bc goTo postMatch
890+
}
891+
} else {
892+
893+
/* Since the JVM doesn't have a way to switch on a string, we switch
894+
* on the `hashCode` of the string then do an `equals` check (with a
895+
* possible second set of jumps if blocks can be reached from multiple
896+
* string alternatives).
897+
*
898+
* This mirrors the way that Java compiles `switch` on Strings.
899+
*/
900+
901+
var default: asm.Label = null
902+
var indirectBlocks: List[(asm.Label, Tree)] = Nil
903+
904+
import scala.collection.mutable
905+
906+
// Cases grouped by their hashCode
907+
val casesByHash = SortedMap.empty[Int, List[(String, Either[asm.Label, Tree])]]
908+
var caseFallback: Tree = null
909+
910+
for (caze @ CaseDef(pat, guard, body) <- cases) {
911+
assert(guard == tpd.EmptyTree, guard)
912+
pat match {
913+
case Literal(value) =>
914+
val strValue = value.stringValue
915+
casesByHash.updateWith(strValue.##) { existingCasesOpt =>
916+
val newCase = (strValue, Right(body))
917+
Some(newCase :: existingCasesOpt.getOrElse(Nil))
918+
}
919+
case Ident(nme.WILDCARD) =>
920+
assert(default == null, s"multiple default targets in a Match node, at ${tree.span}")
921+
default = new asm.Label
922+
indirectBlocks ::= (default, body)
923+
case Alternative(alts) =>
924+
// We need an extra basic block since multiple strings can lead to this code
925+
val indirectCaseGroupLabel = new asm.Label
926+
indirectBlocks ::= (indirectCaseGroupLabel, body)
927+
alts foreach {
928+
case Literal(value) =>
929+
val strValue = value.stringValue
930+
casesByHash.updateWith(strValue.##) { existingCasesOpt =>
931+
val newCase = (strValue, Left(indirectCaseGroupLabel))
932+
Some(newCase :: existingCasesOpt.getOrElse(Nil))
933+
}
934+
case _ =>
935+
abort(s"Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}")
936+
}
937+
938+
case _ =>
939+
abort(s"Invalid pattern in Match node: $tree at: ${tree.span}")
940+
}
941+
}
942+
943+
// Organize the hashCode options into switch cases
944+
var flatKeys: List[Int] = Nil
945+
var targets: List[asm.Label] = Nil
946+
var hashBlocks: List[(asm.Label, List[(String, Either[asm.Label, Tree])])] = Nil
947+
for ((hashValue, hashCases) <- casesByHash) {
948+
val switchBlockPoint = new asm.Label
949+
hashBlocks ::= (switchBlockPoint, hashCases)
950+
flatKeys ::= hashValue
951+
targets ::= switchBlockPoint
952+
}
953+
954+
// Push the hashCode of the string (or `0` if it is `null`) onto the stack and switch on it
955+
genLoadIf(
956+
If(
957+
tree.selector.select(defn.Any_==).appliedTo(nullLiteral),
958+
Literal(Constant(0)),
959+
tree.selector.select(defn.Any_hashCode).appliedToNone
960+
),
961+
INT
962+
)
963+
bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY)
964+
965+
// emit blocks for each hash case
966+
for ((hashLabel, caseAlternatives) <- hashBlocks.reverse) {
967+
markProgramPoint(hashLabel)
968+
for ((caseString, indirectLblOrBody) <- caseAlternatives) {
969+
val comparison = if (caseString == null) defn.Any_== else defn.Any_equals
970+
val condp = Literal(Constant(caseString)).select(defn.Any_==).appliedTo(tree.selector)
971+
val keepGoing = new asm.Label
972+
indirectLblOrBody match {
973+
case Left(jump) =>
974+
genCond(condp, jump, keepGoing, targetIfNoJump = keepGoing)
975+
976+
case Right(caseBody) =>
977+
val thisCaseMatches = new asm.Label
978+
genCond(condp, thisCaseMatches, keepGoing, targetIfNoJump = thisCaseMatches)
979+
markProgramPoint(thisCaseMatches)
980+
genLoad(caseBody, generatedType)
981+
bc goTo postMatch
982+
}
983+
markProgramPoint(keepGoing)
984+
}
985+
bc goTo default
986+
}
987+
988+
// emit blocks for common patterns
989+
for ((caseLabel, caseBody) <- indirectBlocks.reverse) {
990+
markProgramPoint(caseLabel)
991+
genLoad(caseBody, generatedType)
992+
bc goTo postMatch
993+
}
884994
}
885995

886996
markProgramPoint(postMatch)

compiler/src/dotty/tools/backend/sjs/JSCodeGen.scala

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2863,12 +2863,6 @@ class JSCodeGen()(using genCtx: Context) {
28632863
def abortMatch(msg: String): Nothing =
28642864
throw new FatalError(s"$msg in switch-like pattern match at ${tree.span}: $tree")
28652865

2866-
/* Although GenBCode adapts the scrutinee and the cases to `int`, only
2867-
* true `int`s can reach the back-end, as asserted by the String-switch
2868-
* transformation in `cleanup`. Therefore, we do not adapt, preserving
2869-
* the `string`s and `null`s that come out of the pattern matching in
2870-
* Scala 2.13.2+.
2871-
*/
28722866
val genSelector = genExpr(selector)
28732867

28742868
// Sanity check: we can handle Ints and Strings (including `null`s), but nothing else
@@ -2925,11 +2919,6 @@ class JSCodeGen()(using genCtx: Context) {
29252919
* When no optimization applies, and any of the case values is not a
29262920
* literal int, we emit a series of `if..else` instead of a `js.Match`.
29272921
* This became necessary in 2.13.2 with strings and nulls.
2928-
*
2929-
* Note that dotc has not adopted String-switch-Matches yet, so these code
2930-
* paths are dead code at the moment. However, they already existed in the
2931-
* scalac, so were ported, to be immediately available and working when
2932-
* dotc starts emitting switch-Matches on Strings.
29332922
*/
29342923
def isInt(tree: js.Tree): Boolean = tree.tpe == jstpe.IntType
29352924

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import util.Property._
2020

2121
/** The pattern matching transform.
2222
* After this phase, the only Match nodes remaining in the code are simple switches
23-
* where every pattern is an integer constant
23+
* where every pattern is an integer or string constant
2424
*/
2525
class PatternMatcher extends MiniPhase {
2626
import ast.tpd._
@@ -768,13 +768,15 @@ object PatternMatcher {
768768
(tpe isRef defn.IntClass) ||
769769
(tpe isRef defn.ByteClass) ||
770770
(tpe isRef defn.ShortClass) ||
771-
(tpe isRef defn.CharClass)
771+
(tpe isRef defn.CharClass) ||
772+
(tpe isRef defn.StringClass)
772773

773-
val seen = mutable.Set[Int]()
774+
val seen = mutable.Set[Any]()
774775

775-
def isNewIntConst(tree: Tree) = tree match {
776-
case Literal(const) if const.isIntRange && !seen.contains(const.intValue) =>
777-
seen += const.intValue
776+
def isNewSwitchableConst(tree: Tree) = tree match {
777+
case Literal(const)
778+
if (const.isIntRange || const.tag == Constants.StringTag) && !seen.contains(const.value) =>
779+
seen += const.value
778780
true
779781
case _ =>
780782
false
@@ -789,7 +791,7 @@ object PatternMatcher {
789791
val alts = List.newBuilder[Tree]
790792
def rec(innerPlan: Plan): Boolean = innerPlan match {
791793
case SeqPlan(TestPlan(EqualTest(tree), scrut, _, ReturnPlan(`innerLabel`)), tail)
792-
if scrut === scrutinee && isNewIntConst(tree) =>
794+
if scrut === scrutinee && isNewSwitchableConst(tree) =>
793795
alts += tree
794796
rec(tail)
795797
case ReturnPlan(`outerLabel`) =>
@@ -809,7 +811,7 @@ object PatternMatcher {
809811

810812
def recur(plan: Plan): List[(List[Tree], Plan)] = plan match {
811813
case SeqPlan(testPlan @ TestPlan(EqualTest(tree), scrut, _, ons), tail)
812-
if scrut === scrutinee && !canFallThrough(ons) && isNewIntConst(tree) =>
814+
if scrut === scrutinee && !canFallThrough(ons) && isNewSwitchableConst(tree) =>
813815
(tree :: Nil, ons) :: recur(tail)
814816
case SeqPlan(AlternativesPlan(alts, ons), tail) =>
815817
(alts, ons) :: recur(tail)
@@ -832,29 +834,32 @@ object PatternMatcher {
832834

833835
/** Emit a switch-match */
834836
private def emitSwitchMatch(scrutinee: Tree, cases: List[(List[Tree], Plan)]): Match = {
835-
/* Make sure to adapt the scrutinee to Int, as well as all the alternatives
836-
* of all cases, so that only Matches on pritimive Ints survive this phase.
837+
/* Make sure to adapt the scrutinee to Int or String, as well as all the
838+
* alternatives, so that only Matches on primitive Ints or Strings survive
839+
* this phase.
837840
*/
838841

839-
val intScrutinee =
840-
if (scrutinee.tpe.widen.isRef(defn.IntClass)) scrutinee
841-
else scrutinee.select(nme.toInt)
842+
val (primScrutinee, scrutineeTpe) =
843+
if (scrutinee.tpe.widen.isRef(defn.IntClass)) (scrutinee, defn.IntType)
844+
else if (scrutinee.tpe.widen.isRef(defn.StringClass)) (scrutinee, defn.StringType)
845+
else (scrutinee.select(nme.toInt), defn.IntType)
842846

843-
def intLiteral(lit: Tree): Tree =
847+
def primLiteral(lit: Tree): Tree =
844848
val Literal(constant) = lit
845849
if (constant.tag == Constants.IntTag) lit
850+
else if (constant.tag == Constants.StringTag) lit
846851
else cpy.Literal(lit)(Constant(constant.intValue))
847852

848853
val caseDefs = cases.map { (alts, ons) =>
849854
val pat = alts match {
850-
case alt :: Nil => intLiteral(alt)
851-
case Nil => Underscore(defn.IntType) // default case
852-
case _ => Alternative(alts.map(intLiteral))
855+
case alt :: Nil => primLiteral(alt)
856+
case Nil => Underscore(scrutineeTpe) // default case
857+
case _ => Alternative(alts.map(primLiteral))
853858
}
854859
CaseDef(pat, EmptyTree, emit(ons))
855860
}
856861

857-
Match(intScrutinee, caseDefs)
862+
Match(primScrutinee, caseDefs)
858863
}
859864

860865
/** If selfCheck is `true`, used to check whether a tree gets generated twice */
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2
2+
-1
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import annotation.switch
2+
3+
object Test {
4+
def test(s: String): Int = {
5+
(s : @switch) match {
6+
case "1" => 0
7+
case null => -1
8+
case _ => s.toInt
9+
}
10+
}
11+
12+
def main(args: Array[String]): Unit = {
13+
println(test("2"))
14+
println(test(null))
15+
}
16+
}

tests/run/string-switch.check

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
fido Success(dog)
2+
garfield Success(cat)
3+
wanda Success(fish)
4+
henry Success(horse)
5+
felix Failure(scala.MatchError: felix (of class java.lang.String))
6+
deuteronomy Success(cat)
7+
=====
8+
AaAa 2031744 Success(1)
9+
BBBB 2031744 Success(2)
10+
BBAa 2031744 Failure(scala.MatchError: BBAa (of class java.lang.String))
11+
cCCc 3015872 Success(3)
12+
ddDd 3077408 Success(4)
13+
EEee 2125120 Failure(scala.MatchError: EEee (of class java.lang.String))
14+
=====
15+
A Success(())
16+
X Failure(scala.MatchError: X (of class java.lang.String))
17+
=====
18+
Success(3)
19+
null Success(2)
20+
7 Failure(scala.MatchError: 7 (of class java.lang.String))
21+
=====
22+
pig Success(1)
23+
dog Success(2)
24+
=====
25+
Ea 2236 Success(1)
26+
FB 2236 Success(2)
27+
cC 3136 Success(3)
28+
xx 3840 Success(4)
29+
null 0 Success(4)

0 commit comments

Comments
 (0)