Skip to content

Commit 637a5f6

Browse files
author
EnzeXing
committed
Refactor pattern matching, skipping cases when safe to do so
1 parent 332fceb commit 637a5f6

File tree

1 file changed

+38
-18
lines changed

1 file changed

+38
-18
lines changed

compiler/src/dotty/tools/dotc/transform/init/Objects.scala

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,12 @@ class Objects(using Context @constructorOnly):
644644
case (ValueSet(values), b : ValueElement) => ValueSet(values + b)
645645
case (a : ValueElement, b : ValueElement) => ValueSet(ListSet(a, b))
646646

647+
def remove(b: Value): Value = (a, b) match
648+
case (ValueSet(values1), b: ValueElement) => ValueSet(values1 - b)
649+
case (ValueSet(values1), ValueSet(values2)) => ValueSet(values1.removedAll(values2))
650+
case (a: Ref, b: Ref) if a.equals(b) => Bottom
651+
case _ => a
652+
647653
def widen(height: Int)(using Context): Value =
648654
if height == 0 then Cold
649655
else
@@ -1386,29 +1392,25 @@ class Objects(using Context @constructorOnly):
13861392
def getMemberMethod(receiver: Type, name: TermName, tp: Type): Denotation =
13871393
receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp)
13881394

1389-
def evalCase(caseDef: CaseDef): Value =
1390-
evalPattern(scrutinee, caseDef.pat)
1391-
eval(caseDef.guard, thisV, klass)
1392-
eval(caseDef.body, thisV, klass)
1393-
13941395
/** Abstract evaluation of patterns.
13951396
*
13961397
* It augments the local environment for bound pattern variables. As symbols are globally
13971398
* unique, we can put them in a single environment.
13981399
*
13991400
* Currently, we assume all cases are reachable, thus all patterns are assumed to match.
14001401
*/
1401-
def evalPattern(scrutinee: Value, pat: Tree): Value = log("match " + scrutinee.show + " against " + pat.show, printer, (_: Value).show):
1402+
def evalPattern(scrutinee: Value, pat: Tree): (Type, Value) = log("match " + scrutinee.show + " against " + pat.show, printer, (_: (Type, Value))._2.show):
14021403
val trace2 = Trace.trace.add(pat)
14031404
pat match
14041405
case Alternative(pats) =>
1405-
for pat <- pats do evalPattern(scrutinee, pat)
1406-
scrutinee
1406+
val (types, values) = pats.map(evalPattern(scrutinee, _)).unzip()
1407+
val orType = types.fold(defn.NothingType)(OrType(_, _, false))
1408+
(orType, values.join)
14071409

14081410
case bind @ Bind(_, pat) =>
1409-
val value = evalPattern(scrutinee, pat)
1411+
val (tpe, value) = evalPattern(scrutinee, pat)
14101412
initLocal(bind.symbol, value)
1411-
scrutinee
1413+
(tpe, value)
14121414

14131415
case UnApply(fun, implicits, pats) =>
14141416
given Trace = trace2
@@ -1417,6 +1419,10 @@ class Objects(using Context @constructorOnly):
14171419
val funRef = fun1.tpe.asInstanceOf[TermRef]
14181420
val unapplyResTp = funRef.widen.finalResultType
14191421

1422+
val receiverType = fun1 match
1423+
case ident: Ident => funRef.prefix
1424+
case select: Select => select.qualifier.tpe
1425+
14201426
val receiver = fun1 match
14211427
case ident: Ident =>
14221428
evalType(funRef.prefix, thisV, klass)
@@ -1505,17 +1511,18 @@ class Objects(using Context @constructorOnly):
15051511
end if
15061512
end if
15071513
end if
1508-
scrutinee
1514+
(receiverType, scrutinee.filterType(receiverType))
15091515

15101516
case Ident(nme.WILDCARD) | Ident(nme.WILDCARD_STAR) =>
1511-
scrutinee
1517+
(defn.ThrowableType, scrutinee)
15121518

1513-
case Typed(pat, _) =>
1514-
evalPattern(scrutinee, pat)
1519+
case Typed(pat, typeTree) =>
1520+
val (_, value) = evalPattern(scrutinee.filterType(typeTree.tpe), pat)
1521+
(typeTree.tpe, value)
15151522

15161523
case tree =>
15171524
// For all other trees, the semantics is normal.
1518-
eval(tree, thisV, klass)
1525+
(defn.ThrowableType, eval(tree, thisV, klass))
15191526

15201527
end evalPattern
15211528

@@ -1539,12 +1546,12 @@ class Objects(using Context @constructorOnly):
15391546
if isWildcardStarArgList(pats) then
15401547
if pats.size == 1 then
15411548
// call .toSeq
1542-
val toSeqDenot = getMemberMethod(scrutineeType, nme.toSeq, toSeqType(elemType))
1549+
val toSeqDenot = scrutineeType.member(nme.toSeq).suchThat(_.info.isParameterless)
15431550
val toSeqRes = call(scrutinee, toSeqDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
15441551
evalPattern(toSeqRes, pats.head)
15451552
else
15461553
// call .drop
1547-
val dropDenot = getMemberMethod(scrutineeType, nme.drop, dropType(elemType))
1554+
val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType))
15481555
val dropRes = call(scrutinee, dropDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
15491556
for pat <- pats.init do evalPattern(applyRes, pat)
15501557
evalPattern(dropRes, pats.last)
@@ -1555,8 +1562,21 @@ class Objects(using Context @constructorOnly):
15551562
end if
15561563
end evalSeqPatterns
15571564

1565+
def canSkipCase(remainingScrutinee: Value, catchValue: Value) =
1566+
(remainingScrutinee == Bottom && scrutinee != Bottom) ||
1567+
(catchValue == Bottom && remainingScrutinee != Bottom)
15581568

1559-
cases.map(evalCase).join
1569+
var remainingScrutinee = scrutinee
1570+
val caseResults: mutable.ArrayBuffer[Value] = mutable.ArrayBuffer()
1571+
for caseDef <- cases do
1572+
val (tpe, value) = evalPattern(remainingScrutinee, caseDef.pat)
1573+
eval(caseDef.guard, thisV, klass)
1574+
if !canSkipCase(remainingScrutinee, value) then
1575+
caseResults.addOne(eval(caseDef.body, thisV, klass))
1576+
if catchesAllOf(caseDef, tpe) then
1577+
remainingScrutinee = remainingScrutinee.remove(value)
1578+
1579+
caseResults.join
15601580
end patternMatch
15611581

15621582
/** Handle semantics of leaf nodes

0 commit comments

Comments
 (0)