Skip to content

Commit 936ac3b

Browse files
committed
fix type check for kwargs and add function tests
1 parent 028656f commit 936ac3b

File tree

3 files changed

+52
-12
lines changed

3 files changed

+52
-12
lines changed

custom_components/pyscript/function.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -192,11 +192,11 @@ async def service_call(cls, domain, name, **kwargs):
192192
curr_task = asyncio.current_task()
193193
hass_args = {}
194194
for keyword, typ, default in [
195-
("context", Context, cls.task2context.get(curr_task, None)),
196-
("blocking", bool, None),
197-
("limit", float, None),
195+
("context", [Context], cls.task2context.get(curr_task, None)),
196+
("blocking", [bool], None),
197+
("limit", [float, int], None),
198198
]:
199-
if keyword in kwargs and isinstance(kwargs[keyword], typ):
199+
if keyword in kwargs and type(kwargs[keyword]) in typ:
200200
hass_args[keyword] = kwargs.pop(keyword)
201201
elif default:
202202
hass_args[keyword] = default
@@ -262,17 +262,18 @@ async def service_call(*args, **kwargs):
262262
curr_task = asyncio.current_task()
263263
hass_args = {}
264264
for keyword, typ, default in [
265-
("context", Context, cls.task2context.get(curr_task, None)),
266-
("blocking", bool, None),
267-
("limit", float, None),
265+
("context", [Context], cls.task2context.get(curr_task, None)),
266+
("blocking", [bool], None),
267+
("limit", [float, int], None),
268268
]:
269-
if keyword in kwargs and isinstance(kwargs[keyword], typ):
269+
if keyword in kwargs and type(kwargs[keyword]) in typ:
270270
hass_args[keyword] = kwargs.pop(keyword)
271271
elif default:
272272
hass_args[keyword] = default
273273

274274
if len(args) != 0:
275275
raise (TypeError, f"service {domain}.{service} takes no positional arguments")
276+
276277
await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
277278

278279
return service_call

custom_components/pyscript/state.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,11 @@ async def service_call(*args, **kwargs):
260260
curr_task = asyncio.current_task()
261261
hass_args = {}
262262
for keyword, typ, default in [
263-
("context", Context, Function.task2context.get(curr_task, None)),
264-
("blocking", bool, None),
265-
("limit", float, None),
263+
("context", [Context], Function.task2context.get(curr_task, None)),
264+
("blocking", [bool], None),
265+
("limit", [float, int], None),
266266
]:
267-
if keyword in kwargs and isinstance(kwargs[keyword], typ):
267+
if keyword in kwargs and type(kwargs[keyword]) in typ:
268268
hass_args[keyword] = kwargs.pop(keyword)
269269
elif default:
270270
hass_args[keyword] = default

tests/test_function.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pytest_homeassistant_custom_component.async_mock import MagicMock, Mock, mock_open, patch
1313

1414
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_STATE_CHANGED
15+
from homeassistant.core import Context
1516
from homeassistant.setup import async_setup_component
1617

1718

@@ -804,3 +805,41 @@ def set_add(entity_id=None, val1=None, val2=None):
804805
assert literal_eval(await wait_until_done(notify_q)) == [4, "pyscript.var1", "32"]
805806
assert literal_eval(await wait_until_done(notify_q)) == [5, "pyscript.var1", "50", "HomeAssistant"]
806807
assert "TypeError: service pyscript.set_add takes no positional arguments" in caplog.text
808+
809+
810+
async def test_service_call_params(hass):
811+
"""Test that hass params get set properly on service calls."""
812+
Function.init(hass)
813+
with patch.object(Function.hass.services, "async_call") as call, patch.object(
814+
Function, "service_has_service", return_value=True
815+
):
816+
await Function.service_call(
817+
"test", "test", context=Context(id="test"), blocking=True, limit=1, other_service_data="test"
818+
)
819+
assert call.called
820+
assert call.call_args[0] == ("test", "test", {"other_service_data": "test"})
821+
assert call.call_args[1] == {"context": Context(id="test"), "blocking": True, "limit": 1}
822+
call.reset_mock()
823+
824+
await Function.service_call(
825+
"test", "test", context=Context(id="test"), blocking=False, other_service_data="test"
826+
)
827+
assert call.called
828+
assert call.call_args[0] == ("test", "test", {"other_service_data": "test"})
829+
assert call.call_args[1] == {"context": Context(id="test"), "blocking": False}
830+
call.reset_mock()
831+
832+
await Function.get("test.test")(
833+
context=Context(id="test"), blocking=True, limit=1, other_service_data="test"
834+
)
835+
assert call.called
836+
assert call.call_args[0] == ("test", "test", {"other_service_data": "test"})
837+
assert call.call_args[1] == {"context": Context(id="test"), "blocking": True, "limit": 1}
838+
call.reset_mock()
839+
840+
await Function.get("test.test")(
841+
context=Context(id="test"), blocking=False, other_service_data="test"
842+
)
843+
assert call.called
844+
assert call.call_args[0] == ("test", "test", {"other_service_data": "test"})
845+
assert call.call_args[1] == {"context": Context(id="test"), "blocking": False}

0 commit comments

Comments
 (0)