From 380868c3068c0932b1fa2cd7e20b59bc6e4ac313 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Wed, 2 Sep 2020 20:43:03 +0200 Subject: [PATCH] add configuration option "allow_all_imports" --- custom_components/pyscript/__init__.py | 18 +++++++++++++++++- custom_components/pyscript/eval.py | 11 ++++++++--- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/custom_components/pyscript/__init__.py b/custom_components/pyscript/__init__.py index 456150c..9b44017 100644 --- a/custom_components/pyscript/__init__.py +++ b/custom_components/pyscript/__init__.py @@ -5,12 +5,15 @@ import logging import os +import voluptuous as vol + from homeassistant.const import ( EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, EVENT_STATE_CHANGED, SERVICE_RELOAD, ) +import homeassistant.helpers.config_validation as cv from homeassistant.loader import bind_hass from .const import DOMAIN, FOLDER, LOGGER_PATH, SERVICE_JUPYTER_KERNEL_START @@ -24,10 +27,20 @@ _LOGGER = logging.getLogger(LOGGER_PATH) +CONF_ALLOW_ALL_IMPORTS = "allow_all_imports" + +CONFIG_SCHEMA = vol.Schema( + { + DOMAIN: vol.Schema( + {vol.Optional(CONF_ALLOW_ALL_IMPORTS, default=False): cv.boolean} + ), + }, + extra=vol.ALLOW_EXTRA, +) + async def async_setup(hass, config): """Initialize the pyscript component.""" - handler_func = Handler(hass) event_func = Event(hass) trig_time_func = TrigTime(hass, handler_func) @@ -44,6 +57,9 @@ def check_isdir(path): _LOGGER.error("Folder %s not found in configuration folder", FOLDER) return False + hass.data.setdefault(DOMAIN, {}) + hass.data[DOMAIN]["allow_all_imports"] = config[DOMAIN].get(CONF_ALLOW_ALL_IMPORTS) + await compile_scripts( # pylint: disable=unused-variable hass, event_func=event_func, diff --git a/custom_components/pyscript/eval.py b/custom_components/pyscript/eval.py index 74290eb..9f7de44 100644 --- a/custom_components/pyscript/eval.py +++ b/custom_components/pyscript/eval.py @@ -8,7 +8,7 @@ import logging import sys -from .const import ALLOWED_IMPORTS, LOGGER_PATH +from .const import ALLOWED_IMPORTS, DOMAIN, LOGGER_PATH _LOGGER = logging.getLogger(LOGGER_PATH + ".eval") @@ -361,6 +361,11 @@ def __init__( self.logger_handlers = set() self.logger = None self.set_logger_name(logger_name if logger_name is not None else self.name) + self.allow_all_imports = ( + global_ctx.hass.data[DOMAIN]["allow_all_imports"] + if global_ctx.hass is not None + else False + ) async def ast_not_implemented(self, arg, *args): """Raise NotImplementedError exception for unimplemented AST types.""" @@ -399,7 +404,7 @@ async def ast_module(self, arg): async def ast_import(self, arg): """Execute import.""" for imp in arg.names: - if imp.name not in ALLOWED_IMPORTS: + if not self.allow_all_imports and imp.name not in ALLOWED_IMPORTS: raise ModuleNotFoundError(f"import of {imp.name} not allowed") if imp.name not in sys.modules: mod = importlib.import_module(imp.name) @@ -409,7 +414,7 @@ async def ast_import(self, arg): async def ast_importfrom(self, arg): """Execute from X import Y.""" - if arg.module not in ALLOWED_IMPORTS: + if not self.allow_all_imports and arg.module not in ALLOWED_IMPORTS: raise ModuleNotFoundError(f"import from {arg.module} not allowed") if arg.module not in sys.modules: mod = importlib.import_module(arg.module)