31
31
RefExpr ,
32
32
ReturnStmt ,
33
33
StarExpr ,
34
+ TryStmt ,
34
35
TupleExpr ,
35
36
WhileStmt ,
36
37
WithStmt ,
@@ -66,6 +67,13 @@ def __init__(
66
67
self .must_be_defined = set (must_be_defined )
67
68
self .skipped = skipped
68
69
70
+ def copy (self ) -> BranchState :
71
+ return BranchState (
72
+ must_be_defined = set (self .must_be_defined ),
73
+ may_be_defined = set (self .may_be_defined ),
74
+ skipped = self .skipped ,
75
+ )
76
+
69
77
70
78
class BranchStatement :
71
79
def __init__ (self , initial_state : BranchState ) -> None :
@@ -77,6 +85,11 @@ def __init__(self, initial_state: BranchState) -> None:
77
85
)
78
86
]
79
87
88
+ def copy (self ) -> BranchStatement :
89
+ result = BranchStatement (self .initial_state )
90
+ result .branches = [b .copy () for b in self .branches ]
91
+ return result
92
+
80
93
def next_branch (self ) -> None :
81
94
self .branches .append (
82
95
BranchState (
@@ -90,6 +103,11 @@ def record_definition(self, name: str) -> None:
90
103
self .branches [- 1 ].must_be_defined .add (name )
91
104
self .branches [- 1 ].may_be_defined .discard (name )
92
105
106
+ def delete_var (self , name : str ) -> None :
107
+ assert len (self .branches ) > 0
108
+ self .branches [- 1 ].must_be_defined .discard (name )
109
+ self .branches [- 1 ].may_be_defined .discard (name )
110
+
93
111
def record_nested_branch (self , state : BranchState ) -> None :
94
112
assert len (self .branches ) > 0
95
113
current_branch = self .branches [- 1 ]
@@ -151,6 +169,11 @@ def __init__(self, stmts: list[BranchStatement]) -> None:
151
169
self .branch_stmts : list [BranchStatement ] = stmts
152
170
self .undefined_refs : dict [str , set [NameExpr ]] = {}
153
171
172
+ def copy (self ) -> Scope :
173
+ result = Scope ([s .copy () for s in self .branch_stmts ])
174
+ result .undefined_refs = self .undefined_refs .copy ()
175
+ return result
176
+
154
177
def record_undefined_ref (self , o : NameExpr ) -> None :
155
178
if o .name not in self .undefined_refs :
156
179
self .undefined_refs [o .name ] = set ()
@@ -166,6 +189,15 @@ class DefinedVariableTracker:
166
189
def __init__ (self ) -> None :
167
190
# There's always at least one scope. Within each scope, there's at least one "global" BranchingStatement.
168
191
self .scopes : list [Scope ] = [Scope ([BranchStatement (BranchState ())])]
192
+ # disable_branch_skip is used to disable skipping a branch due to a return/raise/etc. This is useful
193
+ # in things like try/except/finally statements.
194
+ self .disable_branch_skip = False
195
+
196
+ def copy (self ) -> DefinedVariableTracker :
197
+ result = DefinedVariableTracker ()
198
+ result .scopes = [s .copy () for s in self .scopes ]
199
+ result .disable_branch_skip = self .disable_branch_skip
200
+ return result
169
201
170
202
def _scope (self ) -> Scope :
171
203
assert len (self .scopes ) > 0
@@ -195,14 +227,19 @@ def end_branch_statement(self) -> None:
195
227
196
228
def skip_branch (self ) -> None :
197
229
# Only skip branch if we're outside of "root" branch statement.
198
- if len (self ._scope ().branch_stmts ) > 1 :
230
+ if len (self ._scope ().branch_stmts ) > 1 and not self . disable_branch_skip :
199
231
self ._scope ().branch_stmts [- 1 ].skip_branch ()
200
232
201
233
def record_definition (self , name : str ) -> None :
202
234
assert len (self .scopes ) > 0
203
235
assert len (self .scopes [- 1 ].branch_stmts ) > 0
204
236
self ._scope ().branch_stmts [- 1 ].record_definition (name )
205
237
238
+ def delete_var (self , name : str ) -> None :
239
+ assert len (self .scopes ) > 0
240
+ assert len (self .scopes [- 1 ].branch_stmts ) > 0
241
+ self ._scope ().branch_stmts [- 1 ].delete_var (name )
242
+
206
243
def record_undefined_ref (self , o : NameExpr ) -> None :
207
244
"""Records an undefined reference. These can later be retrieved via `pop_undefined_ref`."""
208
245
assert len (self .scopes ) > 0
@@ -268,6 +305,7 @@ def __init__(
268
305
self .type_map = type_map
269
306
self .options = options
270
307
self .loops : list [Loop ] = []
308
+ self .try_depth = 0
271
309
self .tracker = DefinedVariableTracker ()
272
310
for name in implicit_module_attrs :
273
311
self .tracker .record_definition (name )
@@ -432,6 +470,75 @@ def visit_expression_stmt(self, o: ExpressionStmt) -> None:
432
470
self .tracker .skip_branch ()
433
471
super ().visit_expression_stmt (o )
434
472
473
+ def visit_try_stmt (self , o : TryStmt ) -> None :
474
+ """
475
+ Note that finding undefined vars in `finally` requires different handling from
476
+ the rest of the code. In particular, we want to disallow skipping branches due to jump
477
+ statements in except/else clauses for finally but not for other cases. Imagine a case like:
478
+ def f() -> int:
479
+ try:
480
+ x = 1
481
+ except:
482
+ # This jump statement needs to be handled differently depending on whether or
483
+ # not we're trying to process `finally` or not.
484
+ return 0
485
+ finally:
486
+ # `x` may be undefined here.
487
+ pass
488
+ # `x` is always defined here.
489
+ return x
490
+ """
491
+ self .try_depth += 1
492
+ if o .finally_body is not None :
493
+ # In order to find undefined vars in `finally`, we need to
494
+ # process try/except with branch skipping disabled. However, for the rest of the code
495
+ # after finally, we need to process try/except with branch skipping enabled.
496
+ # Therefore, we need to process try/finally twice.
497
+ # Because processing is not idempotent, we should make a copy of the tracker.
498
+ old_tracker = self .tracker .copy ()
499
+ self .tracker .disable_branch_skip = True
500
+ self .process_try_stmt (o )
501
+ self .tracker = old_tracker
502
+ self .process_try_stmt (o )
503
+ self .try_depth -= 1
504
+
505
+ def process_try_stmt (self , o : TryStmt ) -> None :
506
+ """
507
+ Processes try statement decomposing it into the following:
508
+ if ...:
509
+ body
510
+ else_body
511
+ elif ...:
512
+ except 1
513
+ elif ...:
514
+ except 2
515
+ else:
516
+ except n
517
+ finally
518
+ """
519
+ self .tracker .start_branch_statement ()
520
+ o .body .accept (self )
521
+ if o .else_body is not None :
522
+ o .else_body .accept (self )
523
+ if len (o .handlers ) > 0 :
524
+ assert len (o .handlers ) == len (o .vars ) == len (o .types )
525
+ for i in range (len (o .handlers )):
526
+ self .tracker .next_branch ()
527
+ exc_type = o .types [i ]
528
+ if exc_type is not None :
529
+ exc_type .accept (self )
530
+ var = o .vars [i ]
531
+ if var is not None :
532
+ self .process_definition (var .name )
533
+ var .accept (self )
534
+ o .handlers [i ].accept (self )
535
+ if var is not None :
536
+ self .tracker .delete_var (var .name )
537
+ self .tracker .end_branch_statement ()
538
+
539
+ if o .finally_body is not None :
540
+ o .finally_body .accept (self )
541
+
435
542
def visit_while_stmt (self , o : WhileStmt ) -> None :
436
543
o .expr .accept (self )
437
544
self .tracker .start_branch_statement ()
@@ -478,7 +585,9 @@ def visit_name_expr(self, o: NameExpr) -> None:
478
585
self .tracker .record_definition (o .name )
479
586
elif self .tracker .is_defined_in_different_branch (o .name ):
480
587
# A variable is defined in one branch but used in a different branch.
481
- if self .loops :
588
+ if self .loops or self .try_depth > 0 :
589
+ # If we're in a loop or in a try, we can't be sure that this variable
590
+ # is undefined. Report it as "may be undefined".
482
591
self .variable_may_be_undefined (o .name , o )
483
592
else :
484
593
self .var_used_before_def (o .name , o )
0 commit comments