Skip to content

Commit 8702d4e

Browse files
committed
added task.add_done_callback(), task.remove_done_callback() and tests; see #112
1 parent adc76f2 commit 8702d4e

File tree

4 files changed

+230
-17
lines changed

4 files changed

+230
-17
lines changed

custom_components/pyscript/eval.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,14 @@ def get_name(self):
762762
"""Return the function name."""
763763
return self.eval_func_var.get_name()
764764

765+
def get_ast_ctx(self):
766+
"""Return the ast context."""
767+
return self.ast_ctx
768+
769+
def get_eval_func_var(self):
770+
"""Return the eval_func_var."""
771+
return self.eval_func_var
772+
765773
async def __call__(self, *args, **kwargs):
766774
"""Call the EvalFunc function using our saved ast ctx."""
767775
return await self.eval_func_var.call(self.ast_ctx, *args, **kwargs)
@@ -1703,25 +1711,32 @@ async def ast_call(self, arg):
17031711
else:
17041712
kwargs[kw_arg.arg] = await self.aeval(kw_arg.value)
17051713
args = await self.eval_elt_list(arg.args)
1706-
arg_str = ", ".join(['"' + elt + '"' if isinstance(elt, str) else str(elt) for elt in args])
17071714
#
17081715
# try to deduce function name, although this only works in simple cases
17091716
#
1717+
func_name = None
17101718
if isinstance(arg.func, ast.Name):
17111719
func_name = arg.func.id
17121720
elif isinstance(arg.func, ast.Attribute):
17131721
func_name = arg.func.attr
1714-
else:
1715-
func_name = "<function>"
17161722
if isinstance(func, EvalLocalVar):
17171723
func_name = func.get_name()
17181724
func = func.get()
1719-
_LOGGER.debug("%s: calling %s(%s, %s)", self.name, func_name, arg_str, kwargs)
17201725
return await self.call_func(func, func_name, *args, **kwargs)
17211726

17221727
async def call_func(self, func, func_name, *args, **kwargs):
17231728
"""Call a function with the given arguments."""
1724-
if isinstance(func, (EvalFuncVar, EvalFuncVarAstCtx)):
1729+
if func_name is None:
1730+
try:
1731+
if isinstance(func, (EvalFunc, EvalFuncVar, EvalFuncVarAstCtx)):
1732+
func_name = func.get_name()
1733+
else:
1734+
func_name = func.__name__
1735+
except Exception:
1736+
func_name = "<function>"
1737+
arg_str = ", ".join(['"' + elt + '"' if isinstance(elt, str) else str(elt) for elt in args])
1738+
_LOGGER.debug("%s: calling %s(%s, %s)", self.name, func_name, arg_str, kwargs)
1739+
if isinstance(func, (EvalFunc, EvalFuncVar, EvalFuncVarAstCtx)):
17251740
return await func.call(self, *args, **kwargs)
17261741
if inspect.isclass(func) and hasattr(func, "__init__evalfunc_wrap__"):
17271742
inst = func()

custom_components/pyscript/function.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ class Function:
3535
#
3636
our_tasks = set()
3737

38+
#
39+
# Done callbacks for each task
40+
#
41+
task2cb = {}
42+
3843
#
3944
# initial list of available functions
4045
#
@@ -149,7 +154,6 @@ async def task_waiter(waiter_q):
149154
if not cls.task_waiter:
150155
cls.task_waiter_q = asyncio.Queue(0)
151156
cls.task_waiter = Function.create_task(task_waiter(cls.task_waiter_q))
152-
_LOGGER.debug("task_waiter: started")
153157

154158
@classmethod
155159
def reaper_cancel(cls, task):
@@ -365,7 +369,12 @@ async def run_coro(cls, coro):
365369
try:
366370
task = asyncio.current_task()
367371
cls.our_tasks.add(task)
368-
return await coro
372+
result = await coro
373+
if task in cls.task2cb:
374+
for callback, info in cls.task2cb[task]["cb"].items():
375+
ast_ctx, args, kwargs = info
376+
await ast_ctx.call_func(callback, None, *args, **kwargs)
377+
return result
369378
except asyncio.CancelledError:
370379
raise
371380
except Exception:
@@ -375,8 +384,8 @@ async def run_coro(cls, coro):
375384
for name in cls.unique_task2name[task]:
376385
del cls.unique_name2task[name]
377386
del cls.unique_task2name[task]
378-
if task in cls.task2context:
379-
del cls.task2context[task]
387+
cls.task2context.pop(task, None)
388+
cls.task2cb.pop(task, None)
380389
cls.our_tasks.discard(task)
381390

382391
@classmethod
@@ -402,3 +411,20 @@ def service_remove(cls, domain, service):
402411
return
403412
cls.service_cnt[key] = 0
404413
cls.hass.services.async_remove(domain, service)
414+
415+
@classmethod
416+
def task_done_callback_ctx(cls, task, ast_ctx):
417+
"""Set the ast_ctx for a task, which is needed for done callbacks."""
418+
cls.task2cb[task] = {"ctx": ast_ctx, "cb": {}}
419+
420+
@classmethod
421+
def task_add_done_callback(cls, task, ast_ctx, callback, *args, **kwargs):
422+
"""Add a done callback to the given task."""
423+
if ast_ctx is None:
424+
ast_ctx = cls.task2cb[task]["ctx"]
425+
cls.task2cb[task]["cb"][callback] = [ast_ctx, args, kwargs]
426+
427+
@classmethod
428+
def task_remove_done_callback(cls, task, callback):
429+
"""Remove a done callback to the given task."""
430+
cls.task2cb[task]["cb"].pop(callback, None)

custom_components/pyscript/trigger.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import homeassistant.helpers.sun as sun
1515

1616
from .const import LOGGER_PATH
17-
from .eval import AstEval, EvalFuncVar, EvalFuncVarAstCtx
17+
from .eval import AstEval, EvalFunc, EvalFuncVar, EvalFuncVarAstCtx
1818
from .event import Event
1919
from .function import Function
2020
from .mqtt import Mqtt
@@ -139,21 +139,28 @@ def user_task_create_factory(ast_ctx):
139139
async def user_task_create(func, *args, **kwargs):
140140
"""Implement task.create()."""
141141

142-
async def func_call(func, new_ast_ctx, *args, **kwargs):
142+
async def func_call(func, func_name, new_ast_ctx, *args, **kwargs):
143143
"""Call user function inside task.create()."""
144-
ret = await func.call(new_ast_ctx, *args, **kwargs)
144+
ret = await new_ast_ctx.call_func(func, func_name, *args, **kwargs)
145145
if new_ast_ctx.get_exception_obj():
146146
new_ast_ctx.get_logger().error(new_ast_ctx.get_exception_long())
147147
return ret
148148

149-
if not isinstance(func, (EvalFuncVar, EvalFuncVarAstCtx)):
150-
raise TypeError("function is not callable by task.create()")
149+
try:
150+
if isinstance(func, (EvalFunc, EvalFuncVar, EvalFuncVarAstCtx)):
151+
func_name = func.get_name()
152+
else:
153+
func_name = func.__name__
154+
except Exception:
155+
func_name = "<function>"
151156

152157
new_ast_ctx = AstEval(
153-
f"{ast_ctx.get_global_ctx_name()}.{func.get_name()}", ast_ctx.get_global_ctx()
158+
f"{ast_ctx.get_global_ctx_name()}.{func_name}", ast_ctx.get_global_ctx()
154159
)
155160
Function.install_ast_funcs(new_ast_ctx)
156-
return Function.create_task(func_call(func, new_ast_ctx, *args, **kwargs))
161+
task = Function.create_task(func_call(func, func_name, new_ast_ctx, *args, **kwargs))
162+
Function.task_done_callback_ctx(task, new_ast_ctx)
163+
return task
157164

158165
return user_task_create
159166

@@ -173,9 +180,25 @@ async def user_task_wait(aws):
173180
"""Implement task.wait()."""
174181
return await asyncio.wait(aws)
175182

183+
async def user_task_add_done_callback(task, callback, *args, **kwargs):
184+
"""Implement task.add_done_callback()."""
185+
ast_ctx = None
186+
if type(callback) is EvalFuncVarAstCtx:
187+
ast_ctx = callback.get_ast_ctx()
188+
callback = callback.get_eval_func_var()
189+
Function.task_add_done_callback(task, ast_ctx, callback, *args, **kwargs)
190+
191+
async def user_task_remove_done_callback(task, callback):
192+
"""Implement task.remove_done_callback()."""
193+
if type(callback) is EvalFuncVarAstCtx:
194+
callback = callback.get_eval_func_var()
195+
Function.task_remove_done_callback(task, callback)
196+
176197
funcs = {
177198
"task.cancel": user_task_cancel,
178199
"task.wait": user_task_wait,
200+
"task.add_done_callback": user_task_add_done_callback,
201+
"task.remove_done_callback": user_task_remove_done_callback,
179202
}
180203
Function.register(funcs)
181204

@@ -1132,7 +1155,7 @@ async def do_func_call(func, ast_ctx, task_unique, task_unique_func, hass_contex
11321155

11331156
if task_unique and task_unique_func:
11341157
await task_unique_func(task_unique)
1135-
await func.call(ast_ctx, **kwargs)
1158+
await ast_ctx.call_func(func, None, **kwargs)
11361159
if ast_ctx.get_exception_obj():
11371160
ast_ctx.get_logger().error(ast_ctx.get_exception_long())
11381161

tests/test_tasks.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""Test the pyscript apps, modules and import features."""
2+
3+
import asyncio
4+
import re
5+
6+
from custom_components.pyscript.const import DOMAIN, FOLDER
7+
from mock_open import MockOpen
8+
from pytest_homeassistant_custom_component.async_mock import patch
9+
10+
from homeassistant.const import EVENT_STATE_CHANGED
11+
from homeassistant.setup import async_setup_component
12+
13+
14+
async def wait_until_done(notify_q):
15+
"""Wait for the done handshake."""
16+
return await asyncio.wait_for(notify_q.get(), timeout=4)
17+
18+
19+
async def test_tasks(hass, caplog):
20+
"""Test starting tasks."""
21+
22+
conf_dir = hass.config.path(FOLDER)
23+
24+
file_contents = {
25+
f"{conf_dir}/hello.py": """
26+
27+
#
28+
# check starting multiple tasks, each stopping the prior one
29+
#
30+
def task1(cnt, last):
31+
task.unique('task1')
32+
if not last:
33+
task.sleep(10)
34+
log.info(f"finished task1, cnt={cnt}")
35+
36+
for cnt in range(10):
37+
task.create(task1, cnt, cnt == 9)
38+
39+
#
40+
# check the return value after wait
41+
#
42+
def task2(arg):
43+
return 2 * arg
44+
45+
t2a = task.create(task2, 21)
46+
t2b = task.create(task2, 51)
47+
done, pending = task.wait({t2a, t2b})
48+
log.info(f"task2() results = {[t2a.result(), t2b.result()]}, len(done) = {len(done)};")
49+
50+
#
51+
# check the return value with a regular function
52+
#
53+
@pyscript_compile
54+
def task3(arg):
55+
return 2 * arg
56+
57+
t3a = task.create(task3, 22)
58+
t3b = task.create(task3, 52)
59+
done, pending = task.wait({t3a, t3b})
60+
log.info(f"task3() results = {[t3a.result(), t3b.result()]}, len(done) = {len(done)};")
61+
62+
63+
#
64+
# check that we can do a done callback
65+
#
66+
def task4(arg):
67+
task.wait_until(state_trigger="pyscript.var4 == '1'")
68+
return 2 * arg
69+
70+
def callback4a(arg):
71+
log.info(f"callback4a arg = {arg}")
72+
73+
def callback4b(arg):
74+
log.info(f"callback4b arg = {arg}")
75+
76+
def callback4c(arg):
77+
log.info(f"callback4c arg = {arg}")
78+
79+
t4 = task.create(task4, 23)
80+
task.add_done_callback(t4, callback4a, 26)
81+
task.add_done_callback(t4, callback4b, 101)
82+
task.add_done_callback(t4, callback4c, 200)
83+
task.add_done_callback(t4, callback4a, 25)
84+
task.add_done_callback(t4, callback4c, 201)
85+
task.add_done_callback(t4, callback4b, 100)
86+
task.add_done_callback(t4, callback4a, 24)
87+
task.remove_done_callback(t4, callback4c)
88+
task.remove_done_callback(t4, task4)
89+
pyscript.var4 = 1
90+
done, pending = task.wait({t4})
91+
log.info(f"task4() result = {t4.result()}, len(done) = {len(done)};")
92+
93+
""",
94+
}
95+
96+
mock_open = MockOpen()
97+
for key, value in file_contents.items():
98+
mock_open[key].read_data = value
99+
100+
def isfile_side_effect(arg):
101+
return arg in file_contents
102+
103+
def glob_side_effect(path, recursive=None):
104+
result = []
105+
path_re = path.replace("*", "[^/]*").replace(".", "\\.")
106+
path_re = path_re.replace("[^/]*[^/]*/", ".*")
107+
for this_path in file_contents:
108+
if re.match(path_re, this_path):
109+
result.append(this_path)
110+
return result
111+
112+
conf = {"apps": {"world": {}}}
113+
with patch("custom_components.pyscript.os.path.isdir", return_value=True), patch(
114+
"custom_components.pyscript.glob.iglob"
115+
) as mock_glob, patch("custom_components.pyscript.global_ctx.open", mock_open), patch(
116+
"custom_components.pyscript.open", mock_open
117+
), patch(
118+
"homeassistant.config.load_yaml_config_file", return_value={"pyscript": conf}
119+
), patch(
120+
"custom_components.pyscript.os.path.getmtime", return_value=1000
121+
), patch(
122+
"custom_components.pyscript.global_ctx.os.path.getmtime", return_value=1000
123+
), patch(
124+
"custom_components.pyscript.os.path.isfile"
125+
) as mock_isfile:
126+
mock_isfile.side_effect = isfile_side_effect
127+
mock_glob.side_effect = glob_side_effect
128+
assert await async_setup_component(hass, "pyscript", {DOMAIN: conf})
129+
130+
notify_q = asyncio.Queue(0)
131+
132+
async def state_changed(event):
133+
var_name = event.data["entity_id"]
134+
if var_name != "pyscript.done":
135+
return
136+
value = event.data["new_state"].state
137+
await notify_q.put(value)
138+
139+
hass.bus.async_listen(EVENT_STATE_CHANGED, state_changed)
140+
141+
assert caplog.text.count("finished task1, cnt=9") == 1
142+
assert "task2() results = [42, 102], len(done) = 2;" in caplog.text
143+
assert "task3() results = [44, 104], len(done) = 2;" in caplog.text
144+
assert "task4() result = 46, len(done) = 1;" in caplog.text
145+
assert caplog.text.count("callback4a arg =") == 1
146+
assert "callback4a arg = 24" in caplog.text
147+
assert caplog.text.count("callback4b arg =") == 1
148+
assert "callback4b arg = 100" in caplog.text
149+
assert "callback4c arg =" not in caplog.text

0 commit comments

Comments
 (0)