4
4
import asyncio
5
5
import builtins
6
6
import importlib
7
+ import inspect
7
8
import logging
8
9
import sys
9
10
@@ -122,14 +123,26 @@ def __init__(self, value):
122
123
"""Initialize return statement value."""
123
124
self .value = value
124
125
126
+ def name (self ):
127
+ """Return short name."""
128
+ return "return"
129
+
125
130
126
131
class EvalBreak (EvalStopFlow ):
127
132
"""Break statement."""
128
133
134
+ def name (self ):
135
+ """Return short name."""
136
+ return "break"
137
+
129
138
130
139
class EvalContinue (EvalStopFlow ):
131
140
"""Continue statement."""
132
141
142
+ def name (self ):
143
+ """Return short name."""
144
+ return "continue"
145
+
133
146
134
147
class EvalName :
135
148
"""Identifier that hasn't yet been resolved."""
@@ -380,7 +393,7 @@ async def ast_module(self, arg):
380
393
for arg1 in arg .body :
381
394
val = await self .aeval (arg1 )
382
395
if isinstance (val , EvalStopFlow ):
383
- return val
396
+ raise SyntaxError ( f" { val . name () } statement outside function" )
384
397
return val
385
398
386
399
async def ast_import (self , arg ):
@@ -461,6 +474,47 @@ async def ast_while(self, arg):
461
474
return val
462
475
return None
463
476
477
+ async def ast_classdef (self , arg ):
478
+ """Evaluate class definition."""
479
+ bases = [(await self .aeval (base )) for base in arg .bases ]
480
+ sym_table = {}
481
+ self .sym_table_stack .append (self .sym_table )
482
+ self .sym_table = sym_table
483
+ for arg1 in arg .body :
484
+ val = await self .aeval (arg1 )
485
+ if isinstance (val , EvalStopFlow ):
486
+ raise SyntaxError (f"{ val .name ()} statement outside function" )
487
+ self .sym_table = self .sym_table_stack .pop ()
488
+
489
+ for name , func in sym_table .items ():
490
+ if not isinstance (func , EvalFunc ):
491
+ continue
492
+
493
+ def class_func_factory (func ):
494
+ async def class_func_wrapper (this_self , * args , ** kwargs ):
495
+ method_args = [this_self , * args ]
496
+ return await func .call (self , method_args , kwargs )
497
+
498
+ return class_func_wrapper
499
+
500
+ sym_table [name ] = class_func_factory (func )
501
+
502
+ if "__init__" in sym_table :
503
+ sym_table ["__init__evalfunc_wrap__" ] = sym_table ["__init__" ]
504
+ del sym_table ["__init__" ]
505
+ self .sym_table [arg .name ] = type (arg .name , tuple (bases ), sym_table )
506
+
507
+ async def ast_functiondef (self , arg ):
508
+ """Evaluate function definition."""
509
+ func = EvalFunc (arg , self .code_list , self .code_str )
510
+ await func .eval_defaults (self )
511
+ await func .eval_decorators (self )
512
+ self .sym_table [func .get_name ()] = func
513
+ if self .sym_table == self .global_sym_table :
514
+ # set up any triggers if this function is in the global context
515
+ await self .global_ctx .trigger_init (func )
516
+ return None
517
+
464
518
async def ast_try (self , arg ):
465
519
"""Execute try...except statement."""
466
520
try :
@@ -534,20 +588,21 @@ async def ast_continue(self, arg):
534
588
535
589
async def ast_return (self , arg ):
536
590
"""Execute return statement - return special class."""
537
- val = await self .aeval (arg .value )
538
- return EvalReturn (val )
591
+ return EvalReturn (await self .aeval (arg .value ) if arg .value else None )
539
592
540
593
async def ast_global (self , arg ):
541
594
"""Execute global statement."""
542
- if self .curr_func :
543
- for var_name in arg .names :
544
- self .curr_func .global_names .add (var_name )
595
+ if not self .curr_func :
596
+ raise SyntaxError ("global statement outside function" )
597
+ for var_name in arg .names :
598
+ self .curr_func .global_names .add (var_name )
545
599
546
600
async def ast_nonlocal (self , arg ):
547
601
"""Execute nonlocal statement."""
548
- if self .curr_func :
549
- for var_name in arg .names :
550
- self .curr_func .nonlocal_names .add (var_name )
602
+ if not self .curr_func :
603
+ raise SyntaxError ("nonlocal statement outside function" )
604
+ for var_name in arg .names :
605
+ self .curr_func .nonlocal_names .add (var_name )
551
606
552
607
async def recurse_assign (self , lhs , val ):
553
608
"""Recursive assignment."""
@@ -579,6 +634,10 @@ async def recurse_assign(self, lhs, val):
579
634
if isinstance (var_name , EvalAttrSet ):
580
635
var_name .setattr (val )
581
636
return
637
+ if not isinstance (var_name , str ):
638
+ raise NotImplementedError (
639
+ f"unknown lhs type { lhs } (got { var_name } ) in assign"
640
+ )
582
641
if var_name .find ("." ) >= 0 :
583
642
self .state .set (var_name , val )
584
643
return
@@ -905,11 +964,9 @@ async def eval_elt_list(self, elts):
905
964
val = []
906
965
for arg in elts :
907
966
if isinstance (arg , ast .Starred ):
908
- for this_val in await self .aeval (arg .value ):
909
- val .append (this_val )
967
+ val += await self .aeval (arg .value )
910
968
else :
911
- this_val = await self .aeval (arg )
912
- val .append (this_val )
969
+ val .append (await self .aeval (arg ))
913
970
return val
914
971
915
972
async def ast_list (self , arg ):
@@ -934,10 +991,7 @@ async def ast_dict(self, arg):
934
991
935
992
async def ast_set (self , arg ):
936
993
"""Evaluate set."""
937
- val = set ()
938
- for elt in await self .eval_elt_list (arg .elts ):
939
- val .add (elt )
940
- return val
994
+ return {elt for elt in await self .eval_elt_list (arg .elts )}
941
995
942
996
async def ast_subscript (self , arg ):
943
997
"""Evaluate subscript."""
@@ -986,6 +1040,14 @@ async def ast_call(self, arg):
986
1040
func_name = arg .func .attr
987
1041
else :
988
1042
func_name = "<other>"
1043
+ if inspect .isclass (func ) and hasattr (func , "__init__evalfunc_wrap__" ):
1044
+ #
1045
+ # since our __init__ function is async, create the class instance
1046
+ # without arguments and then call the async __init__evalfunc_wrap__
1047
+ #
1048
+ inst = func ()
1049
+ await inst .__init__evalfunc_wrap__ (* args , ** kwargs )
1050
+ return inst
989
1051
if callable (func ):
990
1052
_LOGGER .debug (
991
1053
"%s: calling %s(%s, %s)" , self .name , func_name , arg_str , kwargs
@@ -995,17 +1057,6 @@ async def ast_call(self, arg):
995
1057
return func (* args , ** kwargs )
996
1058
raise NameError (f"function '{ func_name } ' is not callable (got { func } )" )
997
1059
998
- async def ast_functiondef (self , arg ):
999
- """Evaluate function definition."""
1000
- func = EvalFunc (arg , self .code_list , self .code_str )
1001
- await func .eval_defaults (self )
1002
- await func .eval_decorators (self )
1003
- self .sym_table [func .get_name ()] = func
1004
- if self .sym_table == self .global_sym_table :
1005
- # set up any triggers if this function is in the global context
1006
- await self .global_ctx .trigger_init (func )
1007
- return None
1008
-
1009
1060
async def ast_ifexp (self , arg ):
1010
1061
"""Evaluate if expression."""
1011
1062
return (
0 commit comments