diff --git a/custom_components/pyscript/__init__.py b/custom_components/pyscript/__init__.py index f3e6514..34ebb2a 100644 --- a/custom_components/pyscript/__init__.py +++ b/custom_components/pyscript/__init__.py @@ -32,6 +32,7 @@ ) from .eval import AstEval from .event import Event +from .mqtt import Mqtt from .function import Function from .global_ctx import GlobalContext, GlobalContextMgr from .jupyter_kernel import Kernel @@ -118,6 +119,7 @@ async def async_setup_entry(hass, config_entry): Function.init(hass) Event.init(hass) + Mqtt.init(hass) TrigTime.init(hass) State.init(hass) State.register_functions() diff --git a/custom_components/pyscript/eval.py b/custom_components/pyscript/eval.py index 96f9d40..ad57d57 100644 --- a/custom_components/pyscript/eval.py +++ b/custom_components/pyscript/eval.py @@ -278,11 +278,13 @@ async def trigger_init(self): "time_trigger", "state_trigger", "event_trigger", + "mqtt_trigger", } trig_decorators = { "time_trigger", "state_trigger", "event_trigger", + "mqtt_trigger", "state_active", "time_active", "task_unique", @@ -393,6 +395,7 @@ async def do_service_call(func, ast_ctx, data): # arg_check = { "event_trigger": {"arg_cnt": {1, 2}}, + "mqtt_trigger": {"arg_cnt": {1, 2}}, "state_active": {"arg_cnt": {1}}, "state_trigger": {"arg_cnt": {"*"}, "type": {list, set}}, "task_unique": {"arg_cnt": {1}}, diff --git a/custom_components/pyscript/mqtt.py b/custom_components/pyscript/mqtt.py new file mode 100644 index 0000000..50c5158 --- /dev/null +++ b/custom_components/pyscript/mqtt.py @@ -0,0 +1,91 @@ +"""Handles mqtt messages and notification.""" + +import logging +import json +from homeassistant.components import mqtt + + +from .const import LOGGER_PATH + +_LOGGER = logging.getLogger(LOGGER_PATH + ".mqtt") + + +class Mqtt: + """Define mqtt functions.""" + + # + # Global hass instance + # + hass = None + + # + # notify message queues by mqtt message topic + # + notify = {} + notify_remove = {} + + def __init__(self): + """Warn on Mqtt instantiation.""" + _LOGGER.error("Mqtt class is not meant to be instantiated") + + @classmethod + def init(cls, hass): + """Initialize Mqtt.""" + + cls.hass = hass + + @classmethod + def mqtt_message_handler_maker(cls, subscribed_topic): + """closure for mqtt_message_handler""" + + async def mqtt_message_handler(mqttmsg): + """Listen for MQTT messages.""" + func_args = { + "trigger_type": "mqtt", + "topic": mqttmsg.topic, + "payload": mqttmsg.payload, + "qos": mqttmsg.qos, + } + + try: + func_args["payload_obj"] = json.loads(mqttmsg.payload) + except ValueError: + pass + + await cls.update(subscribed_topic, func_args) + + return mqtt_message_handler + + @classmethod + async def notify_add(cls, topic, queue): + """Register to notify for mqtt messages of given topic to be sent to queue.""" + + if topic not in cls.notify: + cls.notify[topic] = set() + _LOGGER.debug("mqtt.notify_add(%s) -> adding mqtt subscription", topic) + cls.notify_remove[topic] = await mqtt.async_subscribe( + cls.hass, topic, cls.mqtt_message_handler_maker(topic), encoding='utf-8', qos=0 + ) + cls.notify[topic].add(queue) + + @classmethod + def notify_del(cls, topic, queue): + """Unregister to notify for mqtt messages of given topic for given queue.""" + + if topic not in cls.notify or queue not in cls.notify[topic]: + return + cls.notify[topic].discard(queue) + if len(cls.notify[topic]) == 0: + cls.notify_remove[topic]() + _LOGGER.debug("mqtt.notify_del(%s) -> removing mqtt subscription", topic) + del cls.notify[topic] + del cls.notify_remove[topic] + + @classmethod + async def update(cls, topic, func_args): + """Deliver all notifications for an mqtt message on the given topic.""" + + _LOGGER.debug("mqtt.update(%s, %s, %s)", topic, vars, func_args) + if topic in cls.notify: + for queue in cls.notify[topic]: + await queue.put(["mqtt", func_args]) diff --git a/custom_components/pyscript/trigger.py b/custom_components/pyscript/trigger.py index 6537e74..e4a73c1 100644 --- a/custom_components/pyscript/trigger.py +++ b/custom_components/pyscript/trigger.py @@ -16,6 +16,7 @@ from .const import LOGGER_PATH from .eval import AstEval from .event import Event +from .mqtt import Mqtt from .function import Function from .state import STATE_VIRTUAL_ATTRS, State @@ -149,13 +150,14 @@ async def wait_until( state_check_now=True, time_trigger=None, event_trigger=None, + mqtt_trigger=None, timeout=None, state_hold=None, state_hold_false=None, __test_handshake__=None, ): """Wait for zero or more triggers, until an optional timeout.""" - if state_trigger is None and time_trigger is None and event_trigger is None: + if state_trigger is None and time_trigger is None and event_trigger is None and mqtt_trigger is None: if timeout is not None: await asyncio.sleep(timeout) return {"trigger_type": "timeout"} @@ -164,6 +166,7 @@ async def wait_until( state_trig_ident_any = set() state_trig_eval = None event_trig_expr = None + mqtt_trig_expr = None exc = None notify_q = asyncio.Queue(0) @@ -260,6 +263,23 @@ async def wait_until( State.notify_del(state_trig_ident, notify_q) raise exc Event.notify_add(event_trigger[0], notify_q) + if mqtt_trigger is not None: + if isinstance(mqtt_trigger, str): + mqtt_trigger = [mqtt_trigger] + if len(mqtt_trigger) > 1: + mqtt_trig_expr = AstEval( + f"{ast_ctx.name} mqtt_trigger", + ast_ctx.get_global_ctx(), + logger_name=ast_ctx.get_logger_name(), + ) + Function.install_ast_funcs(mqtt_trig_expr) + mqtt_trig_expr.parse(mqtt_trigger[1], mode="eval") + exc = mqtt_trig_expr.get_exception_obj() + if exc is not None: + if len(state_trig_ident) > 0: + State.notify_del(state_trig_ident, notify_q) + raise exc + await Mqtt.notify_add(mqtt_trigger[0], notify_q) time0 = time.monotonic() if __test_handshake__: @@ -297,7 +317,7 @@ async def wait_until( this_timeout = time_left state_trig_timeout = True if this_timeout is None: - if state_trigger is None and event_trigger is None: + if state_trigger is None and event_trigger is None and mqtt_trigger is None: _LOGGER.debug( "trigger %s wait_until no next time - returning with none", ast_ctx.name, ) @@ -403,6 +423,17 @@ async def wait_until( if event_trig_ok: ret = notify_info break + elif notify_type == "mqtt": + if mqtt_trig_expr is None: + ret = notify_info + break + mqtt_trig_ok = await mqtt_trig_expr.eval(notify_info) + exc = mqtt_trig_expr.get_exception_obj() + if exc is not None: + break + if mqtt_trig_ok: + ret = notify_info + break else: _LOGGER.error( "trigger %s wait_until got unexpected queue message %s", ast_ctx.name, notify_type, @@ -412,6 +443,8 @@ async def wait_until( State.notify_del(state_trig_ident, notify_q) if event_trigger is not None: Event.notify_del(event_trigger[0], notify_q) + if mqtt_trigger is not None: + Mqtt.notify_del(mqtt_trigger[0], notify_q) if exc: raise exc return ret @@ -641,6 +674,7 @@ def __init__( self.state_check_now = self.state_trigger_kwargs.get("state_check_now", False) self.time_trigger = trig_cfg.get("time_trigger", {}).get("args", None) self.event_trigger = trig_cfg.get("event_trigger", {}).get("args", None) + self.mqtt_trigger = trig_cfg.get("mqtt_trigger", {}).get("args", None) self.state_active = trig_cfg.get("state_active", {}).get("args", None) self.time_active = trig_cfg.get("time_active", {}).get("args", None) self.time_active_hold_off = trig_cfg.get("time_active", {}).get("kwargs", {}).get("hold_off", None) @@ -656,6 +690,7 @@ def __init__( self.state_trig_ident = None self.state_trig_ident_any = set() self.event_trig_expr = None + self.mqtt_trig_expr = None self.have_trigger = False self.setup_ok = False self.run_on_startup = False @@ -726,6 +761,19 @@ def __init__( return self.have_trigger = True + if self.mqtt_trigger is not None: + if len(self.mqtt_trigger) == 2: + self.mqtt_trig_expr = AstEval( + f"{self.name} @mqtt_trigger()", self.global_ctx, logger_name=self.name, + ) + Function.install_ast_funcs(self.mqtt_trig_expr) + self.mqtt_trig_expr.parse(self.mqtt_trigger[1], mode="eval") + exc = self.mqtt_trig_expr.get_exception_long() + if exc is not None: + self.mqtt_trig_expr.get_logger().error(exc) + return + self.have_trigger = True + self.setup_ok = True def stop(self): @@ -736,6 +784,8 @@ def stop(self): State.notify_del(self.state_trig_ident, self.notify_q) if self.event_trigger is not None: Event.notify_del(self.event_trigger[0], self.notify_q) + if self.mqtt_trigger is not None: + Mqtt.notify_del(self.mqtt_trigger[0], self.notify_q) if self.task: Function.task_cancel(self.task) @@ -765,6 +815,9 @@ async def trigger_watch(self): if self.event_trigger is not None: _LOGGER.debug("trigger %s adding event_trigger %s", self.name, self.event_trigger[0]) Event.notify_add(self.event_trigger[0], self.notify_q) + if self.mqtt_trigger is not None: + _LOGGER.debug("trigger %s adding mqtt_trigger %s", self.name, self.mqtt_trigger[0]) + await Mqtt.notify_add(self.mqtt_trigger[0], self.notify_q) last_trig_time = None last_state_trig_time = None @@ -924,6 +977,10 @@ async def trigger_watch(self): func_args = notify_info if self.event_trig_expr: trig_ok = await self.event_trig_expr.eval(notify_info) + elif notify_type == "mqtt": + func_args = notify_info + if self.mqtt_trig_expr: + trig_ok = await self.mqtt_trig_expr.eval(notify_info) else: func_args = notify_info @@ -1038,4 +1095,6 @@ async def do_func_call(func, ast_ctx, task_unique, task_unique_func, hass_contex State.notify_del(self.state_trig_ident, self.notify_q) if self.event_trigger is not None: Event.notify_del(self.event_trigger[0], self.notify_q) + if self.mqtt_trigger is not None: + Mqtt.notify_del(self.mqtt_trigger[0], self.notify_q) return diff --git a/docs/reference.rst b/docs/reference.rst index 6aaad13..f5a6c14 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -569,6 +569,33 @@ see that the ``EVENT_CALL_SERVICE`` event has parameters ``domain`` set to ``lig This `wiki page `__ gives more examples of built-in and user events and how to create triggers for them. +@mqtt_trigger +^^^^^^^^^^^^^ + +.. code:: python + + @mqtt_trigger(topic, str_expr=None) + +``@mqtt_trigger`` subscribes to the given MQTT ``topic`` and triggers whenever a message is received +on that topic. An optional ``str_expr`` can be used to match the MQTT message data, and the trigger +will only occur if that expression evaluates to ``True`` or non-zero. This expression has available +these four variables: + +- ``trigger_type`` is set to “mqtt” +- ``topic`` is set to the topic the message was received on +- ``payload`` is set to the string payload of the message +- ``payload_obj`` if the payload was valid JSON, this will be set to the native python object + representing that payload. + +When the ``@mqtt_trigger`` occurs, those same variables are passed as keyword arguments to the +function in case it needs them. + +Wildcards in topics are supported. The ``topic`` variables will be set to the full expanded topic +the message arrived on. + +NOTE: The `MQTT Integration in Home Assistant `__ +must be set up to use ``@mqtt_trigger``. + @task_unique ^^^^^^^^^^^^ @@ -859,6 +886,9 @@ It takes the following keyword arguments (all are optional): - ``event_trigger=None`` can be set to a string or list of two strings, just like ``@event_trigger``. The first string is the name of the event, and the second string (when the setting is a two-element list) is an expression based on the event parameters. +- ``mqtt_trigger=None`` can be set to a string or list of two strings, just like + ``@mqtt_trigger``. The first string is the MQTT topic, and the second string + (when the setting is a two-element list) is an expression based on the message variables. - ``timeout=None`` an overall timeout in seconds, which can be floating point. - ``state_check_now=True`` if set, ``task.wait_until()`` checks any ``state_trigger`` immediately to see if it is already ``True``, and will return immediately if so. If diff --git a/tests/test_decorator_errors.py b/tests/test_decorator_errors.py index 7cd60b1..94e80ca 100644 --- a/tests/test_decorator_errors.py +++ b/tests/test_decorator_errors.py @@ -179,7 +179,7 @@ def func_wrapup(): ) assert "SyntaxError: invalid syntax (file.hello.func3 @state_active(), line 1)" in caplog.text assert ( - "func4 defined in file.hello: needs at least one trigger decorator (ie: event_trigger, state_trigger, time_trigger)" + "func4 defined in file.hello: needs at least one trigger decorator (ie: event_trigger, mqtt_trigger, state_trigger, time_trigger)" in caplog.text ) assert (