Skip to content

Commit 4333d31

Browse files
committed
added class defintion support; only has minimal tests so far
1 parent 68befa8 commit 4333d31

File tree

2 files changed

+126
-28
lines changed

2 files changed

+126
-28
lines changed

custom_components/pyscript/eval.py

Lines changed: 79 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import asyncio
55
import builtins
66
import importlib
7+
import inspect
78
import logging
89
import sys
910

@@ -122,14 +123,26 @@ def __init__(self, value):
122123
"""Initialize return statement value."""
123124
self.value = value
124125

126+
def name(self):
127+
"""Return short name."""
128+
return "return"
129+
125130

126131
class EvalBreak(EvalStopFlow):
127132
"""Break statement."""
128133

134+
def name(self):
135+
"""Return short name."""
136+
return "break"
137+
129138

130139
class EvalContinue(EvalStopFlow):
131140
"""Continue statement."""
132141

142+
def name(self):
143+
"""Return short name."""
144+
return "continue"
145+
133146

134147
class EvalName:
135148
"""Identifier that hasn't yet been resolved."""
@@ -380,7 +393,7 @@ async def ast_module(self, arg):
380393
for arg1 in arg.body:
381394
val = await self.aeval(arg1)
382395
if isinstance(val, EvalStopFlow):
383-
return val
396+
raise SyntaxError(f"{val.name()} statement outside function")
384397
return val
385398

386399
async def ast_import(self, arg):
@@ -461,6 +474,47 @@ async def ast_while(self, arg):
461474
return val
462475
return None
463476

477+
async def ast_classdef(self, arg):
478+
"""Evaluate class definition."""
479+
bases = [(await self.aeval(base)) for base in arg.bases]
480+
sym_table = {}
481+
self.sym_table_stack.append(self.sym_table)
482+
self.sym_table = sym_table
483+
for arg1 in arg.body:
484+
val = await self.aeval(arg1)
485+
if isinstance(val, EvalStopFlow):
486+
raise SyntaxError(f"{val.name()} statement outside function")
487+
self.sym_table = self.sym_table_stack.pop()
488+
489+
for name, func in sym_table.items():
490+
if not isinstance(func, EvalFunc):
491+
continue
492+
493+
def class_func_factory(func):
494+
async def class_func_wrapper(this_self, *args, **kwargs):
495+
method_args = [this_self, *args]
496+
return await func.call(self, method_args, kwargs)
497+
498+
return class_func_wrapper
499+
500+
sym_table[name] = class_func_factory(func)
501+
502+
if "__init__" in sym_table:
503+
sym_table["__init__evalfunc_wrap__"] = sym_table["__init__"]
504+
del sym_table["__init__"]
505+
self.sym_table[arg.name] = type(arg.name, tuple(bases), sym_table)
506+
507+
async def ast_functiondef(self, arg):
508+
"""Evaluate function definition."""
509+
func = EvalFunc(arg, self.code_list, self.code_str)
510+
await func.eval_defaults(self)
511+
await func.eval_decorators(self)
512+
self.sym_table[func.get_name()] = func
513+
if self.sym_table == self.global_sym_table:
514+
# set up any triggers if this function is in the global context
515+
await self.global_ctx.trigger_init(func)
516+
return None
517+
464518
async def ast_try(self, arg):
465519
"""Execute try...except statement."""
466520
try:
@@ -534,20 +588,21 @@ async def ast_continue(self, arg):
534588

535589
async def ast_return(self, arg):
536590
"""Execute return statement - return special class."""
537-
val = await self.aeval(arg.value)
538-
return EvalReturn(val)
591+
return EvalReturn(await self.aeval(arg.value) if arg.value else None)
539592

540593
async def ast_global(self, arg):
541594
"""Execute global statement."""
542-
if self.curr_func:
543-
for var_name in arg.names:
544-
self.curr_func.global_names.add(var_name)
595+
if not self.curr_func:
596+
raise SyntaxError("global statement outside function")
597+
for var_name in arg.names:
598+
self.curr_func.global_names.add(var_name)
545599

546600
async def ast_nonlocal(self, arg):
547601
"""Execute nonlocal statement."""
548-
if self.curr_func:
549-
for var_name in arg.names:
550-
self.curr_func.nonlocal_names.add(var_name)
602+
if not self.curr_func:
603+
raise SyntaxError("nonlocal statement outside function")
604+
for var_name in arg.names:
605+
self.curr_func.nonlocal_names.add(var_name)
551606

552607
async def recurse_assign(self, lhs, val):
553608
"""Recursive assignment."""
@@ -579,6 +634,10 @@ async def recurse_assign(self, lhs, val):
579634
if isinstance(var_name, EvalAttrSet):
580635
var_name.setattr(val)
581636
return
637+
if not isinstance(var_name, str):
638+
raise NotImplementedError(
639+
f"unknown lhs type {lhs} (got {var_name}) in assign"
640+
)
582641
if var_name.find(".") >= 0:
583642
self.state.set(var_name, val)
584643
return
@@ -905,11 +964,9 @@ async def eval_elt_list(self, elts):
905964
val = []
906965
for arg in elts:
907966
if isinstance(arg, ast.Starred):
908-
for this_val in await self.aeval(arg.value):
909-
val.append(this_val)
967+
val += await self.aeval(arg.value)
910968
else:
911-
this_val = await self.aeval(arg)
912-
val.append(this_val)
969+
val.append(await self.aeval(arg))
913970
return val
914971

915972
async def ast_list(self, arg):
@@ -934,10 +991,7 @@ async def ast_dict(self, arg):
934991

935992
async def ast_set(self, arg):
936993
"""Evaluate set."""
937-
val = set()
938-
for elt in await self.eval_elt_list(arg.elts):
939-
val.add(elt)
940-
return val
994+
return {elt for elt in await self.eval_elt_list(arg.elts)}
941995

942996
async def ast_subscript(self, arg):
943997
"""Evaluate subscript."""
@@ -986,6 +1040,14 @@ async def ast_call(self, arg):
9861040
func_name = arg.func.attr
9871041
else:
9881042
func_name = "<other>"
1043+
if inspect.isclass(func) and hasattr(func, "__init__evalfunc_wrap__"):
1044+
#
1045+
# since our __init__ function is async, create the class instance
1046+
# without arguments and then call the async __init__evalfunc_wrap__
1047+
#
1048+
inst = func()
1049+
await inst.__init__evalfunc_wrap__(*args, **kwargs)
1050+
return inst
9891051
if callable(func):
9901052
_LOGGER.debug(
9911053
"%s: calling %s(%s, %s)", self.name, func_name, arg_str, kwargs
@@ -995,17 +1057,6 @@ async def ast_call(self, arg):
9951057
return func(*args, **kwargs)
9961058
raise NameError(f"function '{func_name}' is not callable (got {func})")
9971059

998-
async def ast_functiondef(self, arg):
999-
"""Evaluate function definition."""
1000-
func = EvalFunc(arg, self.code_list, self.code_str)
1001-
await func.eval_defaults(self)
1002-
await func.eval_decorators(self)
1003-
self.sym_table[func.get_name()] = func
1004-
if self.sym_table == self.global_sym_table:
1005-
# set up any triggers if this function is in the global context
1006-
await self.global_ctx.trigger_init(func)
1007-
return None
1008-
10091060
async def ast_ifexp(self, arg):
10101061
"""Evaluate if expression."""
10111062
return (

tests/custom_components/pyscript/test_unit_eval.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,18 @@
148148
["Foo = type('Foo', (), {'x': 100}); Foo.x = 10; Foo.x", 10],
149149
["Foo = type('Foo', (), {'x': 100}); Foo.x += 10; Foo.x", 110],
150150
["Foo = [type('Foo', (), {'x': 100})]; Foo[0].x = 10; Foo[0].x", 10],
151+
[
152+
"Foo = [type('Foo', (), {'x': [100, 101]})]; Foo[0].x[1] = 10; Foo[0].x",
153+
[100, 10],
154+
],
155+
[
156+
"Foo = [type('Foo', (), {'x': [0, [[100, 101]]]})]; Foo[0].x[1][0][1] = 10; Foo[0].x[1]",
157+
[[100, 10]],
158+
],
159+
[
160+
"Foo = [type('Foo', (), {'x': [0, [[100, 101, 102, 103]]]})]; Foo[0].x[1][0][1:2] = [11, 12]; Foo[0].x[1]",
161+
[[100, 11, 12, 102, 103]],
162+
],
151163
["eval('1+2')", 3],
152164
["x = 5; eval('2 * x')", 10],
153165
["x = 5; exec('x = 2 * x'); x", 10],
@@ -168,6 +180,17 @@
168180
["from math import sqrt as sqroot\nsqroot(1024)", 32],
169181
[
170182
"""
183+
def foo(bar=6):
184+
if bar == 5:
185+
return
186+
else:
187+
return 2 * bar
188+
[foo(), foo(5), foo('xxx')]
189+
""",
190+
[12, None, "xxxxxx"],
191+
],
192+
[
193+
"""
171194
bar = 100
172195
def foo(bar=6):
173196
bar += 2
@@ -481,6 +504,30 @@ def func(exc):
481504
""",
482505
[3, 103, 203, 303],
483506
],
507+
[
508+
"""
509+
class Test:
510+
x = 10
511+
def __init__(self, value):
512+
self.y = value
513+
514+
def set_x(self, value):
515+
Test.x += 2
516+
self.x = value
517+
518+
def set_y(self, value):
519+
self.y = value
520+
521+
def get(self):
522+
return [self.x, self.y]
523+
524+
t1 = Test(20)
525+
t2 = Test(40)
526+
Test.x = 5
527+
[t1.get(), t2.get(), t1.set_x(100), t1.get(), t2.get(), Test.x]
528+
""",
529+
[[5, 20], [5, 40], None, [100, 20], [7, 40], 7],
530+
],
484531
]
485532

486533

0 commit comments

Comments
 (0)