Skip to content

Commit 96ac8b3

Browse files
authored
[partially defined] implement support for try statements (#14114)
This adds support for try/except/finally/else check. The implementation ended up pretty complicated because it had to handle jumps different for finally. It took me a few iterations to get to this solution and that's the cleanest one I could come up with. Closes #13928.
1 parent df6e828 commit 96ac8b3

File tree

2 files changed

+295
-2
lines changed

2 files changed

+295
-2
lines changed

mypy/partially_defined.py

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
RefExpr,
3232
ReturnStmt,
3333
StarExpr,
34+
TryStmt,
3435
TupleExpr,
3536
WhileStmt,
3637
WithStmt,
@@ -66,6 +67,13 @@ def __init__(
6667
self.must_be_defined = set(must_be_defined)
6768
self.skipped = skipped
6869

70+
def copy(self) -> BranchState:
71+
return BranchState(
72+
must_be_defined=set(self.must_be_defined),
73+
may_be_defined=set(self.may_be_defined),
74+
skipped=self.skipped,
75+
)
76+
6977

7078
class BranchStatement:
7179
def __init__(self, initial_state: BranchState) -> None:
@@ -77,6 +85,11 @@ def __init__(self, initial_state: BranchState) -> None:
7785
)
7886
]
7987

88+
def copy(self) -> BranchStatement:
89+
result = BranchStatement(self.initial_state)
90+
result.branches = [b.copy() for b in self.branches]
91+
return result
92+
8093
def next_branch(self) -> None:
8194
self.branches.append(
8295
BranchState(
@@ -90,6 +103,11 @@ def record_definition(self, name: str) -> None:
90103
self.branches[-1].must_be_defined.add(name)
91104
self.branches[-1].may_be_defined.discard(name)
92105

106+
def delete_var(self, name: str) -> None:
107+
assert len(self.branches) > 0
108+
self.branches[-1].must_be_defined.discard(name)
109+
self.branches[-1].may_be_defined.discard(name)
110+
93111
def record_nested_branch(self, state: BranchState) -> None:
94112
assert len(self.branches) > 0
95113
current_branch = self.branches[-1]
@@ -151,6 +169,11 @@ def __init__(self, stmts: list[BranchStatement]) -> None:
151169
self.branch_stmts: list[BranchStatement] = stmts
152170
self.undefined_refs: dict[str, set[NameExpr]] = {}
153171

172+
def copy(self) -> Scope:
173+
result = Scope([s.copy() for s in self.branch_stmts])
174+
result.undefined_refs = self.undefined_refs.copy()
175+
return result
176+
154177
def record_undefined_ref(self, o: NameExpr) -> None:
155178
if o.name not in self.undefined_refs:
156179
self.undefined_refs[o.name] = set()
@@ -166,6 +189,15 @@ class DefinedVariableTracker:
166189
def __init__(self) -> None:
167190
# There's always at least one scope. Within each scope, there's at least one "global" BranchingStatement.
168191
self.scopes: list[Scope] = [Scope([BranchStatement(BranchState())])]
192+
# disable_branch_skip is used to disable skipping a branch due to a return/raise/etc. This is useful
193+
# in things like try/except/finally statements.
194+
self.disable_branch_skip = False
195+
196+
def copy(self) -> DefinedVariableTracker:
197+
result = DefinedVariableTracker()
198+
result.scopes = [s.copy() for s in self.scopes]
199+
result.disable_branch_skip = self.disable_branch_skip
200+
return result
169201

170202
def _scope(self) -> Scope:
171203
assert len(self.scopes) > 0
@@ -195,14 +227,19 @@ def end_branch_statement(self) -> None:
195227

196228
def skip_branch(self) -> None:
197229
# Only skip branch if we're outside of "root" branch statement.
198-
if len(self._scope().branch_stmts) > 1:
230+
if len(self._scope().branch_stmts) > 1 and not self.disable_branch_skip:
199231
self._scope().branch_stmts[-1].skip_branch()
200232

201233
def record_definition(self, name: str) -> None:
202234
assert len(self.scopes) > 0
203235
assert len(self.scopes[-1].branch_stmts) > 0
204236
self._scope().branch_stmts[-1].record_definition(name)
205237

238+
def delete_var(self, name: str) -> None:
239+
assert len(self.scopes) > 0
240+
assert len(self.scopes[-1].branch_stmts) > 0
241+
self._scope().branch_stmts[-1].delete_var(name)
242+
206243
def record_undefined_ref(self, o: NameExpr) -> None:
207244
"""Records an undefined reference. These can later be retrieved via `pop_undefined_ref`."""
208245
assert len(self.scopes) > 0
@@ -268,6 +305,7 @@ def __init__(
268305
self.type_map = type_map
269306
self.options = options
270307
self.loops: list[Loop] = []
308+
self.try_depth = 0
271309
self.tracker = DefinedVariableTracker()
272310
for name in implicit_module_attrs:
273311
self.tracker.record_definition(name)
@@ -432,6 +470,75 @@ def visit_expression_stmt(self, o: ExpressionStmt) -> None:
432470
self.tracker.skip_branch()
433471
super().visit_expression_stmt(o)
434472

473+
def visit_try_stmt(self, o: TryStmt) -> None:
474+
"""
475+
Note that finding undefined vars in `finally` requires different handling from
476+
the rest of the code. In particular, we want to disallow skipping branches due to jump
477+
statements in except/else clauses for finally but not for other cases. Imagine a case like:
478+
def f() -> int:
479+
try:
480+
x = 1
481+
except:
482+
# This jump statement needs to be handled differently depending on whether or
483+
# not we're trying to process `finally` or not.
484+
return 0
485+
finally:
486+
# `x` may be undefined here.
487+
pass
488+
# `x` is always defined here.
489+
return x
490+
"""
491+
self.try_depth += 1
492+
if o.finally_body is not None:
493+
# In order to find undefined vars in `finally`, we need to
494+
# process try/except with branch skipping disabled. However, for the rest of the code
495+
# after finally, we need to process try/except with branch skipping enabled.
496+
# Therefore, we need to process try/finally twice.
497+
# Because processing is not idempotent, we should make a copy of the tracker.
498+
old_tracker = self.tracker.copy()
499+
self.tracker.disable_branch_skip = True
500+
self.process_try_stmt(o)
501+
self.tracker = old_tracker
502+
self.process_try_stmt(o)
503+
self.try_depth -= 1
504+
505+
def process_try_stmt(self, o: TryStmt) -> None:
506+
"""
507+
Processes try statement decomposing it into the following:
508+
if ...:
509+
body
510+
else_body
511+
elif ...:
512+
except 1
513+
elif ...:
514+
except 2
515+
else:
516+
except n
517+
finally
518+
"""
519+
self.tracker.start_branch_statement()
520+
o.body.accept(self)
521+
if o.else_body is not None:
522+
o.else_body.accept(self)
523+
if len(o.handlers) > 0:
524+
assert len(o.handlers) == len(o.vars) == len(o.types)
525+
for i in range(len(o.handlers)):
526+
self.tracker.next_branch()
527+
exc_type = o.types[i]
528+
if exc_type is not None:
529+
exc_type.accept(self)
530+
var = o.vars[i]
531+
if var is not None:
532+
self.process_definition(var.name)
533+
var.accept(self)
534+
o.handlers[i].accept(self)
535+
if var is not None:
536+
self.tracker.delete_var(var.name)
537+
self.tracker.end_branch_statement()
538+
539+
if o.finally_body is not None:
540+
o.finally_body.accept(self)
541+
435542
def visit_while_stmt(self, o: WhileStmt) -> None:
436543
o.expr.accept(self)
437544
self.tracker.start_branch_statement()
@@ -478,7 +585,9 @@ def visit_name_expr(self, o: NameExpr) -> None:
478585
self.tracker.record_definition(o.name)
479586
elif self.tracker.is_defined_in_different_branch(o.name):
480587
# A variable is defined in one branch but used in a different branch.
481-
if self.loops:
588+
if self.loops or self.try_depth > 0:
589+
# If we're in a loop or in a try, we can't be sure that this variable
590+
# is undefined. Report it as "may be undefined".
482591
self.variable_may_be_undefined(o.name, o)
483592
else:
484593
self.var_used_before_def(o.name, o)

test-data/unit/check-possibly-undefined.test

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,190 @@ def f3() -> None:
525525
y = x
526526
z = x # E: Name "x" may be undefined
527527

528+
[case testTryBasic]
529+
# flags: --enable-error-code possibly-undefined --enable-error-code used-before-def
530+
def f1() -> int:
531+
try:
532+
x = 1
533+
except:
534+
pass
535+
return x # E: Name "x" may be undefined
536+
537+
def f2() -> int:
538+
try:
539+
pass
540+
except:
541+
x = 1
542+
return x # E: Name "x" may be undefined
543+
544+
def f3() -> int:
545+
try:
546+
x = 1
547+
except:
548+
y = x # E: Name "x" may be undefined
549+
return x # E: Name "x" may be undefined
550+
551+
def f4() -> int:
552+
try:
553+
x = 1
554+
except:
555+
return 0
556+
return x
557+
558+
def f5() -> int:
559+
try:
560+
x = 1
561+
except:
562+
raise
563+
return x
564+
565+
def f6() -> None:
566+
try:
567+
pass
568+
except BaseException as exc:
569+
x = exc # No error.
570+
exc = BaseException()
571+
# This case is covered by the other check, not by possibly undefined check.
572+
y = exc # E: Trying to read deleted variable "exc"
573+
574+
def f7() -> int:
575+
try:
576+
if int():
577+
x = 1
578+
assert False
579+
except:
580+
pass
581+
return x # E: Name "x" may be undefined
582+
[builtins fixtures/exception.pyi]
583+
584+
[case testTryMultiExcept]
585+
# flags: --enable-error-code possibly-undefined
586+
def f1() -> int:
587+
try:
588+
x = 1
589+
except BaseException:
590+
x = 2
591+
except:
592+
x = 3
593+
return x
594+
595+
def f2() -> int:
596+
try:
597+
x = 1
598+
except BaseException:
599+
pass
600+
except:
601+
x = 3
602+
return x # E: Name "x" may be undefined
603+
[builtins fixtures/exception.pyi]
604+
605+
[case testTryFinally]
606+
# flags: --enable-error-code possibly-undefined --enable-error-code used-before-def
607+
def f1() -> int:
608+
try:
609+
x = 1
610+
finally:
611+
x = 2
612+
return x
613+
614+
def f2() -> int:
615+
try:
616+
pass
617+
except:
618+
pass
619+
finally:
620+
x = 2
621+
return x
622+
623+
def f3() -> int:
624+
try:
625+
x = 1
626+
except:
627+
pass
628+
finally:
629+
y = x # E: Name "x" may be undefined
630+
return x
631+
632+
def f4() -> int:
633+
try:
634+
x = 0
635+
except BaseException:
636+
raise
637+
finally:
638+
y = x # E: Name "x" may be undefined
639+
return y
640+
641+
def f5() -> int:
642+
try:
643+
if int():
644+
x = 1
645+
else:
646+
return 0
647+
finally:
648+
pass
649+
return x # No error.
650+
651+
def f6() -> int:
652+
try:
653+
if int():
654+
x = 1
655+
else:
656+
return 0
657+
finally:
658+
a = x # E: Name "x" may be undefined
659+
return a
660+
[builtins fixtures/exception.pyi]
661+
662+
[case testTryElse]
663+
# flags: --enable-error-code possibly-undefined
664+
def f1() -> int:
665+
try:
666+
return 0
667+
except BaseException:
668+
x = 1
669+
else:
670+
x = 2
671+
finally:
672+
y = x
673+
return y
674+
675+
def f2() -> int:
676+
try:
677+
pass
678+
except:
679+
x = 1
680+
else:
681+
x = 2
682+
return x
683+
684+
def f3() -> int:
685+
try:
686+
pass
687+
except:
688+
x = 1
689+
else:
690+
pass
691+
return x # E: Name "x" may be undefined
692+
693+
def f4() -> int:
694+
try:
695+
x = 1
696+
except:
697+
x = 2
698+
else:
699+
pass
700+
return x
701+
702+
def f5() -> int:
703+
try:
704+
pass
705+
except:
706+
x = 1
707+
else:
708+
return 1
709+
return x
710+
[builtins fixtures/exception.pyi]
711+
528712
[case testNoReturn]
529713
# flags: --enable-error-code possibly-undefined
530714

0 commit comments

Comments
 (0)