Skip to content

Commit 006cdc3

Browse files
authored
Merge pull request #10 from basnijholt/allow_all_imports
add configuration option "allow_all_imports"
2 parents 4333d31 + 380868c commit 006cdc3

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

custom_components/pyscript/__init__.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
import logging
66
import os
77

8+
import voluptuous as vol
9+
810
from homeassistant.const import (
911
EVENT_HOMEASSISTANT_STARTED,
1012
EVENT_HOMEASSISTANT_STOP,
1113
EVENT_STATE_CHANGED,
1214
SERVICE_RELOAD,
1315
)
16+
import homeassistant.helpers.config_validation as cv
1417
from homeassistant.loader import bind_hass
1518

1619
from .const import DOMAIN, FOLDER, LOGGER_PATH, SERVICE_JUPYTER_KERNEL_START
@@ -24,10 +27,20 @@
2427

2528
_LOGGER = logging.getLogger(LOGGER_PATH)
2629

30+
CONF_ALLOW_ALL_IMPORTS = "allow_all_imports"
31+
32+
CONFIG_SCHEMA = vol.Schema(
33+
{
34+
DOMAIN: vol.Schema(
35+
{vol.Optional(CONF_ALLOW_ALL_IMPORTS, default=False): cv.boolean}
36+
),
37+
},
38+
extra=vol.ALLOW_EXTRA,
39+
)
40+
2741

2842
async def async_setup(hass, config):
2943
"""Initialize the pyscript component."""
30-
3144
handler_func = Handler(hass)
3245
event_func = Event(hass)
3346
trig_time_func = TrigTime(hass, handler_func)
@@ -44,6 +57,9 @@ def check_isdir(path):
4457
_LOGGER.error("Folder %s not found in configuration folder", FOLDER)
4558
return False
4659

60+
hass.data.setdefault(DOMAIN, {})
61+
hass.data[DOMAIN]["allow_all_imports"] = config[DOMAIN].get(CONF_ALLOW_ALL_IMPORTS)
62+
4763
await compile_scripts( # pylint: disable=unused-variable
4864
hass,
4965
event_func=event_func,

custom_components/pyscript/eval.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import logging
99
import sys
1010

11-
from .const import ALLOWED_IMPORTS, LOGGER_PATH
11+
from .const import ALLOWED_IMPORTS, DOMAIN, LOGGER_PATH
1212

1313
_LOGGER = logging.getLogger(LOGGER_PATH + ".eval")
1414

@@ -361,6 +361,11 @@ def __init__(
361361
self.logger_handlers = set()
362362
self.logger = None
363363
self.set_logger_name(logger_name if logger_name is not None else self.name)
364+
self.allow_all_imports = (
365+
global_ctx.hass.data[DOMAIN]["allow_all_imports"]
366+
if global_ctx.hass is not None
367+
else False
368+
)
364369

365370
async def ast_not_implemented(self, arg, *args):
366371
"""Raise NotImplementedError exception for unimplemented AST types."""
@@ -399,7 +404,7 @@ async def ast_module(self, arg):
399404
async def ast_import(self, arg):
400405
"""Execute import."""
401406
for imp in arg.names:
402-
if imp.name not in ALLOWED_IMPORTS:
407+
if not self.allow_all_imports and imp.name not in ALLOWED_IMPORTS:
403408
raise ModuleNotFoundError(f"import of {imp.name} not allowed")
404409
if imp.name not in sys.modules:
405410
mod = importlib.import_module(imp.name)
@@ -409,7 +414,7 @@ async def ast_import(self, arg):
409414

410415
async def ast_importfrom(self, arg):
411416
"""Execute from X import Y."""
412-
if arg.module not in ALLOWED_IMPORTS:
417+
if not self.allow_all_imports and arg.module not in ALLOWED_IMPORTS:
413418
raise ModuleNotFoundError(f"import from {arg.module} not allowed")
414419
if arg.module not in sys.modules:
415420
mod = importlib.import_module(arg.module)

0 commit comments

Comments
 (0)