Skip to content

Commit f571605

Browse files
committed
Implement memoization
Implement `memo(...)` function which caches its argument on first evaluation and re-uses the cached value afterwards. The cache is placed next to the method enclosing the memo(...) call. `memo` is a member of package `compiletime`.
1 parent 7cde70a commit f571605

File tree

7 files changed

+69
-4
lines changed

7 files changed

+69
-4
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ class Definitions {
237237
@threadUnsafe lazy val Compiletime_constValue : SymbolPerRun = perRunSym(CompiletimePackageObject.requiredMethodRef("constValue"))
238238
@threadUnsafe lazy val Compiletime_constValueOpt: SymbolPerRun = perRunSym(CompiletimePackageObject.requiredMethodRef("constValueOpt"))
239239
@threadUnsafe lazy val Compiletime_code : SymbolPerRun = perRunSym(CompiletimePackageObject.requiredMethodRef("code"))
240+
@threadUnsafe lazy val Compiletime_memo : SymbolPerRun = perRunSym(CompiletimePackageObject.requiredMethodRef("memo"))
240241

241242
/** The `scalaShadowing` package is used to safely modify classes and
242243
* objects in scala so that they can be used from dotty. They will

compiler/src/dotty/tools/dotc/core/NameKinds.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,9 @@ object NameKinds {
213213
safePrefix + info.num
214214
}
215215

216+
def currentCount(prefix: TermName = EmptyTermName) given (ctx: Context): Int =
217+
ctx.freshNames.currentCount(prefix, this)
218+
216219
/** Generate fresh unique term name of this kind with given prefix name */
217220
def fresh(prefix: TermName = EmptyTermName)(implicit ctx: Context): TermName =
218221
ctx.freshNames.newName(prefix, this)
@@ -296,6 +299,7 @@ object NameKinds {
296299
val UniqueInlineName: UniqueNameKind = new UniqueNameKind("$i")
297300
val InlineScrutineeName: UniqueNameKind = new UniqueNameKind("$scrutinee")
298301
val InlineBinderName: UniqueNameKind = new UniqueNameKind("$elem")
302+
val MemoCacheName: UniqueNameKind = new UniqueNameKind("memo$")
299303

300304
/** A kind of unique extension methods; Unlike other unique names, these can be
301305
* unmangled.

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import StdNames._
1515
import transform.SymUtils._
1616
import Contexts.Context
1717
import Names.{Name, TermName}
18-
import NameKinds.{InlineAccessorName, InlineBinderName, InlineScrutineeName}
18+
import NameKinds.{InlineAccessorName, InlineBinderName, InlineScrutineeName, MemoCacheName}
1919
import ProtoTypes.selectionProto
2020
import SymDenotations.SymDenotation
2121
import Inferencing.fullyDefinedType
@@ -188,6 +188,19 @@ object Inliner {
188188
if (callSym.is(Macro)) ref(callSym.topLevelClass.owner).select(callSym.topLevelClass.name).withSpan(pos.span)
189189
else Ident(callSym.topLevelClass.typeRef).withSpan(pos.span)
190190
}
191+
192+
/** For every occurrence of a memo cache symbol `memo$N` of type `T_N` in `tree`,
193+
* an assignment `val memo$N: T_N = null`
194+
*/
195+
def memoCacheDefs(tree: Tree) given Context: Set[ValDef] = {
196+
val memoCacheSyms = tree.deepFold[Set[TermSymbol]](Set.empty) {
197+
(syms, t) => t match {
198+
case Assign(lhs, _) if lhs.symbol.name.is(MemoCacheName) => syms + lhs.symbol.asTerm
199+
case _ => syms
200+
}
201+
}
202+
memoCacheSyms.map(ValDef(_, Literal(Constant(null))))
203+
}
191204
}
192205

193206
/** Produces an inlined version of `call` via its `inlined` method.
@@ -392,6 +405,36 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
392405
case _ => EmptyTree
393406
}
394407

408+
/** The expansion of `memo(op)` where `op: T` is:
409+
*
410+
* { if (memo$N == null) memo$N = op; $memo.asInstanceOf[T] }
411+
*
412+
* This creates as a side effect a memo cache symbol $memo$N` of type `T | Null`.
413+
* TODO: Restrict this to non-null types, once nullability checking is in.
414+
*/
415+
def memoized: Tree = {
416+
val currentOwner = ctx.owner.skipWeakOwner
417+
if (currentOwner.isRealMethod) {
418+
val cacheOwner = ctx.owner.effectiveOwner
419+
val argType = callTypeArgs.head.tpe
420+
val memoVar = ctx.newSymbol(
421+
owner = cacheOwner,
422+
name = MemoCacheName.fresh(),
423+
flags =
424+
if (cacheOwner.isTerm) Synthetic | Mutable
425+
else Synthetic | Mutable | Private | Local,
426+
info = OrType(argType, defn.NullType),
427+
coord = call.span)
428+
val cond = If(
429+
ref(memoVar).select(defn.Any_==).appliedTo(Literal(Constant(null))),
430+
ref(memoVar).becomes(callValueArgss.head.head),
431+
Literal(Constant(())))
432+
val expr = ref(memoVar).cast(argType)
433+
Block(cond :: Nil, expr)
434+
}
435+
else errorTree(call, em"""memo(...) outside method""")
436+
}
437+
395438
/** The Inlined node representing the inlined call */
396439
def inlined(sourcePos: SourcePosition): Tree = {
397440

@@ -408,6 +451,8 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
408451
else New(defn.SomeClass.typeRef.appliedTo(constVal.tpe), constVal :: Nil)
409452
)
410453
}
454+
else if (inlinedMethod == defn.Compiletime_memo)
455+
return memoized
411456

412457
// Compute bindings for all parameters, appending them to bindingsBuf
413458
computeParamBindings(inlinedMethod.info, callTypeArgs, callValueArgss)

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2121,6 +2121,7 @@ class Typer extends Namer
21212121
case Some(xtree) =>
21222122
traverse(xtree :: rest)
21232123
case none =>
2124+
val memoCacheCount = MemoCacheName.currentCount()
21242125
typed(mdef) match {
21252126
case mdef1: DefDef if Inliner.hasBodyToInline(mdef1.symbol) =>
21262127
buf += inlineExpansion(mdef1)
@@ -2131,6 +2132,8 @@ class Typer extends Namer
21312132
mdef match {
21322133
case mdef: untpd.TypeDef if mdef.mods.isEnumClass =>
21332134
enumContexts(mdef1.symbol) = ctx
2135+
case _: untpd.DefDef if MemoCacheName.currentCount() != memoCacheCount =>
2136+
buf ++= Inliner.memoCacheDefs(mdef1)
21342137
case _ =>
21352138
}
21362139
if (!mdef1.isEmpty) // clashing synthetic case methods are converted to empty trees

compiler/src/dotty/tools/dotc/util/FreshNameCreator.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,27 @@ import core.StdNames.str
99

1010
abstract class FreshNameCreator {
1111
def newName(prefix: TermName, unique: UniqueNameKind): TermName
12+
def currentCount(prefix: TermName, unique: UniqueNameKind): Int
1213
}
1314

1415
object FreshNameCreator {
1516
class Default extends FreshNameCreator {
16-
protected var counter: Int = 0
1717
protected val counters: mutable.Map[String, Int] = mutable.AnyRefMap() withDefaultValue 0
1818

19+
private def keyFor(prefix: TermName, unique: UniqueNameKind) =
20+
str.sanitize(prefix.toString) + unique.separator
21+
22+
/** The current counter for the given combination of `prefix` and `unique` */
23+
def currentCount(prefix: TermName, unique: UniqueNameKind): Int =
24+
counters(keyFor(prefix, unique))
25+
1926
/**
2027
* Create a fresh name with the given prefix. It is guaranteed
2128
* that the returned name has never been returned by a previous
2229
* call to this function (provided the prefix does not end in a digit).
2330
*/
2431
def newName(prefix: TermName, unique: UniqueNameKind): TermName = {
25-
val key = str.sanitize(prefix.toString) + unique.separator
32+
val key = keyFor(prefix, unique)
2633
counters(key) += 1
2734
prefix.derived(unique.NumberedInfo(counters(key)))
2835
}

library/src/scala/compiletime/package.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,6 @@ package object compiletime {
3838
inline def constValue[T]: T = ???
3939

4040
type S[X <: Int] <: Int
41+
42+
inline def memo[T](op: => T): T = ???
4143
}

tests/run/memoTest.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
object Test extends App {
2+
import compiletime.memo
23

34
var opCache: Int | Null = null
45

@@ -7,6 +8,8 @@ object Test extends App {
78
opCache.asInstanceOf[Int] + 1
89
}
910

10-
assert(foo(1) + foo(2) == 4)
11+
def bar(x: Int) = memo(x * x) + 1
1112

13+
assert(foo(1) + foo(2) == 4)
14+
assert(bar(1) + bar(2) == 4)
1215
}

0 commit comments

Comments
 (0)