diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index f36e20bed10d..51e7e640fb02 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -14,10 +14,11 @@ import time import unittest import urllib.parse +from collections import UserDict from contextlib import contextmanager from io import BytesIO, StringIO from pathlib import Path -from typing import Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image @@ -48,6 +49,17 @@ from .logging import get_logger +if is_torch_available(): + import torch + + IS_ROCM_SYSTEM = torch.version.hip is not None + IS_CUDA_SYSTEM = torch.version.cuda is not None + IS_XPU_SYSTEM = getattr(torch.version, "xpu", None) is not None +else: + IS_ROCM_SYSTEM = False + IS_CUDA_SYSTEM = False + IS_XPU_SYSTEM = False + global_rng = random.Random() logger = get_logger(__name__) @@ -1275,3 +1287,93 @@ def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name update_mapping_from_spec(BACKEND_RESET_PEAK_MEMORY_STATS, "RESET_PEAK_MEMORY_STATS_FN") update_mapping_from_spec(BACKEND_RESET_MAX_MEMORY_ALLOCATED, "RESET_MAX_MEMORY_ALLOCATED_FN") update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN") + + +# Modified from https://github.com/huggingface/transformers/blob/cdfb018d0300fef3b07d9220f3efe9c2a9974662/src/transformers/testing_utils.py#L3090 + +# Type definition of key used in `Expectations` class. +DeviceProperties = Tuple[Union[str, None], Union[int, None]] + + +@functools.lru_cache +def get_device_properties() -> DeviceProperties: + """ + Get environment device properties. + """ + if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: + import torch + + major, _ = torch.cuda.get_device_capability() + if IS_ROCM_SYSTEM: + return ("rocm", major) + else: + return ("cuda", major) + elif IS_XPU_SYSTEM: + import torch + + # To get more info of the architecture meaning and bit allocation, refer to https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/experimental/device_architecture.def + arch = torch.xpu.get_device_capability()["architecture"] + gen_mask = 0x000000FF00000000 + gen = (arch & gen_mask) >> 32 + return ("xpu", gen) + else: + return (torch_device, None) + + +if TYPE_CHECKING: + DevicePropertiesUserDict = UserDict[DeviceProperties, Any] +else: + DevicePropertiesUserDict = UserDict + + +class Expectations(DevicePropertiesUserDict): + def get_expectation(self) -> Any: + """ + Find best matching expectation based on environment device properties. + """ + return self.find_expectation(get_device_properties()) + + @staticmethod + def is_default(key: DeviceProperties) -> bool: + return all(p is None for p in key) + + @staticmethod + def score(key: DeviceProperties, other: DeviceProperties) -> int: + """ + Returns score indicating how similar two instances of the `Properties` tuple are. Points are calculated using + bits, but documented as int. Rules are as follows: + * Matching `type` gives 8 points. + * Semi-matching `type`, for example cuda and rocm, gives 4 points. + * Matching `major` (compute capability major version) gives 2 points. + * Default expectation (if present) gives 1 points. + """ + (device_type, major) = key + (other_device_type, other_major) = other + + score = 0b0 + if device_type == other_device_type: + score |= 0b1000 + elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]: + score |= 0b100 + + if major == other_major and other_major is not None: + score |= 0b10 + + if Expectations.is_default(other): + score |= 0b1 + + return int(score) + + def find_expectation(self, key: DeviceProperties = (None, None)) -> Any: + """ + Find best matching expectation based on provided device properties. + """ + (result_key, result) = max(self.data.items(), key=lambda x: Expectations.score(key, x[0])) + + if Expectations.score(key, result_key) == 0: + raise ValueError(f"No matching expectation found for {key}") + + return result + + def __repr__(self): + return f"{self.data}" diff --git a/tests/others/test_utils.py b/tests/others/test_utils.py index 65c0b3ece4d2..a17c7a50c866 100755 --- a/tests/others/test_utils.py +++ b/tests/others/test_utils.py @@ -20,7 +20,7 @@ from diffusers import __version__ from diffusers.utils import deprecate -from diffusers.utils.testing_utils import str_to_bool +from diffusers.utils.testing_utils import Expectations, str_to_bool # Used to test the hub @@ -182,6 +182,38 @@ def test_deprecate_stacklevel(self): assert "diffusers/tests/others/test_utils.py" in warning.filename +# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py +class ExpectationsTester(unittest.TestCase): + def test_expectations(self): + expectations = Expectations( + { + (None, None): 1, + ("cuda", 8): 2, + ("cuda", 7): 3, + ("rocm", 8): 4, + ("rocm", None): 5, + ("cpu", None): 6, + ("xpu", 3): 7, + } + ) + + def check(value, key): + assert expectations.find_expectation(key) == value + + # npu has no matches so should find default expectation + check(1, ("npu", None)) + check(7, ("xpu", 3)) + check(2, ("cuda", 8)) + check(3, ("cuda", 7)) + check(4, ("rocm", 9)) + check(4, ("rocm", None)) + check(2, ("cuda", 2)) + + expectations = Expectations({("cuda", 8): 1}) + with self.assertRaises(ValueError): + expectations.find_expectation(("xpu", None)) + + def parse_flag_from_env(key, default=False): try: value = os.environ[key] diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py index f7c450aab93e..80bb35a08e16 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py @@ -15,6 +15,7 @@ ) from diffusers.utils import load_image from diffusers.utils.testing_utils import ( + Expectations, backend_empty_cache, floats_tensor, numpy_cosine_similarity_distance, @@ -208,41 +209,115 @@ def test_sd3_img2img_inference(self): inputs = self.get_inputs(torch_device) image = pipe(**inputs).images[0] image_slice = image[0, :10, :10] - expected_slice = np.array( - [ - 0.5435, - 0.4673, - 0.5732, - 0.4438, - 0.3557, - 0.4912, - 0.4331, - 0.3491, - 0.4915, - 0.4287, - 0.3477, - 0.4849, - 0.4355, - 0.3469, - 0.4871, - 0.4431, - 0.3538, - 0.4912, - 0.4521, - 0.3643, - 0.5059, - 0.4587, - 0.3730, - 0.5166, - 0.4685, - 0.3845, - 0.5264, - 0.4746, - 0.3914, - 0.5342, - ] + expected_slices = Expectations( + { + ("xpu", 3): np.array( + [ + 0.5117, + 0.4421, + 0.3852, + 0.5044, + 0.4219, + 0.3262, + 0.5024, + 0.4329, + 0.3276, + 0.4978, + 0.4412, + 0.3355, + 0.4983, + 0.4338, + 0.3279, + 0.4893, + 0.4241, + 0.3129, + 0.4875, + 0.4253, + 0.3030, + 0.4961, + 0.4267, + 0.2988, + 0.5029, + 0.4255, + 0.3054, + 0.5132, + 0.4248, + 0.3222, + ] + ), + ("cuda", 7): np.array( + [ + 0.5435, + 0.4673, + 0.5732, + 0.4438, + 0.3557, + 0.4912, + 0.4331, + 0.3491, + 0.4915, + 0.4287, + 0.347, + 0.4849, + 0.4355, + 0.3469, + 0.4871, + 0.4431, + 0.3538, + 0.4912, + 0.4521, + 0.3643, + 0.5059, + 0.4587, + 0.373, + 0.5166, + 0.4685, + 0.3845, + 0.5264, + 0.4746, + 0.3914, + 0.5342, + ] + ), + ("cuda", 8): np.array( + [ + 0.5146, + 0.4385, + 0.3826, + 0.5098, + 0.4150, + 0.3218, + 0.5142, + 0.4312, + 0.3298, + 0.5127, + 0.4431, + 0.3411, + 0.5171, + 0.4424, + 0.3374, + 0.5088, + 0.4348, + 0.3242, + 0.5073, + 0.4380, + 0.3174, + 0.5132, + 0.4397, + 0.3115, + 0.5132, + 0.4343, + 0.3118, + 0.5219, + 0.4328, + 0.3256, + ] + ), + } ) + expected_slice = expected_slices.get_expectation() + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) assert max_diff < 1e-4, f"Outputs are not close enough, got {max_diff}"