diff --git a/custom_components/pyscript/eval.py b/custom_components/pyscript/eval.py index 105d8d6..2bdc869 100644 --- a/custom_components/pyscript/eval.py +++ b/custom_components/pyscript/eval.py @@ -41,6 +41,17 @@ "print", } +TRIG_DECORATORS = { + "time_trigger", + "state_trigger", + "event_trigger", + "mqtt_trigger", + "state_active", + "time_active", + "task_unique", +} + +ALL_DECORATORS = TRIG_DECORATORS.union({"service"}) def ast_eval_exec_factory(ast_ctx, mode): """Generate a function that executes eval() or exec() with given ast_ctx.""" @@ -280,15 +291,7 @@ async def trigger_init(self): "event_trigger", "mqtt_trigger", } - trig_decorators = { - "time_trigger", - "state_trigger", - "event_trigger", - "mqtt_trigger", - "state_active", - "time_active", - "task_unique", - } + decorator_used = set() for dec in self.decorators: dec_name, dec_args, dec_kwargs = dec[0], dec[1], dec[2] @@ -303,7 +306,7 @@ async def trigger_init(self): decorator_used.add(dec_name) if dec_name in trig_decorators_reqd: got_reqd_dec = True - if dec_name in trig_decorators: + if dec_name in TRIG_DECORATORS: if dec_name not in trig_args: trig_args[dec_name] = {} trig_args[dec_name]["args"] = [] @@ -383,7 +386,7 @@ async def do_service_call(func, ast_ctx, data): dec_name, ) - for dec_name in trig_decorators: + for dec_name in TRIG_DECORATORS: if dec_name in trig_args and len(trig_args[dec_name]["args"]) == 0: trig_args[dec_name]["args"] = None @@ -518,21 +521,32 @@ async def eval_decorators(self, ast_ctx): self.decorators = [] code_str, code_list = ast_ctx.code_str, ast_ctx.code_list ast_ctx.code_str, ast_ctx.code_list = self.code_str, self.code_list + + dec_funcs = [] for dec in self.func_def.decorator_list: - if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Name): - args = [] - kwargs = {} - for arg in dec.args: - args.append(await ast_ctx.aeval(arg)) - for keyw in dec.keywords: - kwargs[keyw.arg] = await ast_ctx.aeval(keyw.value) + if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Name) and dec.func.id in ALL_DECORATORS: + args = [await ast_ctx.aeval(arg) for arg in dec.args] + kwargs = {keyw.arg: await ast_ctx.aeval(keyw.value) for keyw in dec.keywords} if len(kwargs) == 0: kwargs = None self.decorators.append([dec.func.id, args, kwargs]) - elif isinstance(dec, ast.Name): + elif isinstance(dec, ast.Name) and dec.id in ALL_DECORATORS: self.decorators.append([dec.id, None, None]) else: - _LOGGER.error("function %s has unexpected decorator type %s", self.name, dec) + dec_funcs.append(await ast_ctx.aeval(dec)) + + def make_dec_call(func): + async def dec_call(*args_tuple, **kwargs): + args = list(args_tuple) + if len(args) > 0 and isinstance(args[0], AstEval): + args.pop(0) + return await func(ast_ctx, *args, **kwargs) + + return dec_call + + for func in reversed(dec_funcs): + self.call = await ast_ctx.call_func(func, None, make_dec_call(self.call)) + ast_ctx.code_str, ast_ctx.code_list = code_str, code_list async def resolve_nonlocals(self, ast_ctx):