Skip to content

Commit b9781dc

Browse files
committed
fixed nonlocal scoping for deeper inner functions; also check decorators for local variables
1 parent 5e8afde commit b9781dc

File tree

3 files changed

+67
-35
lines changed

3 files changed

+67
-35
lines changed

custom_components/pyscript/eval.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,9 @@ async def resolve_nonlocals(self, ast_ctx):
563563
var_names = set(args)
564564
local_names = set(args)
565565
for stmt in self.func_def.body:
566-
self.has_closure = self.has_closure or isinstance(stmt, ast.FunctionDef)
566+
self.has_closure = self.has_closure or isinstance(
567+
stmt, (ast.FunctionDef, ast.ClassDef, ast.AsyncFunctionDef)
568+
)
567569
var_names = var_names.union(
568570
await ast_ctx.get_names(
569571
stmt, nonlocal_names=nonlocal_names, global_names=global_names, local_names=local_names,
@@ -1839,7 +1841,7 @@ async def get_target_names(self, lhs):
18391841
names.add(lhs.id)
18401842
return names
18411843

1842-
async def get_names_set(self, arg, names, nonlocal_names=None, global_names=None, local_names=None):
1844+
async def get_names_set(self, arg, names, nonlocal_names, global_names, local_names):
18431845
"""Recursively find all the names mentioned in the AST tree."""
18441846

18451847
cls_name = arg.__class__.__name__
@@ -1891,51 +1893,38 @@ async def get_names_set(self, arg, names, nonlocal_names=None, global_names=None
18911893
local_names.add(handler.name)
18921894
names.add(handler.name)
18931895
elif cls_name == "Call":
1894-
await self.get_names_set(
1895-
arg.func,
1896-
names,
1897-
nonlocal_names=nonlocal_names,
1898-
global_names=global_names,
1899-
local_names=local_names,
1900-
)
1896+
await self.get_names_set(arg.func, names, nonlocal_names, global_names, local_names)
19011897
for this_arg in arg.args:
1902-
await self.get_names_set(
1903-
this_arg,
1904-
names,
1905-
nonlocal_names=nonlocal_names,
1906-
global_names=global_names,
1907-
local_names=local_names,
1908-
)
1898+
await self.get_names_set(this_arg, names, nonlocal_names, global_names, local_names)
19091899
return
19101900
elif cls_name in {"FunctionDef", "ClassDef", "AsyncFunctionDef"}:
19111901
local_names.add(arg.name)
19121902
names.add(arg.name)
1903+
for dec in arg.decorator_list:
1904+
await self.get_names_set(dec, names, nonlocal_names, global_names, local_names)
1905+
#
1906+
# find unbound names from the body of the function or class
1907+
#
1908+
inner_global, inner_names, inner_local = set(), set(), set()
1909+
for child in arg.body:
1910+
await self.get_names_set(child, inner_names, None, inner_global, inner_local)
1911+
for name in inner_names:
1912+
if name not in inner_local and name not in inner_global:
1913+
names.add(name)
19131914
return
19141915
elif cls_name == "Delete":
19151916
for arg1 in arg.targets:
19161917
if isinstance(arg1, ast.Name):
19171918
local_names.add(arg1.id)
19181919
for child in ast.iter_child_nodes(arg):
1919-
await self.get_names_set(
1920-
child,
1921-
names,
1922-
nonlocal_names=nonlocal_names,
1923-
global_names=global_names,
1924-
local_names=local_names,
1925-
)
1920+
await self.get_names_set(child, names, nonlocal_names, global_names, local_names)
19261921

19271922
async def get_names(self, this_ast=None, nonlocal_names=None, global_names=None, local_names=None):
19281923
"""Return set of all the names mentioned in our AST tree."""
19291924
names = set()
19301925
this_ast = this_ast or self.ast
19311926
if this_ast:
1932-
await self.get_names_set(
1933-
this_ast,
1934-
names,
1935-
nonlocal_names=nonlocal_names,
1936-
global_names=global_names,
1937-
local_names=local_names,
1938-
)
1927+
await self.get_names_set(this_ast, names, nonlocal_names, global_names, local_names)
19391928
return names
19401929

19411930
def parse(self, code_str, filename=None, mode="exec"):

tests/test_decorators.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,8 @@ def func3():
133133
pyscript.done = seq_num
134134
135135
def repeat(num_times):
136+
num_times += 0
136137
def decorator_repeat(func):
137-
nonlocal num_times
138-
139138
@state_trigger("pyscript.var1 == '4'")
140139
def wrapper_repeat(*args, **kwargs):
141140
for _ in range(num_times):
@@ -167,8 +166,6 @@ def startup_test():
167166
168167
def add_state_trig(value):
169168
def dec_add_state_trig(func):
170-
nonlocal value
171-
172169
@state_trigger(f"pyscript.var1 == '{value}'")
173170
def dec_add_state_wrapper(*args, **kwargs):
174171
return func(*args, **kwargs)

tests/test_unit_eval.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,22 @@ def foo(cnt):
585585
],
586586
[
587587
"""
588+
def f1(x, y, z):
589+
def f2():
590+
y = 5
591+
def f3():
592+
nonlocal x, y
593+
x += 1
594+
y += 1
595+
return x + y + z
596+
return f3()
597+
return [x, y, z, f2()]
598+
f1(10, 20, 30)
599+
""",
600+
[10, 20, 30, 47],
601+
],
602+
[
603+
"""
588604
def twice(func):
589605
def twice_func(*args, **kwargs):
590606
func(*args, **kwargs)
@@ -624,7 +640,6 @@ def twice_func(*args, **kwargs):
624640
625641
def repeat(num_times):
626642
def decorator_repeat(func):
627-
nonlocal num_times
628643
def wrapper_repeat(*args, **kwargs):
629644
for _ in range(num_times):
630645
value = func(*args, **kwargs)
@@ -657,6 +672,37 @@ def foo2():
657672
],
658673
[
659674
"""
675+
def repeat(num_times):
676+
def decorator_repeat(func):
677+
def wrapper_repeat(*args, **kwargs):
678+
for _ in range(num_times):
679+
value = func(*args, **kwargs)
680+
return value
681+
return wrapper_repeat
682+
return decorator_repeat
683+
684+
def repeat2(num_times):
685+
def decorator_repeat(func):
686+
nonlocal num_times
687+
def wrapper_repeat(*args, **kwargs):
688+
for _ in range(num_times):
689+
value = func(*args, **kwargs)
690+
return value
691+
return wrapper_repeat
692+
return decorator_repeat
693+
694+
x = 0
695+
def func(incr):
696+
global x
697+
x += incr
698+
return x
699+
700+
[repeat(3)(func)(10), repeat2(3)(func)(20)]
701+
""",
702+
[30, 90],
703+
],
704+
[
705+
"""
660706
def foo():
661707
global f_bar
662708
def f_bar():

0 commit comments

Comments
 (0)