Skip to content

Commit 4cc285c

Browse files
committed
don't initialize trigger on inner function inside decorators; see #126
1 parent 9551fdb commit 4cc285c

File tree

2 files changed

+56
-20
lines changed

2 files changed

+56
-20
lines changed

custom_components/pyscript/eval.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,7 @@ def __init__(self, name, global_ctx, logger_name=None):
802802
self.logger = None
803803
self.set_logger_name(logger_name if logger_name is not None else self.name)
804804
self.config_entry = Function.hass.data.get(DOMAIN, {}).get(CONFIG_ENTRY, {})
805+
self.dec_eval_depth = 0
805806

806807
async def ast_not_implemented(self, arg, *args):
807808
"""Raise NotImplementedError exception for unimplemented AST types."""
@@ -1001,16 +1002,21 @@ async def ast_functiondef(self, arg):
10011002
await func.resolve_nonlocals(self)
10021003
name = func.get_name()
10031004
dec_trig, dec_other = await func.eval_decorators(self)
1005+
self.dec_eval_depth += 1
10041006
for dec_func in dec_other:
10051007
func = await self.call_func(dec_func, None, func)
10061008
if isinstance(func, EvalFuncVar):
10071009
func = func.remove_func()
10081010
dec_trig += func.decorators
1011+
self.dec_eval_depth -= 1
10091012
if isinstance(func, EvalFunc):
10101013
func.decorators = dec_trig
1011-
func.trigger_stop()
1012-
await func.trigger_init()
1013-
func_var = EvalFuncVar(func)
1014+
if self.dec_eval_depth == 0:
1015+
func.trigger_stop()
1016+
await func.trigger_init()
1017+
func_var = EvalFuncVar(func)
1018+
else:
1019+
func_var = EvalFuncVar(func)
10141020
else:
10151021
func_var = func
10161022

tests/test_decorators.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,11 @@ async def test_decorator_errors(hass, caplog):
7777
"""
7878
seq_num = 0
7979
80-
@time_trigger("startup")
81-
def func_startup_sync(trigger_type=None, trigger_time=None):
82-
global seq_num
83-
84-
seq_num += 1
85-
log.info(f"func_startup_sync setting pyscript.done = {seq_num}, trigger_type = {trigger_type}, trigger_time = {trigger_time}")
86-
pyscript.done = seq_num
80+
def add_startup_trig(func):
81+
@time_trigger("startup")
82+
def dec_add_startup_wrapper(*args, **kwargs):
83+
return func(*args, **kwargs)
84+
return dec_add_startup_wrapper
8785
8886
def once(func):
8987
def once_func(*args, **kwargs):
@@ -96,6 +94,16 @@ def twice_func(*args, **kwargs):
9694
return func(*args, **kwargs)
9795
return twice_func
9896
97+
@twice
98+
@add_startup_trig
99+
@twice
100+
def func_startup_sync(trigger_type=None, trigger_time=None):
101+
global seq_num
102+
103+
seq_num += 1
104+
log.info(f"func_startup_sync setting pyscript.done = {seq_num}, trigger_type = {trigger_type}, trigger_time = {trigger_time}")
105+
pyscript.done = seq_num
106+
99107
@state_trigger("pyscript.var1 == '1'")
100108
@once
101109
def func1():
@@ -143,6 +151,23 @@ def func4():
143151
seq_num += 1
144152
pyscript.done = seq_num
145153
154+
@state_trigger("pyscript.var1 == '5'")
155+
def func5(value=None):
156+
global seq_num
157+
global startup_test_save
158+
159+
seq_num += 1
160+
pyscript.done = [seq_num, int(value)]
161+
162+
@add_startup_trig
163+
def startup_test():
164+
global seq_num
165+
166+
seq_num += 1
167+
pyscript.done = [seq_num, int(value)]
168+
169+
startup_test_save = startup_test
170+
146171
def add_state_trig(value):
147172
def dec_add_state_trig(func):
148173
nonlocal value
@@ -153,24 +178,24 @@ def dec_add_state_wrapper(*args, **kwargs):
153178
return dec_add_state_wrapper
154179
return dec_add_state_trig
155180
156-
157-
@add_state_trig(5) # same as @state_trigger("pyscript.var1 == '5'")
158-
@add_state_trig(7) # same as @state_trigger("pyscript.var1 == '7'")
159-
@state_trigger("pyscript.var1 == '9'")
160-
def func5():
181+
@add_state_trig(6) # same as @state_trigger("pyscript.var1 == '6'")
182+
@add_state_trig(8) # same as @state_trigger("pyscript.var1 == '8'")
183+
@state_trigger("pyscript.var1 == '10'")
184+
def func6(value):
161185
global seq_num
162186
163187
seq_num += 1
164-
pyscript.done = seq_num
188+
pyscript.done = [seq_num, int(value)]
165189
166190
""",
167191
)
168192
seq_num = 0
169193

170-
seq_num += 1
171194
# fire event to start triggers, and handshake when they are running
172195
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
173-
assert literal_eval(await wait_until_done(notify_q)) == seq_num
196+
for _ in range(4):
197+
seq_num += 1
198+
assert literal_eval(await wait_until_done(notify_q)) == seq_num
174199

175200
hass.states.async_set("pyscript.var1", 0)
176201
hass.states.async_set("pyscript.var1", 1)
@@ -192,7 +217,12 @@ def func5():
192217
seq_num += 1
193218
assert literal_eval(await wait_until_done(notify_q)) == seq_num
194219

220+
hass.states.async_set("pyscript.var1", 5)
221+
for _ in range(2):
222+
seq_num += 1
223+
assert literal_eval(await wait_until_done(notify_q)) == [seq_num, 5]
224+
195225
for i in range(3):
196-
hass.states.async_set("pyscript.var1", 5 + 2 * i)
226+
hass.states.async_set("pyscript.var1", 6 + 2 * i)
197227
seq_num += 1
198-
assert literal_eval(await wait_until_done(notify_q)) == seq_num
228+
assert literal_eval(await wait_until_done(notify_q)) == [seq_num, 6 + 2 * i]

0 commit comments

Comments
 (0)