8
8
import logging
9
9
import sys
10
10
11
- from .const import ALLOWED_IMPORTS , LOGGER_PATH
11
+ from .const import ALLOWED_IMPORTS , DOMAIN , LOGGER_PATH
12
12
13
13
_LOGGER = logging .getLogger (LOGGER_PATH + ".eval" )
14
14
@@ -361,6 +361,11 @@ def __init__(
361
361
self .logger_handlers = set ()
362
362
self .logger = None
363
363
self .set_logger_name (logger_name if logger_name is not None else self .name )
364
+ self .allow_all_imports = (
365
+ global_ctx .hass .data [DOMAIN ]["allow_all_imports" ]
366
+ if global_ctx .hass is not None
367
+ else False
368
+ )
364
369
365
370
async def ast_not_implemented (self , arg , * args ):
366
371
"""Raise NotImplementedError exception for unimplemented AST types."""
@@ -399,7 +404,7 @@ async def ast_module(self, arg):
399
404
async def ast_import (self , arg ):
400
405
"""Execute import."""
401
406
for imp in arg .names :
402
- if imp .name not in ALLOWED_IMPORTS :
407
+ if not self . allow_all_imports and imp .name not in ALLOWED_IMPORTS :
403
408
raise ModuleNotFoundError (f"import of { imp .name } not allowed" )
404
409
if imp .name not in sys .modules :
405
410
mod = importlib .import_module (imp .name )
@@ -409,7 +414,7 @@ async def ast_import(self, arg):
409
414
410
415
async def ast_importfrom (self , arg ):
411
416
"""Execute from X import Y."""
412
- if arg .module not in ALLOWED_IMPORTS :
417
+ if not self . allow_all_imports and arg .module not in ALLOWED_IMPORTS :
413
418
raise ModuleNotFoundError (f"import from { arg .module } not allowed" )
414
419
if arg .module not in sys .modules :
415
420
mod = importlib .import_module (arg .module )
0 commit comments