|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from typing import NamedTuple |
| 4 | + |
| 5 | +from mypy.messages import MessageBuilder |
| 6 | +from mypy.nodes import ( |
| 7 | + AssignmentStmt, |
| 8 | + ForStmt, |
| 9 | + FuncDef, |
| 10 | + FuncItem, |
| 11 | + IfStmt, |
| 12 | + ListExpr, |
| 13 | + Lvalue, |
| 14 | + NameExpr, |
| 15 | + TupleExpr, |
| 16 | + WhileStmt, |
| 17 | +) |
| 18 | +from mypy.traverser import TraverserVisitor |
| 19 | + |
| 20 | + |
class DefinedVars(NamedTuple):
    """Summary of which variables are defined once a branching statement ends.

    Branching statements are constructs such as `if` and `match`.

    Fields:
        may_be_defined: names assigned in some, but not all, of the branches.
        must_be_defined: names assigned in every branch.
    """

    may_be_defined: set[str]
    must_be_defined: set[str]
| 31 | + |
| 32 | + |
class BranchStatement:
    """Per-branch bookkeeping of definitions for one branching statement.

    `defined_by_branch` holds one `DefinedVars` per branch seen so far; the last
    entry is the branch currently being processed. Every branch starts with a
    private copy of `already_defined.must_be_defined` and an empty
    `may_be_defined` set (the `may_be_defined` of `already_defined` is not
    carried over).
    """

    def __init__(self, already_defined: DefinedVars) -> None:
        self.already_defined = already_defined
        self.defined_by_branch: list[DefinedVars] = []
        # The first branch is opened exactly like any later one.
        self.next_branch()

    def next_branch(self) -> None:
        """Open a new branch seeded with the names defined before the statement."""
        inherited = set(self.already_defined.must_be_defined)
        self.defined_by_branch.append(
            DefinedVars(may_be_defined=set(), must_be_defined=inherited)
        )

    def record_definition(self, name: str) -> None:
        """Mark `name` as definitely defined in the current branch."""
        assert self.defined_by_branch
        active = self.defined_by_branch[-1]
        active.must_be_defined.add(name)
        # A name cannot be both "must" and "may" at the same time.
        active.may_be_defined.discard(name)

    def record_nested_branch(self, vars: DefinedVars) -> None:
        """Fold the result of a finished nested statement into the current branch."""
        assert self.defined_by_branch
        active = self.defined_by_branch[-1]
        must, may = active.must_be_defined, active.may_be_defined
        must.update(vars.must_be_defined)
        may.update(vars.may_be_defined)
        # Anything now definitely defined is no longer merely "possibly" defined.
        may.difference_update(must)

    def is_possibly_undefined(self, name: str) -> bool:
        """Return True if `name` is in the current branch's `may_be_defined` set."""
        assert self.defined_by_branch
        return name in self.defined_by_branch[-1].may_be_defined

    def done(self) -> DefinedVars:
        """Combine all recorded branches into a single `DefinedVars`."""
        branches = self.defined_by_branch
        assert branches
        if len(branches) == 1:
            # Only one branch was recorded, so its state is the result as-is.
            # This is distinct from a statement whose empty branch is implicit
            # (e.g. `if` without `else`).
            return branches[0]

        # A name must be defined only if *every* branch defines it, i.e. the
        # intersection of the per-branch must_be_defined sets.
        must_be_defined = set(branches[0].must_be_defined).intersection(
            *(branch.must_be_defined for branch in branches[1:])
        )
        # Every other name mentioned by any branch is only possibly defined.
        mentioned: set[str] = set()
        for branch in branches:
            mentioned |= branch.may_be_defined
            mentioned |= branch.must_be_defined
        return DefinedVars(
            may_be_defined=mentioned - must_be_defined, must_be_defined=must_be_defined
        )
| 81 | + |
| 82 | + |
class DefinedVariableTracker:
    """DefinedVariableTracker manages the state and scope for the UndefinedVariablesVisitor.

    `scopes` is a stack of scopes; each scope is itself a stack of
    `BranchStatement` objects, where the last entry is the innermost branching
    statement currently being processed.
    """

    def __init__(self) -> None:
        # There's always at least one scope. Within each scope, there's at least
        # one "global" BranchStatement.
        self.scopes: list[list[BranchStatement]] = [
            [BranchStatement(DefinedVars(may_be_defined=set(), must_be_defined=set()))]
        ]

    def _scope(self) -> list[BranchStatement]:
        """Return the innermost scope (its stack of branch statements)."""
        assert len(self.scopes) > 0
        return self.scopes[-1]

    def enter_scope(self) -> None:
        """Push a new scope seeded from the current branch of the enclosing scope."""
        assert len(self._scope()) > 0
        self.scopes.append([BranchStatement(self._scope()[-1].defined_by_branch[-1])])

    def exit_scope(self) -> None:
        """Pop the innermost scope."""
        self.scopes.pop()

    def start_branch_statement(self) -> None:
        """Begin a nested branching statement inside the current branch."""
        assert len(self._scope()) > 0
        self._scope().append(BranchStatement(self._scope()[-1].defined_by_branch[-1]))

    def next_branch(self) -> None:
        """Move on to the next branch of the innermost branching statement."""
        assert len(self._scope()) > 1
        self._scope()[-1].next_branch()

    def end_branch_statement(self) -> None:
        """Finish the innermost branching statement and merge it into its parent."""
        assert len(self._scope()) > 1
        result = self._scope().pop().done()
        self._scope()[-1].record_nested_branch(result)

    def record_declaration(self, name: str) -> None:
        """Record that `name` is defined in the current branch."""
        assert len(self.scopes) > 0
        assert len(self.scopes[-1]) > 0
        self._scope()[-1].record_definition(name)

    def is_possibly_undefined(self, name: str) -> bool:
        """Return True if `name` is defined on only some paths reaching this point.

        Cases where a variable is not defined altogether are handled by the
        semantic analyzer.
        """
        assert len(self._scope()) > 0
        # A nested branch starts with an empty `may_be_defined` set, so looking
        # only at the innermost statement would miss names that became
        # "possibly defined" in an enclosing statement of the same scope, e.g.
        #     if a: x = 1
        #     if b: print(x)
        # Walk outwards and stop at the first statement that knows the name.
        for statement in reversed(self._scope()):
            if statement.is_possibly_undefined(name):
                return True
            if name in statement.defined_by_branch[-1].must_be_defined:
                # Definitely defined here, which overrides any outer "may".
                return False
        return False
| 126 | + |
| 127 | + |
class PartiallyDefinedVariableVisitor(TraverserVisitor):
    """Report reads of variables that are assigned on only some code paths.

    Example of what is reported:

        if foo():
            x = 1
        print(x)  # Error: "x" may be undefined.

    Variables that are not defined in any branch are out of scope here; the
    semantic analyzer takes care of those.
    """

    def __init__(self, msg: MessageBuilder) -> None:
        self.msg = msg
        self.tracker = DefinedVariableTracker()

    def process_lvalue(self, lvalue: Lvalue) -> None:
        """Record every plain name assigned by `lvalue`, descending into tuples/lists."""
        if isinstance(lvalue, NameExpr):
            self.tracker.record_declaration(lvalue.name)
            return
        if isinstance(lvalue, (ListExpr, TupleExpr)):
            for element in lvalue.items:
                self.process_lvalue(element)

    def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
        for target in o.lvalues:
            self.process_lvalue(target)
        super().visit_assignment_stmt(o)

    def visit_if_stmt(self, o: IfStmt) -> None:
        # Conditions are evaluated before any branch is entered.
        for condition in o.expr:
            condition.accept(self)
        tracker = self.tracker
        tracker.start_branch_statement()
        for block in o.body:
            block.accept(self)
            # Always open a following branch; when there is no `else`, it stays
            # empty and stands for the path on which no body ran.
            tracker.next_branch()
        if o.else_body:
            o.else_body.accept(self)
        tracker.end_branch_statement()

    def visit_func_def(self, o: FuncDef) -> None:
        self.tracker.enter_scope()
        super().visit_func_def(o)
        self.tracker.exit_scope()

    def visit_func(self, o: FuncItem) -> None:
        # Arguments count as defined from the start of the function.
        for argument in o.arguments or ():
            self.tracker.record_declaration(argument.variable.name)
        super().visit_func(o)

    def _visit_loop_branches(self, loop: ForStmt | WhileStmt) -> None:
        """Treat a loop's body and its optional `else` as two branches."""
        tracker = self.tracker
        tracker.start_branch_statement()
        loop.body.accept(self)
        tracker.next_branch()
        if loop.else_body:
            loop.else_body.accept(self)
        tracker.end_branch_statement()

    def visit_for_stmt(self, o: ForStmt) -> None:
        o.expr.accept(self)
        # The index is recorded outside of the loop's branches.
        self.process_lvalue(o.index)
        o.index.accept(self)
        self._visit_loop_branches(o)

    def visit_while_stmt(self, o: WhileStmt) -> None:
        o.expr.accept(self)
        self._visit_loop_branches(o)

    def visit_name_expr(self, o: NameExpr) -> None:
        if self.tracker.is_possibly_undefined(o.name):
            self.msg.variable_may_be_undefined(o.name, o)
        super().visit_name_expr(o)
0 commit comments