Skip to content

Commit b94bb59

Browse files
committed
Implement flow typing for nullability
When working with explicit nulls we often would like to do flow typing so that common idioms typecheck. e.g.: val x: String|Null = ??? if (x != null && x.length < 10) To do flow typing, we maintain a set of TermRefs that must be non-null in the current context. The set is then updated as we type conditions in if-expressions/others. We currently only handle `val`s (we don't handle `var`s). The current implementation does flow typing in 1) the branches of an if-statement 2) within a condition (as in the example above) 3) within a block: e.g. val x: String|Null = ??? if (x == null) return val y = x.length TODO: the current implementation doesn't play well with TASTy, in that we lose the flow typing info when we load trees through TASTy.
1 parent 6946371 commit b94bb59

21 files changed

+1137
-47
lines changed

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ import xsbti.AnalysisCallback
3737
import plugins._
3838
import java.util.concurrent.atomic.AtomicInteger
3939

40+
import dotty.tools.dotc.core.FlowTyper.FlowFacts
41+
4042
object Contexts {
4143

4244
private val (compilerCallbackLoc, store1) = Store.empty.newLocation[CompilerCallback]()
@@ -143,6 +145,11 @@ object Contexts {
143145
protected def gadt_=(gadt: GADTMap): Unit = _gadt = gadt
144146
final def gadt: GADTMap = _gadt
145147

148+
/** The terms currently known to be non-null (in spite of their declared type) */
149+
private[this] var _flowFacts: FlowFacts = _
150+
protected def flowFacts_=(flowFacts: FlowFacts): Unit = _flowFacts = flowFacts
151+
def flowFacts: FlowFacts = _flowFacts
152+
146153
/** The history of implicit searches that are currently active */
147154
private[this] var _searchHistory: SearchHistory = null
148155
protected def searchHistory_= (searchHistory: SearchHistory): Unit = _searchHistory = searchHistory
@@ -432,6 +439,7 @@ object Contexts {
432439
_typeAssigner = origin.typeAssigner
433440
_importInfo = origin.importInfo
434441
_gadt = origin.gadt
442+
_flowFacts = origin.flowFacts
435443
_searchHistory = origin.searchHistory
436444
_typeComparer = origin.typeComparer
437445
_source = origin.source
@@ -536,6 +544,11 @@ object Contexts {
536544
def setImportInfo(importInfo: ImportInfo): this.type = { this.importInfo = importInfo; this }
537545
def setGadt(gadt: GADTMap): this.type = { this.gadt = gadt; this }
538546
def setFreshGADTBounds: this.type = setGadt(gadt.fresh)
547+
def addFlowFacts(facts: FlowFacts): this.type = {
548+
assert(settings.YexplicitNulls.value)
549+
this.flowFacts ++= facts
550+
this
551+
}
539552
def setSearchHistory(searchHistory: SearchHistory): this.type = { this.searchHistory = searchHistory; this }
540553
def setSource(source: SourceFile): this.type = { this.source = source; this }
541554
def setTypeComparerFn(tcfn: Context => TypeComparer): this.type = { this.typeComparer = tcfn(this); this }
@@ -618,6 +631,7 @@ object Contexts {
618631
typeComparer = new TypeComparer(this)
619632
searchHistory = new SearchRoot
620633
gadt = EmptyGADTMap
634+
flowFacts = FlowTyper.emptyFlowFacts
621635
}
622636

623637
@sharable object NoContext extends Context(null) {
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
package dotty.tools.dotc.core
2+
3+
import dotty.tools.dotc.ast.tpd._
4+
import StdNames.nme
5+
import dotty.tools.dotc.ast.Trees.{Apply, Block, If, Select, TypeApply}
6+
import dotty.tools.dotc.ast.tpd
7+
import dotty.tools.dotc.core.Constants.Constant
8+
import dotty.tools.dotc.core.Contexts.Context
9+
import dotty.tools.dotc.core.Names.Name
10+
import dotty.tools.dotc.core.Types.{NonNullTermRef, TermRef, Type}
11+
12+
import scala.annotation.internal.sharable
13+
14+
/** Flow-sensitive typer */
15+
object FlowTyper {
16+
17+
/** A set of `TermRef`s known to be non-null at the current program point */
18+
type FlowFacts = Set[TermRef]
19+
20+
/** The initial state where no `TermRef`s are known to be non-null */
21+
@sharable val emptyFlowFacts = Set.empty[TermRef]
22+
23+
/** Tries to improve the precision of `tpe` using flow-sensitive type information.
24+
* For nullability, is `tpe` is a `TermRef` declared as nullable but known to be non-nullable because of the
25+
* contextual info, returns the non-nullable version of the type.
26+
* If the precision of the type can't be improved, then returns the type unchanged.
27+
*/
28+
def refineType(tpe: Type)(implicit ctx: Context): Type = {
29+
assert(ctx.settings.YexplicitNulls.value)
30+
tpe match {
31+
case tref: TermRef if ctx.flowFacts.contains(tref) =>
32+
NonNullTermRef.fromTermRef(tref)
33+
case _ => tpe
34+
}
35+
}
36+
37+
/** Nullability facts inferred from a condition.
38+
* @param ifTrue are the terms known to be non-null if the condition is true.
39+
* @param ifFalse are the terms known to be non-null if the condition is false.
40+
*/
41+
case class Inferred(ifTrue: FlowFacts, ifFalse: FlowFacts) {
42+
// Let `NN(e, true/false)` be the set of terms that are non-null if `e` evaluates to `true/false`.
43+
// We can use De Morgan's laws to underapproximate `NN` via `Inferred`.
44+
// e.g. say `e = e1 && e2`. Then if `e` is `false`, we know that either `!e1` or `!e2`.
45+
// Let `t` be a term that is in both `NN(e1, false)` and `NN(e2, false)`.
46+
// Then it follows that `t` must be in `NN(e, false)`. This means that if we set
47+
// `Inferred(e1 && e2, false) = Inferred(e1, false) ∩ Inferred(e2, false)`, we'll have
48+
// `Inferred(e1 && e2, false) ⊂ NN(e1 && e2, false)` (formally, we'd do a structural induction on `e`).
49+
// This means that when we infer something we do so soundly. The methods below use this approach.
50+
51+
/** If `this` corresponds to a condition `e1` and `other` to `e2`, calculate the inferred facts for `e1 && e2`. */
52+
def combineAnd(other: Inferred): Inferred = Inferred(ifTrue.union(other.ifTrue), ifFalse.intersect(other.ifFalse))
53+
54+
/** If `this` corresponds to a condition `e1` and `other` to `e2`, calculate the inferred facts for `e1 || e2`. */
55+
def combineOr(other: Inferred): Inferred = Inferred(ifTrue.intersect(other.ifTrue), ifFalse.union(other.ifFalse))
56+
57+
/** The inferred facts for the negation of this condition. */
58+
def negate: Inferred = Inferred(ifFalse, ifTrue)
59+
}
60+
61+
object Inferred {
62+
/** Create a singleton inferred fact containing `tref`. */
63+
def apply(tref: TermRef, ifTrue: Boolean): Inferred = {
64+
if (ifTrue) Inferred(Set(tref), emptyFlowFacts)
65+
else Inferred(emptyFlowFacts, Set(tref))
66+
}
67+
}
68+
69+
/** Analyze the tree for a condition `cond` to learn new flow facts.
70+
* Supports ands, ors, and unary negation.
71+
*
72+
* Example:
73+
* (1)
74+
* ```
75+
* val x: String|Null = "foo"
76+
* if (x != null) {
77+
* // x: String in the "then" branch
78+
* }
79+
* ```
80+
* Notice that `x` must be stable for the above to work.
81+
*
82+
* Let NN(cond, true/false) be the set of paths (`TermRef`s) that we can infer to be non-null
83+
* if `cond` is true/false, respectively. Then define NN by (basically De Morgan's laws):
84+
*
85+
* NN(p == null, true) = {} we also handle `eq`
86+
* NN(p == null, false) = {p} if p is stable
87+
* NN(p != null, true) = {p} if p is stable we also handle `ne`
88+
* NN(p != null, false) = {}
89+
* NN(p.isInstanceOf[Null], true) = {}
90+
* NN(p.isInstanceOf[Null], false) = {p} if p is stable
91+
* NN(A && B, true) = ∪(NN(A, true), NN(B, true))
92+
* NN(A && B, false) = ∩(NN(A, false), NN(B, false))
93+
* NN(A || B, true) = ∩(NN(A, true), NN(B, true))
94+
* NN(A || B, false) = ∪(NN(A, false), NN(B, false))
95+
* NN(!A, true) = NN(A, false)
96+
* NN(!A, false) = NN(A, true)
97+
* NN({S1; ...; Sn, cond}, true/false) = NN(cond, true/false)
98+
* NN(cond, _) = {} otherwise
99+
*/
100+
def inferFromCond(cond: Tree)(implicit ctx: Context): Inferred = {
101+
assert(ctx.settings.YexplicitNulls.value)
102+
/** Combine two sets of facts according to `op`. */
103+
def combine(lhs: Inferred, op: Name, rhs: Inferred): Inferred = {
104+
op match {
105+
case _ if op == nme.ZAND => lhs.combineAnd(rhs)
106+
case _ if op == nme.ZOR => lhs.combineOr(rhs)
107+
}
108+
}
109+
110+
val emptyFacts = Inferred(emptyFlowFacts, emptyFlowFacts)
111+
val nullLit = tpd.Literal(Constant(null))
112+
113+
/** Recurse over a conditional to extract flow facts. */
114+
def recur(tree: Tree): Inferred = {
115+
tree match {
116+
case Apply(Select(lhs, op), List(rhs)) =>
117+
if (op == nme.ZAND || op == nme.ZOR) combine(recur(lhs), op, recur(rhs))
118+
else if (op == nme.EQ || op == nme.NE || op == nme.eq || op == nme.ne) newFact(lhs, isEq = (op == nme.EQ || op == nme.eq), rhs)
119+
else emptyFacts
120+
// TODO(abeln): handle type test with argument that's not a subtype of `Null`.
121+
// We could infer "non-null" in that case: e.g. `if (x.isInstanceOf[String]) { // x can't be null }`
122+
// case TypeApply(Select(lhs, op), List(tArg)) if op == nme.isInstanceOf_ && tArg.tpe.isNullType =>
123+
// newFact(lhs, isEq = true, nullLit)
124+
case Select(lhs, op) if op == nme.UNARY_! => recur(lhs).negate
125+
case Block(_, expr) => recur(expr)
126+
case inline: Inlined => recur(inline.expansion)
127+
case typed: Typed => recur(typed.expr) // TODO(abeln): check that the type is `Boolean`?
128+
case _ => emptyFacts
129+
}
130+
}
131+
132+
/** Extract new facts from an expression `lhs = rhs` or `lhs != rhs`
133+
* if either the lhs or rhs is the `null` literal.
134+
*/
135+
def newFact(lhs: Tree, isEq: Boolean, rhs: Tree): Inferred = {
136+
def isNullLit(tree: Tree): Boolean = tree match {
137+
case lit: Literal if lit.const.tag == Constants.NullTag => true
138+
case _ => false
139+
}
140+
141+
def isStableTermRef(tree: Tree): Boolean = asStableTermRef(tree).isDefined
142+
143+
def asStableTermRef(tree: Tree): Option[TermRef] = tree.tpe match {
144+
case tref: TermRef if tref.isStable => Some(tref)
145+
case _ => None
146+
}
147+
148+
val trefOpt =
149+
if (isNullLit(lhs) && isStableTermRef(rhs)) asStableTermRef(rhs)
150+
else if (isStableTermRef(lhs) && isNullLit(rhs)) asStableTermRef(lhs)
151+
else None
152+
153+
trefOpt match {
154+
case Some(tref) =>
155+
// If `isEq`, then the condition is of the form `lhs == null`,
156+
// in which case we know `lhs` is non-null if the condition is false.
157+
Inferred(tref, ifTrue = !isEq)
158+
case _ => emptyFacts
159+
}
160+
}
161+
162+
recur(cond)
163+
}
164+
165+
/** Infer flow-sensitive type information inside a condition.
166+
*
167+
* Specifically, if `cond` is of the form `lhs &&` or `lhs ||`, where the lhs has already been typed
168+
* (and the rhs hasn't been typed yet), compute the non-null facts that must hold so that the rhs can
169+
* execute. These facts can then be soundly assumed when typing the rhs, because boolean operators are
170+
* short-circuiting.
171+
*
172+
* This is useful in e.g.
173+
* ```
174+
* val x: String|Null = ???
175+
* if (x != null && x.length > 0) ...
176+
* ```
177+
*/
178+
def inferWithinCond(cond: Tree)(implicit ctx: Context): FlowFacts = {
179+
assert(ctx.settings.YexplicitNulls.value)
180+
cond match {
181+
case Select(lhs, op) if op == nme.ZAND || op == nme.ZOR =>
182+
val Inferred(ifTrue, ifFalse) = inferFromCond(lhs)
183+
if (op == nme.ZAND) ifTrue
184+
else ifFalse
185+
case _ => emptyFlowFacts
186+
}
187+
}
188+
189+
/** Infer flow-sensitive type information within a block.
190+
*
191+
* More precisely, if `s1; s2` are consecutive statements in a block, this returns
192+
* a context with nullability facts that hold once `s1` has executed.
193+
* The new facts can then be used to type `s2`.
194+
*
195+
* This is useful for e.g.
196+
* ```
197+
* val x: String|Null = ???
198+
* if (x == null) return "foo"
199+
* val y = x.length // x: String inferred
200+
* ```
201+
*
202+
* How can we obtain additional facts just from the fact that `s1` executed?
203+
* This can happen if `s1` is of the form `If(cond, then, else)`, where `then` or
204+
* `else` have non-local control flow.
205+
*
206+
* The following qualify as non-local:
207+
* 1) a return
208+
* 2) an expression of type `Nothing` (in particular, usages of `throw`)
209+
* 3) a block where the last expression is non-local
210+
* 4) nothing else is non-local
211+
*
212+
* So, for example, if we know that `x` must be non-null if `cond` is true, and `else` is non-local,
213+
* then in order for `s2` to execute `cond` must be true. We can thus soundly add `x` to our
214+
* flow facts.
215+
*/
216+
def inferWithinBlock(stat: Tree)(implicit ctx: Context): FlowFacts = {
217+
def isNonLocal(s: Tree): Boolean = s match {
218+
case _: Return => true
219+
case Block(_, expr) => isNonLocal(expr)
220+
case _ =>
221+
// If the type is bottom (like the result of a `throw`), then we assume the statement
222+
// won't finish executing.
223+
s.tpe.isBottomType
224+
}
225+
226+
assert(ctx.settings.YexplicitNulls.value)
227+
stat match {
228+
case If(cond, thenExpr, elseExpr) =>
229+
val Inferred(ifTrue, ifFalse) = inferFromCond(cond)
230+
if (isNonLocal(thenExpr) && isNonLocal(elseExpr)) ifTrue ++ ifFalse
231+
else if (isNonLocal(thenExpr)) ifFalse
232+
else if (isNonLocal(elseExpr)) ifTrue
233+
else emptyFlowFacts
234+
case _ => emptyFlowFacts
235+
}
236+
}
237+
}

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1979,15 +1979,25 @@ object Types {
19791979

19801980
private def computeDenot(implicit ctx: Context): Denotation = {
19811981

1982-
def finish(d: Denotation) = {
1983-
if (d.exists)
1982+
def finish(d: Denotation): Denotation = {
1983+
if (d.exists) {
1984+
val d1 = if (ctx.settings.YexplicitNulls.value && this.isInstanceOf[NonNullTermRef]) {
1985+
// If the denotation is computed for the first time, or if it's ever updated, make sure
1986+
// that the `info` is non-null.
1987+
d.mapInfo(_.stripNull)
1988+
} else {
1989+
d
1990+
}
19841991
// Avoid storing NoDenotations in the cache - we will not be able to recover from
19851992
// them. The situation might arise that a type has NoDenotation in some later
19861993
// phase but a defined denotation earlier (e.g. a TypeRef to an abstract type
19871994
// is undefined after erasure.) We need to be able to do time travel back and
19881995
// forth also in these cases.
1989-
setDenot(d)
1990-
d
1996+
setDenot(d1)
1997+
d1
1998+
} else {
1999+
d
2000+
}
19912001
}
19922002

19932003
def fromDesignator = designator match {
@@ -2444,6 +2454,29 @@ object Types {
24442454
myHash = hc
24452455
}
24462456

2457+
/** A `TermRef` that, through flow-sensitive type inference, we know is non-null.
2458+
* Accordingly, the `info` in its denotation won't be of the form `T|Null`.
2459+
* Notice that this class isn't cached, unlike the regular `TermRef`.
2460+
*/
2461+
final class NonNullTermRef(prefix: Type, designator: Designator) extends TermRef(prefix, designator) {
2462+
// There's nothing special about this class: it's just used as a marker to identify certain
2463+
// `TermRef`s in `computeDenot`.
2464+
}
2465+
2466+
object NonNullTermRef {
2467+
2468+
/** Create a `TermRef` that's just like `tref`, but whose `info` is always non-null. */
2469+
def fromTermRef(tref: TermRef)(implicit ctx: Context): NonNullTermRef = {
2470+
assert(ctx.settings.YexplicitNulls.value)
2471+
val denot = tref.denot.mapInfo(_.stripNull)
2472+
val nn = new NonNullTermRef(tref.prefix, denot.symbol)
2473+
// We need to set the non-null denotation directly because normally the "non-nullable" denotations
2474+
// are created in `computeDenot`, but they _won't_ be computed if the original `tref` _already_ had
2475+
// a cached denotation.
2476+
nn.withDenot(denot).asInstanceOf[NonNullTermRef]
2477+
}
2478+
}
2479+
24472480
final class CachedTypeRef(prefix: Type, designator: Designator, hc: Int) extends TypeRef(prefix, designator) {
24482481
assert((prefix ne NoPrefix) || designator.isInstanceOf[Symbol])
24492482
myHash = hc

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -823,19 +823,35 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
823823
typr.println(i"result failure for $tree with type ${fun1.tpe.widen}, expected = $pt")
824824

825825
/** Type application where arguments come from prototype, and no implicits are inserted */
826-
def simpleApply(fun1: Tree, proto: FunProto)(implicit ctx: Context): Tree =
827-
methPart(fun1).tpe match {
828-
case funRef: TermRef =>
829-
val app =
830-
if (proto.allArgTypesAreCurrent())
831-
new ApplyToTyped(tree, fun1, funRef, proto.unforcedTypedArgs, pt)
832-
else
833-
new ApplyToUntyped(tree, fun1, funRef, proto, pt)(argCtx(tree))
834-
convertNewGenericArray(app.result)
835-
case _ =>
836-
handleUnexpectedFunType(tree, fun1)
826+
def simpleApply(fun1: Tree, proto: FunProto)(implicit ctx: Context): Tree = {
827+
val ctx1 = if (ctx.settings.YexplicitNulls.value) {
828+
// TODO(abeln): we're re-doing work here by recomputing what's implies by the lhs of the comparison.
829+
// e.g. in `A && B && C && D`, we'll recompute the facts implied by `A && B` twice.
830+
// Find a more-efficient way to do this.
831+
val newFacts = FlowTyper.inferWithinCond(fun1)
832+
if (newFacts.isEmpty) ctx else ctx.fresh.addFlowFacts(newFacts)
833+
} else {
834+
ctx
837835
}
838836

837+
// Separate into a function so we can pass the updated context.
838+
def proc(implicit ctx: Context): tpd.Tree = {
839+
methPart(fun1).tpe match {
840+
case funRef: TermRef =>
841+
val app =
842+
if (proto.allArgTypesAreCurrent())
843+
new ApplyToTyped(tree, fun1, funRef, proto.unforcedTypedArgs, pt)
844+
else
845+
new ApplyToUntyped(tree, fun1, funRef, proto, pt)(argCtx(tree))
846+
convertNewGenericArray(app.result)
847+
case _ =>
848+
handleUnexpectedFunType(tree, fun1)
849+
}
850+
}
851+
852+
proc(ctx1)
853+
}
854+
839855
/** Try same application with an implicit inserted around the qualifier of the function
840856
* part. Return an optional value to indicate success.
841857
*/

0 commit comments

Comments
 (0)