21
21
_LOGGER = logging .getLogger (LOGGER_PATH + ".trigger" )
22
22
23
23
24
+ STATE_RE = re .compile (r"[a-zA-Z]\w*\.[a-zA-Z]\w*$" )
25
+
26
+
24
27
def dt_now ():
25
28
"""Return current time."""
26
29
return dt .datetime .now ()
@@ -101,32 +104,57 @@ async def wait_until(
101
104
await asyncio .sleep (timeout )
102
105
return {"trigger_type" : "timeout" }
103
106
return {"trigger_type" : "none" }
104
- state_trig_ident = None
105
- state_trig_expr = None
107
+ state_trig_ident = set ()
108
+ state_trig_ident_any = set ()
109
+ state_trig_eval = None
106
110
event_trig_expr = None
107
111
exc = None
108
112
notify_q = asyncio .Queue (0 )
109
113
if state_trigger is not None :
110
- state_trig_expr = AstEval (
111
- f"{ ast_ctx .name } state_trigger" ,
112
- ast_ctx .get_global_ctx (),
113
- logger_name = ast_ctx .get_logger_name (),
114
- )
115
- Function .install_ast_funcs (state_trig_expr )
116
- state_trig_expr .parse (state_trigger )
117
- exc = state_trig_expr .get_exception_obj ()
118
- if exc is not None :
119
- raise exc
114
+ state_trig = []
115
+ if isinstance (state_trigger , str ):
116
+ state_trigger = [state_trigger ]
117
+ elif isinstance (state_trigger , set ):
118
+ state_trigger = list (state_trigger )
120
119
#
121
- # check straight away to see if the condition is met (to avoid race conditions)
120
+ # separate out the entries that are just state var names, which mean trigger
121
+ # on any change (no expr)
122
122
#
123
- state_trig_ok = await state_trig_expr .eval ()
124
- exc = state_trig_expr .get_exception_obj ()
125
- if exc is not None :
126
- raise exc
127
- if state_trig_ok :
128
- return {"trigger_type" : "state" }
129
- state_trig_ident = await state_trig_expr .get_names ()
123
+ for trig in state_trigger :
124
+ if STATE_RE .match (trig ):
125
+ state_trig_ident_any .add (trig )
126
+ else :
127
+ state_trig .append (trig )
128
+
129
+ if len (state_trig ) > 0 :
130
+ if len (state_trig ) == 1 :
131
+ state_trig_expr = state_trig [0 ]
132
+ else :
133
+ state_trig_expr = f"any([{ ', ' .join (state_trig )} ])"
134
+ state_trig_eval = AstEval (
135
+ f"{ ast_ctx .name } state_trigger" ,
136
+ ast_ctx .get_global_ctx (),
137
+ logger_name = ast_ctx .get_logger_name (),
138
+ )
139
+ Function .install_ast_funcs (state_trig_eval )
140
+ state_trig_eval .parse (state_trig_expr )
141
+ state_trig_ident = await state_trig_eval .get_names ()
142
+ exc = state_trig_eval .get_exception_obj ()
143
+ if exc is not None :
144
+ raise exc
145
+
146
+ state_trig_ident .update (state_trig_ident_any )
147
+ if state_trig_eval :
148
+ #
149
+ # check straight away to see if the condition is met (to avoid race conditions)
150
+ #
151
+ state_trig_ok = await state_trig_eval .eval (State .notify_var_get (state_trig_ident , {}))
152
+ exc = state_trig_eval .get_exception_obj ()
153
+ if exc is not None :
154
+ raise exc
155
+ if state_trig_ok :
156
+ return {"trigger_type" : "state" }
157
+
130
158
_LOGGER .debug (
131
159
"trigger %s wait_until: watching vars %s" , ast_ctx .name , state_trig_ident ,
132
160
)
@@ -145,7 +173,7 @@ async def wait_until(
145
173
event_trig_expr .parse (event_trigger [1 ])
146
174
exc = event_trig_expr .get_exception_obj ()
147
175
if exc is not None :
148
- if state_trig_ident :
176
+ if len ( state_trig_ident ) > 0 :
149
177
State .notify_del (state_trig_ident , notify_q )
150
178
raise exc
151
179
Event .notify_add (event_trigger [0 ], notify_q )
@@ -191,11 +219,19 @@ async def wait_until(
191
219
ret ["trigger_time" ] = time_next
192
220
break
193
221
if notify_type == "state" :
194
- new_vars = notify_info [0 ] if notify_info else None
195
- state_trig_ok = await state_trig_expr .eval (new_vars )
196
- exc = state_trig_expr .get_exception_obj ()
197
- if exc is not None :
198
- break
222
+ if notify_info :
223
+ new_vars , func_args = notify_info
224
+ else :
225
+ new_vars , func_args = None , {}
226
+
227
+ state_trig_ok = False
228
+ if func_args .get ("var_name" , "" ) in state_trig_ident_any :
229
+ state_trig_ok = True
230
+ elif state_trig_eval :
231
+ state_trig_ok = await state_trig_eval .eval (new_vars )
232
+ exc = state_trig_eval .get_exception_obj ()
233
+ if exc is not None :
234
+ break
199
235
if state_trig_ok :
200
236
ret = notify_info [1 ] if notify_info else None
201
237
break
@@ -215,7 +251,7 @@ async def wait_until(
215
251
"trigger %s wait_until got unexpected queue message %s" , ast_ctx .name , notify_type ,
216
252
)
217
253
218
- if state_trig_ident :
254
+ if len ( state_trig_ident ) > 0 :
219
255
State .notify_del (state_trig_ident , notify_q )
220
256
if event_trigger is not None :
221
257
Event .notify_del (event_trigger [0 ], notify_q )
@@ -454,7 +490,9 @@ def __init__(
454
490
self .active_expr = None
455
491
self .state_active_ident = None
456
492
self .state_trig_expr = None
493
+ self .state_trig_eval = None
457
494
self .state_trig_ident = None
495
+ self .state_trig_ident_any = set ()
458
496
self .event_trig_expr = None
459
497
self .have_trigger = False
460
498
self .setup_ok = False
@@ -481,15 +519,36 @@ def __init__(
481
519
self .run_on_startup = True
482
520
483
521
if self .state_trigger is not None :
484
- self .state_trig_expr = AstEval (
485
- f"{ self .name } @state_trigger()" , self .global_ctx , logger_name = self .name
486
- )
487
- Function .install_ast_funcs (self .state_trig_expr )
488
- self .state_trig_expr .parse (self .state_trigger )
489
- exc = self .state_trig_expr .get_exception_long ()
490
- if exc is not None :
491
- self .state_trig_expr .get_logger ().error (exc )
492
- return
522
+ state_trig = []
523
+ for triggers in self .state_trigger :
524
+ if isinstance (triggers , str ):
525
+ triggers = [triggers ]
526
+ elif isinstance (triggers , set ):
527
+ triggers = list (triggers )
528
+ #
529
+ # separate out the entries that are just state var names, which mean trigger
530
+ # on any change (no expr)
531
+ #
532
+ for trig in triggers :
533
+ if STATE_RE .match (trig ):
534
+ self .state_trig_ident_any .add (trig )
535
+ else :
536
+ state_trig .append (trig )
537
+
538
+ if len (state_trig ) > 0 :
539
+ if len (state_trig ) == 1 :
540
+ self .state_trig_expr = state_trig [0 ]
541
+ else :
542
+ self .state_trig_expr = f"any([{ ', ' .join (state_trig )} ])"
543
+ self .state_trig_eval = AstEval (
544
+ f"{ self .name } @state_trigger()" , self .global_ctx , logger_name = self .name
545
+ )
546
+ Function .install_ast_funcs (self .state_trig_eval )
547
+ self .state_trig_eval .parse (self .state_trig_expr )
548
+ exc = self .state_trig_eval .get_exception_long ()
549
+ if exc is not None :
550
+ self .state_trig_eval .get_logger ().error (exc )
551
+ return
493
552
self .have_trigger = True
494
553
495
554
if self .event_trigger is not None :
@@ -530,7 +589,10 @@ async def trigger_watch(self):
530
589
try :
531
590
532
591
if self .state_trigger is not None :
533
- self .state_trig_ident = await self .state_trig_expr .get_names ()
592
+ self .state_trig_ident = set ()
593
+ if self .state_trig_eval :
594
+ self .state_trig_ident = await self .state_trig_eval .get_names ()
595
+ self .state_trig_ident .update (self .state_trig_ident_any )
534
596
_LOGGER .debug ("trigger %s: watching vars %s" , self .name , self .state_trig_ident )
535
597
if len (self .state_trig_ident ) > 0 :
536
598
State .notify_add (self .state_trig_ident , self .notify_q )
@@ -587,11 +649,14 @@ async def trigger_watch(self):
587
649
if notify_type == "state" :
588
650
new_vars , func_args = notify_info
589
651
590
- if self .state_trig_expr :
591
- trig_ok = await self .state_trig_expr .eval (new_vars )
592
- exc = self .state_trig_expr .get_exception_long ()
593
- if exc is not None :
594
- self .state_trig_expr .get_logger ().error (exc )
652
+ if func_args ["var_name" ] not in self .state_trig_ident_any :
653
+ if self .state_trig_eval :
654
+ trig_ok = await self .state_trig_eval .eval (new_vars )
655
+ exc = self .state_trig_eval .get_exception_long ()
656
+ if exc is not None :
657
+ self .state_trig_eval .get_logger ().error (exc )
658
+ trig_ok = False
659
+ else :
595
660
trig_ok = False
596
661
597
662
elif notify_type == "event" :
0 commit comments