@@ -563,7 +563,9 @@ async def resolve_nonlocals(self, ast_ctx):
563
563
var_names = set (args )
564
564
local_names = set (args )
565
565
for stmt in self .func_def .body :
566
- self .has_closure = self .has_closure or isinstance (stmt , ast .FunctionDef )
566
+ self .has_closure = self .has_closure or isinstance (
567
+ stmt , (ast .FunctionDef , ast .ClassDef , ast .AsyncFunctionDef )
568
+ )
567
569
var_names = var_names .union (
568
570
await ast_ctx .get_names (
569
571
stmt , nonlocal_names = nonlocal_names , global_names = global_names , local_names = local_names ,
@@ -1839,7 +1841,7 @@ async def get_target_names(self, lhs):
1839
1841
names .add (lhs .id )
1840
1842
return names
1841
1843
1842
- async def get_names_set (self , arg , names , nonlocal_names = None , global_names = None , local_names = None ):
1844
+ async def get_names_set (self , arg , names , nonlocal_names , global_names , local_names ):
1843
1845
"""Recursively find all the names mentioned in the AST tree."""
1844
1846
1845
1847
cls_name = arg .__class__ .__name__
@@ -1891,51 +1893,38 @@ async def get_names_set(self, arg, names, nonlocal_names=None, global_names=None
1891
1893
local_names .add (handler .name )
1892
1894
names .add (handler .name )
1893
1895
elif cls_name == "Call" :
1894
- await self .get_names_set (
1895
- arg .func ,
1896
- names ,
1897
- nonlocal_names = nonlocal_names ,
1898
- global_names = global_names ,
1899
- local_names = local_names ,
1900
- )
1896
+ await self .get_names_set (arg .func , names , nonlocal_names , global_names , local_names )
1901
1897
for this_arg in arg .args :
1902
- await self .get_names_set (
1903
- this_arg ,
1904
- names ,
1905
- nonlocal_names = nonlocal_names ,
1906
- global_names = global_names ,
1907
- local_names = local_names ,
1908
- )
1898
+ await self .get_names_set (this_arg , names , nonlocal_names , global_names , local_names )
1909
1899
return
1910
1900
elif cls_name in {"FunctionDef" , "ClassDef" , "AsyncFunctionDef" }:
1911
1901
local_names .add (arg .name )
1912
1902
names .add (arg .name )
1903
+ for dec in arg .decorator_list :
1904
+ await self .get_names_set (dec , names , nonlocal_names , global_names , local_names )
1905
+ #
1906
+ # find unbound names from the body of the function or class
1907
+ #
1908
+ inner_global , inner_names , inner_local = set (), set (), set ()
1909
+ for child in arg .body :
1910
+ await self .get_names_set (child , inner_names , None , inner_global , inner_local )
1911
+ for name in inner_names :
1912
+ if name not in inner_local and name not in inner_global :
1913
+ names .add (name )
1913
1914
return
1914
1915
elif cls_name == "Delete" :
1915
1916
for arg1 in arg .targets :
1916
1917
if isinstance (arg1 , ast .Name ):
1917
1918
local_names .add (arg1 .id )
1918
1919
for child in ast .iter_child_nodes (arg ):
1919
- await self .get_names_set (
1920
- child ,
1921
- names ,
1922
- nonlocal_names = nonlocal_names ,
1923
- global_names = global_names ,
1924
- local_names = local_names ,
1925
- )
1920
+ await self .get_names_set (child , names , nonlocal_names , global_names , local_names )
1926
1921
1927
1922
async def get_names (self , this_ast = None , nonlocal_names = None , global_names = None , local_names = None ):
1928
1923
"""Return set of all the names mentioned in our AST tree."""
1929
1924
names = set ()
1930
1925
this_ast = this_ast or self .ast
1931
1926
if this_ast :
1932
- await self .get_names_set (
1933
- this_ast ,
1934
- names ,
1935
- nonlocal_names = nonlocal_names ,
1936
- global_names = global_names ,
1937
- local_names = local_names ,
1938
- )
1927
+ await self .get_names_set (this_ast , names , nonlocal_names , global_names , local_names )
1939
1928
return names
1940
1929
1941
1930
def parse (self , code_str , filename = None , mode = "exec" ):
0 commit comments