Skip to content

Commit 82a97f7

Browse files
authored
Implement foundation for detecting partially defined vars (#13601)
This diff lays the foundation for detecting partially defined variables. Consider the following situation:

```
if foo():
    x = 1
print(x)  # Error: "x" may be undefined.
```

With this change, mypy generates an error in such a case. Note that this diff is not complete: it still generates a lot of false positives. Running it on mypy itself generated 182 errors. Therefore, this feature is disabled by default and the error code must be explicitly enabled. I will implement it in multiple PRs.
1 parent 4de0caa commit 82a97f7

File tree

6 files changed

+363
-0
lines changed

6 files changed

+363
-0
lines changed

mypy/build.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@
4747
from mypy.checker import TypeChecker
4848
from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error
4949
from mypy.indirection import TypeIndirectionVisitor
50+
from mypy.messages import MessageBuilder
5051
from mypy.nodes import Import, ImportAll, ImportBase, ImportFrom, MypyFile, SymbolTable
52+
from mypy.partially_defined import PartiallyDefinedVariableVisitor
5153
from mypy.semanal import SemanticAnalyzer
5254
from mypy.semanal_pass1 import SemanticAnalyzerPreAnalysis
5355
from mypy.util import (
@@ -2335,6 +2337,15 @@ def type_check_second_pass(self) -> bool:
23352337
self.time_spent_us += time_spent_us(t0)
23362338
return result
23372339

2340+
def detect_partially_defined_vars(self) -> None:
    """Run the partially-defined-variable visitor over this file's tree.

    Nothing happens unless the PARTIALLY_DEFINED error code is enabled.
    """
    assert self.tree is not None, "Internal error: method must be called on parsed file only"
    manager = self.manager
    errors = manager.errors
    if not errors.is_error_code_enabled(codes.PARTIALLY_DEFINED):
        return
    # Point error reporting at this file before the visitor can emit anything.
    errors.set_file(self.xpath, self.tree.fullname, options=manager.options)
    visitor = PartiallyDefinedVariableVisitor(MessageBuilder(errors, manager.modules))
    self.tree.accept(visitor)
2348+
23382349
def finish_passes(self) -> None:
23392350
assert self.tree is not None, "Internal error: method must be called on parsed file only"
23402351
manager = self.manager
@@ -3364,6 +3375,7 @@ def process_stale_scc(graph: Graph, scc: list[str], manager: BuildManager) -> No
33643375
graph[id].type_check_first_pass()
33653376
if not graph[id].type_checker().deferred_nodes:
33663377
unfinished_modules.discard(id)
3378+
graph[id].detect_partially_defined_vars()
33673379
graph[id].finish_passes()
33683380

33693381
while unfinished_modules:
@@ -3372,6 +3384,7 @@ def process_stale_scc(graph: Graph, scc: list[str], manager: BuildManager) -> No
33723384
continue
33733385
if not graph[id].type_check_second_pass():
33743386
unfinished_modules.discard(id)
3387+
graph[id].detect_partially_defined_vars()
33753388
graph[id].finish_passes()
33763389
for id in stale:
33773390
graph[id].generate_unused_ignore_notes()

mypy/errorcodes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,12 @@ def __str__(self) -> str:
124124
UNREACHABLE: Final = ErrorCode(
125125
"unreachable", "Warn about unreachable statements or expressions", "General"
126126
)
127+
# Opt-in error code: constructed with default_enabled=False, so it must be enabled explicitly.
PARTIALLY_DEFINED: Final[ErrorCode] = ErrorCode(
128+
"partially-defined",
129+
"Warn about variables that are defined only in some execution paths",
130+
"General",
131+
default_enabled=False,
132+
)
127133
REDUNDANT_EXPR: Final = ErrorCode(
128134
"redundant-expr", "Warn about redundant expressions", "General", default_enabled=False
129135
)

mypy/messages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,6 +1214,9 @@ def invalid_keyword_var_arg(self, typ: Type, is_mapping: bool, context: Context)
12141214
def undefined_in_superclass(self, member: str, context: Context) -> None:
12151215
self.fail(f'"{member}" undefined in superclass', context)
12161216

1217+
def variable_may_be_undefined(self, name: str, context: Context) -> None:
    """Report, under the PARTIALLY_DEFINED error code, that `name` may be undefined."""
    message = f'Name "{name}" may be undefined'
    self.fail(message, context, code=codes.PARTIALLY_DEFINED)
1219+
12171220
def first_argument_for_super_must_be_type(self, actual: Type, context: Context) -> None:
12181221
actual = get_proper_type(actual)
12191222
if isinstance(actual, Instance):

mypy/partially_defined.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
from __future__ import annotations
2+
3+
from typing import NamedTuple
4+
5+
from mypy.messages import MessageBuilder
6+
from mypy.nodes import (
7+
AssignmentStmt,
8+
ForStmt,
9+
FuncDef,
10+
FuncItem,
11+
IfStmt,
12+
ListExpr,
13+
Lvalue,
14+
NameExpr,
15+
TupleExpr,
16+
WhileStmt,
17+
)
18+
from mypy.traverser import TraverserVisitor
19+
20+
21+
class DefinedVars(NamedTuple):
    """Variable-definition state at the end of a branching statement.

    Examples of branching statements are `if` and `match`.
    """

    # Variables that were defined in only some branches.
    may_be_defined: set[str]
    # Variables that were defined in all branches.
    must_be_defined: set[str]
31+
32+
33+
class BranchStatement:
    """Per-branch record of the variables defined inside one branching statement.

    Every branch starts from a copy of the `must_be_defined` set that was in effect
    when the statement began; `done()` merges the branches into a single result.
    """

    def __init__(self, already_defined: DefinedVars) -> None:
        self.already_defined = already_defined
        first_branch = DefinedVars(
            may_be_defined=set(), must_be_defined=set(already_defined.must_be_defined)
        )
        self.defined_by_branch: list[DefinedVars] = [first_branch]

    def next_branch(self) -> None:
        """Open a new branch seeded with the variables defined before the statement."""
        inherited = set(self.already_defined.must_be_defined)
        self.defined_by_branch.append(DefinedVars(may_be_defined=set(), must_be_defined=inherited))

    def record_definition(self, name: str) -> None:
        """Mark `name` as definitely defined in the current branch."""
        assert len(self.defined_by_branch) > 0
        active = self.defined_by_branch[-1]
        active.must_be_defined.add(name)
        active.may_be_defined.discard(name)

    def record_nested_branch(self, vars: DefinedVars) -> None:
        """Fold the result of a finished nested branching statement into the current branch."""
        assert len(self.defined_by_branch) > 0
        active = self.defined_by_branch[-1]
        # The sets are mutated in place: the NamedTuple fields themselves cannot be rebound.
        active.must_be_defined.update(vars.must_be_defined)
        active.may_be_defined.update(vars.may_be_defined)
        # Anything now definitely defined is no longer merely "maybe" defined.
        active.may_be_defined.difference_update(active.must_be_defined)

    def is_possibly_undefined(self, name: str) -> bool:
        """Return whether `name` is in the current branch's `may_be_defined` set."""
        assert len(self.defined_by_branch) > 0
        return name in self.defined_by_branch[-1].may_be_defined

    def done(self) -> DefinedVars:
        """Merge all branches into one `DefinedVars` describing the whole statement."""
        branches = self.defined_by_branch
        assert len(branches) > 0
        if len(branches) == 1:
            # A lone branch is returned as-is. This is not the same situation as a statement
            # whose empty branch was omitted (e.g. `if` without `else`).
            return branches[0]

        # must_be_defined is the intersection of must_be_defined across all branches.
        must_be_defined = set(branches[0].must_be_defined).intersection(
            *(branch.must_be_defined for branch in branches[1:])
        )
        # may_be_defined is every variable seen in any branch that is not must-be-defined.
        seen: set[str] = set()
        for branch in branches:
            seen.update(branch.may_be_defined, branch.must_be_defined)
        may_be_defined = seen - must_be_defined
        return DefinedVars(may_be_defined=may_be_defined, must_be_defined=must_be_defined)
81+
82+
83+
class DefinedVariableTracker:
    """Manage the state and scope for the PartiallyDefinedVariableVisitor."""

    def __init__(self) -> None:
        # There is always at least one scope, and every scope holds at least one
        # "global" BranchStatement at the bottom of its stack.
        root = BranchStatement(DefinedVars(may_be_defined=set(), must_be_defined=set()))
        self.scopes: list[list[BranchStatement]] = [[root]]

    def _scope(self) -> list[BranchStatement]:
        """Return the branch-statement stack of the innermost scope."""
        assert len(self.scopes) > 0
        return self.scopes[-1]

    def enter_scope(self) -> None:
        """Push a new scope seeded from the enclosing scope's current branch."""
        enclosing = self._scope()
        assert len(enclosing) > 0
        seed = enclosing[-1].defined_by_branch[-1]
        self.scopes.append([BranchStatement(seed)])

    def exit_scope(self) -> None:
        """Discard the innermost scope."""
        self.scopes.pop()

    def start_branch_statement(self) -> None:
        """Push a new branching statement seeded from the current branch."""
        stack = self._scope()
        assert len(stack) > 0
        stack.append(BranchStatement(stack[-1].defined_by_branch[-1]))

    def next_branch(self) -> None:
        """Advance the innermost branching statement to its next branch."""
        stack = self._scope()
        assert len(stack) > 1
        stack[-1].next_branch()

    def end_branch_statement(self) -> None:
        """Pop the innermost branching statement and merge its result into its parent."""
        stack = self._scope()
        assert len(stack) > 1
        merged = stack.pop().done()
        stack[-1].record_nested_branch(merged)

    def record_declaration(self, name: str) -> None:
        """Record `name` as defined in the current branch of the innermost scope."""
        stack = self._scope()
        assert len(stack) > 0
        stack[-1].record_definition(name)

    def is_possibly_undefined(self, name: str) -> bool:
        """Return whether `name` is currently only possibly defined."""
        stack = self._scope()
        assert len(stack) > 0
        # A variable is possibly undefined if it is in `may_be_defined` but not in
        # `must_be_defined`. Variables that are not defined at all are handled by the
        # semantic analyzer.
        return stack[-1].is_possibly_undefined(name)
126+
127+
128+
class PartiallyDefinedVariableVisitor(TraverserVisitor):
    """Report reads of variables that are defined on only some execution paths.

    Example of what this visitor flags:

        if foo():
            x = 1
        print(x)  # Error: "x" may be undefined.

    Variables that are not defined in any of the branches are not detected here -- that is
    handled by the semantic analyzer.
    """

    def __init__(self, msg: MessageBuilder) -> None:
        self.msg = msg
        self.tracker = DefinedVariableTracker()

    def process_lvalue(self, lvalue: Lvalue) -> None:
        """Record each plain name bound by `lvalue`, recursing into list/tuple targets."""
        if isinstance(lvalue, NameExpr):
            self.tracker.record_declaration(lvalue.name)
            return
        if isinstance(lvalue, (ListExpr, TupleExpr)):
            for element in lvalue.items:
                self.process_lvalue(element)

    def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
        # Targets are recorded first; the generic traversal of the statement runs afterwards.
        for target in o.lvalues:
            self.process_lvalue(target)
        super().visit_assignment_stmt(o)

    def visit_if_stmt(self, o: IfStmt) -> None:
        # All conditions are visited before the branching statement is opened.
        for condition in o.expr:
            condition.accept(self)
        tracker = self.tracker
        tracker.start_branch_statement()
        for block in o.body:
            block.accept(self)
            # Close this arm; the branch opened after the last arm is the else branch,
            # which stays empty when there is no else body.
            tracker.next_branch()
        if o.else_body:
            o.else_body.accept(self)
        tracker.end_branch_statement()

    def visit_func_def(self, o: FuncDef) -> None:
        # Each function body gets its own scope.
        tracker = self.tracker
        tracker.enter_scope()
        super().visit_func_def(o)
        tracker.exit_scope()

    def visit_func(self, o: FuncItem) -> None:
        # Arguments count as defined for the whole function body.
        if o.arguments is not None:
            for argument in o.arguments:
                self.tracker.record_declaration(argument.variable.name)
        super().visit_func(o)

    def visit_for_stmt(self, o: ForStmt) -> None:
        o.expr.accept(self)
        # The index target is recorded before the loop body's branch is opened.
        self.process_lvalue(o.index)
        o.index.accept(self)
        tracker = self.tracker
        tracker.start_branch_statement()
        o.body.accept(self)
        tracker.next_branch()
        if o.else_body:
            o.else_body.accept(self)
        tracker.end_branch_statement()

    def visit_while_stmt(self, o: WhileStmt) -> None:
        o.expr.accept(self)
        tracker = self.tracker
        tracker.start_branch_statement()
        o.body.accept(self)
        tracker.next_branch()
        if o.else_body:
            o.else_body.accept(self)
        tracker.end_branch_statement()

    def visit_name_expr(self, o: NameExpr) -> None:
        possibly_undefined = self.tracker.is_possibly_undefined(o.name)
        if possibly_undefined:
            self.msg.variable_may_be_undefined(o.name, o)
        super().visit_name_expr(o)

mypy/server/update.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,7 @@ def restore(ids: list[str]) -> None:
651651
state.type_checker().reset()
652652
state.type_check_first_pass()
653653
state.type_check_second_pass()
654+
state.detect_partially_defined_vars()
654655
t2 = time.time()
655656
state.finish_passes()
656657
t3 = time.time()

0 commit comments

Comments
 (0)