diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 2e055d85fd93..406f1d999d9f 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -36,7 +36,10 @@ import importlib_metadata else: import importlib.metadata as importlib_metadata - +try: + _package_map = importlib_metadata.packages_distributions() # load-once to avoid expensive calls +except Exception: + _package_map = None logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -56,35 +59,32 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]: + global _package_map pkg_exists = importlib.util.find_spec(pkg_name) is not None pkg_version = "N/A" if pkg_exists: + if _package_map is None: + _package_map = defaultdict(list) + try: + # Fallback for Python < 3.10 + for dist in importlib_metadata.distributions(): + _top_level_declared = (dist.read_text("top_level.txt") or "").split() + _infered_opt_names = { + f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f) for f in (dist.files or []) + } - {None} + _top_level_inferred = filter(lambda name: "." not in name, _infered_opt_names) + for pkg in _top_level_declared or _top_level_inferred: + _package_map[pkg].append(dist.metadata["Name"]) + except Exception as _: + pass try: - package_map = importlib_metadata.packages_distributions() - except Exception as e: - package_map = defaultdict(list) - if isinstance(e, AttributeError): - try: - # Fallback for Python < 3.10 - for dist in importlib_metadata.distributions(): - _top_level_declared = (dist.read_text("top_level.txt") or "").split() - _infered_opt_names = { - f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f) for f in (dist.files or []) - } - {None} - _top_level_inferred = filter(lambda name: "." not in name, _infered_opt_names) - for pkg in _top_level_declared or _top_level_inferred: - package_map[pkg].append(dist.metadata["Name"]) - except Exception as _: - pass - - try: - if get_dist_name and pkg_name in package_map and package_map[pkg_name]: - if len(package_map[pkg_name]) > 1: + if get_dist_name and pkg_name in _package_map and _package_map[pkg_name]: + if len(_package_map[pkg_name]) > 1: logger.warning( - f"Multiple distributions found for package {pkg_name}. Picked distribution: {package_map[pkg_name][0]}" + f"Multiple distributions found for package {pkg_name}. Picked distribution: {_package_map[pkg_name][0]}" ) - pkg_name = package_map[pkg_name][0] + pkg_name = _package_map[pkg_name][0] pkg_version = importlib_metadata.version(pkg_name) logger.debug(f"Successfully imported {pkg_name} version {pkg_version}") except (ImportError, importlib_metadata.PackageNotFoundError):