Skip to content

Commit 3ffe6e4

Browse files
committed
adds virtual methods for services that have an entity_id parameter; see #64
1 parent 79140a3 commit 3ffe6e4

File tree

3 files changed

+93
-17
lines changed

3 files changed

+93
-17
lines changed

custom_components/pyscript/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ async def reload_scripts_handler(call):
142142
await update_yaml_config(hass, config_entry)
143143
State.set_pyscript_config(config_entry.data)
144144

145+
await State.get_service_params()
146+
145147
global_ctx_only = call.data.get("global_ctx", None)
146148

147149
if global_ctx_only is not None and not GlobalContextMgr.get(global_ctx_only):
@@ -202,6 +204,7 @@ async def state_changed(event):
202204

203205
async def hass_started(event):
204206
_LOGGER.debug("adding state changed listener and starting global contexts")
207+
await State.get_service_params()
205208
hass.data[DOMAIN][UNSUB_LISTENERS].append(hass.bus.async_listen(EVENT_STATE_CHANGED, state_changed))
206209
start_global_contexts()
207210

custom_components/pyscript/function.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -252,17 +252,22 @@ def get(cls, name):
252252
if not cls.service_has_service(domain, service):
253253
return None
254254

255-
async def service_call(*args, **kwargs):
256-
curr_task = asyncio.current_task()
257-
if "context" in kwargs and isinstance(kwargs["context"], Context):
258-
context = kwargs["context"]
259-
del kwargs["context"]
260-
else:
261-
context = cls.task2context.get(curr_task, None)
262-
263-
await cls.hass.services.async_call(domain, service, kwargs, context=context)
264-
265-
return service_call
255+
def service_call_factory(domain, service):
256+
async def service_call(*args, **kwargs):
257+
curr_task = asyncio.current_task()
258+
if "context" in kwargs and isinstance(kwargs["context"], Context):
259+
context = kwargs["context"]
260+
del kwargs["context"]
261+
else:
262+
context = cls.task2context.get(curr_task, None)
263+
264+
if len(args) != 0:
265+
raise (TypeError, f"service {domain}.{service} takes no positional arguments")
266+
await cls.hass.services.async_call(domain, service, kwargs, context=context)
267+
268+
return service_call
269+
270+
return service_call_factory(domain, service)
266271

267272
@classmethod
268273
async def run_coro(cls, coro):

custom_components/pyscript/state.py

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from homeassistant.core import Context
77
from homeassistant.helpers.restore_state import RestoreStateData
8+
from homeassistant.helpers.service import async_get_all_descriptions
89

910
from .const import LOGGER_PATH
1011
from .function import Function
@@ -43,6 +44,11 @@ class State:
4344
#
4445
persisted_vars = set()
4546

47+
#
48+
# other parameters of all services that have "entity_id" as a parameter
49+
#
50+
service2args = {}
51+
4652
def __init__(self):
4753
"""Warn on State instantiation."""
4854
_LOGGER.error("State class is not meant to be instantiated")
@@ -52,6 +58,19 @@ def init(cls, hass):
5258
"""Initialize State."""
5359
cls.hass = hass
5460

61+
@classmethod
62+
async def get_service_params(cls):
63+
"""Get parameters for all services."""
64+
cls.service2args = {}
65+
all_services = await async_get_all_descriptions(cls.hass)
66+
for domain in all_services:
67+
cls.service2args[domain] = {}
68+
for service, desc in all_services[domain].items():
69+
if "entity_id" not in desc["fields"]:
70+
continue
71+
cls.service2args[domain][service] = set(desc["fields"].keys())
72+
cls.service2args[domain][service].discard("entity_id")
73+
5574
@classmethod
5675
async def notify_add(cls, var_names, queue):
5776
"""Register to notify state variables changes to be sent to queue."""
@@ -165,25 +184,70 @@ def exist(cls, var_name):
165184
if len(parts) != 2 and len(parts) != 3:
166185
return False
167186
value = cls.hass.states.get(f"{parts[0]}.{parts[1]}")
168-
return value and (
169-
len(parts) == 2 or parts[2] in value.attributes or parts[2] in {"last_changed", "last_updated"}
170-
)
187+
if value is None:
188+
return False
189+
if (
190+
len(parts) == 2
191+
or (parts[0] in cls.service2args and parts[2] in cls.service2args[parts[0]])
192+
or parts[2] in value.attributes
193+
or parts[2] in {"last_changed", "last_updated"}
194+
):
195+
return True
196+
return False
171197

172198
@classmethod
173199
async def get(cls, var_name):
174200
"""Get a state variable value or attribute from hass."""
175201
parts = var_name.split(".")
176202
if len(parts) != 2 and len(parts) != 3:
177-
raise NameError(f"invalid name '{var_name}' (should be 'domain.entity')")
203+
raise NameError(f"invalid name '{var_name}' (should be 'domain.entity' or 'domain.entity.attr')")
178204
value = cls.hass.states.get(f"{parts[0]}.{parts[1]}")
179205
if not value:
180206
raise NameError(f"name '{parts[0]}.{parts[1]}' is not defined")
207+
#
208+
# simplest case is just the state value
209+
#
181210
if len(parts) == 2:
182211
return value.state
212+
#
213+
# handle virtual attributes
214+
#
183215
if parts[2] == "last_changed":
184216
return value.last_changed
185217
if parts[2] == "last_updated":
186218
return value.last_updated
219+
#
220+
# see if this is a service that has an entity_id parameter
221+
#
222+
if parts[0] in cls.service2args and parts[2] in cls.service2args[parts[0]]:
223+
params = cls.service2args[parts[0]][parts[2]]
224+
225+
def service_call_factory(domain, service, entity_id, params):
226+
async def service_call(*args, **kwargs):
227+
curr_task = asyncio.current_task()
228+
if "context" in kwargs and isinstance(kwargs["context"], Context):
229+
context = kwargs["context"]
230+
del kwargs["context"]
231+
else:
232+
context = Function.task2context.get(curr_task, None)
233+
234+
kwargs["entity_id"] = entity_id
235+
if len(args) == 1 and len(params) == 1:
236+
#
237+
# with just a single parameter and positional argument, create the keyword setting
238+
#
239+
[param_name] = params
240+
kwargs[param_name] = args[0]
241+
elif len(args) != 0:
242+
raise TypeError(f"service {domain}.{service} takes no positional arguments")
243+
await cls.hass.services.async_call(domain, service, kwargs, context=context)
244+
245+
return service_call
246+
247+
return service_call_factory(parts[0], parts[2], f"{parts[0]}.{parts[1]}", params)
248+
#
249+
# finally see if it is an attribute
250+
#
187251
if parts[2] not in value.attributes:
188252
raise AttributeError(f"state '{parts[0]}.{parts[1]}' has no attribute '{parts[2]}'")
189253
return value.attributes.get(parts[2])
@@ -202,7 +266,8 @@ async def get_attr(cls, var_name):
202266
def completions(cls, root):
203267
"""Return possible completions of state variables."""
204268
words = set()
205-
num_period = root.count(".")
269+
parts = root.split(".")
270+
num_period = len(parts) - 1
206271
if num_period == 2:
207272
#
208273
# complete state attributes
@@ -212,7 +277,10 @@ def completions(cls, root):
212277
value = cls.hass.states.get(name)
213278
if value:
214279
attr_root = root[last_period + 1 :]
215-
for attr_name in set(value.attributes.keys()).union({"last_changed", "last_updated"}):
280+
attrs = set(value.attributes.keys()).union({"last_changed", "last_updated"})
281+
if parts[0] in cls.service2args:
282+
attrs.update(set(cls.service2args[parts[0]].keys()))
283+
for attr_name in attrs:
216284
if attr_name.lower().startswith(attr_root):
217285
words.add(f"{name}.{attr_name}")
218286
elif num_period < 2:

0 commit comments

Comments
 (0)