Skip to content

Commit e6156fe

Browse files
author
Lucy Martin
committed
Preventing compilation of a @tailrec method when it does not rewrite, but an inner method does
Adding warnings if there is an annotated def at the top level that is referenced from an inner def
1 parent adf089b commit e6156fe

File tree

5 files changed

+58
-2
lines changed

5 files changed

+58
-2
lines changed

compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMe
208208
case UnstableInlineAccessorID // errorNumber: 192
209209
case VolatileOnValID // errorNumber: 193
210210
case ExtensionNullifiedByMemberID // errorNumber: 194
211+
case TailrecNestedCallID //errorNumber: 195
211212

212213
def errorNumber = ordinal - 1
213214

compiler/src/dotty/tools/dotc/reporting/messages.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1907,6 +1907,20 @@ class TailrecNotApplicable(symbol: Symbol)(using Context)
19071907
def explain(using Context) = ""
19081908
}
19091909

1910+
class TailrecNestedCall(definition: Symbol, innerDef: Symbol)(using Context)
1911+
extends SyntaxMsg(TailrecNestedCallID) {
1912+
def msg(using Context) = {
1913+
s"The tail recursive def ${definition.name} contains a recursive call inside the non-inlined inner def ${innerDef.name}"
1914+
}
1915+
1916+
def explain(using Context) =
1917+
"""Tail recursion is only validated and optimised directly in the definition
1918+
|any calls to the recursive method via an inner def cannot be validated as
1919+
|tail recursive, nor optimised if they are. To enable tail recursion from
1920+
|inner calls, mark the inner def as inline
1921+
|""".stripMargin
1922+
}
1923+
19101924
class FailureToEliminateExistential(tp: Type, tp1: Type, tp2: Type, boundSyms: List[Symbol], classRoot: Symbol)(using Context)
19111925
extends Message(FailureToEliminateExistentialID) {
19121926
def kind = MessageKind.Compatibility

compiler/src/dotty/tools/dotc/transform/TailRec.scala

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,10 +429,20 @@ class TailRec extends MiniPhase {
429429
assert(false, "We should never have gotten inside a pattern")
430430
tree
431431

432-
case tree: ValOrDefDef =>
432+
case tree: ValDef =>
433433
if (isMandatory) noTailTransform(tree.rhs)
434434
tree
435435

436+
case tree: DefDef =>
437+
if (isMandatory)
438+
// We cant tail recurse through nested definitions, so dont want to propagate to child nodes
439+
// We dont want to fail if there is a call that would recurse (as this would be a non self recurse), so dont
440+
// want to call noTailTransform
441+
// We can however warn in this case, as its likely in this situation that someone would expect a tail
442+
// recursion optimization and enabling this to optimise would be a simple case of inlining the inner method
443+
new NestedTailRecAlerter(method, tree.symbol).traverse(tree)
444+
tree
445+
436446
case _: Super | _: This | _: Literal | _: TypeTree | _: TypeDef | EmptyTree =>
437447
tree
438448

@@ -446,14 +456,28 @@ class TailRec extends MiniPhase {
446456

447457
case Return(expr, from) =>
448458
val fromSym = from.symbol
449-
val inTailPosition = !fromSym.is(Label) || tailPositionLabeledSyms.contains(fromSym)
459+
val inTailPosition = (!fromSym.is(Label) || tailPositionLabeledSyms.contains(fromSym)) // Label returns are only tail if the label is in tail position
460+
&& (!fromSym.is(Method) || (fromSym eq method)) // Method returns are only tail if we are looking at the original method
450461
cpy.Return(tree)(transform(expr, inTailPosition), from)
451462

452463
case _ =>
453464
super.transform(tree)
454465
}
455466
}
456467
}
468+
469+
class NestedTailRecAlerter(method: Symbol, inner: Symbol) extends TreeTraverser {
470+
override def traverse(tree: tpd.Tree)(using Context): Unit =
471+
tree match {
472+
case a: Apply =>
473+
if (a.fun.symbol eq method) {
474+
report.warning(new TailrecNestedCall(method, inner), a.srcPos)
475+
}
476+
traverseChildren(tree)
477+
case _ =>
478+
traverseChildren(tree)
479+
}
480+
}
457481
}
458482

459483
object TailRec {

tests/neg/i20105.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
@tailrec
2+
def foo(): Unit =
3+
def bar(): Unit =
4+
if (???)
5+
foo()
6+
else
7+
bar()
8+
bar()

tests/warn/i20105.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
@tailrec
2+
def foo(): Unit =
3+
def bar(): Unit =
4+
if (???)
5+
foo() // warn
6+
else
7+
bar()
8+
bar()
9+
foo()

0 commit comments

Comments
 (0)