Skip to content

introduce compute arch specific expectations and fix test_sd3_img2img_inference failure #11227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 103 additions & 1 deletion src/diffusers/utils/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
import time
import unittest
import urllib.parse
from collections import UserDict
from contextlib import contextmanager
from io import BytesIO, StringIO
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
Expand Down Expand Up @@ -48,6 +49,17 @@
from .logging import get_logger


if is_torch_available():
import torch

IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
IS_XPU_SYSTEM = getattr(torch.version, "xpu", None) is not None
else:
IS_ROCM_SYSTEM = False
IS_CUDA_SYSTEM = False
IS_XPU_SYSTEM = False

global_rng = random.Random()

logger = get_logger(__name__)
Expand Down Expand Up @@ -1275,3 +1287,93 @@ def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name
update_mapping_from_spec(BACKEND_RESET_PEAK_MEMORY_STATS, "RESET_PEAK_MEMORY_STATS_FN")
update_mapping_from_spec(BACKEND_RESET_MAX_MEMORY_ALLOCATED, "RESET_MAX_MEMORY_ALLOCATED_FN")
update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")


# Modified from https://github.com/huggingface/transformers/blob/cdfb018d0300fef3b07d9220f3efe9c2a9974662/src/transformers/testing_utils.py#L3090

# Type definition of key used in `Expectations` class.
DeviceProperties = Tuple[Union[str, None], Union[int, None]]


@functools.lru_cache
def get_device_properties() -> DeviceProperties:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this return a dict instead? We could have sensible key namings like:

{"device_type": "cuda", "compute_info": 8.0}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's raise it internally with transformers team as it's best to keep this in sync with their code (except Python 3.8 changes)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine by me. Cc: @ivarflakstad

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey!
Honestly any type that can be used as a key is fine by me.
Originally it was a dataclass with fields, but we wanted to keep it as simple as possible while still solving the problem it's addressing.

Feel free to upstream any improvements you can think of ☺️

"""
Get environment device properties.
"""
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
import torch

major, _ = torch.cuda.get_device_capability()
if IS_ROCM_SYSTEM:
return ("rocm", major)
else:
return ("cuda", major)
elif IS_XPU_SYSTEM:
import torch

# To get more info of the architecture meaning and bit allocation, refer to https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/experimental/device_architecture.def
arch = torch.xpu.get_device_capability()["architecture"]
gen_mask = 0x000000FF00000000
gen = (arch & gen_mask) >> 32
return ("xpu", gen)
else:
return (torch_device, None)


if TYPE_CHECKING:
DevicePropertiesUserDict = UserDict[DeviceProperties, Any]
else:
DevicePropertiesUserDict = UserDict


class Expectations(DevicePropertiesUserDict):
def get_expectation(self) -> Any:
"""
Find best matching expectation based on environment device properties.
"""
return self.find_expectation(get_device_properties())

@staticmethod
def is_default(key: DeviceProperties) -> bool:
return all(p is None for p in key)

@staticmethod
def score(key: DeviceProperties, other: DeviceProperties) -> int:
"""
Returns score indicating how similar two instances of the `Properties` tuple are. Points are calculated using
bits, but documented as int. Rules are as follows:
* Matching `type` gives 8 points.
* Semi-matching `type`, for example cuda and rocm, gives 4 points.
* Matching `major` (compute capability major version) gives 2 points.
* Default expectation (if present) gives 1 points.
"""
(device_type, major) = key
(other_device_type, other_major) = other

score = 0b0
if device_type == other_device_type:
score |= 0b1000
elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
score |= 0b100

if major == other_major and other_major is not None:
score |= 0b10

if Expectations.is_default(other):
score |= 0b1

return int(score)

def find_expectation(self, key: DeviceProperties = (None, None)) -> Any:
"""
Find best matching expectation based on provided device properties.
"""
(result_key, result) = max(self.data.items(), key=lambda x: Expectations.score(key, x[0]))

if Expectations.score(key, result_key) == 0:
raise ValueError(f"No matching expectation found for {key}")

return result

def __repr__(self):
return f"{self.data}"
34 changes: 33 additions & 1 deletion tests/others/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from diffusers import __version__
from diffusers.utils import deprecate
from diffusers.utils.testing_utils import str_to_bool
from diffusers.utils.testing_utils import Expectations, str_to_bool


# Used to test the hub
Expand Down Expand Up @@ -182,6 +182,38 @@ def test_deprecate_stacklevel(self):
assert "diffusers/tests/others/test_utils.py" in warning.filename


# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
class ExpectationsTester(unittest.TestCase):
def test_expectations(self):
expectations = Expectations(
{
(None, None): 1,
("cuda", 8): 2,
("cuda", 7): 3,
("rocm", 8): 4,
("rocm", None): 5,
("cpu", None): 6,
("xpu", 3): 7,
}
)

def check(value, key):
assert expectations.find_expectation(key) == value

# npu has no matches so should find default expectation
check(1, ("npu", None))
check(7, ("xpu", 3))
check(2, ("cuda", 8))
check(3, ("cuda", 7))
check(4, ("rocm", 9))
check(4, ("rocm", None))
check(2, ("cuda", 2))

expectations = Expectations({("cuda", 8): 1})
with self.assertRaises(ValueError):
expectations.find_expectation(("xpu", None))


def parse_flag_from_env(key, default=False):
try:
value = os.environ[key]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
Expectations,
backend_empty_cache,
floats_tensor,
numpy_cosine_similarity_distance,
Expand Down Expand Up @@ -208,41 +209,115 @@ def test_sd3_img2img_inference(self):
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
expected_slice = np.array(
[
0.5435,
0.4673,
0.5732,
0.4438,
0.3557,
0.4912,
0.4331,
0.3491,
0.4915,
0.4287,
0.3477,
0.4849,
0.4355,
0.3469,
0.4871,
0.4431,
0.3538,
0.4912,
0.4521,
0.3643,
0.5059,
0.4587,
0.3730,
0.5166,
0.4685,
0.3845,
0.5264,
0.4746,
0.3914,
0.5342,
]
expected_slices = Expectations(
{
("xpu", 3): np.array(
[
0.5117,
0.4421,
0.3852,
0.5044,
0.4219,
0.3262,
0.5024,
0.4329,
0.3276,
0.4978,
0.4412,
0.3355,
0.4983,
0.4338,
0.3279,
0.4893,
0.4241,
0.3129,
0.4875,
0.4253,
0.3030,
0.4961,
0.4267,
0.2988,
0.5029,
0.4255,
0.3054,
0.5132,
0.4248,
0.3222,
]
),
("cuda", 7): np.array(
[
0.5435,
0.4673,
0.5732,
0.4438,
0.3557,
0.4912,
0.4331,
0.3491,
0.4915,
0.4287,
0.347,
0.4849,
0.4355,
0.3469,
0.4871,
0.4431,
0.3538,
0.4912,
0.4521,
0.3643,
0.5059,
0.4587,
0.373,
0.5166,
0.4685,
0.3845,
0.5264,
0.4746,
0.3914,
0.5342,
]
),
("cuda", 8): np.array(
[
0.5146,
0.4385,
0.3826,
0.5098,
0.4150,
0.3218,
0.5142,
0.4312,
0.3298,
0.5127,
0.4431,
0.3411,
0.5171,
0.4424,
0.3374,
0.5088,
0.4348,
0.3242,
0.5073,
0.4380,
0.3174,
0.5132,
0.4397,
0.3115,
0.5132,
0.4343,
0.3118,
0.5219,
0.4328,
0.3256,
]
),
}
)

expected_slice = expected_slices.get_expectation()

max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

assert max_diff < 1e-4, f"Outputs are not close enough, got {max_diff}"
Loading