From 24a46cc3c6e7e158f9f58043c32b7562154a0658 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 15 May 2025 18:05:05 +0530 Subject: [PATCH 01/56] start overhauling the benchmarking suite. --- benchmarks/__init__.py | 0 benchmarks/base_classes.py | 346 -------------------------- benchmarks/benchmark_controlnet.py | 26 -- benchmarks/benchmark_ip_adapters.py | 33 --- benchmarks/benchmark_sd_img.py | 29 --- benchmarks/benchmark_sd_inpainting.py | 28 --- benchmarks/benchmark_t2i_adapter.py | 28 --- benchmarks/benchmark_t2i_lcm_lora.py | 23 -- benchmarks/benchmark_text_to_image.py | 40 --- benchmarks/benchmarking_flux.py | 40 +++ benchmarks/benchmarking_utils.py | 81 ++++++ benchmarks/push_results.py | 72 ------ benchmarks/run_all.py | 101 -------- benchmarks/utils.py | 98 -------- 14 files changed, 121 insertions(+), 824 deletions(-) create mode 100644 benchmarks/__init__.py delete mode 100644 benchmarks/base_classes.py delete mode 100644 benchmarks/benchmark_controlnet.py delete mode 100644 benchmarks/benchmark_ip_adapters.py delete mode 100644 benchmarks/benchmark_sd_img.py delete mode 100644 benchmarks/benchmark_sd_inpainting.py delete mode 100644 benchmarks/benchmark_t2i_adapter.py delete mode 100644 benchmarks/benchmark_t2i_lcm_lora.py delete mode 100644 benchmarks/benchmark_text_to_image.py create mode 100644 benchmarks/benchmarking_flux.py create mode 100644 benchmarks/benchmarking_utils.py delete mode 100644 benchmarks/push_results.py delete mode 100644 benchmarks/run_all.py delete mode 100644 benchmarks/utils.py diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/benchmarks/base_classes.py b/benchmarks/base_classes.py deleted file mode 100644 index 45bf65c93c93..000000000000 --- a/benchmarks/base_classes.py +++ /dev/null @@ -1,346 +0,0 @@ -import os -import sys - -import torch - -from diffusers import ( - AutoPipelineForImage2Image, - AutoPipelineForInpainting, - AutoPipelineForText2Image, - ControlNetModel, - LCMScheduler, - StableDiffusionAdapterPipeline, - StableDiffusionControlNetPipeline, - StableDiffusionXLAdapterPipeline, - StableDiffusionXLControlNetPipeline, - T2IAdapter, - WuerstchenCombinedPipeline, -) -from diffusers.utils import load_image - - -sys.path.append(".") - -from utils import ( # noqa: E402 - BASE_PATH, - PROMPT, - BenchmarkInfo, - benchmark_fn, - bytes_to_giga_bytes, - flush, - generate_csv_dict, - write_to_csv, -) - - -RESOLUTION_MAPPING = { - "Lykon/DreamShaper": (512, 512), - "lllyasviel/sd-controlnet-canny": (512, 512), - "diffusers/controlnet-canny-sdxl-1.0": (1024, 1024), - "TencentARC/t2iadapter_canny_sd14v1": (512, 512), - "TencentARC/t2i-adapter-canny-sdxl-1.0": (1024, 1024), - "stabilityai/stable-diffusion-2-1": (768, 768), - "stabilityai/stable-diffusion-xl-base-1.0": (1024, 1024), - "stabilityai/stable-diffusion-xl-refiner-1.0": (1024, 1024), - "stabilityai/sdxl-turbo": (512, 512), -} - - -class BaseBenchmak: - pipeline_class = None - - def __init__(self, args): - super().__init__() - - def run_inference(self, args): - raise NotImplementedError - - def benchmark(self, args): - raise NotImplementedError - - def get_result_filepath(self, args): - pipeline_class_name = str(self.pipe.__class__.__name__) - name = ( - args.ckpt.replace("/", "_") - + "_" - + pipeline_class_name - + f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv" - ) - filepath = os.path.join(BASE_PATH, name) - return filepath - - -class 
TextToImageBenchmark(BaseBenchmak): - pipeline_class = AutoPipelineForText2Image - - def __init__(self, args): - pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16) - pipe = pipe.to("cuda") - - if args.run_compile: - if not isinstance(pipe, WuerstchenCombinedPipeline): - pipe.unet.to(memory_format=torch.channels_last) - print("Run torch compile") - pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) - - if hasattr(pipe, "movq") and getattr(pipe, "movq", None) is not None: - pipe.movq.to(memory_format=torch.channels_last) - pipe.movq = torch.compile(pipe.movq, mode="reduce-overhead", fullgraph=True) - else: - print("Run torch compile") - pipe.decoder = torch.compile(pipe.decoder, mode="reduce-overhead", fullgraph=True) - pipe.vqgan = torch.compile(pipe.vqgan, mode="reduce-overhead", fullgraph=True) - - pipe.set_progress_bar_config(disable=True) - self.pipe = pipe - - def run_inference(self, pipe, args): - _ = pipe( - prompt=PROMPT, - num_inference_steps=args.num_inference_steps, - num_images_per_prompt=args.batch_size, - ) - - def benchmark(self, args): - flush() - - print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n") - - time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds. - memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs. - benchmark_info = BenchmarkInfo(time=time, memory=memory) - - pipeline_class_name = str(self.pipe.__class__.__name__) - flush() - csv_dict = generate_csv_dict( - pipeline_cls=pipeline_class_name, ckpt=args.ckpt, args=args, benchmark_info=benchmark_info - ) - filepath = self.get_result_filepath(args) - write_to_csv(filepath, csv_dict) - print(f"Logs written to: {filepath}") - flush() - - -class TurboTextToImageBenchmark(TextToImageBenchmark): - def __init__(self, args): - super().__init__(args) - - def run_inference(self, pipe, args): - _ = pipe( - prompt=PROMPT, - num_inference_steps=args.num_inference_steps, - num_images_per_prompt=args.batch_size, - guidance_scale=0.0, - ) - - -class LCMLoRATextToImageBenchmark(TextToImageBenchmark): - lora_id = "latent-consistency/lcm-lora-sdxl" - - def __init__(self, args): - super().__init__(args) - self.pipe.load_lora_weights(self.lora_id) - self.pipe.fuse_lora() - self.pipe.unload_lora_weights() - self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config) - - def get_result_filepath(self, args): - pipeline_class_name = str(self.pipe.__class__.__name__) - name = ( - self.lora_id.replace("/", "_") - + "_" - + pipeline_class_name - + f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv" - ) - filepath = os.path.join(BASE_PATH, name) - return filepath - - def run_inference(self, pipe, args): - _ = pipe( - prompt=PROMPT, - num_inference_steps=args.num_inference_steps, - num_images_per_prompt=args.batch_size, - guidance_scale=1.0, - ) - - def benchmark(self, args): - flush() - - print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n") - - time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds. - memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs. 
- benchmark_info = BenchmarkInfo(time=time, memory=memory) - - pipeline_class_name = str(self.pipe.__class__.__name__) - flush() - csv_dict = generate_csv_dict( - pipeline_cls=pipeline_class_name, ckpt=self.lora_id, args=args, benchmark_info=benchmark_info - ) - filepath = self.get_result_filepath(args) - write_to_csv(filepath, csv_dict) - print(f"Logs written to: {filepath}") - flush() - - -class ImageToImageBenchmark(TextToImageBenchmark): - pipeline_class = AutoPipelineForImage2Image - url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/1665_Girl_with_a_Pearl_Earring.jpg" - image = load_image(url).convert("RGB") - - def __init__(self, args): - super().__init__(args) - self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt]) - - def run_inference(self, pipe, args): - _ = pipe( - prompt=PROMPT, - image=self.image, - num_inference_steps=args.num_inference_steps, - num_images_per_prompt=args.batch_size, - ) - - -class TurboImageToImageBenchmark(ImageToImageBenchmark): - def __init__(self, args): - super().__init__(args) - - def run_inference(self, pipe, args): - _ = pipe( - prompt=PROMPT, - image=self.image, - num_inference_steps=args.num_inference_steps, - num_images_per_prompt=args.batch_size, - guidance_scale=0.0, - strength=0.5, - ) - - -class InpaintingBenchmark(ImageToImageBenchmark): - pipeline_class = AutoPipelineForInpainting - mask_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/overture-creations-5sI6fQgYIuo_mask.png" - mask = load_image(mask_url).convert("RGB") - - def __init__(self, args): - super().__init__(args) - self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt]) - self.mask = self.mask.resize(RESOLUTION_MAPPING[args.ckpt]) - - def run_inference(self, pipe, args): - _ = pipe( - prompt=PROMPT, - image=self.image, - mask_image=self.mask, - num_inference_steps=args.num_inference_steps, - num_images_per_prompt=args.batch_size, - ) - - -class IPAdapterTextToImageBenchmark(TextToImageBenchmark): - url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png" - image = load_image(url) - - def __init__(self, args): - pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16).to("cuda") - pipe.load_ip_adapter( - args.ip_adapter_id[0], - subfolder="models" if "sdxl" not in args.ip_adapter_id[1] else "sdxl_models", - weight_name=args.ip_adapter_id[1], - ) - - if args.run_compile: - pipe.unet.to(memory_format=torch.channels_last) - print("Run torch compile") - pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) - - pipe.set_progress_bar_config(disable=True) - self.pipe = pipe - - def run_inference(self, pipe, args): - _ = pipe( - prompt=PROMPT, - ip_adapter_image=self.image, - num_inference_steps=args.num_inference_steps, - num_images_per_prompt=args.batch_size, - ) - - -class ControlNetBenchmark(TextToImageBenchmark): - pipeline_class = StableDiffusionControlNetPipeline - aux_network_class = ControlNetModel - root_ckpt = "Lykon/DreamShaper" - - url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_image_condition.png" - image = load_image(url).convert("RGB") - - def __init__(self, args): - aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16) - pipe = self.pipeline_class.from_pretrained(self.root_ckpt, controlnet=aux_network, torch_dtype=torch.float16) - pipe = pipe.to("cuda") - - pipe.set_progress_bar_config(disable=True) - 
self.pipe = pipe - - if args.run_compile: - pipe.unet.to(memory_format=torch.channels_last) - pipe.controlnet.to(memory_format=torch.channels_last) - - print("Run torch compile") - pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) - pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True) - - self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt]) - - def run_inference(self, pipe, args): - _ = pipe( - prompt=PROMPT, - image=self.image, - num_inference_steps=args.num_inference_steps, - num_images_per_prompt=args.batch_size, - ) - - -class ControlNetSDXLBenchmark(ControlNetBenchmark): - pipeline_class = StableDiffusionXLControlNetPipeline - root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0" - - def __init__(self, args): - super().__init__(args) - - -class T2IAdapterBenchmark(ControlNetBenchmark): - pipeline_class = StableDiffusionAdapterPipeline - aux_network_class = T2IAdapter - root_ckpt = "Lykon/DreamShaper" - - url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter.png" - image = load_image(url).convert("L") - - def __init__(self, args): - aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16) - pipe = self.pipeline_class.from_pretrained(self.root_ckpt, adapter=aux_network, torch_dtype=torch.float16) - pipe = pipe.to("cuda") - - pipe.set_progress_bar_config(disable=True) - self.pipe = pipe - - if args.run_compile: - pipe.unet.to(memory_format=torch.channels_last) - pipe.adapter.to(memory_format=torch.channels_last) - - print("Run torch compile") - pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) - pipe.adapter = torch.compile(pipe.adapter, mode="reduce-overhead", fullgraph=True) - - self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt]) - - -class T2IAdapterSDXLBenchmark(T2IAdapterBenchmark): - pipeline_class = StableDiffusionXLAdapterPipeline - root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0" - - url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter_sdxl.png" - image = load_image(url) - - def __init__(self, args): - super().__init__(args) diff --git a/benchmarks/benchmark_controlnet.py b/benchmarks/benchmark_controlnet.py deleted file mode 100644 index 9217004461dc..000000000000 --- a/benchmarks/benchmark_controlnet.py +++ /dev/null @@ -1,26 +0,0 @@ -import argparse -import sys - - -sys.path.append(".") -from base_classes import ControlNetBenchmark, ControlNetSDXLBenchmark # noqa: E402 - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--ckpt", - type=str, - default="lllyasviel/sd-controlnet-canny", - choices=["lllyasviel/sd-controlnet-canny", "diffusers/controlnet-canny-sdxl-1.0"], - ) - parser.add_argument("--batch_size", type=int, default=1) - parser.add_argument("--num_inference_steps", type=int, default=50) - parser.add_argument("--model_cpu_offload", action="store_true") - parser.add_argument("--run_compile", action="store_true") - args = parser.parse_args() - - benchmark_pipe = ( - ControlNetBenchmark(args) if args.ckpt == "lllyasviel/sd-controlnet-canny" else ControlNetSDXLBenchmark(args) - ) - benchmark_pipe.benchmark(args) diff --git a/benchmarks/benchmark_ip_adapters.py b/benchmarks/benchmark_ip_adapters.py deleted file mode 100644 index 9a31a21fc60d..000000000000 --- a/benchmarks/benchmark_ip_adapters.py +++ /dev/null @@ -1,33 +0,0 @@ -import argparse -import sys - - -sys.path.append(".") 
-from base_classes import IPAdapterTextToImageBenchmark # noqa: E402 - - -IP_ADAPTER_CKPTS = { - # because original SD v1.5 has been taken down. - "Lykon/DreamShaper": ("h94/IP-Adapter", "ip-adapter_sd15.bin"), - "stabilityai/stable-diffusion-xl-base-1.0": ("h94/IP-Adapter", "ip-adapter_sdxl.bin"), -} - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--ckpt", - type=str, - default="rstabilityai/stable-diffusion-xl-base-1.0", - choices=list(IP_ADAPTER_CKPTS.keys()), - ) - parser.add_argument("--batch_size", type=int, default=1) - parser.add_argument("--num_inference_steps", type=int, default=50) - parser.add_argument("--model_cpu_offload", action="store_true") - parser.add_argument("--run_compile", action="store_true") - args = parser.parse_args() - - args.ip_adapter_id = IP_ADAPTER_CKPTS[args.ckpt] - benchmark_pipe = IPAdapterTextToImageBenchmark(args) - args.ckpt = f"{args.ckpt} (IP-Adapter)" - benchmark_pipe.benchmark(args) diff --git a/benchmarks/benchmark_sd_img.py b/benchmarks/benchmark_sd_img.py deleted file mode 100644 index 772befe8795f..000000000000 --- a/benchmarks/benchmark_sd_img.py +++ /dev/null @@ -1,29 +0,0 @@ -import argparse -import sys - - -sys.path.append(".") -from base_classes import ImageToImageBenchmark, TurboImageToImageBenchmark # noqa: E402 - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--ckpt", - type=str, - default="Lykon/DreamShaper", - choices=[ - "Lykon/DreamShaper", - "stabilityai/stable-diffusion-2-1", - "stabilityai/stable-diffusion-xl-refiner-1.0", - "stabilityai/sdxl-turbo", - ], - ) - parser.add_argument("--batch_size", type=int, default=1) - parser.add_argument("--num_inference_steps", type=int, default=50) - parser.add_argument("--model_cpu_offload", action="store_true") - parser.add_argument("--run_compile", action="store_true") - args = parser.parse_args() - - benchmark_pipe = ImageToImageBenchmark(args) if "turbo" not in args.ckpt else TurboImageToImageBenchmark(args) - benchmark_pipe.benchmark(args) diff --git a/benchmarks/benchmark_sd_inpainting.py b/benchmarks/benchmark_sd_inpainting.py deleted file mode 100644 index 143adcb0d87c..000000000000 --- a/benchmarks/benchmark_sd_inpainting.py +++ /dev/null @@ -1,28 +0,0 @@ -import argparse -import sys - - -sys.path.append(".") -from base_classes import InpaintingBenchmark # noqa: E402 - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--ckpt", - type=str, - default="Lykon/DreamShaper", - choices=[ - "Lykon/DreamShaper", - "stabilityai/stable-diffusion-2-1", - "stabilityai/stable-diffusion-xl-base-1.0", - ], - ) - parser.add_argument("--batch_size", type=int, default=1) - parser.add_argument("--num_inference_steps", type=int, default=50) - parser.add_argument("--model_cpu_offload", action="store_true") - parser.add_argument("--run_compile", action="store_true") - args = parser.parse_args() - - benchmark_pipe = InpaintingBenchmark(args) - benchmark_pipe.benchmark(args) diff --git a/benchmarks/benchmark_t2i_adapter.py b/benchmarks/benchmark_t2i_adapter.py deleted file mode 100644 index 44b04b470ea6..000000000000 --- a/benchmarks/benchmark_t2i_adapter.py +++ /dev/null @@ -1,28 +0,0 @@ -import argparse -import sys - - -sys.path.append(".") -from base_classes import T2IAdapterBenchmark, T2IAdapterSDXLBenchmark # noqa: E402 - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--ckpt", - type=str, - 
default="TencentARC/t2iadapter_canny_sd14v1", - choices=["TencentARC/t2iadapter_canny_sd14v1", "TencentARC/t2i-adapter-canny-sdxl-1.0"], - ) - parser.add_argument("--batch_size", type=int, default=1) - parser.add_argument("--num_inference_steps", type=int, default=50) - parser.add_argument("--model_cpu_offload", action="store_true") - parser.add_argument("--run_compile", action="store_true") - args = parser.parse_args() - - benchmark_pipe = ( - T2IAdapterBenchmark(args) - if args.ckpt == "TencentARC/t2iadapter_canny_sd14v1" - else T2IAdapterSDXLBenchmark(args) - ) - benchmark_pipe.benchmark(args) diff --git a/benchmarks/benchmark_t2i_lcm_lora.py b/benchmarks/benchmark_t2i_lcm_lora.py deleted file mode 100644 index 957e0a463e28..000000000000 --- a/benchmarks/benchmark_t2i_lcm_lora.py +++ /dev/null @@ -1,23 +0,0 @@ -import argparse -import sys - - -sys.path.append(".") -from base_classes import LCMLoRATextToImageBenchmark # noqa: E402 - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--ckpt", - type=str, - default="stabilityai/stable-diffusion-xl-base-1.0", - ) - parser.add_argument("--batch_size", type=int, default=1) - parser.add_argument("--num_inference_steps", type=int, default=4) - parser.add_argument("--model_cpu_offload", action="store_true") - parser.add_argument("--run_compile", action="store_true") - args = parser.parse_args() - - benchmark_pipe = LCMLoRATextToImageBenchmark(args) - benchmark_pipe.benchmark(args) diff --git a/benchmarks/benchmark_text_to_image.py b/benchmarks/benchmark_text_to_image.py deleted file mode 100644 index ddc7fb2676a5..000000000000 --- a/benchmarks/benchmark_text_to_image.py +++ /dev/null @@ -1,40 +0,0 @@ -import argparse -import sys - - -sys.path.append(".") -from base_classes import TextToImageBenchmark, TurboTextToImageBenchmark # noqa: E402 - - -ALL_T2I_CKPTS = [ - "Lykon/DreamShaper", - "segmind/SSD-1B", - "stabilityai/stable-diffusion-xl-base-1.0", - "kandinsky-community/kandinsky-2-2-decoder", - "warp-ai/wuerstchen", - "stabilityai/sdxl-turbo", -] - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--ckpt", - type=str, - default="Lykon/DreamShaper", - choices=ALL_T2I_CKPTS, - ) - parser.add_argument("--batch_size", type=int, default=1) - parser.add_argument("--num_inference_steps", type=int, default=50) - parser.add_argument("--model_cpu_offload", action="store_true") - parser.add_argument("--run_compile", action="store_true") - args = parser.parse_args() - - benchmark_cls = None - if "turbo" in args.ckpt: - benchmark_cls = TurboTextToImageBenchmark - else: - benchmark_cls = TextToImageBenchmark - - benchmark_pipe = benchmark_cls(args) - benchmark_pipe.benchmark(args) diff --git a/benchmarks/benchmarking_flux.py b/benchmarks/benchmarking_flux.py new file mode 100644 index 000000000000..5218eb9ac698 --- /dev/null +++ b/benchmarks/benchmarking_flux.py @@ -0,0 +1,40 @@ +import torch + +from diffusers import FluxTransformer2DModel +from diffusers.utils.testing_utils import torch_device + +from .benchmarking_utils import BenchmarkMixin + + +class BenchmarkFlux(BenchmarkMixin): + model_class = FluxTransformer2DModel + compile_kwargs = {"fullgraph": True, "mode": "max-autotune"} + + def get_model_init_dict(self): + return {"ckpt_id": "black-forest-labs/FLUX.1-dev", "subfolder": "transformer", "torch_dtype": torch.bfloat16} + + def initialize_model(self): + model = self.model_class.from_pretrained(**self.get_model_init_dict()) + model = model.to(torch_device).eval() + 
return model + + def get_input_dict(self): + # resolution: 1024x1024 + # maximum sequence length 512 + hidden_states = torch.randn(1, 4096, 64, device=torch_device, dtype=torch.bfloat16) + encoder_hidden_states = torch.randn(1, 512, 4096, device=torch_device, dtype=torch.bfloat16) + pooled_prompt_embeds = torch.randn(1, 768, device=torch_device, dtype=torch.bfloat16) + image_ids = torch.ones(512, 3, device=torch_device, dtype=torch.bfloat16) + text_ids = torch.ones(4096, 3, device=torch_device, dtype=torch.bfloat16) + timestep = torch.tensor([1.0], device=torch_device) + guidance = torch.tensor([1.0], device=torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "img_ids": image_ids, + "txt_ids": text_ids, + "pooled_projections": pooled_prompt_embeds, + "timestep": timestep, + "guidance": guidance, + } diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py new file mode 100644 index 000000000000..0199f1fd7ec6 --- /dev/null +++ b/benchmarks/benchmarking_utils.py @@ -0,0 +1,81 @@ +import gc + +import torch +from torch.utils.benchmark import benchmark + +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.testing_utils import require_torch_gpu + + +def benchmark_fn(f, *args, **kwargs): + t0 = benchmark.Timer( + stmt="f(*args, **kwargs)", + globals={"args": args, "kwargs": kwargs, "f": f}, + num_threads=torch.get_num_threads(), + ) + return f"{(t0.blocked_autorange().mean):.3f}" + + +def flush(): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + +@require_torch_gpu +class BenchmarkMixin: + model_class: ModelMixin = None + compile_kwargs: dict = None + + def get_model_init_dict(self): + raise NotImplementedError + + def initialize_model(self): + raise NotImplementedError + + def get_input_dict(self): + raise NotImplementedError + + def pre_benchmark(self): + flush() + torch.compiler.reset() + + def post_benchmark(self, model): + model.cpu() + flush() + torch.compiler.reset() + + @torch.no_grad() + def run_benchmark(self): + self.pre_benchmark() + + model = self.initialize_model() # Takes care of device placement. + input_dict = self.get_input_dict() # Takes care of device placement. 
+ + # warmup + for _ in range(5): + _ = model(**input_dict) + + time = benchmark_fn(lambda model, input_dict: model(**input_dict), model, input_dict) + memory = torch.cuda.max_memory_allocated() / (1024**3) + memory = float(f"{memory:.2f}") + non_compile_stats = {"time": time, "memory": memory} + + self.post_benchmark(model) + del model + self.pre_benchmark() + + compile_stats = None + if self.compile_kwargs is not None: + model = self.initialize_model() + with torch._inductor.utils.fresh_inductor_cache(): + model.compile(**self.compile_kwargs) + time = benchmark_fn(lambda model, input_dict: model(**input_dict), model, input_dict) + memory = torch.cuda.max_memory_allocated() / (1024**3) + memory = float(f"{memory:.2f}") + compile_stats = {"time": time, "memory": memory} + + self.post_benchmark(model) + del model + return non_compile_stats, compile_stats diff --git a/benchmarks/push_results.py b/benchmarks/push_results.py deleted file mode 100644 index 71cd60f32c0f..000000000000 --- a/benchmarks/push_results.py +++ /dev/null @@ -1,72 +0,0 @@ -import glob -import sys - -import pandas as pd -from huggingface_hub import hf_hub_download, upload_file -from huggingface_hub.utils import EntryNotFoundError - - -sys.path.append(".") -from utils import BASE_PATH, FINAL_CSV_FILE, GITHUB_SHA, REPO_ID, collate_csv # noqa: E402 - - -def has_previous_benchmark() -> str: - csv_path = None - try: - csv_path = hf_hub_download(repo_id=REPO_ID, repo_type="dataset", filename=FINAL_CSV_FILE) - except EntryNotFoundError: - csv_path = None - return csv_path - - -def filter_float(value): - if isinstance(value, str): - return float(value.split()[0]) - return value - - -def push_to_hf_dataset(): - all_csvs = sorted(glob.glob(f"{BASE_PATH}/*.csv")) - collate_csv(all_csvs, FINAL_CSV_FILE) - - # If there's an existing benchmark file, we should report the changes. - csv_path = has_previous_benchmark() - if csv_path is not None: - current_results = pd.read_csv(FINAL_CSV_FILE) - previous_results = pd.read_csv(csv_path) - - numeric_columns = current_results.select_dtypes(include=["float64", "int64"]).columns - numeric_columns = [ - c for c in numeric_columns if c not in ["batch_size", "num_inference_steps", "actual_gpu_memory (gbs)"] - ] - - for column in numeric_columns: - previous_results[column] = previous_results[column].map(lambda x: filter_float(x)) - - # Calculate the percentage change - current_results[column] = current_results[column].astype(float) - previous_results[column] = previous_results[column].astype(float) - percent_change = ((current_results[column] - previous_results[column]) / previous_results[column]) * 100 - - # Format the values with '+' or '-' sign and append to original values - current_results[column] = current_results[column].map(str) + percent_change.map( - lambda x: f" ({'+' if x > 0 else ''}{x:.2f}%)" - ) - # There might be newly added rows. So, filter out the NaNs. - current_results[column] = current_results[column].map(lambda x: x.replace(" (nan%)", "")) - - # Overwrite the current result file. 
- current_results.to_csv(FINAL_CSV_FILE, index=False) - - commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results" - upload_file( - repo_id=REPO_ID, - path_in_repo=FINAL_CSV_FILE, - path_or_fileobj=FINAL_CSV_FILE, - repo_type="dataset", - commit_message=commit_message, - ) - - -if __name__ == "__main__": - push_to_hf_dataset() diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py deleted file mode 100644 index c9932cc71c38..000000000000 --- a/benchmarks/run_all.py +++ /dev/null @@ -1,101 +0,0 @@ -import glob -import subprocess -import sys -from typing import List - - -sys.path.append(".") -from benchmark_text_to_image import ALL_T2I_CKPTS # noqa: E402 - - -PATTERN = "benchmark_*.py" - - -class SubprocessCallException(Exception): - pass - - -# Taken from `test_examples_utils.py` -def run_command(command: List[str], return_stdout=False): - """ - Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture - if an error occurred while running `command` - """ - try: - output = subprocess.check_output(command, stderr=subprocess.STDOUT) - if return_stdout: - if hasattr(output, "decode"): - output = output.decode("utf-8") - return output - except subprocess.CalledProcessError as e: - raise SubprocessCallException( - f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}" - ) from e - - -def main(): - python_files = glob.glob(PATTERN) - - for file in python_files: - print(f"****** Running file: {file} ******") - - # Run with canonical settings. - if file != "benchmark_text_to_image.py" and file != "benchmark_ip_adapters.py": - command = f"python {file}" - run_command(command.split()) - - command += " --run_compile" - run_command(command.split()) - - # Run variants. 
- for file in python_files: - # See: https://github.com/pytorch/pytorch/issues/129637 - if file == "benchmark_ip_adapters.py": - continue - - if file == "benchmark_text_to_image.py": - for ckpt in ALL_T2I_CKPTS: - command = f"python {file} --ckpt {ckpt}" - - if "turbo" in ckpt: - command += " --num_inference_steps 1" - - run_command(command.split()) - - command += " --run_compile" - run_command(command.split()) - - elif file == "benchmark_sd_img.py": - for ckpt in ["stabilityai/stable-diffusion-xl-refiner-1.0", "stabilityai/sdxl-turbo"]: - command = f"python {file} --ckpt {ckpt}" - - if ckpt == "stabilityai/sdxl-turbo": - command += " --num_inference_steps 2" - - run_command(command.split()) - command += " --run_compile" - run_command(command.split()) - - elif file in ["benchmark_sd_inpainting.py", "benchmark_ip_adapters.py"]: - sdxl_ckpt = "stabilityai/stable-diffusion-xl-base-1.0" - command = f"python {file} --ckpt {sdxl_ckpt}" - run_command(command.split()) - - command += " --run_compile" - run_command(command.split()) - - elif file in ["benchmark_controlnet.py", "benchmark_t2i_adapter.py"]: - sdxl_ckpt = ( - "diffusers/controlnet-canny-sdxl-1.0" - if "controlnet" in file - else "TencentARC/t2i-adapter-canny-sdxl-1.0" - ) - command = f"python {file} --ckpt {sdxl_ckpt}" - run_command(command.split()) - - command += " --run_compile" - run_command(command.split()) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/utils.py b/benchmarks/utils.py deleted file mode 100644 index 5fce920ac6c3..000000000000 --- a/benchmarks/utils.py +++ /dev/null @@ -1,98 +0,0 @@ -import argparse -import csv -import gc -import os -from dataclasses import dataclass -from typing import Dict, List, Union - -import torch -import torch.utils.benchmark as benchmark - - -GITHUB_SHA = os.getenv("GITHUB_SHA", None) -BENCHMARK_FIELDS = [ - "pipeline_cls", - "ckpt_id", - "batch_size", - "num_inference_steps", - "model_cpu_offload", - "run_compile", - "time (secs)", - "memory (gbs)", - "actual_gpu_memory (gbs)", - "github_sha", -] - -PROMPT = "ghibli style, a fantasy landscape with castles" -BASE_PATH = os.getenv("BASE_PATH", ".") -TOTAL_GPU_MEMORY = float(os.getenv("TOTAL_GPU_MEMORY", torch.cuda.get_device_properties(0).total_memory / (1024**3))) - -REPO_ID = "diffusers/benchmarks" -FINAL_CSV_FILE = "collated_results.csv" - - -@dataclass -class BenchmarkInfo: - time: float - memory: float - - -def flush(): - """Wipes off memory.""" - gc.collect() - torch.cuda.empty_cache() - torch.cuda.reset_max_memory_allocated() - torch.cuda.reset_peak_memory_stats() - - -def bytes_to_giga_bytes(bytes): - return f"{(bytes / 1024 / 1024 / 1024):.3f}" - - -def benchmark_fn(f, *args, **kwargs): - t0 = benchmark.Timer( - stmt="f(*args, **kwargs)", - globals={"args": args, "kwargs": kwargs, "f": f}, - num_threads=torch.get_num_threads(), - ) - return f"{(t0.blocked_autorange().mean):.3f}" - - -def generate_csv_dict( - pipeline_cls: str, ckpt: str, args: argparse.Namespace, benchmark_info: BenchmarkInfo -) -> Dict[str, Union[str, bool, float]]: - """Packs benchmarking data into a dictionary for latter serialization.""" - data_dict = { - "pipeline_cls": pipeline_cls, - "ckpt_id": ckpt, - "batch_size": args.batch_size, - "num_inference_steps": args.num_inference_steps, - "model_cpu_offload": args.model_cpu_offload, - "run_compile": args.run_compile, - "time (secs)": benchmark_info.time, - "memory (gbs)": benchmark_info.memory, - "actual_gpu_memory (gbs)": f"{(TOTAL_GPU_MEMORY):.3f}", - "github_sha": GITHUB_SHA, - } - return 
data_dict - - -def write_to_csv(file_name: str, data_dict: Dict[str, Union[str, bool, float]]): - """Serializes a dictionary into a CSV file.""" - with open(file_name, mode="w", newline="") as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=BENCHMARK_FIELDS) - writer.writeheader() - writer.writerow(data_dict) - - -def collate_csv(input_files: List[str], output_file: str): - """Collates multiple identically structured CSVs into a single CSV file.""" - with open(output_file, mode="w", newline="") as outfile: - writer = csv.DictWriter(outfile, fieldnames=BENCHMARK_FIELDS) - writer.writeheader() - - for file in input_files: - with open(file, mode="r") as infile: - reader = csv.DictReader(infile) - for row in reader: - writer.writerow(row) From ab7f381c5219eaae445128f404598acc300df15f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 15 May 2025 18:41:52 +0530 Subject: [PATCH 02/56] fixes --- benchmarks/benchmarking_flux.py | 9 ++++++--- benchmarks/benchmarking_utils.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmarking_flux.py b/benchmarks/benchmarking_flux.py index 5218eb9ac698..64820e66327f 100644 --- a/benchmarks/benchmarking_flux.py +++ b/benchmarks/benchmarking_flux.py @@ -1,17 +1,20 @@ import torch +from benchmarking_utils import BenchmarkMixin from diffusers import FluxTransformer2DModel from diffusers.utils.testing_utils import torch_device -from .benchmarking_utils import BenchmarkMixin - class BenchmarkFlux(BenchmarkMixin): model_class = FluxTransformer2DModel compile_kwargs = {"fullgraph": True, "mode": "max-autotune"} def get_model_init_dict(self): - return {"ckpt_id": "black-forest-labs/FLUX.1-dev", "subfolder": "transformer", "torch_dtype": torch.bfloat16} + return { + "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev", + "subfolder": "transformer", + "torch_dtype": torch.bfloat16, + } def initialize_model(self): model = self.model_class.from_pretrained(**self.get_model_init_dict()) diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index 0199f1fd7ec6..c69766ae1e45 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -1,7 +1,7 @@ import gc import torch -from torch.utils.benchmark import benchmark +import torch.utils.benchmark as benchmark from diffusers.models.modeling_utils import ModelMixin from diffusers.utils.testing_utils import require_torch_gpu From cc0a38a2254263b35af5a3c9dcb8e46fba80f034 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 15 May 2025 21:05:11 +0530 Subject: [PATCH 03/56] fixes --- benchmarks/benchmarking_flux.py | 6 +++--- benchmarks/benchmarking_utils.py | 12 ++++-------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/benchmarks/benchmarking_flux.py b/benchmarks/benchmarking_flux.py index 64820e66327f..802cadaf9df8 100644 --- a/benchmarks/benchmarking_flux.py +++ b/benchmarks/benchmarking_flux.py @@ -7,7 +7,7 @@ class BenchmarkFlux(BenchmarkMixin): model_class = FluxTransformer2DModel - compile_kwargs = {"fullgraph": True, "mode": "max-autotune"} + compile_kwargs = {"fullgraph": True} def get_model_init_dict(self): return { @@ -29,8 +29,8 @@ def get_input_dict(self): pooled_prompt_embeds = torch.randn(1, 768, device=torch_device, dtype=torch.bfloat16) image_ids = torch.ones(512, 3, device=torch_device, dtype=torch.bfloat16) text_ids = torch.ones(4096, 3, device=torch_device, dtype=torch.bfloat16) - timestep = torch.tensor([1.0], device=torch_device) - guidance = torch.tensor([1.0], device=torch_device) + 
timestep = torch.tensor([1.0], device=torch_device, dtype=torch.bfloat16) + guidance = torch.tensor([1.0], device=torch_device, dtype=torch.bfloat16) return { "hidden_states": hidden_states, diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index c69766ae1e45..9f92424582ce 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -11,7 +11,7 @@ def benchmark_fn(f, *args, **kwargs): t0 = benchmark.Timer( stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}, - num_threads=torch.get_num_threads(), + num_threads=1, ) return f"{(t0.blocked_autorange().mean):.3f}" @@ -53,10 +53,6 @@ def run_benchmark(self): model = self.initialize_model() # Takes care of device placement. input_dict = self.get_input_dict() # Takes care of device placement. - # warmup - for _ in range(5): - _ = model(**input_dict) - time = benchmark_fn(lambda model, input_dict: model(**input_dict), model, input_dict) memory = torch.cuda.max_memory_allocated() / (1024**3) memory = float(f"{memory:.2f}") @@ -69,9 +65,9 @@ def run_benchmark(self): compile_stats = None if self.compile_kwargs is not None: model = self.initialize_model() - with torch._inductor.utils.fresh_inductor_cache(): - model.compile(**self.compile_kwargs) - time = benchmark_fn(lambda model, input_dict: model(**input_dict), model, input_dict) + input_dict = self.get_input_dict() + model.compile(**self.compile_kwargs) + time = benchmark_fn(lambda model, input_dict: model(**input_dict), model, input_dict) memory = torch.cuda.max_memory_allocated() / (1024**3) memory = float(f"{memory:.2f}") compile_stats = {"time": time, "memory": memory} From 169f831cf33f8e24706ea67546b28dd48d18c956 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 15 May 2025 21:20:13 +0530 Subject: [PATCH 04/56] checking. 
--- benchmarks/benchmarking_utils.py | 89 +++++++++++++++++++++++--------- 1 file changed, 64 insertions(+), 25 deletions(-) diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index 9f92424582ce..0074dedf4339 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -1,4 +1,6 @@ import gc +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional import torch import torch.utils.benchmark as benchmark @@ -13,7 +15,7 @@ def benchmark_fn(f, *args, **kwargs): globals={"args": args, "kwargs": kwargs, "f": f}, num_threads=1, ) - return f"{(t0.blocked_autorange().mean):.3f}" + return float(f"{(t0.blocked_autorange().mean):.3f}") def flush(): @@ -23,11 +25,18 @@ def flush(): torch.cuda.reset_peak_memory_stats() +@dataclass +class BenchmarkScenario: + name: str + model_cls: ModelMixin + model_init_kwargs: Dict[str, Any] + model_init_fn: Callable + get_model_input_dict: Callable[[], Dict[str, Any]] + compile_kwargs: Optional[Dict[str, Any]] = None + + @require_torch_gpu class BenchmarkMixin: - model_class: ModelMixin = None - compile_kwargs: dict = None - def get_model_init_dict(self): raise NotImplementedError @@ -47,31 +56,61 @@ def post_benchmark(self, model): torch.compiler.reset() @torch.no_grad() - def run_benchmark(self): + def run_benchmark(self, scenario: BenchmarkScenario): + # 1) plain stats + plain = self._run_phase( + init_fn=scenario.model_init_fn, + init_kwargs=scenario.model_init_kwargs, + get_input_fn=scenario.get_model_input_dict, + compile_kwargs=None, + ) + + # 2) compiled stats (if any) + compiled = None + if scenario.compile_kwargs: + compiled = self._run_phase( + init_fn=scenario.model_init_fn, + init_kwargs=scenario.model_init_kwargs, + get_input_fn=scenario.get_model_input_dict, + compile_kwargs=scenario.compile_kwargs, + ) + + # 3) merge + result = {"scenario": scenario.name, "time_plain_s": plain["time"], "mem_plain_GB": plain["memory"]} + if compiled: + result.update( + { + "time_compile_s": compiled["time"], + "mem_compile_GB": compiled["memory"], + } + ) + return result + + def _run_phase( + self, + *, + init_fn: Callable[..., Any], + init_kwargs: Dict[str, Any], + get_input_fn: Callable[[], Dict[str, torch.Tensor]], + compile_kwargs: Optional[Dict[str, Any]], + ) -> Dict[str, float]: + # setup self.pre_benchmark() - model = self.initialize_model() # Takes care of device placement. - input_dict = self.get_input_dict() # Takes care of device placement. 
- - time = benchmark_fn(lambda model, input_dict: model(**input_dict), model, input_dict) - memory = torch.cuda.max_memory_allocated() / (1024**3) - memory = float(f"{memory:.2f}") - non_compile_stats = {"time": time, "memory": memory} + # init & (optional) compile + model = init_fn(**init_kwargs) + if compile_kwargs: + model.compile(**compile_kwargs) - self.post_benchmark(model) - del model - self.pre_benchmark() + # build inputs + inp = get_input_fn() - compile_stats = None - if self.compile_kwargs is not None: - model = self.initialize_model() - input_dict = self.get_input_dict() - model.compile(**self.compile_kwargs) - time = benchmark_fn(lambda model, input_dict: model(**input_dict), model, input_dict) - memory = torch.cuda.max_memory_allocated() / (1024**3) - memory = float(f"{memory:.2f}") - compile_stats = {"time": time, "memory": memory} + # measure + time_s = benchmark_fn(lambda m, d: m(**d), model, inp) + mem_gb = torch.cuda.max_memory_allocated() / (1024**3) + mem_gb = round(mem_gb, 2) + # teardown self.post_benchmark(model) del model - return non_compile_stats, compile_stats + return {"time": time_s, "memory": mem_gb} From ad18983985b6bc4823f89f97986473869277ecad Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 15 May 2025 21:20:33 +0530 Subject: [PATCH 05/56] checking --- benchmarks/benchmarking_utils.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index 0074dedf4339..f9004f4d5af5 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -37,15 +37,6 @@ class BenchmarkScenario: @require_torch_gpu class BenchmarkMixin: - def get_model_init_dict(self): - raise NotImplementedError - - def initialize_model(self): - raise NotImplementedError - - def get_input_dict(self): - raise NotImplementedError - def pre_benchmark(self): flush() torch.compiler.reset() From 31e34d5e3ebf6909befd4f7e32e3b55271b44611 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 16 May 2025 13:33:10 +0530 Subject: [PATCH 06/56] fixes. 
--- benchmarks/benchmarking_flux.py | 130 ++++++++++++++++++++++--------- benchmarks/benchmarking_utils.py | 64 +++++++++++---- 2 files changed, 140 insertions(+), 54 deletions(-) diff --git a/benchmarks/benchmarking_flux.py b/benchmarks/benchmarking_flux.py index 802cadaf9df8..dbea44d7da6a 100644 --- a/benchmarks/benchmarking_flux.py +++ b/benchmarks/benchmarking_flux.py @@ -1,43 +1,97 @@ +from functools import partial + import torch -from benchmarking_utils import BenchmarkMixin +from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn -from diffusers import FluxTransformer2DModel +from diffusers import BitsAndBytesConfig, FluxTransformer2DModel from diffusers.utils.testing_utils import torch_device -class BenchmarkFlux(BenchmarkMixin): - model_class = FluxTransformer2DModel - compile_kwargs = {"fullgraph": True} - - def get_model_init_dict(self): - return { - "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev", - "subfolder": "transformer", - "torch_dtype": torch.bfloat16, - } - - def initialize_model(self): - model = self.model_class.from_pretrained(**self.get_model_init_dict()) - model = model.to(torch_device).eval() - return model - - def get_input_dict(self): - # resolution: 1024x1024 - # maximum sequence length 512 - hidden_states = torch.randn(1, 4096, 64, device=torch_device, dtype=torch.bfloat16) - encoder_hidden_states = torch.randn(1, 512, 4096, device=torch_device, dtype=torch.bfloat16) - pooled_prompt_embeds = torch.randn(1, 768, device=torch_device, dtype=torch.bfloat16) - image_ids = torch.ones(512, 3, device=torch_device, dtype=torch.bfloat16) - text_ids = torch.ones(4096, 3, device=torch_device, dtype=torch.bfloat16) - timestep = torch.tensor([1.0], device=torch_device, dtype=torch.bfloat16) - guidance = torch.tensor([1.0], device=torch_device, dtype=torch.bfloat16) - - return { - "hidden_states": hidden_states, - "encoder_hidden_states": encoder_hidden_states, - "img_ids": image_ids, - "txt_ids": text_ids, - "pooled_projections": pooled_prompt_embeds, - "timestep": timestep, - "guidance": guidance, - } +CKPT_ID = "black-forest-labs/FLUX.1-dev" + + +def get_input_dict(**device_dtype_kwargs): + # resolution: 1024x1024 + # maximum sequence length 512 + hidden_states = torch.randn(1, 4096, 64, **device_dtype_kwargs) + encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs) + pooled_prompt_embeds = torch.randn(1, 768, **device_dtype_kwargs) + image_ids = torch.ones(512, 3, **device_dtype_kwargs) + text_ids = torch.ones(4096, 3, **device_dtype_kwargs) + timestep = torch.tensor([1.0], **device_dtype_kwargs) + guidance = torch.tensor([1.0], **device_dtype_kwargs) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "img_ids": image_ids, + "txt_ids": text_ids, + "pooled_projections": pooled_prompt_embeds, + "timestep": timestep, + "guidance": guidance, + } + + +if __name__ == "__main__": + scenarios = [ + BenchmarkScenario( + name=f"{CKPT_ID}-bf16", + model_cls=FluxTransformer2DModel, + model_init_kwargs={ + "pretrained_model_name_or_path": CKPT_ID, + "torch_dtype": torch.bfloat16, + "subfolder": "transformer", + }, + get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16), + model_init_fn=model_init_fn, + compile_kwargs={"fullgraph": True}, + ), + BenchmarkScenario( + name=f"{CKPT_ID}-bnb-nf4", + model_cls=FluxTransformer2DModel, + model_init_kwargs={ + "pretrained_model_name_or_path": CKPT_ID, + "torch_dtype": torch.bfloat16, + "subfolder": 
"transformer", + "quantization_config": BitsAndBytesConfig( + load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4" + ), + }, + get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16), + model_init_fn=model_init_fn, + ), + BenchmarkScenario( + name=f"{CKPT_ID}-layerwise-upcasting", + model_cls=FluxTransformer2DModel, + model_init_kwargs={ + "pretrained_model_name_or_path": CKPT_ID, + "torch_dtype": torch.bfloat16, + "subfolder": "transformer", + }, + get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16), + model_init_fn=partial(model_init_fn, layerwise_upcasting=True), + ), + BenchmarkScenario( + name=f"{CKPT_ID}-group-offload-leaf", + model_cls=FluxTransformer2DModel, + model_init_kwargs={ + "pretrained_model_name_or_path": CKPT_ID, + "torch_dtype": torch.bfloat16, + "subfolder": "transformer", + }, + get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16), + model_init_fn=partial( + model_init_fn, + group_offload_kwargs={ + "onload_device": torch_device, + "offload_device": torch.device("cpu"), + "offload_type": "leaf_level", + "use_stream": True, + "non_blocking": True, + }, + ), + ), + ] + + runner = BenchmarkMixin() + runner.run_bencmarks_and_collate(scenarios, filename="flux.csv") diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index f9004f4d5af5..27d73be98bb5 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -1,12 +1,14 @@ import gc +from contextlib import nullcontext from dataclasses import dataclass -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Union +import pandas as pd import torch import torch.utils.benchmark as benchmark from diffusers.models.modeling_utils import ModelMixin -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import require_torch_gpu, torch_device def benchmark_fn(f, *args, **kwargs): @@ -25,13 +27,26 @@ def flush(): torch.cuda.reset_peak_memory_stats() +def model_init_fn(model_cls, group_offload_kwargs=None, layerwise_upcasting=False, **init_kwargs): + model = model_cls.from_pretrained(**init_kwargs).eval() + if group_offload_kwargs and isinstance(group_offload_kwargs, dict): + model.enable_group_offload(**group_offload_kwargs) + else: + model.to(torch_device) + if layerwise_upcasting: + model.enable_layerwise_casting( + storage_dtype=torch.float8_e4m3fn, compute_dtype=init_kwargs.get("torch_dtype", torch.bfloat16) + ) + return model + + @dataclass class BenchmarkScenario: name: str model_cls: ModelMixin model_init_kwargs: Dict[str, Any] model_init_fn: Callable - get_model_input_dict: Callable[[], Dict[str, Any]] + get_model_input_dict: Callable compile_kwargs: Optional[Dict[str, Any]] = None @@ -50,6 +65,7 @@ def post_benchmark(self, model): def run_benchmark(self, scenario: BenchmarkScenario): # 1) plain stats plain = self._run_phase( + model_cls=scenario.model_cls, init_fn=scenario.model_init_fn, init_kwargs=scenario.model_init_kwargs, get_input_fn=scenario.get_model_input_dict, @@ -57,9 +73,10 @@ def run_benchmark(self, scenario: BenchmarkScenario): ) # 2) compiled stats (if any) - compiled = None + compiled = {"time": None, "memory": None} if scenario.compile_kwargs: compiled = self._run_phase( + model_cls=scenario.model_cls, init_fn=scenario.model_init_fn, init_kwargs=scenario.model_init_kwargs, get_input_fn=scenario.get_model_input_dict, @@ -67,29 +84,42 @@ 
def run_benchmark(self, scenario: BenchmarkScenario): ) # 3) merge - result = {"scenario": scenario.name, "time_plain_s": plain["time"], "mem_plain_GB": plain["memory"]} - if compiled: - result.update( - { - "time_compile_s": compiled["time"], - "mem_compile_GB": compiled["memory"], - } - ) + result = { + "scenario": scenario.name, + "model_cls": scenario.model_cls.__name__, + "time_plain_s": plain["time"], + "mem_plain_GB": plain["memory"], + "time_compile_s": compiled["time"], + "mem_compile_GB": compiled["memory"], + } + if scenario.compile_kwargs: + result["fullgraph"] = scenario.compile_kwargs.get("fullgraph", False) + result["mode"] = scenario.compile_kwargs.get("mode", "default") + else: + result["fullgraph"], result["mode"] = None, None return result + def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[BenchmarkScenario]], filename: str): + if not isinstance(scenarios, list): + scenarios = [scenarios] + records = [self.run_benchmark(s) for s in scenarios] + df = pd.DataFrame.from_records(records) + df.to_csv(filename, index=False) + def _run_phase( self, *, - init_fn: Callable[..., Any], + model_cls: ModelMixin, + init_fn: Callable, init_kwargs: Dict[str, Any], - get_input_fn: Callable[[], Dict[str, torch.Tensor]], + get_input_fn: Callable, compile_kwargs: Optional[Dict[str, Any]], ) -> Dict[str, float]: # setup self.pre_benchmark() # init & (optional) compile - model = init_fn(**init_kwargs) + model = init_fn(model_cls, **init_kwargs) if compile_kwargs: model.compile(**compile_kwargs) @@ -97,7 +127,9 @@ def _run_phase( inp = get_input_fn() # measure - time_s = benchmark_fn(lambda m, d: m(**d), model, inp) + run_ctx = torch._inductor.utils.fresh_inductor_cache() if compile_kwargs else nullcontext() + with run_ctx: + time_s = benchmark_fn(lambda m, d: m(**d), model, inp) mem_gb = torch.cuda.max_memory_allocated() / (1024**3) mem_gb = round(mem_gb, 2) From 36afdea9ab2598bc3e9cbbb64ff86847965392f7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 16 May 2025 14:01:47 +0530 Subject: [PATCH 07/56] error handling and logging. --- benchmarks/benchmarking_utils.py | 46 ++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index 27d73be98bb5..7377b4dcbd08 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -8,9 +8,13 @@ import torch.utils.benchmark as benchmark from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import logging from diffusers.utils.testing_utils import require_torch_gpu, torch_device +logger = logging.get_logger(__name__) + + def benchmark_fn(f, *args, **kwargs): t0 = benchmark.Timer( stmt="f(*args, **kwargs)", @@ -27,6 +31,8 @@ def flush(): torch.cuda.reset_peak_memory_stats() +# Users can define their own in case this doesn't suffice. For most cases, +# it should be sufficient. 
def model_init_fn(model_cls, group_offload_kwargs=None, layerwise_upcasting=False, **init_kwargs): model = model_cls.from_pretrained(**init_kwargs).eval() if group_offload_kwargs and isinstance(group_offload_kwargs, dict): @@ -64,24 +70,35 @@ def post_benchmark(self, model): @torch.no_grad() def run_benchmark(self, scenario: BenchmarkScenario): # 1) plain stats - plain = self._run_phase( - model_cls=scenario.model_cls, - init_fn=scenario.model_init_fn, - init_kwargs=scenario.model_init_kwargs, - get_input_fn=scenario.get_model_input_dict, - compile_kwargs=None, - ) - - # 2) compiled stats (if any) - compiled = {"time": None, "memory": None} - if scenario.compile_kwargs: - compiled = self._run_phase( + results = {} + plain = None + try: + plain = self._run_phase( model_cls=scenario.model_cls, init_fn=scenario.model_init_fn, init_kwargs=scenario.model_init_kwargs, get_input_fn=scenario.get_model_input_dict, - compile_kwargs=scenario.compile_kwargs, + compile_kwargs=None, ) + except Exception as e: + logger.error(f"Benchmark could not be run with the following error\n: {e}") + return results + + # 2) compiled stats (if any) + compiled = {"time": None, "memory": None} + if scenario.compile_kwargs: + try: + compiled = self._run_phase( + model_cls=scenario.model_cls, + init_fn=scenario.model_init_fn, + init_kwargs=scenario.model_init_kwargs, + get_input_fn=scenario.get_model_input_dict, + compile_kwargs=scenario.compile_kwargs, + ) + except Exception as e: + logger.error(f"Compilation benchmark could not be run with the following error\n: {e}") + if plain is None: + return results # 3) merge result = { @@ -103,8 +120,9 @@ def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[Ben if not isinstance(scenarios, list): scenarios = [scenarios] records = [self.run_benchmark(s) for s in scenarios] - df = pd.DataFrame.from_records(records) + df = pd.DataFrame.from_records([r for r in records if r]) df.to_csv(filename, index=False) + logger.info(f"Results serialized to {filename=}.") def _run_phase( self, From 4d83a478b26a3425943759ddd278386c25a3bacc Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 20 May 2025 09:34:40 +0530 Subject: [PATCH 08/56] add flops and params. --- benchmarks/benchmarking_utils.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index 7377b4dcbd08..743bf4c11217 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -6,6 +6,7 @@ import pandas as pd import torch import torch.utils.benchmark as benchmark +from torchprofile import profile_macs from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import logging @@ -31,6 +32,19 @@ def flush(): torch.cuda.reset_peak_memory_stats() +# Taken from https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py +def calculate_flops(model, input_dict): + model.eval() + with torch.no_grad(): + macs = profile_macs(model, **input_dict) + flops = 2 * macs # 1 MAC operation = 2 FLOPs (1 multiplication + 1 addition) + return flops + + +def calculate_params(model): + return sum(p.numel() for p in model.parameters()) + + # Users can define their own in case this doesn't suffice. For most cases, # it should be sufficient. 
def model_init_fn(model_cls, group_offload_kwargs=None, layerwise_upcasting=False, **init_kwargs): @@ -69,6 +83,14 @@ def post_benchmark(self, model): @torch.no_grad() def run_benchmark(self, scenario: BenchmarkScenario): + # 0) Basic stats + model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs) + num_params = calculate_params(model) + flops = calculate_flops(model, input_dict=scenario.model_init_kwargs) + model.cpu() + del model + self.pre_benchmark() + # 1) plain stats results = {} plain = None @@ -104,6 +126,8 @@ def run_benchmark(self, scenario: BenchmarkScenario): result = { "scenario": scenario.name, "model_cls": scenario.model_cls.__name__, + "num_params": num_params, + "flops": flops, "time_plain_s": plain["time"], "mem_plain_GB": plain["memory"], "time_compile_s": compiled["time"], From 6815cae9eeef2d241100f1545aef57a01fcd1084 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 20 May 2025 12:37:29 +0530 Subject: [PATCH 09/56] add more models. --- benchmarks/benchmarking_ltx.py | 79 +++++++++++++++++++++++++++++++ benchmarks/benchmarking_sdxl.py | 81 ++++++++++++++++++++++++++++++++ benchmarks/benchmarking_utils.py | 29 +++++++++--- benchmarks/benchmarking_wan.py | 73 ++++++++++++++++++++++++++++ 4 files changed, 256 insertions(+), 6 deletions(-) create mode 100644 benchmarks/benchmarking_ltx.py create mode 100644 benchmarks/benchmarking_sdxl.py create mode 100644 benchmarks/benchmarking_wan.py diff --git a/benchmarks/benchmarking_ltx.py b/benchmarks/benchmarking_ltx.py new file mode 100644 index 000000000000..3515fbf6837e --- /dev/null +++ b/benchmarks/benchmarking_ltx.py @@ -0,0 +1,79 @@ +from functools import partial + +import torch +from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn + +from diffusers import LTXVideoTransformer3DModel +from diffusers.utils.testing_utils import torch_device + + +CKPT_ID = "Lightricks/LTX-Video-0.9.7-dev" + + +def get_input_dict(**device_dtype_kwargs): + # 512x704 (161 frames) + # `max_sequence_length`: 256 + hidden_states = torch.randn(1, 7392, 128, **device_dtype_kwargs) + encoder_hidden_states = torch.randn(1, 256, 4096, **device_dtype_kwargs) + encoder_attention_mask = torch.ones(1, 256, **device_dtype_kwargs) + timestep = torch.tensor([1.0], **device_dtype_kwargs) + video_coords = torch.randn(1, 3, 7392, **device_dtype_kwargs) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + "timestep": timestep, + "video_coords": video_coords, + } + + +if __name__ == "__main__": + scenarios = [ + BenchmarkScenario( + name=f"{CKPT_ID}-bf16", + model_cls=LTXVideoTransformer3DModel, + model_init_kwargs={ + "pretrained_model_name_or_path": CKPT_ID, + "torch_dtype": torch.bfloat16, + "subfolder": "transformer", + }, + get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16), + model_init_fn=model_init_fn, + compile_kwargs={"fullgraph": True}, + ), + BenchmarkScenario( + name=f"{CKPT_ID}-layerwise-upcasting", + model_cls=LTXVideoTransformer3DModel, + model_init_kwargs={ + "pretrained_model_name_or_path": CKPT_ID, + "torch_dtype": torch.bfloat16, + "subfolder": "transformer", + }, + get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16), + model_init_fn=partial(model_init_fn, layerwise_upcasting=True), + ), + BenchmarkScenario( + name=f"{CKPT_ID}-group-offload-leaf", + model_cls=LTXVideoTransformer3DModel, + model_init_kwargs={ + 
"pretrained_model_name_or_path": CKPT_ID, + "torch_dtype": torch.bfloat16, + "subfolder": "transformer", + }, + get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16), + model_init_fn=partial( + model_init_fn, + group_offload_kwargs={ + "onload_device": torch_device, + "offload_device": torch.device("cpu"), + "offload_type": "leaf_level", + "use_stream": True, + "non_blocking": True, + }, + ), + ), + ] + + runner = BenchmarkMixin() + runner.run_bencmarks_and_collate(scenarios, filename="ltx.csv") diff --git a/benchmarks/benchmarking_sdxl.py b/benchmarks/benchmarking_sdxl.py new file mode 100644 index 000000000000..165ed3c89052 --- /dev/null +++ b/benchmarks/benchmarking_sdxl.py @@ -0,0 +1,81 @@ +from functools import partial + +import torch +from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn + +from diffusers import UNet2DConditionModel +from diffusers.utils.testing_utils import torch_device + + +CKPT_ID = "stabilityai/stable-diffusion-xl-base-1.0" + + +def get_input_dict(**device_dtype_kwargs): + # height: 1024 + # width: 1024 + # max_sequence_length: 77 + hidden_states = torch.randn(1, 4, 128, 128, **device_dtype_kwargs) + encoder_hidden_states = torch.randn(1, 77, 2048, **device_dtype_kwargs) + timestep = torch.tensor([1.0], **device_dtype_kwargs) + added_cond_kwargs = { + "text_embeds": torch.randn(1, 1280, **device_dtype_kwargs), + "time_ids": torch.ones(1, 6, **device_dtype_kwargs), + } + + return { + "sample": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + "added_cond_kwargs": added_cond_kwargs, + } + + +if __name__ == "__main__": + scenarios = [ + BenchmarkScenario( + name=f"{CKPT_ID}-bf16", + model_cls=UNet2DConditionModel, + model_init_kwargs={ + "pretrained_model_name_or_path": CKPT_ID, + "torch_dtype": torch.bfloat16, + "subfolder": "unet", + }, + get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16), + model_init_fn=model_init_fn, + compile_kwargs={"fullgraph": True}, + ), + BenchmarkScenario( + name=f"{CKPT_ID}-layerwise-upcasting", + model_cls=UNet2DConditionModel, + model_init_kwargs={ + "pretrained_model_name_or_path": CKPT_ID, + "torch_dtype": torch.bfloat16, + "subfolder": "unet", + }, + get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16), + model_init_fn=partial(model_init_fn, layerwise_upcasting=True), + ), + BenchmarkScenario( + name=f"{CKPT_ID}-group-offload-leaf", + model_cls=UNet2DConditionModel, + model_init_kwargs={ + "pretrained_model_name_or_path": CKPT_ID, + "torch_dtype": torch.bfloat16, + "subfolder": "unet", + }, + get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16), + model_init_fn=partial( + model_init_fn, + group_offload_kwargs={ + "onload_device": torch_device, + "offload_device": torch.device("cpu"), + "offload_type": "leaf_level", + "use_stream": True, + "non_blocking": True, + }, + ), + ), + ] + + runner = BenchmarkMixin() + runner.run_bencmarks_and_collate(scenarios, filename="sdxl.csv") diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index 743bf4c11217..6759fcc75dd5 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -1,4 +1,5 @@ import gc +import inspect from contextlib import nullcontext from dataclasses import dataclass from typing import Any, Callable, Dict, Optional, Union @@ -32,11 +33,27 @@ def flush(): torch.cuda.reset_peak_memory_stats() -# Taken from 
https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py +# Adapted from https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py def calculate_flops(model, input_dict): + # This is a hacky way to convert the kwargs to args as `profile_macs` cries about kwargs. + sig = inspect.signature(model.forward) + param_names = [ + p.name + for p in sig.parameters.values() + if p.kind + in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + and p.name != "self" + ] + bound = sig.bind_partial(**input_dict) + bound.apply_defaults() + args = tuple(bound.arguments[name] for name in param_names) + model.eval() with torch.no_grad(): - macs = profile_macs(model, **input_dict) + macs = profile_macs(model, args) flops = 2 * macs # 1 MAC operation = 2 FLOPs (1 multiplication + 1 addition) return flops @@ -85,8 +102,8 @@ def post_benchmark(self, model): def run_benchmark(self, scenario: BenchmarkScenario): # 0) Basic stats model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs) - num_params = calculate_params(model) - flops = calculate_flops(model, input_dict=scenario.model_init_kwargs) + num_params = round(calculate_params(model) / 1e6, 2) + flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e6, 2) model.cpu() del model self.pre_benchmark() @@ -126,8 +143,8 @@ def run_benchmark(self, scenario: BenchmarkScenario): result = { "scenario": scenario.name, "model_cls": scenario.model_cls.__name__, - "num_params": num_params, - "flops": flops, + "num_params_M": num_params, + "flops_M": flops, "time_plain_s": plain["time"], "mem_plain_GB": plain["memory"], "time_compile_s": compiled["time"], diff --git a/benchmarks/benchmarking_wan.py b/benchmarks/benchmarking_wan.py new file mode 100644 index 000000000000..349793e4468b --- /dev/null +++ b/benchmarks/benchmarking_wan.py @@ -0,0 +1,73 @@ +from functools import partial + +import torch +from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn + +from diffusers import WanTransformer3DModel +from diffusers.utils.testing_utils import torch_device + + +CKPT_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers" + + +def get_input_dict(**device_dtype_kwargs): + # height: 480 + # width: 832 + # num_frames: 81 + # max_sequence_length: 512 + hidden_states = torch.randn(1, 16, 21, 60, 104, **device_dtype_kwargs) + encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs) + timestep = torch.tensor([1.0], **device_dtype_kwargs) + + return {"hidden_states": hidden_states, "encoder_hidden_states": encoder_hidden_states, "timestep": timestep} + + +if __name__ == "__main__": + scenarios = [ + BenchmarkScenario( + name=f"{CKPT_ID}-bf16", + model_cls=WanTransformer3DModel, + model_init_kwargs={ + "pretrained_model_name_or_path": CKPT_ID, + "torch_dtype": torch.bfloat16, + "subfolder": "transformer", + }, + get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16), + model_init_fn=model_init_fn, + compile_kwargs={"fullgraph": True}, + ), + BenchmarkScenario( + name=f"{CKPT_ID}-layerwise-upcasting", + model_cls=WanTransformer3DModel, + model_init_kwargs={ + "pretrained_model_name_or_path": CKPT_ID, + "torch_dtype": torch.bfloat16, + "subfolder": "transformer", + }, + get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16), + model_init_fn=partial(model_init_fn, layerwise_upcasting=True), + ), + BenchmarkScenario( + 
name=f"{CKPT_ID}-group-offload-leaf", + model_cls=WanTransformer3DModel, + model_init_kwargs={ + "pretrained_model_name_or_path": CKPT_ID, + "torch_dtype": torch.bfloat16, + "subfolder": "transformer", + }, + get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16), + model_init_fn=partial( + model_init_fn, + group_offload_kwargs={ + "onload_device": torch_device, + "offload_device": torch.device("cpu"), + "offload_type": "leaf_level", + "use_stream": True, + "non_blocking": True, + }, + ), + ), + ] + + runner = BenchmarkMixin() + runner.run_bencmarks_and_collate(scenarios, filename="wan.csv") From 5635bf86e6f1a664cee31258da95be22b83c44b5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 20 May 2025 15:21:00 +0530 Subject: [PATCH 10/56] utility to fire execution of all benchmarking scripts. --- benchmarks/run_all.py | 52 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 benchmarks/run_all.py diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py new file mode 100644 index 000000000000..3e91775de0b0 --- /dev/null +++ b/benchmarks/run_all.py @@ -0,0 +1,52 @@ +import glob +import subprocess +import pandas as pd +import os + +PATTERN = "benchmarking_*.py" +FINAL_CSV_FILENAME = "collated_results.csv" +GITHUB_SHA = os.getenv("GITHUB_SHA", None) + +class SubprocessCallException(Exception): + pass + + +# Taken from `test_examples_utils.py` +def run_command(command: list[str], return_stdout=False): + """ + Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture + if an error occurred while running `command` + """ + try: + output = subprocess.check_output(command, stderr=subprocess.STDOUT) + if return_stdout: + if hasattr(output, "decode"): + output = output.decode("utf-8") + return output + except subprocess.CalledProcessError as e: + raise SubprocessCallException( + f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}" + ) from e + + +def run_scripts(): + python_files = sorted(glob.glob(PATTERN)) + + for file in python_files: + if file != "benchmarking_utils.py": + print(f"****** Running file: {file} ******") + command = f"python {file}" + run_command(command.split()) + + +def merge_csvs(): + all_csvs = glob.glob("*.csv") + final_df = pd.concat([pd.read_csv(f) for f in all_csvs]).reset_index(drop=True) + if GITHUB_SHA: + final_df["github_sha"] = GITHUB_SHA + final_df.to_csv(FINAL_CSV_FILENAME) + + +if __name__ == "__main__": + run_scripts() + merge_csvs() From cfbd21e2cc3d44af0e45e51655f387806d05ad30 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 20 May 2025 15:41:45 +0530 Subject: [PATCH 11/56] utility to push to the hub. 
--- benchmarks/push_results.py | 75 ++++++++++++++++++++++++++++++++++++++ benchmarks/run_all.py | 7 +++- 2 files changed, 80 insertions(+), 2 deletions(-) create mode 100644 benchmarks/push_results.py diff --git a/benchmarks/push_results.py b/benchmarks/push_results.py new file mode 100644 index 000000000000..0e6dc4e87040 --- /dev/null +++ b/benchmarks/push_results.py @@ -0,0 +1,75 @@ +import os + +import pandas as pd +import torch +from huggingface_hub import hf_hub_download, upload_file +from huggingface_hub.utils import EntryNotFoundError + + +if torch.cuda.is_available(): + TOTAL_GPU_MEMORY = float( + os.getenv("TOTAL_GPU_MEMORY", torch.cuda.get_device_properties(0).total_memory / (1024**3)) + ) +else: + raise + +REPO_ID = "diffusers/benchmarks" + + +def has_previous_benchmark() -> str: + from run_all import FINAL_CSV_FILENAME + + csv_path = None + try: + csv_path = hf_hub_download(repo_id=REPO_ID, repo_type="dataset", filename=FINAL_CSV_FILENAME) + except EntryNotFoundError: + csv_path = None + return csv_path + + +def filter_float(value): + if isinstance(value, str): + return float(value.split()[0]) + return value + + +def push_to_hf_dataset(): + from run_all import FINAL_CSV_FILENAME, GITHUB_SHA + + # If there's an existing benchmark file, we should report the changes. + csv_path = has_previous_benchmark() + if csv_path is not None: + current_results = pd.read_csv(FINAL_CSV_FILENAME) + previous_results = pd.read_csv(csv_path) + + # identify the numeric columns we want to annotate + numeric_columns = current_results.select_dtypes(include=["float64", "int64"]).columns + + # for each numeric column, append the old value in () if present + for column in numeric_columns: + # coerce any “x units” strings back to float + prev_vals = previous_results[column].map(filter_float) + # align indices in case rows were added/removed + prev_vals = prev_vals.reindex(current_results.index) + + # build the new string: "current_value (previous_value)" + curr_str = current_results[column].astype(str) + prev_str = prev_vals.map(lambda x: f" ({x})" if pd.notnull(x) else "") + + current_results[column] = curr_str + prev_str + + # overwrite the CSV + current_results.to_csv(FINAL_CSV_FILENAME, index=False) + + commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results" + upload_file( + repo_id=REPO_ID, + path_in_repo=FINAL_CSV_FILENAME, + path_or_fileobj=FINAL_CSV_FILENAME, + repo_type="dataset", + commit_message=commit_message, + ) + + +if __name__ == "__main__": + push_to_hf_dataset() diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py index 3e91775de0b0..7c02d7a5921b 100644 --- a/benchmarks/run_all.py +++ b/benchmarks/run_all.py @@ -1,12 +1,15 @@ import glob +import os import subprocess + import pandas as pd -import os + PATTERN = "benchmarking_*.py" FINAL_CSV_FILENAME = "collated_results.csv" GITHUB_SHA = os.getenv("GITHUB_SHA", None) + class SubprocessCallException(Exception): pass @@ -33,7 +36,7 @@ def run_scripts(): python_files = sorted(glob.glob(PATTERN)) for file in python_files: - if file != "benchmarking_utils.py": + if file != "benchmarking_utils.py": print(f"****** Running file: {file} ******") command = f"python {file}" run_command(command.split()) From 4ccfad027b37eb1bcf5003fcb2bcd2df52379a80 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 20 May 2025 15:57:25 +0530 Subject: [PATCH 12/56] push utility improvement --- benchmarks/push_results.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) 
diff --git a/benchmarks/push_results.py b/benchmarks/push_results.py index 0e6dc4e87040..b2451dc64261 100644 --- a/benchmarks/push_results.py +++ b/benchmarks/push_results.py @@ -36,32 +36,43 @@ def filter_float(value): def push_to_hf_dataset(): from run_all import FINAL_CSV_FILENAME, GITHUB_SHA - # If there's an existing benchmark file, we should report the changes. csv_path = has_previous_benchmark() if csv_path is not None: current_results = pd.read_csv(FINAL_CSV_FILENAME) previous_results = pd.read_csv(csv_path) - # identify the numeric columns we want to annotate numeric_columns = current_results.select_dtypes(include=["float64", "int64"]).columns - # for each numeric column, append the old value in () if present for column in numeric_columns: - # coerce any “x units” strings back to float - prev_vals = previous_results[column].map(filter_float) - # align indices in case rows were added/removed - prev_vals = prev_vals.reindex(current_results.index) + # get previous values as floats, aligned to current index + prev_vals = ( + previous_results[column] + .map(filter_float) + .reindex(current_results.index) + ) - # build the new string: "current_value (previous_value)" - curr_str = current_results[column].astype(str) - prev_str = prev_vals.map(lambda x: f" ({x})" if pd.notnull(x) else "") + # get current values as floats + curr_vals = current_results[column].astype(float) - current_results[column] = curr_str + prev_str + # stringify the current values + curr_str = curr_vals.map(str) + + # build an appendage only when prev exists and differs + append_str = prev_vals.where( + prev_vals.notnull() & (prev_vals != curr_vals), + other=pd.NA + ).map(lambda x: f" ({x})" if pd.notnull(x) else "") + + # combine + current_results[column] = curr_str + append_str - # overwrite the CSV current_results.to_csv(FINAL_CSV_FILENAME, index=False) - commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results" + commit_message = ( + f"upload from sha: {GITHUB_SHA}" + if GITHUB_SHA is not None else + "upload benchmark results" + ) upload_file( repo_id=REPO_ID, path_in_repo=FINAL_CSV_FILENAME, From dff314469a1032b1ae34f8395f10fe9bac975fae Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 20 May 2025 16:30:18 +0530 Subject: [PATCH 13/56] seems to be working. 
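For reference, the conditional annotation from the previous patch (reformatted below) only appends the previous value when one exists and it differs from the current value. A small illustration with made-up numbers:

import pandas as pd

curr_vals = pd.Series([0.074, 0.152, 0.560])
prev_vals = pd.Series([0.079, 0.152, None])  # last entry has no previous record

append_str = prev_vals.where(prev_vals.notnull() & (prev_vals != curr_vals), other=pd.NA).map(
    lambda x: f" ({x})" if pd.notnull(x) else ""
)
print(curr_vals.map(str) + append_str)
# -> "0.074 (0.079)", "0.152", "0.56": only the entry that changed gets annotated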
--- benchmarks/benchmarking_flux.py | 3 ++- benchmarks/benchmarking_ltx.py | 3 ++- benchmarks/benchmarking_sdxl.py | 3 ++- benchmarks/benchmarking_wan.py | 3 ++- benchmarks/push_results.py | 19 +++++-------------- benchmarks/run_all.py | 7 ++++++- 6 files changed, 19 insertions(+), 19 deletions(-) diff --git a/benchmarks/benchmarking_flux.py b/benchmarks/benchmarking_flux.py index dbea44d7da6a..18a2680052ea 100644 --- a/benchmarks/benchmarking_flux.py +++ b/benchmarks/benchmarking_flux.py @@ -8,6 +8,7 @@ CKPT_ID = "black-forest-labs/FLUX.1-dev" +RESULT_FILENAME = "flux.csv" def get_input_dict(**device_dtype_kwargs): @@ -94,4 +95,4 @@ def get_input_dict(**device_dtype_kwargs): ] runner = BenchmarkMixin() - runner.run_bencmarks_and_collate(scenarios, filename="flux.csv") + runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME) diff --git a/benchmarks/benchmarking_ltx.py b/benchmarks/benchmarking_ltx.py index 3515fbf6837e..3d698fd0bd57 100644 --- a/benchmarks/benchmarking_ltx.py +++ b/benchmarks/benchmarking_ltx.py @@ -8,6 +8,7 @@ CKPT_ID = "Lightricks/LTX-Video-0.9.7-dev" +RESULT_FILENAME = "ltx.csv" def get_input_dict(**device_dtype_kwargs): @@ -76,4 +77,4 @@ def get_input_dict(**device_dtype_kwargs): ] runner = BenchmarkMixin() - runner.run_bencmarks_and_collate(scenarios, filename="ltx.csv") + runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME) diff --git a/benchmarks/benchmarking_sdxl.py b/benchmarks/benchmarking_sdxl.py index 165ed3c89052..ded62784f290 100644 --- a/benchmarks/benchmarking_sdxl.py +++ b/benchmarks/benchmarking_sdxl.py @@ -8,6 +8,7 @@ CKPT_ID = "stabilityai/stable-diffusion-xl-base-1.0" +RESULT_FILENAME = "sdxl.csv" def get_input_dict(**device_dtype_kwargs): @@ -78,4 +79,4 @@ def get_input_dict(**device_dtype_kwargs): ] runner = BenchmarkMixin() - runner.run_bencmarks_and_collate(scenarios, filename="sdxl.csv") + runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME) diff --git a/benchmarks/benchmarking_wan.py b/benchmarks/benchmarking_wan.py index 349793e4468b..64e81fdb6b09 100644 --- a/benchmarks/benchmarking_wan.py +++ b/benchmarks/benchmarking_wan.py @@ -8,6 +8,7 @@ CKPT_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers" +RESULT_FILENAME = "wan.csv" def get_input_dict(**device_dtype_kwargs): @@ -70,4 +71,4 @@ def get_input_dict(**device_dtype_kwargs): ] runner = BenchmarkMixin() - runner.run_bencmarks_and_collate(scenarios, filename="wan.csv") + runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME) diff --git a/benchmarks/push_results.py b/benchmarks/push_results.py index b2451dc64261..105b0444f0e2 100644 --- a/benchmarks/push_results.py +++ b/benchmarks/push_results.py @@ -45,11 +45,7 @@ def push_to_hf_dataset(): for column in numeric_columns: # get previous values as floats, aligned to current index - prev_vals = ( - previous_results[column] - .map(filter_float) - .reindex(current_results.index) - ) + prev_vals = previous_results[column].map(filter_float).reindex(current_results.index) # get current values as floats curr_vals = current_results[column].astype(float) @@ -58,21 +54,16 @@ def push_to_hf_dataset(): curr_str = curr_vals.map(str) # build an appendage only when prev exists and differs - append_str = prev_vals.where( - prev_vals.notnull() & (prev_vals != curr_vals), - other=pd.NA - ).map(lambda x: f" ({x})" if pd.notnull(x) else "") + append_str = prev_vals.where(prev_vals.notnull() & (prev_vals != curr_vals), other=pd.NA).map( + lambda x: f" ({x})" if pd.notnull(x) else "" + ) # combine 
current_results[column] = curr_str + append_str current_results.to_csv(FINAL_CSV_FILENAME, index=False) - commit_message = ( - f"upload from sha: {GITHUB_SHA}" - if GITHUB_SHA is not None else - "upload benchmark results" - ) + commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results" upload_file( repo_id=REPO_ID, path_in_repo=FINAL_CSV_FILENAME, diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py index 7c02d7a5921b..bb6fd6a225fc 100644 --- a/benchmarks/run_all.py +++ b/benchmarks/run_all.py @@ -4,12 +4,17 @@ import pandas as pd +from diffusers.utils import logging + PATTERN = "benchmarking_*.py" FINAL_CSV_FILENAME = "collated_results.csv" GITHUB_SHA = os.getenv("GITHUB_SHA", None) +logger = logging.get_logger(__name__) + + class SubprocessCallException(Exception): pass @@ -37,7 +42,7 @@ def run_scripts(): for file in python_files: if file != "benchmarking_utils.py": - print(f"****** Running file: {file} ******") + logger.info(f"****** Running file: {file} ******") command = f"python {file}" run_command(command.split()) From accd5989363678f10ec493c835f29ea460c2b5eb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 20 May 2025 16:33:59 +0530 Subject: [PATCH 14/56] okay --- .github/workflows/benchmark.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index ff915e046946..87b4964b9e6b 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -13,10 +13,10 @@ env: MKL_NUM_THREADS: 8 jobs: - torch_pipelines_cuda_benchmark_tests: + torch_models_cuda_benchmark_tests: env: SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_BENCHMARK }} - name: Torch Core Pipelines CUDA Benchmarking Tests + name: Torch Core Models CUDA Benchmarking Tests strategy: fail-fast: false max-parallel: 1 @@ -37,8 +37,7 @@ jobs: run: | python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" python -m uv pip install -e [quality,test] - python -m uv pip install pandas peft - python -m uv pip uninstall transformers && python -m uv pip install transformers==4.48.0 + python -m uv pip install pandas peft bitsandbytes - name: Environment run: | python utils/print_env.py @@ -48,7 +47,8 @@ jobs: BASE_PATH: benchmark_outputs run: | export TOTAL_GPU_MEMORY=$(python -c "import torch; print(torch.cuda.get_device_properties(0).total_memory / (1024**3))") - cd benchmarks && mkdir ${BASE_PATH} && python run_all.py && python push_results.py + cd benchmarks && python run_all.py && python push_results.py + mkdir ${BASE_PATH} && mv *.csv ${BASE_PATH} - name: Test suite reports artifacts if: ${{ always() }} From 41f79a00c5a090a5a3887b22cdaf8e8e90fad189 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 20 May 2025 16:41:41 +0530 Subject: [PATCH 15/56] add torchprofile dep. 
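torchprofile backs `calculate_flops` through `profile_macs`, which only takes positional inputs; that is why the helper binds the keyword inputs to the forward signature before profiling. A self-contained sketch on a toy module (not one of the diffusers models):

import torch
from torchprofile import profile_macs


class TinyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.fc(x)


model = TinyMLP().eval()
inputs = (torch.randn(1, 64),)  # positional args, as profile_macs expects
with torch.no_grad():
    macs = profile_macs(model, inputs)
print(f"MACs: {macs}, FLOPs: {2 * macs}")  # 1 MAC = 1 multiply + 1 add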
--- .github/workflows/benchmark.yml | 2 +- benchmarks/benchmarking_utils.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 87b4964b9e6b..850348319370 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -37,7 +37,7 @@ jobs: run: | python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" python -m uv pip install -e [quality,test] - python -m uv pip install pandas peft bitsandbytes + python -m uv pip install pandas peft bitsandbytes torchprofile - name: Environment run: | python utils/print_env.py diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index 6759fcc75dd5..7e80653178e8 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -7,7 +7,6 @@ import pandas as pd import torch import torch.utils.benchmark as benchmark -from torchprofile import profile_macs from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import logging @@ -35,6 +34,11 @@ def flush(): # Adapted from https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py def calculate_flops(model, input_dict): + try: + from torchprofile import profile_macs + except ModuleNotFoundError: + raise + # This is a hacky way to convert the kwargs to args as `profile_macs` cries about kwargs. sig = inspect.signature(model.forward) param_names = [ From befdd9ea297eaae371309025901bad426f416c44 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 20 May 2025 16:57:09 +0530 Subject: [PATCH 16/56] remove total gpu memory --- .github/workflows/benchmark.yml | 1 - benchmarks/push_results.py | 10 ---------- 2 files changed, 11 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 850348319370..6e632b3badfe 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -46,7 +46,6 @@ jobs: HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }} BASE_PATH: benchmark_outputs run: | - export TOTAL_GPU_MEMORY=$(python -c "import torch; print(torch.cuda.get_device_properties(0).total_memory / (1024**3))") cd benchmarks && python run_all.py && python push_results.py mkdir ${BASE_PATH} && mv *.csv ${BASE_PATH} diff --git a/benchmarks/push_results.py b/benchmarks/push_results.py index 105b0444f0e2..30da0c053863 100644 --- a/benchmarks/push_results.py +++ b/benchmarks/push_results.py @@ -1,18 +1,8 @@ -import os - import pandas as pd -import torch from huggingface_hub import hf_hub_download, upload_file from huggingface_hub.utils import EntryNotFoundError -if torch.cuda.is_available(): - TOTAL_GPU_MEMORY = float( - os.getenv("TOTAL_GPU_MEMORY", torch.cuda.get_device_properties(0).total_memory / (1024**3)) - ) -else: - raise - REPO_ID = "diffusers/benchmarks" From 4784b8bc9ebb53092d908eceebf0fe1017f89a0a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 20 May 2025 17:27:24 +0530 Subject: [PATCH 17/56] fixes --- .github/workflows/benchmark.yml | 4 +++- benchmarks/benchmarking_utils.py | 1 + benchmarks/run_all.py | 6 +++++- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 6e632b3badfe..6346fbd92424 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -37,7 +37,9 @@ jobs: run: | python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" python -m uv pip install -e [quality,test] - python -m uv pip install pandas peft 
bitsandbytes torchprofile + python -m uv pip install pandas peft torchprofile + # Temporary. + pip install --force-reinstall https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl - name: Environment run: | python utils/print_env.py diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index 7e80653178e8..9241d99acd57 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -105,6 +105,7 @@ def post_benchmark(self, model): @torch.no_grad() def run_benchmark(self, scenario: BenchmarkScenario): # 0) Basic stats + logger.info(f"Running scenario: {scenario.name}.") model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs) num_params = round(calculate_params(model) / 1e6, 2) flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e6, 2) diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py index bb6fd6a225fc..2951dd647a2b 100644 --- a/benchmarks/run_all.py +++ b/benchmarks/run_all.py @@ -44,7 +44,11 @@ def run_scripts(): if file != "benchmarking_utils.py": logger.info(f"****** Running file: {file} ******") command = f"python {file}" - run_command(command.split()) + try: + run_command(command) + except SubprocessCallException as e: + logger.error(f"Error running {file}: {e}") + continue def merge_csvs(): From c19dc5ba4a6dee40d8c07301025aeef126d38387 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 20 May 2025 17:38:23 +0530 Subject: [PATCH 18/56] fix --- benchmarks/run_all.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py index 2951dd647a2b..5d30ec3953c4 100644 --- a/benchmarks/run_all.py +++ b/benchmarks/run_all.py @@ -45,7 +45,7 @@ def run_scripts(): logger.info(f"****** Running file: {file} ******") command = f"python {file}" try: - run_command(command) + run_command(command.split()) except SubprocessCallException as e: logger.error(f"Error running {file}: {e}") continue From 2da4facd9423c4b7732591ff294a0356446e8dc1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 20 May 2025 17:44:06 +0530 Subject: [PATCH 19/56] need a big gpu --- .github/workflows/benchmark.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 6346fbd92424..1fb02b86fe96 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -21,7 +21,7 @@ jobs: fail-fast: false max-parallel: 1 runs-on: - group: aws-g6-4xlarge-plus + group: aws-g6e-xlarge-plus container: image: diffusers/diffusers-pytorch-compile-cuda options: --shm-size "16gb" --ipc host --gpus 0 From 7367bb10127b7224b1ce4c05c38682e0588913be Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 20 May 2025 17:50:46 +0530 Subject: [PATCH 20/56] better --- benchmarks/benchmarking_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index 9241d99acd57..cf65961d839b 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -165,7 +165,12 @@ def run_benchmark(self, scenario: BenchmarkScenario): def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[BenchmarkScenario]], filename: str): if not isinstance(scenarios, list): scenarios = [scenarios] - records = [self.run_benchmark(s) for s in scenarios] + records = [] + for s in records: + try: + 
records.append(self.run_benchmark(s)) + except Exception as e: + logger.error(f"Running scenario ({s.name}) led to error:\n{e}") df = pd.DataFrame.from_records([r for r in records if r]) df.to_csv(filename, index=False) logger.info(f"Results serialized to {filename=}.") From 1cd472fbb0d593e8acb0009256cad7a88aca2c7b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 20 May 2025 18:07:41 +0530 Subject: [PATCH 21/56] what's happening. --- benchmarks/benchmarking_utils.py | 14 +++++--------- benchmarks/run_all.py | 9 ++------- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index cf65961d839b..12b90a080f94 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -9,13 +9,9 @@ import torch.utils.benchmark as benchmark from diffusers.models.modeling_utils import ModelMixin -from diffusers.utils import logging from diffusers.utils.testing_utils import require_torch_gpu, torch_device -logger = logging.get_logger(__name__) - - def benchmark_fn(f, *args, **kwargs): t0 = benchmark.Timer( stmt="f(*args, **kwargs)", @@ -105,7 +101,7 @@ def post_benchmark(self, model): @torch.no_grad() def run_benchmark(self, scenario: BenchmarkScenario): # 0) Basic stats - logger.info(f"Running scenario: {scenario.name}.") + print(f"Running scenario: {scenario.name}.") model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs) num_params = round(calculate_params(model) / 1e6, 2) flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e6, 2) @@ -125,7 +121,7 @@ def run_benchmark(self, scenario: BenchmarkScenario): compile_kwargs=None, ) except Exception as e: - logger.error(f"Benchmark could not be run with the following error\n: {e}") + print(f"Benchmark could not be run with the following error\n: {e}") return results # 2) compiled stats (if any) @@ -140,7 +136,7 @@ def run_benchmark(self, scenario: BenchmarkScenario): compile_kwargs=scenario.compile_kwargs, ) except Exception as e: - logger.error(f"Compilation benchmark could not be run with the following error\n: {e}") + print(f"Compilation benchmark could not be run with the following error\n: {e}") if plain is None: return results @@ -170,10 +166,10 @@ def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[Ben try: records.append(self.run_benchmark(s)) except Exception as e: - logger.error(f"Running scenario ({s.name}) led to error:\n{e}") + print(f"Running scenario ({s.name}) led to error:\n{e}") df = pd.DataFrame.from_records([r for r in records if r]) df.to_csv(filename, index=False) - logger.info(f"Results serialized to {filename=}.") + print(f"Results serialized to {filename=}.") def _run_phase( self, diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py index 5d30ec3953c4..278683bdc254 100644 --- a/benchmarks/run_all.py +++ b/benchmarks/run_all.py @@ -4,17 +4,12 @@ import pandas as pd -from diffusers.utils import logging - PATTERN = "benchmarking_*.py" FINAL_CSV_FILENAME = "collated_results.csv" GITHUB_SHA = os.getenv("GITHUB_SHA", None) -logger = logging.get_logger(__name__) - - class SubprocessCallException(Exception): pass @@ -42,12 +37,12 @@ def run_scripts(): for file in python_files: if file != "benchmarking_utils.py": - logger.info(f"****** Running file: {file} ******") + print(f"****** Running file: {file} ******") command = f"python {file}" try: run_command(command.split()) except SubprocessCallException as e: - logger.error(f"Error running {file}: {e}") + print(f"Error 
running {file}: {e}") continue From 214795d56f7f1dfe861f94e72f5f426ae36948d5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 20 May 2025 18:42:20 +0530 Subject: [PATCH 22/56] okay --- benchmarks/benchmarking_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index 12b90a080f94..c4a3a976309e 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -162,7 +162,7 @@ def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[Ben if not isinstance(scenarios, list): scenarios = [scenarios] records = [] - for s in records: + for s in scenarios: try: records.append(self.run_benchmark(s)) except Exception as e: From 1122cad949ca5f5a0ad2eb08f9f11c663879c2b7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 5 Jun 2025 08:52:49 +0530 Subject: [PATCH 23/56] separate requirements and make it nightly. --- .github/workflows/benchmark.yml | 6 ++---- benchmarks/requirements.txt | 4 ++++ 2 files changed, 6 insertions(+), 4 deletions(-) create mode 100644 benchmarks/requirements.txt diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index a042f8614ccb..c443ec640a0e 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -3,7 +3,7 @@ name: Benchmarking tests on: workflow_dispatch: schedule: - - cron: "30 1 1,15 * *" # every 2 weeks on the 1st and the 15th of every month at 1:30 AM + - cron: "0 0 * * *" # every day at midnight env: DIFFUSERS_IS_CI: yes @@ -37,9 +37,7 @@ jobs: run: | python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" python -m uv pip install -e [quality,test] - python -m uv pip install pandas peft torchprofile - # Temporary. - pip install --force-reinstall https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl + python -m uv pip install -r benchmarks/requirements.txt - name: Environment run: | python utils/print_env.py diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt new file mode 100644 index 000000000000..ddbb535341be --- /dev/null +++ b/benchmarks/requirements.txt @@ -0,0 +1,4 @@ +pandas +peft +torchprofile +bitsandbytes \ No newline at end of file From baa92c279f95179bef237e2306eff9259dc4d3ee Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 5 Jun 2025 09:21:14 +0530 Subject: [PATCH 24/56] add db population script. 
--- .github/workflows/benchmark.yml | 12 +++- benchmarks/populate_into_db.py | 124 ++++++++++++++++++++++++++++++++ benchmarks/requirements.txt | 3 +- 3 files changed, 136 insertions(+), 3 deletions(-) create mode 100644 benchmarks/populate_into_db.py diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index c443ec640a0e..809de0e48727 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -11,6 +11,7 @@ env: HF_HOME: /mnt/cache OMP_NUM_THREADS: 8 MKL_NUM_THREADS: 8 + BASE_PATH: benchmark_outputs jobs: torch_models_cuda_benchmark_tests: @@ -43,8 +44,7 @@ jobs: python utils/print_env.py - name: Diffusers Benchmarking env: - HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }} - BASE_PATH: benchmark_outputs + HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }} run: | cd benchmarks && python run_all.py && python push_results.py mkdir ${BASE_PATH} && mv *.csv ${BASE_PATH} @@ -56,6 +56,14 @@ jobs: name: benchmark_test_reports path: benchmarks/benchmark_outputs + - name: Update benchmarking results to DB + env: + PGDATABASE: metrics + PGHOST: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGHOST }} + PGUSER: transformers_benchmarks + PGPASSWORD: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGPASSWORD }} + run: cd benchmarks && python populate_into_db.py + - name: Report success status if: ${{ success() }} run: | diff --git a/benchmarks/populate_into_db.py b/benchmarks/populate_into_db.py new file mode 100644 index 000000000000..62cfbd132717 --- /dev/null +++ b/benchmarks/populate_into_db.py @@ -0,0 +1,124 @@ +import os + +import pandas as pd +import psycopg2 +import psycopg2.extras + + +FINAL_CSV_FILENAME = "benchmark_outputs/collated_results.csv" +TABLE_NAME = "diffusers_benchmarks" + +if __name__ == "__main__": + conn = psycopg2.connect( + host=os.getenv("PGHOST"), + database=os.getenv("PGDATABASE"), + user=os.getenv("PGUSER"), + password=os.getenv("PGPASSWORD"), + ) + cur = conn.cursor() + + cur.execute(f""" + CREATE TABLE IF NOT EXISTS {TABLE_NAME} ( + scenario TEXT, + model_cls TEXT, + num_params_M REAL, + flops_M REAL, + time_plain_s REAL, + mem_plain_GB REAL, + time_compile_s REAL, + mem_compile_GB REAL, + fullgraph BOOLEAN, + mode TEXT, + github_sha TEXT + ); + """) + conn.commit() + + df = pd.read_csv(FINAL_CSV_FILENAME) + + # Helper to cast values (or None) given a dtype + def _cast_value(val, dtype: str): + if pd.isna(val): + return None + + if dtype == "text": + return str(val).strip() + + if dtype == "float": + try: + return float(val) + except ValueError: + return None + + if dtype == "bool": + s = str(val).strip().lower() + if s in ("true", "t", "yes", "1"): + return True + if s in ("false", "f", "no", "0"): + return False + if val in (1, 1.0): + return True + if val in (0, 0.0): + return False + return None + + return val + + rows_to_insert = [] + for _, row in df.iterrows(): + scenario = _cast_value(row.get("scenario"), "text") + model_cls = _cast_value(row.get("model_cls"), "text") + num_params_M = _cast_value(row.get("num_params_M"), "float") + flops_M = _cast_value(row.get("flops_M"), "float") + time_plain_s = _cast_value(row.get("time_plain_s"), "float") + mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float") + time_compile_s = _cast_value(row.get("time_compile_s"), "float") + mem_compile_GB = _cast_value(row.get("mem_compile_GB"), "float") + fullgraph = _cast_value(row.get("fullgraph"), "bool") + mode = _cast_value(row.get("mode"), "text") + + # If "github_sha" column exists in the CSV, cast it; else default to None + if "github_sha" in 
df.columns: + github_sha = _cast_value(row.get("github_sha"), "text") + else: + github_sha = None + + rows_to_insert.append( + ( + scenario, + model_cls, + num_params_M, + flops_M, + time_plain_s, + mem_plain_GB, + time_compile_s, + mem_compile_GB, + fullgraph, + mode, + github_sha, + ) + ) + + # Batch-insert all rows (with NULL for any None) + insert_sql = """ + INSERT INTO benchmarks ( + scenario, + model_cls, + num_params_M, + flops_M, + time_plain_s, + mem_plain_GB, + time_compile_s, + mem_compile_GB, + fullgraph, + mode, + github_sha + ) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s); + """ + + psycopg2.extras.execute_batch(cur, insert_sql, rows_to_insert) + conn.commit() + + cur.close() + conn.close() diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt index ddbb535341be..29e681a2d61d 100644 --- a/benchmarks/requirements.txt +++ b/benchmarks/requirements.txt @@ -1,4 +1,5 @@ pandas peft torchprofile -bitsandbytes \ No newline at end of file +bitsandbytes +psycopg2==2.9.9 \ No newline at end of file From 9e1f17fdf73c55b4ac1afda0b304ebd3892ad475 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 5 Jun 2025 09:24:22 +0530 Subject: [PATCH 25/56] update secret name --- .github/workflows/benchmark.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 809de0e48727..8e3599a1ac22 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -59,9 +59,9 @@ jobs: - name: Update benchmarking results to DB env: PGDATABASE: metrics - PGHOST: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGHOST }} + PGHOST: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGHOST }} # TODO PGUSER: transformers_benchmarks - PGPASSWORD: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGPASSWORD }} + PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }} run: cd benchmarks && python populate_into_db.py - name: Report success status From 71200da071034d0d984cf43ea9d70bd8d9119282 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 5 Jun 2025 16:08:17 +0530 Subject: [PATCH 26/56] update secret. 
--- .github/workflows/benchmark.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 8e3599a1ac22..7da906ca130a 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -59,7 +59,7 @@ jobs: - name: Update benchmarking results to DB env: PGDATABASE: metrics - PGHOST: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGHOST }} # TODO + PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }} # TODO PGUSER: transformers_benchmarks PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }} run: cd benchmarks && python populate_into_db.py From e45e4ebc832d2155005a59714b62f31489666a9e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 11:52:06 +0530 Subject: [PATCH 27/56] population db update --- .github/workflows/benchmark.yml | 2 + benchmarks/populate_into_db.py | 154 +++++++++++++++----------------- 2 files changed, 74 insertions(+), 82 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 7da906ca130a..121d77a34250 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -36,6 +36,8 @@ jobs: nvidia-smi - name: Install dependencies run: | + apt update + apt install -y libpq-dev postgresql-client python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" python -m uv pip install -e [quality,test] python -m uv pip install -r benchmarks/requirements.txt diff --git a/benchmarks/populate_into_db.py b/benchmarks/populate_into_db.py index 62cfbd132717..a64723b33f94 100644 --- a/benchmarks/populate_into_db.py +++ b/benchmarks/populate_into_db.py @@ -1,40 +1,31 @@ +import datetime import os +import uuid import pandas as pd import psycopg2 import psycopg2.extras -FINAL_CSV_FILENAME = "benchmark_outputs/collated_results.csv" -TABLE_NAME = "diffusers_benchmarks" +# FINAL_CSV_FILENAME = "benchmark_outputs/collated_results.csv" +# https://github.com/huggingface/transformers/blob/593e29c5e2a9b17baec010e8dc7c1431fed6e841/benchmark/init_db.sql#L27 +TABLE_NAME = "model_measurements" if __name__ == "__main__": - conn = psycopg2.connect( - host=os.getenv("PGHOST"), - database=os.getenv("PGDATABASE"), - user=os.getenv("PGUSER"), - password=os.getenv("PGPASSWORD"), - ) + try: + conn = psycopg2.connect( + host=os.getenv("PGHOST"), + database=os.getenv("PGDATABASE"), + user=os.getenv("PGUSER"), + password=os.getenv("PGPASSWORD"), + ) + print("DB connection established successfully.") + except Exception: + raise cur = conn.cursor() - cur.execute(f""" - CREATE TABLE IF NOT EXISTS {TABLE_NAME} ( - scenario TEXT, - model_cls TEXT, - num_params_M REAL, - flops_M REAL, - time_plain_s REAL, - mem_plain_GB REAL, - time_compile_s REAL, - mem_compile_GB REAL, - fullgraph BOOLEAN, - mode TEXT, - github_sha TEXT - ); - """) - conn.commit() - - df = pd.read_csv(FINAL_CSV_FILENAME) + # df = pd.read_csv(FINAL_CSV_FILENAME) + df = pd.read_csv("collated_results.csv") # Helper to cast values (or None) given a dtype def _cast_value(val, dtype: str): @@ -64,61 +55,60 @@ def _cast_value(val, dtype: str): return val - rows_to_insert = [] - for _, row in df.iterrows(): - scenario = _cast_value(row.get("scenario"), "text") - model_cls = _cast_value(row.get("model_cls"), "text") - num_params_M = _cast_value(row.get("num_params_M"), "float") - flops_M = _cast_value(row.get("flops_M"), "float") - time_plain_s = _cast_value(row.get("time_plain_s"), "float") - mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float") - time_compile_s = 
_cast_value(row.get("time_compile_s"), "float") - mem_compile_GB = _cast_value(row.get("mem_compile_GB"), "float") - fullgraph = _cast_value(row.get("fullgraph"), "bool") - mode = _cast_value(row.get("mode"), "text") - - # If "github_sha" column exists in the CSV, cast it; else default to None - if "github_sha" in df.columns: - github_sha = _cast_value(row.get("github_sha"), "text") - else: - github_sha = None - - rows_to_insert.append( - ( - scenario, - model_cls, - num_params_M, - flops_M, - time_plain_s, - mem_plain_GB, - time_compile_s, - mem_compile_GB, - fullgraph, - mode, - github_sha, - ) + try: + rows_to_insert = [] + id_for_benchmark = str(uuid.uuid4()) + "_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S") + for _, row in df.iterrows(): + scenario = _cast_value(row.get("scenario"), "text") + model_cls = _cast_value(row.get("model_cls"), "text") + num_params_M = _cast_value(row.get("num_params_M"), "float") + flops_M = _cast_value(row.get("flops_M"), "float") + time_plain_s = _cast_value(row.get("time_plain_s"), "float") + mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float") + time_compile_s = _cast_value(row.get("time_compile_s"), "float") + mem_compile_GB = _cast_value(row.get("mem_compile_GB"), "float") + fullgraph = _cast_value(row.get("fullgraph"), "bool") + mode = _cast_value(row.get("mode"), "text") + + # If "github_sha" column exists in the CSV, cast it; else default to None + if "github_sha" in df.columns: + github_sha = _cast_value(row.get("github_sha"), "text") + else: + github_sha = None + + if github_sha: + benchmark_id = f"{model_cls}-{scenario}-{github_sha}" + else: + benchmark_id = f"{model_cls}-{scenario}-{id_for_benchmark}" + + measurements = { + "scenario": scenario, + "model_cls": model_cls, + "num_params_M": num_params_M, + "flops_M": flops_M, + "time_plain_s": time_plain_s, + "mem_plain_GB": mem_plain_GB, + "time_compile_s": time_compile_s, + "mem_compile_GB": mem_compile_GB, + "fullgraph": fullgraph, + "mode": mode, + "github_sha": github_sha, + } + rows_to_insert.append((benchmark_id, measurements)) + + # Batch-insert all rows + insert_sql = f""" + INSERT INTO {TABLE_NAME} ( + benchmark_id, + measurements ) + VALUES (%s, %s); + """ + + psycopg2.extras.execute_batch(cur, insert_sql, rows_to_insert) + conn.commit() - # Batch-insert all rows (with NULL for any None) - insert_sql = """ - INSERT INTO benchmarks ( - scenario, - model_cls, - num_params_M, - flops_M, - time_plain_s, - mem_plain_GB, - time_compile_s, - mem_compile_GB, - fullgraph, - mode, - github_sha - ) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s); - """ - - psycopg2.extras.execute_batch(cur, insert_sql, rows_to_insert) - conn.commit() - - cur.close() - conn.close() + cur.close() + conn.close() + except Exception as e: + print(f"Exception: {e}") From 4a60155dc420cabec31f946a20d0b0a1de024144 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 11:52:50 +0530 Subject: [PATCH 28/56] disable db population for now. 
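One assumption worth flagging for when this is re-enabled: if the `measurements` column is JSON/JSONB, psycopg2 does not adapt plain Python dicts by itself, so one option (not what the script does today) is to wrap the payload with the `Json` adapter:

import psycopg2.extras

payload = {"scenario": "example-bf16", "time_plain_s": 0.074}  # hypothetical measurement dict
row = ("UNet2DConditionModel-example-bf16-abc123", psycopg2.extras.Json(payload))
# psycopg2.extras.execute_batch(cur, insert_sql, [row])  # cur/insert_sql as defined in the script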
--- .github/workflows/benchmark.yml | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 121d77a34250..e7adb23668a2 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -57,14 +57,15 @@ jobs: with: name: benchmark_test_reports path: benchmarks/benchmark_outputs - - - name: Update benchmarking results to DB - env: - PGDATABASE: metrics - PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }} # TODO - PGUSER: transformers_benchmarks - PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }} - run: cd benchmarks && python populate_into_db.py + + # TODO: enable this once the connection problem has been resolved. + # - name: Update benchmarking results to DB + # env: + # PGDATABASE: metrics + # PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }} # TODO + # PGUSER: transformers_benchmarks + # PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }} + # run: cd benchmarks && python populate_into_db.py - name: Report success status if: ${{ success() }} From e0ccb602537e331227914d4a6f84b6cd112f7355 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 13:53:52 +0530 Subject: [PATCH 29/56] change to every monday --- .github/workflows/benchmark.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index e7adb23668a2..b157cebdf4f2 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -3,7 +3,7 @@ name: Benchmarking tests on: workflow_dispatch: schedule: - - cron: "0 0 * * *" # every day at midnight + - cron: "0 17 * * 1" # every monday at 5 PM. env: DIFFUSERS_IS_CI: yes From 61dd029b0026eface3c7ab7da0e600674fe4c77a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 6 Jun 2025 13:54:10 +0530 Subject: [PATCH 30/56] Update .github/workflows/benchmark.yml Co-authored-by: Dhruv Nair --- .github/workflows/benchmark.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index b157cebdf4f2..60c5dfc33452 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -49,7 +49,7 @@ jobs: HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }} run: | cd benchmarks && python run_all.py && python push_results.py - mkdir ${BASE_PATH} && mv *.csv ${BASE_PATH} + mkdir $BASE_PATH && mv *.csv $BASE_PATH - name: Test suite reports artifacts if: ${{ always() }} From ee0fcd496caccc856b5828b0d1e940c2f256771b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 14:18:32 +0530 Subject: [PATCH 31/56] quality improvements. 
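The bare prints in the benchmarking utilities are swapped for the shared diffusers logger below. When running the scripts directly, INFO-level records may need the verbosity raised to show up, roughly like this:

from diffusers.utils import logging

logging.set_verbosity_info()  # the default verbosity hides INFO records from get_logger()
logger = logging.get_logger(__name__)
logger.info("Running scenario: example-scenario.")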
--- benchmarks/benchmarking_utils.py | 28 ++++++++++++++++++---------- benchmarks/run_all.py | 16 ++++++++-------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index c4a3a976309e..9c0b8644f897 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -9,9 +9,13 @@ import torch.utils.benchmark as benchmark from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import logging from diffusers.utils.testing_utils import require_torch_gpu, torch_device +logger = logging.get_logger(__name__) + + def benchmark_fn(f, *args, **kwargs): t0 = benchmark.Timer( stmt="f(*args, **kwargs)", @@ -101,12 +105,16 @@ def post_benchmark(self, model): @torch.no_grad() def run_benchmark(self, scenario: BenchmarkScenario): # 0) Basic stats - print(f"Running scenario: {scenario.name}.") - model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs) - num_params = round(calculate_params(model) / 1e6, 2) - flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e6, 2) - model.cpu() - del model + logger.info(f"Running scenario: {scenario.name}.") + try: + model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs) + num_params = round(calculate_params(model) / 1e6, 2) + flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e6, 2) + model.cpu() + del model + except Exception as e: + logger.info(f"Error while initializing the model and calculating FLOPs:\n{e}") + return {} self.pre_benchmark() # 1) plain stats @@ -121,7 +129,7 @@ def run_benchmark(self, scenario: BenchmarkScenario): compile_kwargs=None, ) except Exception as e: - print(f"Benchmark could not be run with the following error\n: {e}") + logger.info(f"Benchmark could not be run with the following error:\n{e}") return results # 2) compiled stats (if any) @@ -136,7 +144,7 @@ def run_benchmark(self, scenario: BenchmarkScenario): compile_kwargs=scenario.compile_kwargs, ) except Exception as e: - print(f"Compilation benchmark could not be run with the following error\n: {e}") + logger.info(f"Compilation benchmark could not be run with the following error\n: {e}") if plain is None: return results @@ -166,10 +174,10 @@ def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[Ben try: records.append(self.run_benchmark(s)) except Exception as e: - print(f"Running scenario ({s.name}) led to error:\n{e}") + logger.info(f"Running scenario ({s.name}) led to error:\n{e}") df = pd.DataFrame.from_records([r for r in records if r]) df.to_csv(filename, index=False) - print(f"Results serialized to {filename=}.") + logger.info(f"Results serialized to {filename=}.") def _run_phase( self, diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py index 278683bdc254..4994dfedc4ef 100644 --- a/benchmarks/run_all.py +++ b/benchmarks/run_all.py @@ -34,16 +34,16 @@ def run_command(command: list[str], return_stdout=False): def run_scripts(): python_files = sorted(glob.glob(PATTERN)) + python_files = [f for f in python_files if f != "benchmarking_utils.py"] for file in python_files: - if file != "benchmarking_utils.py": - print(f"****** Running file: {file} ******") - command = f"python {file}" - try: - run_command(command.split()) - except SubprocessCallException as e: - print(f"Error running {file}: {e}") - continue + print(f"****** Running file: {file} ******") + command = f"python {file}" + try: + run_command(command.split()) + except 
SubprocessCallException as e: + print(f"Error running {file}:\n{e}") + continue def merge_csvs(): From e35ffe83a01a176d233dc0a670ce22a1a136ef6b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 14:57:53 +0530 Subject: [PATCH 32/56] reparate hub upload step. --- .github/workflows/benchmark.yml | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 60c5dfc33452..043a623dfcc9 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -46,11 +46,18 @@ jobs: python utils/print_env.py - name: Diffusers Benchmarking env: - HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }} + HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} run: | - cd benchmarks && python run_all.py && python push_results.py + cd benchmarks && python run_all.py mkdir $BASE_PATH && mv *.csv $BASE_PATH + - name: Push results to the Hub + env: + HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }} + run: + cd benchmarks && cp $BASE_PATH/collated_results.csv . + python push_results.py + - name: Test suite reports artifacts if: ${{ always() }} uses: actions/upload-artifact@v4 From d3c494a73c483081b0cc96175e40833ea4115c11 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 15:21:00 +0530 Subject: [PATCH 33/56] repository --- benchmarks/collated_results.csv | 4 ++++ benchmarks/populate_into_db.py | 1 + 2 files changed, 5 insertions(+) create mode 100644 benchmarks/collated_results.csv diff --git a/benchmarks/collated_results.csv b/benchmarks/collated_results.csv new file mode 100644 index 000000000000..c179ecefed91 --- /dev/null +++ b/benchmarks/collated_results.csv @@ -0,0 +1,4 @@ +Unnamed: 0,scenario,model_cls,num_params_M,flops_M,time_plain_s,mem_plain_GB,time_compile_s,mem_compile_GB,fullgraph,mode,github_sha +0.0,stabilityai/stable-diffusion-xl-base-1.0-bf16,UNet2DConditionModel,2567.46,5979098.32,0.074 (0.079),5.05,0.054 (0.055),5.24,True,default,8dd326fdba7c2063029b502e3b2ebd7a20a1bb95 +1.0,stabilityai/stable-diffusion-xl-base-1.0-layerwise-upcasting,UNet2DConditionModel,2567.46,5979098.32,0.152 (0.164),4.89,nan,nan,,,8dd326fdba7c2063029b502e3b2ebd7a20a1bb95 +2.0,stabilityai/stable-diffusion-xl-base-1.0-group-offload-leaf,UNet2DConditionModel,2567.46,5979098.32,0.56 (0.516),0.2,nan,nan,,,8dd326fdba7c2063029b502e3b2ebd7a20a1bb95 diff --git a/benchmarks/populate_into_db.py b/benchmarks/populate_into_db.py index a64723b33f94..dd81cf2c6b26 100644 --- a/benchmarks/populate_into_db.py +++ b/benchmarks/populate_into_db.py @@ -82,6 +82,7 @@ def _cast_value(val, dtype: str): benchmark_id = f"{model_cls}-{scenario}-{id_for_benchmark}" measurements = { + "repository": "huggingface/diffusers", "scenario": scenario, "model_cls": model_cls, "num_params_M": num_params_M, From ce8d1ec649c948b77a83e0acb314683c8941bbc0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 15:30:28 +0530 Subject: [PATCH 34/56] remove csv --- benchmarks/collated_results.csv | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 benchmarks/collated_results.csv diff --git a/benchmarks/collated_results.csv b/benchmarks/collated_results.csv deleted file mode 100644 index c179ecefed91..000000000000 --- a/benchmarks/collated_results.csv +++ /dev/null @@ -1,4 +0,0 @@ -Unnamed: 0,scenario,model_cls,num_params_M,flops_M,time_plain_s,mem_plain_GB,time_compile_s,mem_compile_GB,fullgraph,mode,github_sha -0.0,stabilityai/stable-diffusion-xl-base-1.0-bf16,UNet2DConditionModel,2567.46,5979098.32,0.074 (0.079),5.05,0.054 
(0.055),5.24,True,default,8dd326fdba7c2063029b502e3b2ebd7a20a1bb95 -1.0,stabilityai/stable-diffusion-xl-base-1.0-layerwise-upcasting,UNet2DConditionModel,2567.46,5979098.32,0.152 (0.164),4.89,nan,nan,,,8dd326fdba7c2063029b502e3b2ebd7a20a1bb95 -2.0,stabilityai/stable-diffusion-xl-base-1.0-group-offload-leaf,UNet2DConditionModel,2567.46,5979098.32,0.56 (0.516),0.2,nan,nan,,,8dd326fdba7c2063029b502e3b2ebd7a20a1bb95 From fc69eb86b66fd22715bd4d4921c78ff85b9dd9f2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 15:36:21 +0530 Subject: [PATCH 35/56] check --- benchmarks/benchmarking_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index 9c0b8644f897..097644ca6f97 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -172,7 +172,11 @@ def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[Ben records = [] for s in scenarios: try: - records.append(self.run_benchmark(s)) + record = self.run_benchmark(s) + if record: + records.append(record) + else: + logger.info(f"Record empty from scenario: {s.name}.") except Exception as e: logger.info(f"Running scenario ({s.name}) led to error:\n{e}") df = pd.DataFrame.from_records([r for r in records if r]) From a43e8ef784a628c68f7b31f55f5e0f7f5c1d32b3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 16:21:51 +0530 Subject: [PATCH 36/56] update --- .github/workflows/benchmark.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 043a623dfcc9..c482be97c736 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -55,8 +55,7 @@ jobs: env: HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }} run: - cd benchmarks && cp $BASE_PATH/collated_results.csv . - python push_results.py + cd benchmarks && cp $BASE_PATH/collated_results.csv . && python push_results.py - name: Test suite reports artifacts if: ${{ always() }} From 2f5c8d07ffde0606258fc2a79a8ee4476e008be3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 16:29:08 +0530 Subject: [PATCH 37/56] update --- benchmarks/benchmarking_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index 097644ca6f97..5c79e21bee85 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -1,5 +1,6 @@ import gc import inspect +import logging as std_logging from contextlib import nullcontext from dataclasses import dataclass from typing import Any, Callable, Dict, Optional, Union @@ -13,6 +14,7 @@ from diffusers.utils.testing_utils import require_torch_gpu, torch_device +std_logging.basicConfig(level=std_logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") logger = logging.get_logger(__name__) From 1f7587e2c0578f348ec99f5423f52d682f57839c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 16:33:18 +0530 Subject: [PATCH 38/56] threading. 
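Stream benchmark records to disk as they are produced instead of collecting them all in memory: each result is pushed onto a queue and a daemon writer thread appends it to the CSV, so that partial results are already on disk if a later scenario fails. A minimal, self-contained sketch of the producer/consumer pattern this change adopts (the file name and the toy records below are illustrative, not the benchmark code itself):

```py
import os
import queue
import threading

import pandas as pd

record_queue = queue.Queue()
stop_signal = object()  # sentinel object that tells the writer to exit
filename = "results.csv"  # illustrative output path


def _writer():
    # Consume records until the sentinel arrives, appending each one as a CSV row.
    while True:
        item = record_queue.get()
        if item is stop_signal:
            break
        write_header = not os.path.exists(filename)  # only write the header once
        pd.DataFrame([item]).to_csv(filename, mode="a", header=write_header, index=False)


writer = threading.Thread(target=_writer, daemon=True)
writer.start()

for i in range(3):  # stand-in for iterating over benchmark scenarios
    record_queue.put({"scenario": f"scenario-{i}", "time_s": 0.1 * (i + 1)})

record_queue.put(stop_signal)  # signal completion so the writer can exit
writer.join()
```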
--- benchmarks/benchmarking_utils.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index 5c79e21bee85..45332fe22a13 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -1,6 +1,9 @@ import gc import inspect import logging as std_logging +import os +import queue +import threading from contextlib import nullcontext from dataclasses import dataclass from typing import Any, Callable, Dict, Optional, Union @@ -171,18 +174,34 @@ def run_benchmark(self, scenario: BenchmarkScenario): def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[BenchmarkScenario]], filename: str): if not isinstance(scenarios, list): scenarios = [scenarios] - records = [] + record_queue = queue.Queue() + stop_signal = object() + + def _writer_thread(): + while True: + item = record_queue.get() + if item is stop_signal: + break + df_row = pd.DataFrame([item]) + write_header = not os.path.exists(filename) + df_row.to_csv(filename, mode="a", header=write_header, index=False) + record_queue.task_done() + + record_queue.task_done() + + writer = threading.Thread(target=_writer_thread, daemon=True) + writer.start() + for s in scenarios: try: record = self.run_benchmark(s) if record: - records.append(record) + record_queue.put(record) else: logger.info(f"Record empty from scenario: {s.name}.") except Exception as e: logger.info(f"Running scenario ({s.name}) led to error:\n{e}") - df = pd.DataFrame.from_records([r for r in records if r]) - df.to_csv(filename, index=False) + record_queue.put(stop_signal) logger.info(f"Results serialized to {filename=}.") def _run_phase( From 7a935a49ffb399f9b313f7217ffa646c2526de59 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 16:49:51 +0530 Subject: [PATCH 39/56] update --- benchmarks/run_all.py | 73 ++++++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 29 deletions(-) diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py index 4994dfedc4ef..02b160efa6a4 100644 --- a/benchmarks/run_all.py +++ b/benchmarks/run_all.py @@ -1,12 +1,10 @@ import glob import os import subprocess - import pandas as pd - PATTERN = "benchmarking_*.py" -FINAL_CSV_FILENAME = "collated_results.csv" +FINAL_CSV_FILENAME = "collated_results.py" GITHUB_SHA = os.getenv("GITHUB_SHA", None) @@ -14,46 +12,63 @@ class SubprocessCallException(Exception): pass -# Taken from `test_examples_utils.py` def run_command(command: list[str], return_stdout=False): - """ - Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. 
Will also properly capture - if an error occurred while running `command` - """ try: output = subprocess.check_output(command, stderr=subprocess.STDOUT) - if return_stdout: - if hasattr(output, "decode"): - output = output.decode("utf-8") - return output + if return_stdout and hasattr(output, "decode"): + return output.decode("utf-8") except subprocess.CalledProcessError as e: raise SubprocessCallException( - f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}" + f"Command `{' '.join(command)}` failed with:\n{e.output.decode()}" ) from e +def merge_csvs(final_csv: str = "collated_results.csv"): + all_csvs = glob.glob("*.csv") + if not all_csvs: + print("No result CSVs found to merge.") + return + + df_list = [] + for f in all_csvs: + try: + d = pd.read_csv(f) + except pd.errors.EmptyDataError: + # If a file existed but was zero‐bytes or corrupted, skip it + continue + df_list.append(d) + + if not df_list: + print("All result CSVs were empty or invalid; nothing to merge.") + return + + final_df = pd.concat(df_list, ignore_index=True) + if GITHUB_SHA is not None: + final_df["github_sha"] = GITHUB_SHA + final_df.to_csv(final_csv, index=False) + print(f"Merged {len(all_csvs)} partial CSVs → {final_csv}.") + + def run_scripts(): python_files = sorted(glob.glob(PATTERN)) python_files = [f for f in python_files if f != "benchmarking_utils.py"] for file in python_files: - print(f"****** Running file: {file} ******") - command = f"python {file}" + script_name = file.split(".py")[0].split("_")[-1] # example: benchmarking_foo.py -> foo + print(f"\n****** Running file: {file} ******") + + partial_csv = f"{script_name}.csv" + if os.path.exists(partial_csv): + os.remove(partial_csv) + + command = ["python", file] try: - run_command(command.split()) + run_command(command) + print(f"→ {file} finished normally.") except SubprocessCallException as e: print(f"Error running {file}:\n{e}") - continue - - -def merge_csvs(): - all_csvs = glob.glob("*.csv") - final_df = pd.concat([pd.read_csv(f) for f in all_csvs]).reset_index(drop=True) - if GITHUB_SHA: - final_df["github_sha"] = GITHUB_SHA - final_df.to_csv(FINAL_CSV_FILENAME) - + finally: + print(f"→ Merging partial CSVs after {file} …") + merge_csvs(final_csv=FINAL_CSV_FILENAME) -if __name__ == "__main__": - run_scripts() - merge_csvs() + print(f"\nAll scripts attempted. Final collated CSV: {FINAL_CSV_FILENAME}") From a6c7359f2d6ca22668871ec7e709b5fe747ea294 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 16:50:36 +0530 Subject: [PATCH 40/56] update --- benchmarks/run_all.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py index 02b160efa6a4..1f535f575891 100644 --- a/benchmarks/run_all.py +++ b/benchmarks/run_all.py @@ -59,6 +59,7 @@ def run_scripts(): partial_csv = f"{script_name}.csv" if os.path.exists(partial_csv): + print(f"Found {partial_csv}. 
Removing for safer numbers and duplication.") os.remove(partial_csv) command = ["python", file] From 1150cb09e0b81c576c6b067be862ed53af0a59e7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 16:53:21 +0530 Subject: [PATCH 41/56] updaye --- benchmarks/run_all.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py index 1f535f575891..ca87ea105876 100644 --- a/benchmarks/run_all.py +++ b/benchmarks/run_all.py @@ -1,8 +1,10 @@ import glob import os import subprocess + import pandas as pd + PATTERN = "benchmarking_*.py" FINAL_CSV_FILENAME = "collated_results.py" GITHUB_SHA = os.getenv("GITHUB_SHA", None) @@ -18,9 +20,7 @@ def run_command(command: list[str], return_stdout=False): if return_stdout and hasattr(output, "decode"): return output.decode("utf-8") except subprocess.CalledProcessError as e: - raise SubprocessCallException( - f"Command `{' '.join(command)}` failed with:\n{e.output.decode()}" - ) from e + raise SubprocessCallException(f"Command `{' '.join(command)}` failed with:\n{e.output.decode()}") from e def merge_csvs(final_csv: str = "collated_results.csv"): @@ -54,7 +54,7 @@ def run_scripts(): python_files = [f for f in python_files if f != "benchmarking_utils.py"] for file in python_files: - script_name = file.split(".py")[0].split("_")[-1] # example: benchmarking_foo.py -> foo + script_name = file.split(".py")[0].split("_")[-1] # example: benchmarking_foo.py -> foo print(f"\n****** Running file: {file} ******") partial_csv = f"{script_name}.csv" @@ -73,3 +73,7 @@ def run_scripts(): merge_csvs(final_csv=FINAL_CSV_FILENAME) print(f"\nAll scripts attempted. Final collated CSV: {FINAL_CSV_FILENAME}") + + +if __name__ == "__main__": + run_scripts() From 6cc4707338ee3b2789b57aae8f9ac1232366ac3a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 17:07:35 +0530 Subject: [PATCH 42/56] update --- benchmarks/run_all.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py index ca87ea105876..3aa75b641f8b 100644 --- a/benchmarks/run_all.py +++ b/benchmarks/run_all.py @@ -1,9 +1,15 @@ import glob +import logging as std_logging import os import subprocess import pandas as pd +from diffusers.utils import logging + + +std_logging.basicConfig(level=std_logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") +logger = logging.get_logger(__name__) PATTERN = "benchmarking_*.py" FINAL_CSV_FILENAME = "collated_results.py" @@ -26,7 +32,7 @@ def run_command(command: list[str], return_stdout=False): def merge_csvs(final_csv: str = "collated_results.csv"): all_csvs = glob.glob("*.csv") if not all_csvs: - print("No result CSVs found to merge.") + logger.info("No result CSVs found to merge.") return df_list = [] @@ -39,14 +45,14 @@ def merge_csvs(final_csv: str = "collated_results.csv"): df_list.append(d) if not df_list: - print("All result CSVs were empty or invalid; nothing to merge.") + logger.info("All result CSVs were empty or invalid; nothing to merge.") return final_df = pd.concat(df_list, ignore_index=True) if GITHUB_SHA is not None: final_df["github_sha"] = GITHUB_SHA final_df.to_csv(final_csv, index=False) - print(f"Merged {len(all_csvs)} partial CSVs → {final_csv}.") + logger.info(f"Merged {len(all_csvs)} partial CSVs → {final_csv}.") def run_scripts(): @@ -55,24 +61,24 @@ def run_scripts(): for file in python_files: script_name = file.split(".py")[0].split("_")[-1] # example: benchmarking_foo.py -> 
foo - print(f"\n****** Running file: {file} ******") + logger.info(f"\n****** Running file: {file} ******") partial_csv = f"{script_name}.csv" if os.path.exists(partial_csv): - print(f"Found {partial_csv}. Removing for safer numbers and duplication.") + logger.info(f"Found {partial_csv}. Removing for safer numbers and duplication.") os.remove(partial_csv) command = ["python", file] try: run_command(command) - print(f"→ {file} finished normally.") + logger.info(f"→ {file} finished normally.") except SubprocessCallException as e: - print(f"Error running {file}:\n{e}") + logger.info(f"Error running {file}:\n{e}") finally: - print(f"→ Merging partial CSVs after {file} …") + logger.info(f"→ Merging partial CSVs after {file} …") merge_csvs(final_csv=FINAL_CSV_FILENAME) - print(f"\nAll scripts attempted. Final collated CSV: {FINAL_CSV_FILENAME}") + logger.info(f"\nAll scripts attempted. Final collated CSV: {FINAL_CSV_FILENAME}") if __name__ == "__main__": From f1ee6315e5d2bd660d0575ab9f60fd0d5c4091c4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 18:00:27 +0530 Subject: [PATCH 43/56] update --- benchmarks/requirements.txt | 1 + utils/print_env.py | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt index 29e681a2d61d..41c6604f2cf5 100644 --- a/benchmarks/requirements.txt +++ b/benchmarks/requirements.txt @@ -1,5 +1,6 @@ pandas peft +psutil torchprofile bitsandbytes psycopg2==2.9.9 \ No newline at end of file diff --git a/utils/print_env.py b/utils/print_env.py index 2d2acb59d5cc..2fe0777daf7d 100644 --- a/utils/print_env.py +++ b/utils/print_env.py @@ -28,6 +28,16 @@ print("OS platform:", platform.platform()) print("OS architecture:", platform.machine()) +try: + import psutil + + vm = psutil.virtual_memory() + total_gb = vm.total / (1024**3) + available_gb = vm.available / (1024**3) + print(f"Total RAM: {total_gb:.2f} GB") + print(f"Available RAM: {available_gb:.2f} GB") +except ImportError: + pass try: import torch From 73e07baa1925a6e0fa5c8eaeae236368366c612e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 20:03:20 +0530 Subject: [PATCH 44/56] update --- .github/workflows/benchmark.yml | 10 +++++----- benchmarks/benchmarking_utils.py | 7 +++---- benchmarks/populate_into_db.py | 5 ++--- benchmarks/run_all.py | 8 +++----- 4 files changed, 13 insertions(+), 17 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index c482be97c736..0617f6353c0a 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -49,26 +49,26 @@ jobs: HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} run: | cd benchmarks && python run_all.py - mkdir $BASE_PATH && mv *.csv $BASE_PATH - name: Push results to the Hub env: HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }} - run: - cd benchmarks && cp $BASE_PATH/collated_results.csv . && python push_results.py + run: | + cd benchmarks && python push_results.py + mkdir $BASE_PATH && cp *.csv $BASE_PATH - name: Test suite reports artifacts if: ${{ always() }} uses: actions/upload-artifact@v4 with: name: benchmark_test_reports - path: benchmarks/benchmark_outputs + path: benchmarks/$BASE_PATH # TODO: enable this once the connection problem has been resolved. 
# - name: Update benchmarking results to DB # env: # PGDATABASE: metrics - # PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }} # TODO + # PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }} # PGUSER: transformers_benchmarks # PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }} # run: cd benchmarks && python populate_into_db.py diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index 45332fe22a13..c9da19be06dc 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -1,6 +1,6 @@ import gc import inspect -import logging as std_logging +import logging import os import queue import threading @@ -13,12 +13,11 @@ import torch.utils.benchmark as benchmark from diffusers.models.modeling_utils import ModelMixin -from diffusers.utils import logging from diffusers.utils.testing_utils import require_torch_gpu, torch_device -std_logging.basicConfig(level=std_logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") -logger = logging.get_logger(__name__) +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") +logger = logging.getLogger(__name__) def benchmark_fn(f, *args, **kwargs): diff --git a/benchmarks/populate_into_db.py b/benchmarks/populate_into_db.py index dd81cf2c6b26..7dc24ed8dd21 100644 --- a/benchmarks/populate_into_db.py +++ b/benchmarks/populate_into_db.py @@ -7,7 +7,7 @@ import psycopg2.extras -# FINAL_CSV_FILENAME = "benchmark_outputs/collated_results.csv" +FINAL_CSV_FILENAME = "collated_results.csv" # https://github.com/huggingface/transformers/blob/593e29c5e2a9b17baec010e8dc7c1431fed6e841/benchmark/init_db.sql#L27 TABLE_NAME = "model_measurements" @@ -24,8 +24,7 @@ raise cur = conn.cursor() - # df = pd.read_csv(FINAL_CSV_FILENAME) - df = pd.read_csv("collated_results.csv") + df = pd.read_csv(FINAL_CSV_FILENAME) # Helper to cast values (or None) given a dtype def _cast_value(val, dtype: str): diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py index 3aa75b641f8b..23e18997c6a6 100644 --- a/benchmarks/run_all.py +++ b/benchmarks/run_all.py @@ -1,15 +1,13 @@ import glob -import logging as std_logging +import logging import os import subprocess import pandas as pd -from diffusers.utils import logging - -std_logging.basicConfig(level=std_logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") -logger = logging.get_logger(__name__) +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") +logger = logging.getLogger(__name__) PATTERN = "benchmarking_*.py" FINAL_CSV_FILENAME = "collated_results.py" From 2a65a891270ce978897ee2f87275f27bdb3efaf2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 20:49:47 +0530 Subject: [PATCH 45/56] remove peft dep --- benchmarks/requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt index 41c6604f2cf5..4f69b192f3f9 100644 --- a/benchmarks/requirements.txt +++ b/benchmarks/requirements.txt @@ -1,5 +1,4 @@ pandas -peft psutil torchprofile bitsandbytes From dc778b008a0bbca1117c5b536b72e19860c73d73 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 20:54:53 +0530 Subject: [PATCH 46/56] upgrade runner. 
--- .github/workflows/benchmark.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 0617f6353c0a..02b4aebd6d09 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -22,7 +22,7 @@ jobs: fail-fast: false max-parallel: 1 runs-on: - group: aws-g6e-xlarge-plus + group: aws-g6e-4xlarge container: image: diffusers/diffusers-pytorch-cuda options: --shm-size "16gb" --ipc host --gpus 0 From 8ddf57cfdd6c3d0fa3fcbb4ea16c370d684f32e0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 21:32:11 +0530 Subject: [PATCH 47/56] fix --- .github/workflows/benchmark.yml | 2 +- benchmarks/run_all.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 02b4aebd6d09..df596ce30ac9 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -62,7 +62,7 @@ jobs: uses: actions/upload-artifact@v4 with: name: benchmark_test_reports - path: benchmarks/$BASE_PATH + path: benchmarks/${{ env.BASE_PATH }} # TODO: enable this once the connection problem has been resolved. # - name: Update benchmarking results to DB diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py index 23e18997c6a6..9b8f6bacc7f4 100644 --- a/benchmarks/run_all.py +++ b/benchmarks/run_all.py @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) PATTERN = "benchmarking_*.py" -FINAL_CSV_FILENAME = "collated_results.py" +FINAL_CSV_FILENAME = "collated_results.csv" GITHUB_SHA = os.getenv("GITHUB_SHA", None) From 8161e36d610a2789b6b7040b0d3c7ccf742c3706 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 22:45:39 +0530 Subject: [PATCH 48/56] fixes --- benchmarks/benchmarking_utils.py | 12 ++++++++---- benchmarks/populate_into_db.py | 8 ++++---- benchmarks/push_results.py | 4 +++- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index c9da19be06dc..d3bec9373f01 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -112,8 +112,12 @@ def run_benchmark(self, scenario: BenchmarkScenario): logger.info(f"Running scenario: {scenario.name}.") try: model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs) - num_params = round(calculate_params(model) / 1e6, 2) - flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e6, 2) + num_params = round(calculate_params(model) / 1e9, 2) + try: + flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e9, 2) + except Exception as e: + logger.info(f"Problem in calculating FLOPs:\n{e}") + flops = None model.cpu() del model except Exception as e: @@ -156,8 +160,8 @@ def run_benchmark(self, scenario: BenchmarkScenario): result = { "scenario": scenario.name, "model_cls": scenario.model_cls.__name__, - "num_params_M": num_params, - "flops_M": flops, + "num_params_B": num_params, + "flops_G": flops, "time_plain_s": plain["time"], "mem_plain_GB": plain["memory"], "time_compile_s": compiled["time"], diff --git a/benchmarks/populate_into_db.py b/benchmarks/populate_into_db.py index 7dc24ed8dd21..82d11328f1c4 100644 --- a/benchmarks/populate_into_db.py +++ b/benchmarks/populate_into_db.py @@ -60,8 +60,8 @@ def _cast_value(val, dtype: str): for _, row in df.iterrows(): scenario = _cast_value(row.get("scenario"), "text") model_cls = _cast_value(row.get("model_cls"), "text") - num_params_M = 
_cast_value(row.get("num_params_M"), "float") - flops_M = _cast_value(row.get("flops_M"), "float") + num_params_B = _cast_value(row.get("num_params_B"), "float") + flops_G = _cast_value(row.get("flops_G"), "float") time_plain_s = _cast_value(row.get("time_plain_s"), "float") mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float") time_compile_s = _cast_value(row.get("time_compile_s"), "float") @@ -84,8 +84,8 @@ def _cast_value(val, dtype: str): "repository": "huggingface/diffusers", "scenario": scenario, "model_cls": model_cls, - "num_params_M": num_params_M, - "flops_M": flops_M, + "num_params_B": num_params_B, + "flops_G": flops_G, "time_plain_s": time_plain_s, "mem_plain_GB": mem_plain_GB, "time_compile_s": time_compile_s, diff --git a/benchmarks/push_results.py b/benchmarks/push_results.py index 30da0c053863..96c303f0d796 100644 --- a/benchmarks/push_results.py +++ b/benchmarks/push_results.py @@ -1,3 +1,5 @@ +import os + import pandas as pd from huggingface_hub import hf_hub_download, upload_file from huggingface_hub.utils import EntryNotFoundError @@ -50,7 +52,7 @@ def push_to_hf_dataset(): # combine current_results[column] = curr_str + append_str - + os.remove(FINAL_CSV_FILENAME) current_results.to_csv(FINAL_CSV_FILENAME, index=False) commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results" From 807f5113860869a15c6639377745dabce5d0d1ba Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 7 Jun 2025 09:05:50 +0530 Subject: [PATCH 49/56] fix merging csvs. --- benchmarks/run_all.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py index 9b8f6bacc7f4..9cf053f5480c 100644 --- a/benchmarks/run_all.py +++ b/benchmarks/run_all.py @@ -29,6 +29,7 @@ def run_command(command: list[str], return_stdout=False): def merge_csvs(final_csv: str = "collated_results.csv"): all_csvs = glob.glob("*.csv") + all_csvs = [f for f in all_csvs if f != final_csv] if not all_csvs: logger.info("No result CSVs found to merge.") return From a09768f2bd006dd36ce5e705e65cc87565b42702 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 7 Jun 2025 14:40:27 +0530 Subject: [PATCH 50/56] push dataset to the Space repo for analysis. --- benchmarks/push_results.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/benchmarks/push_results.py b/benchmarks/push_results.py index 96c303f0d796..8be3b393683b 100644 --- a/benchmarks/push_results.py +++ b/benchmarks/push_results.py @@ -63,6 +63,13 @@ def push_to_hf_dataset(): repo_type="dataset", commit_message=commit_message, ) + upload_file( + repo_id="diffusers/benchmark-analyzer", + path_in_repo=FINAL_CSV_FILENAME, + path_or_fileobj=FINAL_CSV_FILENAME, + repo_type="space", + commit_message=commit_message, + ) if __name__ == "__main__": From 1683c47b4061b0c88fc4986bef8eb8683084f765 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 8 Jun 2025 20:03:45 +0530 Subject: [PATCH 51/56] warm up. 
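Run a few untimed forward passes before the measured phase so that one-off costs (CUDA context setup, kernel autotuning, and especially `torch.compile` graph compilation) do not leak into the reported latency. A rough sketch of the warmup-then-measure pattern, using a toy module and a simplified stand-in for the `benchmark_fn` helper in `benchmarking_utils.py`; the model, inputs, and the reported statistic are placeholders:

```py
import torch
import torch.utils.benchmark as benchmark

NUM_WARMUP_ROUNDS = 5  # same value this patch adds to benchmarking_utils.py


def benchmark_fn(f, *args, **kwargs):
    # Simplified Timer-based helper: returns a runtime estimate in seconds.
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
    )
    return round(t0.blocked_autorange().mean, 5)


model = torch.nn.Linear(1024, 1024)  # placeholder for the diffusers model under test
inp = {"input": torch.randn(8, 1024)}

with torch.no_grad():
    for _ in range(NUM_WARMUP_ROUNDS):  # untimed warmup iterations
        _ = model(**inp)
    time_s = benchmark_fn(lambda m, d: m(**d), model, inp)

print(f"measured latency: {time_s} s")
```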
--- benchmarks/benchmarking_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py index d3bec9373f01..c8c1a10ef899 100644 --- a/benchmarks/benchmarking_utils.py +++ b/benchmarks/benchmarking_utils.py @@ -19,6 +19,8 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") logger = logging.getLogger(__name__) +NUM_WARMUP_ROUNDS = 5 + def benchmark_fn(f, *args, **kwargs): t0 = benchmark.Timer( stmt="f(*args, **kwargs)", @@ -230,6 +232,8 @@ def _run_phase( # measure run_ctx = torch._inductor.utils.fresh_inductor_cache() if compile_kwargs else nullcontext() with run_ctx: + for _ in range(NUM_WARMUP_ROUNDS): + _ = model(**inp) time_s = benchmark_fn(lambda m, d: m(**d), model, inp) mem_gb = torch.cuda.max_memory_allocated() / (1024**3) mem_gb = round(mem_gb, 2) From 858dc09a60ce7ed80836827d64747ab2c4bcbee9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 10 Jun 2025 07:41:48 +0530 Subject: [PATCH 52/56] add a readme --- benchmarks/README.md | 69 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 benchmarks/README.md diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 000000000000..cf9d090bcc88 --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,69 @@ +# Diffusers Benchmarks + +Welcome to Diffusers Benchmarks. These benchmarks are used to obtain latency and memory information of the most popular models across different scenarios such as: + +* Base case, i.e., when using `torch.bfloat16` and `torch.nn.functional.scaled_dot_product_attention`. +* Base + `torch.compile()` +* NF4 quantization +* Layerwise upcasting + +Instead of full diffusion pipelines, only the forward pass of the respective model classes (such as `FluxTransformer2DModel`) is tested with the real checkpoints (such as `"black-forest-labs/FLUX.1-dev"`). + +The entrypoint to running all the currently available benchmarks is in `run_all.py`. However, one can run the individual benchmarks, too, i.e., `python benchmarking_flux.py`. It should produce a CSV file containing various information about the benchmarks run. + +The benchmarks are run on a weekly basis and the CI is defined in [benchmark.yml](../.github/workflows/benchmark.yml). + +## Running the benchmarks manually + +First set up `torch` and install `diffusers` from the root of the directory: + +```py +pip install -e ".[quality,test]" +``` + +Then make sure the other dependencies are installed: + +```sh +cd benchmarks/ +pip install -r requirements.txt +``` + +We need to be authenticated to access some of the checkpoints used during benchmarking: + +```sh +huggingface-cli login +``` + +We use an L40 GPU with 128GB RAM to run the benchmark CI. As such, the benchmarks are configured to run on NVIDIA GPUs. So, make sure you have access to a similar machine (or modify the benchmarking scripts accordingly). + +Then you can either launch the entire benchmarking suite by running: + +```sh +python run_all.py +``` + +Or, you can run the individual benchmarks. + +## Customizing the benchmarks + +We define "scenarios" to cover the most common ways in which these models are used.
You can +define a new scenario, modifying an existing benchmark file: + +```py +BenchmarkScenario( + name=f"{CKPT_ID}-bnb-8bit", + model_cls=FluxTransformer2DModel, + model_init_kwargs={ + "pretrained_model_name_or_path": CKPT_ID, + "torch_dtype": torch.bfloat16, + "subfolder": "transformer", + "quantization_config": BitsAndBytesConfig(load_in_8bit=True), + }, + get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16), + model_init_fn=model_init_fn, +) +``` + +You can also configure a new model-level benchmark and add it to the existing suite. To do so, just defining a valid benchmarking file like `benchmarking_flux.py` should be enough. + +Happy benchmarking 🧨 \ No newline at end of file From 6bfdae67a4732381812c980e1a1b17b0c004e063 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 10 Jun 2025 15:15:32 +0530 Subject: [PATCH 53/56] Apply suggestions from code review Co-authored-by: Luc Georges --- benchmarks/populate_into_db.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmarks/populate_into_db.py b/benchmarks/populate_into_db.py index 82d11328f1c4..42eded5045d7 100644 --- a/benchmarks/populate_into_db.py +++ b/benchmarks/populate_into_db.py @@ -9,7 +9,8 @@ FINAL_CSV_FILENAME = "collated_results.csv" # https://github.com/huggingface/transformers/blob/593e29c5e2a9b17baec010e8dc7c1431fed6e841/benchmark/init_db.sql#L27 -TABLE_NAME = "model_measurements" +BENCHMARKS_TABLE_NAME = "benchmarks" +MEASUREMENTS_TABLE_NAME = "model_measurements" if __name__ == "__main__": try: @@ -98,7 +99,7 @@ def _cast_value(val, dtype: str): # Batch-insert all rows insert_sql = f""" - INSERT INTO {TABLE_NAME} ( + INSERT INTO {MEASUREMENTS_TABLE_NAME} ( benchmark_id, measurements ) From 6b11973c0ca88988981a126f3bc97dbfa8451f21 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 10 Jun 2025 15:45:33 +0530 Subject: [PATCH 54/56] address feedback --- .github/workflows/benchmark.yml | 18 ++++++---- benchmarks/populate_into_db.py | 62 +++++++++++++++++++++++++++------ 2 files changed, 62 insertions(+), 18 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index df596ce30ac9..e9dd1dec4e6d 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -65,13 +65,17 @@ jobs: path: benchmarks/${{ env.BASE_PATH }} # TODO: enable this once the connection problem has been resolved. 
- # - name: Update benchmarking results to DB - # env: - # PGDATABASE: metrics - # PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }} - # PGUSER: transformers_benchmarks - # PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }} - # run: cd benchmarks && python populate_into_db.py + - name: Update benchmarking results to DB + env: + PGDATABASE: metrics + PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }} + PGUSER: transformers_benchmarks + PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }} + BRANCH_NAME: ${{ github.head_ref || github.ref_name }} + run: | + commit_id=$GITHUB_SHA + commit_msg=$(git show -s --format=%s "$commit_id" | cut -c1-70) + cd benchmarks && python populate_into_db.py "$BRANCH_NAME" "$commit_id" "$commit_msg" - name: Report success status if: ${{ success() }} diff --git a/benchmarks/populate_into_db.py b/benchmarks/populate_into_db.py index 42eded5045d7..5c918e403733 100644 --- a/benchmarks/populate_into_db.py +++ b/benchmarks/populate_into_db.py @@ -1,6 +1,6 @@ -import datetime +import argparse import os -import uuid +import sys import pandas as pd import psycopg2 @@ -12,7 +12,45 @@ BENCHMARKS_TABLE_NAME = "benchmarks" MEASUREMENTS_TABLE_NAME = "model_measurements" + +def _init_benchmark(conn, branch, commit_id, commit_msg): + metadata = {} + repository = "huggingface/diffusers" + with conn.cursor() as cur: + cur.execute( + f"INSERT INTO {BENCHMARKS_TABLE_NAME} (repository, branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s, %s) RETURNING benchmark_id", + (repository, branch, commit_id, commit_msg, metadata), + ) + benchmark_id = cur.fetchone()[0] + print(f"Initialised benchmark #{benchmark_id}") + return benchmark_id + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "branch", + type=str, + help="The branch name on which the benchmarking is performed.", + ) + + parser.add_argument( + "commit_id", + type=str, + help="The commit hash on which the benchmarking is performed.", + ) + + parser.add_argument( + "commit_msg", + type=str, + help="The commit message associated with the commit, truncated to 70 characters.", + ) + args = parser.parse_args() + return args + + if __name__ == "__main__": + args = parse_args() try: conn = psycopg2.connect( host=os.getenv("PGHOST"), @@ -21,8 +59,17 @@ password=os.getenv("PGPASSWORD"), ) print("DB connection established successfully.") - except Exception: - raise + except Exception as e: + print(f"Problem during DB init: {e}") + sys.exit(1) + + benchmark_id = _init_benchmark( + conn=conn, + branch=args.branch, + commit_id=args.commit_id, + commit_msg=args.commit_msg, + ) + cur = conn.cursor() df = pd.read_csv(FINAL_CSV_FILENAME) @@ -57,7 +104,6 @@ def _cast_value(val, dtype: str): try: rows_to_insert = [] - id_for_benchmark = str(uuid.uuid4()) + "_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S") for _, row in df.iterrows(): scenario = _cast_value(row.get("scenario"), "text") model_cls = _cast_value(row.get("model_cls"), "text") @@ -76,13 +122,7 @@ def _cast_value(val, dtype: str): else: github_sha = None - if github_sha: - benchmark_id = f"{model_cls}-{scenario}-{github_sha}" - else: - benchmark_id = f"{model_cls}-{scenario}-{id_for_benchmark}" - measurements = { - "repository": "huggingface/diffusers", "scenario": scenario, "model_cls": model_cls, "num_params_B": num_params_B, From ba7a89c662b0030be7cbb8b385e451559fd8b25d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 10 Jun 2025 17:00:46 +0530 Subject: [PATCH 55/56] Apply suggestions from code review --- 
benchmarks/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index cf9d090bcc88..574779bb5059 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -9,7 +9,7 @@ Welcome to Diffusers Benchmarks. These benchmarks are used to obtain latency and Instead of full diffusion pipelines, only the forward pass of the respective model classes (such as `FluxTransformer2DModel`) is tested with the real checkpoints (such as `"black-forest-labs/FLUX.1-dev"`). -The entrypoint to running all the currently available benchmarks is in `run_all.py`. However, one can run the individual benchmarks, too, i.e., `python benchmarking_flux.py`. It should produce a CSV file containing various information about the benchmarks run. +The entrypoint to running all the currently available benchmarks is in `run_all.py`. However, one can run the individual benchmarks, too, e.g., `python benchmarking_flux.py`. It should produce a CSV file containing various information about the benchmarks run. The benchmarks are run on a weekly basis and the CI is defined in [benchmark.yml](../.github/workflows/benchmark.yml). From f9285fdf8c42d5f4093ed4fb24da594f5d8821bf Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 10 Jun 2025 17:03:39 +0530 Subject: [PATCH 56/56] disable db workflow. --- .github/workflows/benchmark.yml | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index e9dd1dec4e6d..9b8489ebef03 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -65,17 +65,17 @@ jobs: path: benchmarks/${{ env.BASE_PATH }} # TODO: enable this once the connection problem has been resolved. - - name: Update benchmarking results to DB - env: - PGDATABASE: metrics - PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }} - PGUSER: transformers_benchmarks - PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }} - BRANCH_NAME: ${{ github.head_ref || github.ref_name }} - run: | - commit_id=$GITHUB_SHA - commit_msg=$(git show -s --format=%s "$commit_id" | cut -c1-70) - cd benchmarks && python populate_into_db.py "$BRANCH_NAME" "$commit_id" "$commit_msg" + # - name: Update benchmarking results to DB + # env: + # PGDATABASE: metrics + # PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }} + # PGUSER: transformers_benchmarks + # PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }} + # BRANCH_NAME: ${{ github.head_ref || github.ref_name }} + # run: | + # commit_id=$GITHUB_SHA + # commit_msg=$(git show -s --format=%s "$commit_id" | cut -c1-70) + # cd benchmarks && python populate_into_db.py "$BRANCH_NAME" "$commit_id" "$commit_msg" - name: Report success status if: ${{ success() }}