5
5
6
6
from homeassistant .core import Context
7
7
from homeassistant .helpers .restore_state import RestoreStateData
8
+ from homeassistant .helpers .service import async_get_all_descriptions
8
9
9
10
from .const import LOGGER_PATH
10
11
from .function import Function
@@ -43,6 +44,11 @@ class State:
43
44
#
44
45
persisted_vars = set ()
45
46
47
+ #
48
+ # other parameters of all services that have "entity_id" as a parameter
49
+ #
50
+ service2args = {}
51
+
46
52
def __init__ (self ):
47
53
"""Warn on State instantiation."""
48
54
_LOGGER .error ("State class is not meant to be instantiated" )
@@ -52,6 +58,19 @@ def init(cls, hass):
52
58
"""Initialize State."""
53
59
cls .hass = hass
54
60
61
+ @classmethod
62
+ async def get_service_params (cls ):
63
+ """Get parameters for all services."""
64
+ cls .service2args = {}
65
+ all_services = await async_get_all_descriptions (cls .hass )
66
+ for domain in all_services :
67
+ cls .service2args [domain ] = {}
68
+ for service , desc in all_services [domain ].items ():
69
+ if "entity_id" not in desc ["fields" ]:
70
+ continue
71
+ cls .service2args [domain ][service ] = set (desc ["fields" ].keys ())
72
+ cls .service2args [domain ][service ].discard ("entity_id" )
73
+
55
74
@classmethod
56
75
async def notify_add (cls , var_names , queue ):
57
76
"""Register to notify state variables changes to be sent to queue."""
@@ -165,25 +184,70 @@ def exist(cls, var_name):
165
184
if len (parts ) != 2 and len (parts ) != 3 :
166
185
return False
167
186
value = cls .hass .states .get (f"{ parts [0 ]} .{ parts [1 ]} " )
168
- return value and (
169
- len (parts ) == 2 or parts [2 ] in value .attributes or parts [2 ] in {"last_changed" , "last_updated" }
170
- )
187
+ if value is None :
188
+ return False
189
+ if (
190
+ len (parts ) == 2
191
+ or (parts [0 ] in cls .service2args and parts [2 ] in cls .service2args [parts [0 ]])
192
+ or parts [2 ] in value .attributes
193
+ or parts [2 ] in {"last_changed" , "last_updated" }
194
+ ):
195
+ return True
196
+ return False
171
197
172
198
@classmethod
173
199
async def get (cls , var_name ):
174
200
"""Get a state variable value or attribute from hass."""
175
201
parts = var_name .split ("." )
176
202
if len (parts ) != 2 and len (parts ) != 3 :
177
- raise NameError (f"invalid name '{ var_name } ' (should be 'domain.entity')" )
203
+ raise NameError (f"invalid name '{ var_name } ' (should be 'domain.entity' or 'domain.entity.attr' )" )
178
204
value = cls .hass .states .get (f"{ parts [0 ]} .{ parts [1 ]} " )
179
205
if not value :
180
206
raise NameError (f"name '{ parts [0 ]} .{ parts [1 ]} ' is not defined" )
207
+ #
208
+ # simplest case is just the state value
209
+ #
181
210
if len (parts ) == 2 :
182
211
return value .state
212
+ #
213
+ # handle virtual attributes
214
+ #
183
215
if parts [2 ] == "last_changed" :
184
216
return value .last_changed
185
217
if parts [2 ] == "last_updated" :
186
218
return value .last_updated
219
+ #
220
+ # see if this is a service that has an entity_id parameter
221
+ #
222
+ if parts [0 ] in cls .service2args and parts [2 ] in cls .service2args [parts [0 ]]:
223
+ params = cls .service2args [parts [0 ]][parts [2 ]]
224
+
225
+ def service_call_factory (domain , service , entity_id , params ):
226
+ async def service_call (* args , ** kwargs ):
227
+ curr_task = asyncio .current_task ()
228
+ if "context" in kwargs and isinstance (kwargs ["context" ], Context ):
229
+ context = kwargs ["context" ]
230
+ del kwargs ["context" ]
231
+ else :
232
+ context = Function .task2context .get (curr_task , None )
233
+
234
+ kwargs ["entity_id" ] = entity_id
235
+ if len (args ) == 1 and len (params ) == 1 :
236
+ #
237
+ # with just a single parameter and positional argument, create the keyword setting
238
+ #
239
+ [param_name ] = params
240
+ kwargs [param_name ] = args [0 ]
241
+ elif len (args ) != 0 :
242
+ raise TypeError (f"service { domain } .{ service } takes no positional arguments" )
243
+ await cls .hass .services .async_call (domain , service , kwargs , context = context )
244
+
245
+ return service_call
246
+
247
+ return service_call_factory (parts [0 ], parts [2 ], f"{ parts [0 ]} .{ parts [1 ]} " , params )
248
+ #
249
+ # finally see if it is an attribute
250
+ #
187
251
if parts [2 ] not in value .attributes :
188
252
raise AttributeError (f"state '{ parts [0 ]} .{ parts [1 ]} ' has no attribute '{ parts [2 ]} '" )
189
253
return value .attributes .get (parts [2 ])
@@ -202,7 +266,8 @@ async def get_attr(cls, var_name):
202
266
def completions (cls , root ):
203
267
"""Return possible completions of state variables."""
204
268
words = set ()
205
- num_period = root .count ("." )
269
+ parts = root .split ("." )
270
+ num_period = len (parts ) - 1
206
271
if num_period == 2 :
207
272
#
208
273
# complete state attributes
@@ -212,7 +277,10 @@ def completions(cls, root):
212
277
value = cls .hass .states .get (name )
213
278
if value :
214
279
attr_root = root [last_period + 1 :]
215
- for attr_name in set (value .attributes .keys ()).union ({"last_changed" , "last_updated" }):
280
+ attrs = set (value .attributes .keys ()).union ({"last_changed" , "last_updated" })
281
+ if parts [0 ] in cls .service2args :
282
+ attrs .update (set (cls .service2args [parts [0 ]].keys ()))
283
+ for attr_name in attrs :
216
284
if attr_name .lower ().startswith (attr_root ):
217
285
words .add (f"{ name } .{ attr_name } " )
218
286
elif num_period < 2 :
0 commit comments