Skip to content

Commit 95ac1c9

Browse files
committed
use builtins; support assign and augassign to object attributes;
small cleanup on if/while
1 parent 89533eb commit 95ac1c9

File tree

1 file changed

+88
-169
lines changed

1 file changed

+88
-169
lines changed

custom_components/pyscript/eval.py

Lines changed: 88 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import ast
44
import asyncio
5+
import builtins
56
import importlib
67
import logging
78
import sys
@@ -11,125 +12,15 @@
1112
_LOGGER = logging.getLogger(LOGGER_PATH + ".eval")
1213

1314
#
14-
# Built-in functions available. Certain functions are excluded
15-
# to avoid potential security issues.
15+
# Built-ins to exclude to improve security or avoid i/o
1616
#
17-
BUILTIN_FUNCS = {
18-
"abs": abs,
19-
"all": all,
20-
"any": any,
21-
"ascii": ascii,
22-
"bin": bin,
23-
"bool": bool,
24-
"bytearray": bytearray,
25-
"bytearray.fromhex": bytearray.fromhex,
26-
"bytes": bytes,
27-
"bytes.fromhex": bytes.fromhex,
28-
"callable": callable,
29-
"chr": chr,
30-
"complex": complex,
31-
"dict": dict,
32-
"divmod": divmod,
33-
"enumerate": enumerate,
34-
"filter": filter,
35-
"float": float,
36-
"format": format,
37-
"frozenset": frozenset,
38-
"hash": hash,
39-
"hex": hex,
40-
"int": int,
41-
"isinstance": isinstance,
42-
"issubclass": issubclass,
43-
"iter": iter,
44-
"len": len,
45-
"list": list,
46-
"map": map,
47-
"max": max,
48-
"min": min,
49-
"next": next,
50-
"oct": oct,
51-
"ord": ord,
52-
"pow": pow,
53-
"range": range,
54-
"repr": repr,
55-
"reversed": reversed,
56-
"round": round,
57-
"set": set,
58-
"slice": slice,
59-
"sorted": sorted,
60-
"str": str,
61-
"sum": sum,
62-
"tuple": tuple,
63-
"type": type,
64-
"zip": zip,
65-
}
66-
67-
68-
BUILTIN_EXCEPTIONS = {
69-
"BaseException": BaseException,
70-
"SystemExit": SystemExit,
71-
"KeyboardInterrupt": KeyboardInterrupt,
72-
"GeneratorExit": GeneratorExit,
73-
"Exception": Exception,
74-
"StopIteration": StopIteration,
75-
"StopAsyncIteration": StopAsyncIteration,
76-
"ArithmeticError": ArithmeticError,
77-
"FloatingPointError": FloatingPointError,
78-
"OverflowError": OverflowError,
79-
"ZeroDivisionError": ZeroDivisionError,
80-
"AssertionError": AssertionError,
81-
"AttributeError": AttributeError,
82-
"BufferError": BufferError,
83-
"EOFError": EOFError,
84-
"ImportError": ImportError,
85-
"ModuleNotFoundError": ModuleNotFoundError,
86-
"LookupError": LookupError,
87-
"IndexError": IndexError,
88-
"KeyError": KeyError,
89-
"MemoryError": MemoryError,
90-
"NameError": NameError,
91-
"UnboundLocalError": UnboundLocalError,
92-
"OSError": OSError,
93-
"BlockingIOError": BlockingIOError,
94-
"ChildProcessError": ChildProcessError,
95-
"ConnectionError": ConnectionError,
96-
"BrokenPipeError": BrokenPipeError,
97-
"ConnectionAbortedError": ConnectionAbortedError,
98-
"ConnectionRefusedError": ConnectionRefusedError,
99-
"ConnectionResetError": ConnectionResetError,
100-
"FileExistsError": FileExistsError,
101-
"FileNotFoundError": FileNotFoundError,
102-
"InterruptedError": InterruptedError,
103-
"IsADirectoryError": IsADirectoryError,
104-
"NotADirectoryError": NotADirectoryError,
105-
"PermissionError": PermissionError,
106-
"ProcessLookupError": ProcessLookupError,
107-
"TimeoutError": TimeoutError,
108-
"ReferenceError": ReferenceError,
109-
"RuntimeError": RuntimeError,
110-
"NotImplementedError": NotImplementedError,
111-
"RecursionError": RecursionError,
112-
"SyntaxError": SyntaxError,
113-
"IndentationError": IndentationError,
114-
"TabError": TabError,
115-
"SystemError": SystemError,
116-
"TypeError": TypeError,
117-
"ValueError": ValueError,
118-
"UnicodeError": UnicodeError,
119-
"UnicodeDecodeError": UnicodeDecodeError,
120-
"UnicodeEncodeError": UnicodeEncodeError,
121-
"UnicodeTranslateError": UnicodeTranslateError,
122-
"Warning": Warning,
123-
"DeprecationWarning": DeprecationWarning,
124-
"PendingDeprecationWarning": PendingDeprecationWarning,
125-
"RuntimeWarning": RuntimeWarning,
126-
"SyntaxWarning": SyntaxWarning,
127-
"UserWarning": UserWarning,
128-
"FutureWarning": FutureWarning,
129-
"ImportWarning": ImportWarning,
130-
"UnicodeWarning": UnicodeWarning,
131-
"BytesWarning": BytesWarning,
132-
"ResourceWarning": ResourceWarning,
17+
BUILTIN_EXCLUDE = {
18+
"breakpoint",
19+
"compile",
20+
"input",
21+
"memoryview",
22+
"open",
23+
"print",
13324
}
13425

13526

@@ -217,7 +108,8 @@ async def locals_func():
217108

218109

219110
#
220-
# Objects returned by return, break and continue statements that change execution flow
111+
# Objects returned by return, break and continue statements that change execution flow,
112+
# or objects returned that capture particular information
221113
#
222114
class EvalStopFlow:
223115
"""Denotes a statement or action that stops execution flow, eg: return, break etc."""
@@ -251,6 +143,23 @@ def __getattr__(self, attr):
251143
raise NameError(f"name '{self.name}.{attr}' is not defined")
252144

253145

146+
class EvalAttrSet:
147+
"""Class for object and attribute on lhs of assignment."""
148+
149+
def __init__(self, obj, attr):
150+
"""Initialize identifier to name."""
151+
self.obj = obj
152+
self.attr = attr
153+
154+
def setattr(self, value):
155+
"""Set the attribute value."""
156+
setattr(self.obj, self.attr, value)
157+
158+
def getattr(self):
159+
"""Get the attribute value."""
160+
return getattr(self.obj, self.attr)
161+
162+
254163
class EvalFunc:
255164
"""Class for a callable pyscript function."""
256165

@@ -525,11 +434,9 @@ async def ast_for(self, arg):
525434
break
526435
if isinstance(val, EvalBreak):
527436
break
528-
if isinstance(val, EvalContinue):
529-
continue
530437
if isinstance(val, EvalReturn):
531438
return val
532-
if not isinstance(val, EvalBreak):
439+
else:
533440
for arg1 in arg.orelse:
534441
val = await self.aeval(arg1)
535442
if isinstance(val, EvalReturn):
@@ -538,21 +445,16 @@ async def ast_for(self, arg):
538445

539446
async def ast_while(self, arg):
540447
"""Execute while statement."""
541-
while 1:
542-
val = await self.aeval(arg.test)
543-
if not val:
544-
break
448+
while await self.aeval(arg.test):
545449
for arg1 in arg.body:
546450
val = await self.aeval(arg1)
547451
if isinstance(val, EvalStopFlow):
548452
break
549453
if isinstance(val, EvalBreak):
550454
break
551-
if isinstance(val, EvalContinue):
552-
continue
553455
if isinstance(val, EvalReturn):
554456
return val
555-
if not isinstance(val, EvalBreak):
457+
else:
556458
for arg1 in arg.orelse:
557459
val = await self.aeval(arg1)
558460
if isinstance(val, EvalReturn):
@@ -566,7 +468,6 @@ async def ast_try(self, arg):
566468
val = await self.aeval(arg1)
567469
if isinstance(val, EvalStopFlow):
568470
return val
569-
print(f"exception_obj = {self.exception_obj}")
570471
if self.exception_obj is not None:
571472
raise self.exception_obj # pylint: disable=raising-bad-type
572473
except Exception as err: # pylint: disable=broad-except
@@ -675,22 +576,23 @@ async def recurse_assign(self, lhs, val):
675576
var[slice(lower, upper, step)] = val
676577
else:
677578
var_name = await self.aeval(lhs)
579+
if isinstance(var_name, EvalAttrSet):
580+
var_name.setattr(val)
581+
return
678582
if var_name.find(".") >= 0:
679583
self.state.set(var_name, val)
680-
else:
681-
if self.curr_func and var_name in self.curr_func.global_names:
682-
self.global_sym_table[var_name] = val
683-
elif self.curr_func and var_name in self.curr_func.nonlocal_names:
684-
for sym_table in reversed(self.sym_table_stack[1:]):
685-
if var_name in sym_table:
686-
sym_table[var_name] = val
687-
break
688-
else:
689-
raise TypeError(
690-
f"can't find nonlocal '{var_name}' for assignment"
691-
)
584+
return
585+
if self.curr_func and var_name in self.curr_func.global_names:
586+
self.global_sym_table[var_name] = val
587+
return
588+
if self.curr_func and var_name in self.curr_func.nonlocal_names:
589+
for sym_table in reversed(self.sym_table_stack[1:]):
590+
if var_name in sym_table:
591+
sym_table[var_name] = val
592+
return
692593
else:
693-
self.sym_table[var_name] = val
594+
raise TypeError(f"can't find nonlocal '{var_name}' for assignment")
595+
self.sym_table[var_name] = val
694596

695597
async def ast_assign(self, arg):
696598
"""Execute assignment statement."""
@@ -699,24 +601,36 @@ async def ast_assign(self, arg):
699601
async def ast_augassign(self, arg):
700602
"""Execute augmented assignment statement (lhs <BinOp>= value)."""
701603
var_name = await self.aeval(arg.target)
702-
val = await self.aeval(
703-
ast.BinOp(
704-
left=ast.Name(id=var_name, ctx=ast.Load()), op=arg.op, right=arg.value
604+
if isinstance(var_name, EvalAttrSet):
605+
val = await self.aeval(
606+
ast.BinOp(
607+
left=ast.Constant(value=var_name.getattr()),
608+
op=arg.op,
609+
right=arg.value,
610+
)
705611
)
706-
)
707-
if self.curr_func and var_name in self.curr_func.global_names:
708-
self.global_sym_table[var_name] = val
709-
elif self.curr_func and var_name in self.curr_func.nonlocal_names:
710-
for sym_table in reversed(self.sym_table_stack[1:]):
711-
if var_name in sym_table:
712-
sym_table[var_name] = val
713-
break
714-
else:
715-
raise TypeError(f"can't find nonlocal '{var_name}' for assignment")
716-
elif self.state.exist(var_name):
717-
self.state.set(var_name, val)
612+
var_name.setattr(val)
718613
else:
719-
self.sym_table[var_name] = val
614+
val = await self.aeval(
615+
ast.BinOp(
616+
left=ast.Name(id=var_name, ctx=ast.Load()),
617+
op=arg.op,
618+
right=arg.value,
619+
)
620+
)
621+
if self.curr_func and var_name in self.curr_func.global_names:
622+
self.global_sym_table[var_name] = val
623+
elif self.curr_func and var_name in self.curr_func.nonlocal_names:
624+
for sym_table in reversed(self.sym_table_stack[1:]):
625+
if var_name in sym_table:
626+
sym_table[var_name] = val
627+
break
628+
else:
629+
raise TypeError(f"can't find nonlocal '{var_name}' for assignment")
630+
elif self.state.exist(var_name):
631+
self.state.set(var_name, val)
632+
else:
633+
self.sym_table[var_name] = val
720634

721635
async def ast_delete(self, arg):
722636
"""Execute del statement."""
@@ -771,11 +685,10 @@ async def ast_attribute_collapse(self, arg): # pylint: disable=no-self-use
771685
val = val.value
772686
if isinstance(val, ast.Name):
773687
name = val.id + "." + name
774-
if isinstance(arg.ctx, ast.Load):
775-
# ensure the first portion of name is undefined
776-
val = await self.ast_name(ast.Name(id=val.id, ctx=arg.ctx))
777-
if not isinstance(val, EvalName):
778-
return None
688+
# ensure the first portion of name is undefined
689+
val = await self.ast_name(ast.Name(id=val.id, ctx=ast.Load()))
690+
if not isinstance(val, EvalName):
691+
return None
779692
return name
780693
return None
781694

@@ -789,6 +702,8 @@ async def ast_attribute(self, arg):
789702
if not isinstance(val, EvalName):
790703
return val
791704
val = await self.aeval(arg.value)
705+
if isinstance(arg.ctx, ast.Store):
706+
return EvalAttrSet(val, arg.attr)
792707
return getattr(val, arg.attr)
793708

794709
async def ast_name(self, arg):
@@ -815,10 +730,14 @@ async def ast_name(self, arg):
815730
return self.local_sym_table[arg.id]
816731
if arg.id in self.global_sym_table:
817732
return self.global_sym_table[arg.id]
818-
if arg.id in BUILTIN_FUNCS:
819-
return BUILTIN_FUNCS[arg.id]
820733
if arg.id in BUILTIN_AST_FUNCS_FACTORY:
821734
return BUILTIN_AST_FUNCS_FACTORY[arg.id](self)
735+
if (
736+
hasattr(builtins, arg.id)
737+
and arg.id not in BUILTIN_EXCLUDE
738+
and arg.id[0] != "_"
739+
):
740+
return getattr(builtins, arg.id)
822741
if self.handler.get(arg.id):
823742
return self.handler.get(arg.id)
824743
num_dots = arg.id.count(".")
@@ -832,8 +751,6 @@ async def ast_name(self, arg):
832751
)
833752
if num_dots == 1 or (num_dots == 2 and self.state.exist(arg.id)):
834753
return self.state.get(arg.id)
835-
if arg.id in BUILTIN_EXCEPTIONS:
836-
return BUILTIN_EXCEPTIONS[arg.id]
837754
#
838755
# Couldn't find it, so return just the name wrapped in EvalName to
839756
# distinguish from a string variable value. This is to support
@@ -1288,7 +1205,9 @@ def completions(self, root):
12881205
except Exception: # pylint: disable=broad-except
12891206
pass
12901207
sym_table = BUILTIN_AST_FUNCS_FACTORY.copy()
1291-
sym_table.update(BUILTIN_FUNCS)
1208+
for name, value in builtins.__dict__.items():
1209+
if name[0] != "_" and name not in BUILTIN_EXCLUDE:
1210+
sym_table[name] = value
12921211
sym_table.update(self.global_sym_table.items())
12931212
for name, value in sym_table.items():
12941213
if name.lower().startswith(root):

0 commit comments

Comments
 (0)