
[modular diffusers] more refactor #11235


Open · wants to merge 50 commits into base: modular-diffusers

Conversation

@yiyixuxu (Collaborator) commented Apr 8, 2025

Continuing #9672 here.

Getting started with Modular diffusers

Guides on the basics

Modular Repo

See an example of a modular repo here: https://huggingface.co/YiYiXu/modular-diffdiff/tree/main
It includes:

remote pipeline blocks

you can load a pipeline block from the Hub with one line of code and mix and match it into your workflow

diffdiff_blocks = ModularPipelineMixin.from_pretrained(repo_id, trust_remote_code=True)
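
Once loaded, the remote block behaves like any other ModularPipelineMixin, so you can attach a loader to it and run it directly. A minimal sketch (the repo/collection names and prompt here are only illustrative; the loader/run APIs are the ones shown in the test scripts below, and the block's workflow-specific inputs are listed in its doc property):

import torch
from diffusers.modular_pipelines import ComponentsManager

components = ComponentsManager()
diffdiff_blocks.setup_loader(modular_repo="YiYiXu/modular-diffdiff", component_manager=components, collection="diffdiff")
diffdiff_blocks.loader.load(torch_dtype=torch.float16)
# pass the inputs documented in diffdiff_blocks.doc (for diff-diff that includes an image and a change map, not just the prompt)
images = diffdiff_blocks.run(prompt="a castle in the mountains", output="images")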

You can also load it directly into a ModularNode, the diffusers abstraction for creating UI nodes

node = ModularNode.from_pretrained("YiYiXu/modular-diffdiff", trust_remote_code=True)

It comes with the node definition too

node.save_mellon_config("diffdiff_mellon_config.json")

You can use it to make a Mellon node like this

class DiffDiff(NodeBase):
    def __init__(self, node_id=None):
        super().__init__(node_id)
        self._diffdiff_block = ModularNode.from_pretrained("YiYiXu/modular-diffdiff", trust_remote_code=True)
        self._diffdiff_block.setup(components=components)

    def execute(self, **kwargs):
        return self._diffdiff_block.execute(**kwargs)

ModularNode is platform-agnostic: it will work with Comfy, Mellon, or whatever UI platform we support in the future. It saves a default node config, lets you make changes, and will automatically convert the config into a usable module map for each platform (this last part is TODO but very doable!)

See an example of what we currently have here.

modular_model_index.json

You can pack all the relevant checkpoints for your workflow into modular_model_index.json, and you can reference subfolders in other repos. It also works beyond diffusers and transformers: any class works as long as it has a from_pretrained method that accepts the standard loading arguments (subfolder, variant, etc.). For example, we used image_gen_aux here: https://huggingface.co/YiYiXu/modular-depth-block/blob/main/block.py
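
As a sketch of the "beyond diffusers" point (the repo id and component name here are only illustrative), any class that exposes from_pretrained with the standard loading arguments can be declared as a component, e.g. an image_gen_aux preprocessor:

from diffusers import ComponentSpec
from image_gen_aux import DepthPreprocessor  # third-party library, neither diffusers nor transformers

depth_spec = ComponentSpec(
    name="depth_preprocessor",
    type_hint=DepthPreprocessor,
    repo="depth-anything/Depth-Anything-V2-Large-hf",  # illustrative repo id
)
depth_preprocessor = depth_spec.load()  # dispatches to DepthPreprocessor.from_pretrained(...), loading kwargs are forwarded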

TODO

  • move methods to pipeline block
  • fix the awkward __init__ of pipeline block, no longer do the self.component["vae"] = None thing
  • guider refactor by @a-r-r-o-w (more details on this PR Modular Diffusers Guiders #11311)
  • introducing ModularLoader (more details on this PR [modular diffusers] introducing ModularLoader #11462)
  • removed ModularPipeline, replaced it with ModularPipelineMixin:
    - it is the base class for PipelineBlock, SequentialPipelines and AutoPipelines
    - it provides a set_loader method that creates a ModularLoader based on its expected_components
    - it provides a run() method to invoke the pipeline call (you can use it like images = node.run(prompt=...))
  • make LoopSequentialBlocks: explore ways to make the denoise loop itself modular, i.e. you can add/remove blocks inside the loop ([Modular diffusers] more refactors  #11484)
  • [ModularLoader] support partial loading/saving for model components
  • [ModularLoader] should be able to work with model_index.json too
  • kwargs_type: either remove it or make sure it works well
  • make guider a config mixin (maybe image processor too)
  • components manager: offloading strategy based on maximum model memory (instead of what's available on the device)

Documentation

  • simplified the __repr__ of various pipeline blocks and will simplify it further; the goal is to keep __repr__ concise and simple while moving more detailed information to the @doc property
  • add docs
  • simplify the auto-documentation process from the user/developer perspective: we will keep a schema table (type/default/description) for all the standard inputs of a given model type, e.g. SDXL, Flux, so users only need to add info for the additional variables they introduce with their new blocks

Testing

  • Add a basic testing suite

Custom code sharing

testing it under this PR with https://huggingface.co/YiYiXu/modular-diffdiff

UI Nodes

faster test script

# ModularLoader PR examples

from diffusers.modular_pipelines import SequentialPipelineBlocks, ComponentsManager
from diffusers.modular_pipelines.stable_diffusion_xl.modular_pipeline_block_mappings import TEXT2IMAGE_BLOCKS, IMAGE2IMAGE_BLOCKS

from diffusers.utils import load_image
import torch
dtype = torch.float16
device = torch.device("cuda:2")

# create pipeline blocks (here we use diffusers official block presets and just assemble them, but you can create your own)
t2i_blocks = TEXT2IMAGE_BLOCKS.copy()
i2i_blocks = IMAGE2IMAGE_BLOCKS.copy()
i2i_blocks.pop("image_encoder") # don't need this for refiner because we use latent from t2i pipeline

class Text2ImageBlocks(SequentialPipelineBlocks):
    block_classes = list(t2i_blocks.values())
    block_names = list(t2i_blocks.keys())


class RefinerBlocks(SequentialPipelineBlocks):
    block_classes = list(i2i_blocks.values())
    block_names = list(i2i_blocks.keys())

# this is your text2image pipeline
t2i = Text2ImageBlocks()

# this is your refiner pipeline
refiner = RefinerBlocks()

# create components manager
components = ComponentsManager()

# setup loader with component manager
t2i.setup_loader(modular_repo="YiYiXu/modular-loader-t2i", component_manager=components, collection="t2i")
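# (the collection="t2i" tag groups everything this loader registers in the components manager,
#  so we can query it later with components.get(..., collection="t2i"))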
t2i.loader.load(torch_dtype=dtype)

# set up the offloading strategy on the components manager: models are moved to the device only when used and offloaded when not needed
components.enable_auto_cpu_offload(device=device)



prompt = "A crystal orb resting on a wooden table with a yellow rubber duck, surrounded by aged scrolls and alchemy tools, illuminated by candlelight, detailed texture, high resolution image"

# generate image; using the `run` method here means it will:
# 1. run the pipeline blocks in the order/logic defined in SequentialPipelineBlocks
# 2. prepare the inputs for each block.__call__() method: pipeline_state and `pipeline`
#    - the `pipeline` input passed to each block, which contains all the models the block needs, is actually just the ModularLoader we set up here!
image = t2i.run(prompt=prompt, num_inference_steps=25, output="images")[0]
image.save("yiyi_test_7_t2i.png")



# ok now I want to set up the refiner, but reuse the same components because I know only the unet is different
# here I already have a repo made for refiner-specific configs, I will just use it
refiner.setup_loader(modular_repo="YiYiXu/modular_refiner", component_manager=components, collection="refiner")

# if you run refiner.loader.load() here it would just work, but you would get complaints from the components manager about duplicated components
# it is easy to remove the duplicates, but let's not do that for now
# feel free to uncomment and try it out
# refiner.loader.load(torch_dtype=dtype)

# let's only load unet for now
refiner.loader.load(component_names="unet", torch_dtype=dtype)
# uncomment to check the loader: you can verify only unet is loaded 
# print(refiner.loader)
# uncomment this line below to check out component manager:
# you should see that only unet is registered in component manager under the "refiner" collection
# print(components)

# let's reuse components from the t2i pipeline
# this gets you everything under the "t2i" collection that is not unet/text_encoder/tokenizer (the refiner has its own unet and only uses text_encoder_2/tokenizer_2)
# we get them as tuples of (name, component)
reuse_components = components.get("!unet|text_encoder|tokenizer", collection="t2i", as_name_component_tuples=True)
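# ("!" negates the selection and "|" separates names, so this matches every component in the collection except unet, text_encoder and tokenizer)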
for name, component in reuse_components:
    print(f"reuse {name}: {component.__class__.__name__}")


# ok now let's update the refiner loader with the reused components and the new unet
# since these are the exact same objects (same id() and everything), the components manager won't re-register them either
refiner.loader.update(**dict(reuse_components))
# uncomment the lines below to check out the loader and the components manager:
# you should see that only the refiner unet is registered under the "refiner" collection
# print(refiner.loader)
# print(components)




# running a refiner example 

latents = t2i.run(prompt=prompt, num_inference_steps=25, output_type="latent", output="images")
print(f"latents: {latents}")
# assert False
print(f" refiner: {refiner}")
print(f" refiner.doc: {refiner.doc}")
image = refiner.run(image_latents=latents, prompt=prompt, num_inference_steps=10, output="images")
image[0].save("yiyi_test_7_example2.png")

# all the loading-related methods are available on the "loader", not the pipeline itself
# for example, use lora

t2i.loader.load_lora_weights("rajkumaralma/dissolve_dust_style", weight_name="ral-dissolve-sdxl.safetensors", adapter_name="ral-dissolve")
image = t2i.run(prompt=prompt, num_inference_steps=25, output="images")[0]
image.save("yiyi_test_7_example3_lora.png")

# uncomment this line below to check out component manager:
# you should see info about the lora weights loaded in unet/text_encoder
# print(components)


# ip-adapter
t2i.loader.unload_lora_weights()
t2i.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
t2i.loader.set_ip_adapter_scale(0.6)

ip_adapter_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")
image = t2i.run(prompt=prompt, num_inference_steps=25, ip_adapter_image=ip_adapter_image, output="images")[0]
image.save("yiyi_test_7_example4_ip-adapter.png")
# uncomment this line below to check out component manager:
# you should see info about the ip-adapter weights loaded in unet/text_encoder
# print(components)

slower testing script

# test modular pipeline (slower test)

import os
import shutil

import torch
import numpy as np
from PIL import Image
import cv2


from diffusers import (
    ControlNetModel,
    UNet2DConditionModel,
    AutoencoderKL,
    ControlNetUnionModel,
    AdaptiveProjectedGuidance,
    ClassifierFreeGuidance,
    SkipLayerGuidance,
    LayerSkipConfig,
)
from diffusers.utils import load_image
from diffusers import StableDiffusionXLAutoPipeline, ComponentsManager, ComponentSpec
from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLAutoIPAdapterStep

from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor

import logging
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("diffusers").setLevel(logging.INFO)


# define device and dtype
device = "cuda:3"
dtype = torch.float16
num_images_per_prompt = 1

# test related parameters
test_pag = True
test_lora = False
tests_to_run = [1,2,3,4,5,6,7,8,9,10,11,12,13,14, 15, 16]


# define output folder
out_folder = "modular_test_outputs_0517"
os.makedirs(out_folder, exist_ok=True)

# functions for memory info
def reset_memory():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def clear_memory():
    torch.cuda.empty_cache()

def print_memory(message=None):
    """
    Print detailed GPU memory statistics for the current device.

    Args:
        message (str, optional): label appended to the printed header.
    """

    def print_mem(mem_size, name):
        mem_gb = mem_size / 1024**3
        mem_mb = mem_size / 1024**2
        print(f"- {name}: {mem_gb:.2f} GB ({mem_mb:.2f} MB)")

    allocated_mem = torch.cuda.memory_allocated(device)
    reserved_mem = torch.cuda.memory_reserved(device)
    mem_on_device = torch.cuda.mem_get_info(device)[0]
    peak_mem = torch.cuda.max_memory_allocated(device)

    print(f"\nGPU:{device} Memory Status {message}:")
    print_mem(allocated_mem, "allocated memory")
    print_mem(reserved_mem, "reserved memory")
    print_mem(peak_mem, "peak memory")
    print_mem(mem_on_device, "mem on device")


# make a preprocessor block (mostly for controlnet)
from diffusers.modular_pipelines import PipelineBlock, PipelineState, InputParam, OutputParam
class GetImageStep(PipelineBlock):

    PROCESSOR_IDS = {"canny", "lineart_anime"}

    def __init__(self):
        super().__init__()  # initialize base PipelineBlock state
        from controlnet_aux.processor import Processor
        self.processor = Processor
    
    @staticmethod
    def make_canny(image):
        image = np.array(image)
        image = cv2.Canny(image, 100, 200)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        return Image.fromarray(image)
    
    def make_lineart_anime(self, image):
        return self.processor("lineart_anime")(image)
    
    
    def check_inputs(self, data) -> None:
        """
        Validates that exactly one of `image` / `image_url` is provided and that
        `processor_id`, if given, is one of the supported processors.
        Raises:
            ValueError: if the image inputs are missing or conflicting, or if `processor_id` is not in PROCESSOR_IDS.
        """

        if data.image_url is None and data.image is None:
            raise ValueError("Either `image_url` or `image` must be provided.")

        if data.image_url is not None and data.image is not None:
            raise ValueError("Only one of `image_url` or `image` must be provided.")
        
        if data.processor_id is not None and data.processor_id not in self.PROCESSOR_IDS:
            raise ValueError(
                f"Processor id '{data.processor_id}' not found. "
                f"Please use one of the following: {self.PROCESSOR_IDS}"
            )
    
    @property
    def inputs(self):
        return [
            InputParam("image", type_hint=Image.Image),
            InputParam("image_url", type_hint=str, description="The url of the image to load"),
            InputParam("size", description="The size of the image"),
            InputParam("processor_id", type_hint=str, description="The id of the processor to use for controlnet")
        ]

    @property
    def intermediates_outputs(self):
        return [
          OutputParam("image", type_hint=Image.Image),
        ]

    def __call__(self, pipeline, state: PipelineState):

        data = self.get_block_state(state)
        self.check_inputs(data)

        if data.image is None:
            data.image = load_image(data.image_url).convert("RGB")
        
        if data.size is not None:
            data.image = data.image.resize(data.size)

        if data.processor_id is not None:
            if data.processor_id == "canny":
                data.image = self.make_canny(data.image)
            elif data.processor_id == "lineart_anime":
                data.image = self.make_lineart_anime(data.image)
        
        self.add_block_state(state, data)
        
        return pipeline, state
    

# (1)Define inputs
# prompts
prompt = "a bear sitting in a chair drinking a milkshake"
negative_prompt = "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality"
# image urls
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
inpaint_img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
inpaint_mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
ip_adapter_image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png"
# strength/scale etc
strength = 0.9 #img2img strength
inpaint_strength = 0.99 #inpainting strength
controlnet_conditioning_scale = 0.5  # recommended for good generalization

# get all the images
get_image_step = GetImageStep()
get_image_step.loader = None # should be able to skip this step if no checkpoints needed (TODO)

init_image = get_image_step.run(image_url=url,output="image")
control_image = get_image_step.run(image_url=url, processor_id="canny",output="image")
controlnet_union_image = get_image_step.run(image_url=url, processor_id="lineart_anime",output="image")
inpaint_image = get_image_step.run(image_url=inpaint_img_url, size=(1024, 1024),output="image")
inpaint_mask = get_image_step.run(image_url=inpaint_mask_url, size=(1024, 1024),output="image")
ip_adapter_image = get_image_step.run(image_url=ip_adapter_image_url,output="image")



# (2) create pipelines  
auto_pipeline = StableDiffusionXLAutoPipeline()
refiner_pipeline = StableDiffusionXLAutoPipeline()


# (3) define model components needed for the tests

# specs
refiner_spec = ComponentSpec(name="refiner", type_hint=UNet2DConditionModel, repo="stabilityai/stable-diffusion-xl-refiner-1.0", subfolder="unet")
inpaint_spec = ComponentSpec(name="inpaint", type_hint=UNet2DConditionModel, repo="diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="unet")
controlnet_union_spec = ComponentSpec(name="controlnet_union", type_hint=ControlNetUnionModel, repo="brad-twinkl/controlnet-union-sdxl-1.0-promax")
# repos
ip_adapter_repo = "h94/IP-Adapter"
modular_repo = "YiYiXu/modular_demo"
# create guiders: pag/cfg/apg
pag_guider_spec_config = {
    "guidance_scale": 5.0,
    "skip_layer_guidance_scale": 3.0,
    "skip_layer_config": LayerSkipConfig(
        indices=[2, 3, 7, 8],
        fqn="mid_block.attentions.0.transformer_blocks",
        skip_attention=False,
        skip_ff=False,
        skip_attention_scores=True,
    ),
    "start": 0.0,
    "stop": 1.0,
}
pag_guider_spec = ComponentSpec(name="guider", type_hint=SkipLayerGuidance, config=pag_guider_spec_config, default_creation_method="from_config")
cfg_guider_spec = ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={"guidance_scale": 5.0}, default_creation_method="from_config")
apg_guider_spec = ComponentSpec(name="guider", type_hint=AdaptiveProjectedGuidance, config={"guidance_scale": 15.0, "adaptive_projected_guidance_momentum": -0.3, "adaptive_projected_guidance_rescale": 12.0, "start": 0.01}, default_creation_method="from_config")
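# (default_creation_method="from_config" means the spec builds the guider from the config dict rather than loading it from a repo with from_pretrained)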


# (4) create components manager and load the pipeline
components = ComponentsManager()
auto_pipeline.setup_loader(modular_repo=modular_repo, component_manager=components, collection="sdxl_auto")
auto_pipeline.loader.load(torch_dtype=dtype)
print(f" auto_pipeline.loader: {auto_pipeline.loader}")


print(f" ")
print(f"auto_pipeline.loader:")
print(auto_pipeline.loader)
print(f" loader components:")
for key, value in auto_pipeline.loader.components.items():
    if isinstance(value, torch.nn.Module):
        print(f" {key}: {value.__class__.__name__}, dtype: {value.dtype}, device: {value.device}")


# enable auto cpu offload: automatically offload models when available gpu memory goes below a certain threshold
components.enable_auto_cpu_offload(device=device)
print(components)
reset_memory()



# using auto_pipeline to generate images

# to get info about auto_pipeline and how to use it: inputs/outputs/components
# this is an "auto" workflow that works for all use cases: text2img, img2img, inpainting, controlnet, etc.
print(f" ")
print(f" auto_pipeline:")
print(auto_pipeline)
print(" ")



# since we want to use text2img use case, we can run the following to see components/blocks/inputs for this use case
print(f" ")
print(f" auto_pipeline info (default use case: text2img)")
print(auto_pipeline.get_execution_blocks())
print(" ")


# test1: text2img use case
# when you run the auto workflow, you will get logs like these telling you which blocks are actually running
# (should match what auto_pipeline.get_execution_blocks() told you)
# Running block: StableDiffusionXLBeforeDenoiseStep, trigger: None
# Running block: StableDiffusionXLDenoiseStep, trigger: None
# Running block: StableDiffusionXLDecodeStep, trigger: None

# assert False

if 1 in tests_to_run:
    generator = torch.Generator(device="cuda").manual_seed(0)
    images_output = auto_pipeline.run(
        prompt=prompt, 
        num_images_per_prompt=num_images_per_prompt,
        generator=generator, 
        output="images"
    )
    for i, image in enumerate(images_output):
        image.save(f"{out_folder}/test1_out_text2img_{i}.png")
    print(f" save modular output ({len(images_output)} images) to {out_folder}/test1_out_text2img.png")

clear_memory()



# test2: text2img with lora use case
print(f" ")
print(f" running test2: text2img with lora use case")
auto_pipeline.loader.load_lora_weights("rajkumaralma/dissolve_dust_style", weight_name="ral-dissolve-sdxl.safetensors", adapter_name="ral-dissolve")
if 2 in tests_to_run:
    generator = torch.Generator(device="cuda").manual_seed(0)
    images_output = auto_pipeline.run(
        prompt=prompt, 
        num_images_per_prompt=num_images_per_prompt,
        generator=generator, 
        output="images"
    )
    for i, image in enumerate(images_output):
        image.save(f"{out_folder}/test2_out_text2img_lora_{i}.png")
    print(f" save modular output ({len(images_output)} images) to {out_folder}/test2_out_text2img_lora.png")

# test3:text2image with pag
print(f" ")
print(f" running test3:text2image with pag")
if not test_lora:
    auto_pipeline.loader.unload_lora_weights()
auto_pipeline.loader.update(guider=pag_guider_spec)

if 3 in tests_to_run:
    generator = torch.Generator(device="cuda").manual_seed(0)
    images_output = auto_pipeline.run(
        prompt=prompt, 
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
        output="images"
    )

    for i, image in enumerate(images_output):
        image.save(f"{out_folder}/test3_out_text2img_pag_{i}.png")
    print(f" save modular output ({len(images_output)} images) to {out_folder}/test3_out_text2img_pag.png")

clear_memory()
# check out the components manager if you want: the models used are moved to the device, some might get offloaded to cpu
# print(components)


# test4: SDXL(text2img) with ip_adapter + pag
print(f" ")
print(f" running test4: SDXL(text2img) with ip_adapter")

auto_pipeline.loader.load_ip_adapter(ip_adapter_repo, subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
auto_pipeline.loader.set_ip_adapter_scale(0.6)

if 4 in tests_to_run: 
    generator = torch.Generator(device="cuda").manual_seed(0)
    images_output = auto_pipeline.run(
        prompt=prompt, 
        num_images_per_prompt=num_images_per_prompt,
        generator=generator, 
        ip_adapter_image=ip_adapter_image,
        output="images"
    )

    for i, image in enumerate(images_output):
        image.save(f"{out_folder}/test4_out_text2img_ip_adapter_{i}.png")
    print(f" save modular output ({len(images_output)} images) to {out_folder}/test4_out_text2img_ip_adapter.png")

auto_pipeline.loader.unload_ip_adapter()
clear_memory()

# test5: SDXL(text2img) with controlnet

# we are going to pass a new input now `control_image` so the workflow will be automatically converted to controlnet use case
# let's checkout the info for controlnet use case
print(f" auto_pipeline info (controlnet use case)")
print(auto_pipeline.get_execution_blocks("control_image"))
print(" ")

print(f" ")
print(f" running test5: SDXL(text2img) with controlnet")

if 5 in tests_to_run:
    generator = torch.Generator(device="cuda").manual_seed(0)
    images_output = auto_pipeline.run(
        prompt=prompt, 
        control_image=control_image, 
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
        output="images"
    )

    for i, image in enumerate(images_output):
        image.save(f"{out_folder}/test5_out_text2img_control_{i}.png")
    print(f" save modular output ({len(images_output)} images) to {out_folder}/test5_out_text2img_control.png")

clear_memory()


# test6: SDXL(img2img)

print(f" ")
print(f" running test6: SDXL(img2img)")

# let's check out the auto_pipeline info for the img2img use case
print(f" auto_pipeline info (img2img use case)")
print(auto_pipeline.get_execution_blocks("image"))
print(" ")

if 6 in tests_to_run:
    generator = torch.Generator(device="cuda").manual_seed(0)
    images_output = auto_pipeline.run(
        prompt=prompt, 
        image=init_image, 
        strength=strength, 
        num_images_per_prompt=num_images_per_prompt,
        generator=generator, 
        output="images"
    )
    for i, image in enumerate(images_output):
        image.save(f"{out_folder}/test6_out_img2img_{i}.png")
    print(f" save modular output ({len(images_output)} images) to {out_folder}/test6_out_img2img.png")

clear_memory()


# test7: SDXL(img2img) with controlnet
# let's check out the auto_pipeline info for the img2img controlnet use case
print(f" auto_pipeline info (img2img controlnet use case)")
print(auto_pipeline.get_execution_blocks("image", "control_image"))
print(" ")

print(f" ")
print(f" running test7: SDXL(img2img) with controlnet")
if 7 in tests_to_run:
    generator = torch.Generator(device="cuda").manual_seed(0)
    images_output = auto_pipeline.run(
        prompt=prompt, 
        image=init_image, 
        strength=strength, 
        num_images_per_prompt=num_images_per_prompt,
        control_image=control_image, 
        controlnet_conditioning_scale=controlnet_conditioning_scale, 
        generator=generator, 
        output="images"
    )

    for i, image in enumerate(images_output):
        print(f"image: {image.size}")
        image.save(f"{out_folder}/test7_out_img2img_control_{i}.png")
    print(f" save modular output ({len(images_output)} images) to {out_folder}/test7_out_img2img_control.png")

clear_memory()

# test8: img2img with refiner

# test refiner pipeline but not using a repo
refiner_pipeline.setup_loader(component_manager=components, collection="refiner")

print(f" ")
print(f" after setup refiner loader (initial setup, should be empty)")
print(refiner_pipeline.loader)
print(f" ")

refiner_components = components.get("!unet|text_encoder|tokenizer|guider", collection="sdxl_auto", as_name_component_tuples=True)
print(f" reuse these components for refiner pipeline:")
for name, component in refiner_components:
    print(f" {name}: {component.__class__.__name__}")
print(f" ")


refiner_pipeline.loader.update(**dict(refiner_components), unet=refiner_spec.load(torch_dtype=dtype), force_zeros_for_empty_prompt=False, requires_aesthetics_score=True)
print(f" ")
print(f" refiner loader after update")
print(refiner_pipeline.loader)
print(f" ")

print(f" ")
print(f" ")
print(f" components info")
print(components)
print(f" ")


print(f" running test8: img2img with refiner (reuse components from components manager)")

if 8 in tests_to_run:
    print(f" ")
    print(f" step1 run auto pipeline to get latents")
    generator = torch.Generator(device="cuda").manual_seed(0)
    latents = auto_pipeline.run(
        prompt=prompt, 
        num_images_per_prompt=num_images_per_prompt,
        generator=generator, 
        denoising_end=0.8,
        output="images",
        output_type="latent",
    )
    print(f" ")
    print(f" step2 run refiner pipeline to get images")
    images_output = refiner_pipeline.run(
        image_latents=latents,  
        prompt=prompt, 
        denoising_start=0.8, 
        generator=generator, 
        num_images_per_prompt=num_images_per_prompt,
        output="images"
    )
    for i, image in enumerate(images_output):
        image.save(f"{out_folder}/test8_out_img2img_refiner_{i}.png")
    print(f" save modular output ({len(images_output)} images) to {out_folder}/test8_out_img2img_refiner.png")

clear_memory()

# test9: SDXL(inpainting)
# let's check out the auto_pipeline info for the inpainting use case
print(f" auto_pipeline info (inpainting use case)")
print(auto_pipeline.get_execution_blocks("mask_image", "image"))
print(" ")

print(f" ") 
print(f" running test9: SDXL(inpainting)")

if 9 in tests_to_run:
    generator = torch.Generator(device="cuda").manual_seed(0)
    images_output = auto_pipeline.run(
        prompt=prompt, 
        image=inpaint_image, 
        mask_image=inpaint_mask, 
        height=1024, 
        width=1024, 
        generator=generator, 
        num_images_per_prompt=num_images_per_prompt,
        strength=inpaint_strength,  # make sure to use `strength` below 1.0
        output="images"
    )
    for i, image in enumerate(images_output):
        image.save(f"{out_folder}/test9_out_inpainting_{i}.png")
    print(f" save modular output ({len(images_output)} images) to {out_folder}/test9_out_inpainting.png")

clear_memory()

# test10: SDXL(inpainting) with controlnet
# let's check out the auto_pipeline info for the inpainting + controlnet use case
print(f" auto_pipeline info (inpainting + controlnet use case)")
print(auto_pipeline.get_execution_blocks("mask_image", "control_image"))
print(" ")

print(f" ") 
print(f" running test10: SDXL(inpainting) with controlnet")

if 10 in tests_to_run:
    generator = torch.Generator(device="cuda").manual_seed(0)
    images_output = auto_pipeline.run(
        prompt=prompt, 
        control_image=control_image, 
        image=inpaint_image,
        height=1024,
        width=1024,
        mask_image=inpaint_mask,
        num_images_per_prompt=num_images_per_prompt,
        controlnet_conditioning_scale=controlnet_conditioning_scale, 
        strength=inpaint_strength,  # make sure to use `strength` below 1.0
        generator=generator,
        output="images"
    )
    for i, image in enumerate(images_output):
      image.save(f"{out_folder}/test10_out_inpainting_control_{i}.png")
    print(f" save modular output ({len(images_output)} images) to {out_folder}/test10_out_inpainting_control.png")

clear_memory()

# test11: SDXL(inpainting) with inpaint_unet
print(f" ") 
print(f" running test11: SDXL(inpainting) with inpaint_unet")

inpaint_unet = inpaint_spec.load(torch_dtype=dtype)
# make a backup to switch back later
sdxl_unet_spec = ComponentSpec.from_component("unet", auto_pipeline.loader.unet)
auto_pipeline.loader.update(unet=inpaint_unet)
if 11 in tests_to_run:
    generator = torch.Generator(device="cuda").manual_seed(0)
    images_output = auto_pipeline.run(
        prompt=prompt, 
        image=inpaint_image, 
        mask_image=inpaint_mask, 
        height=1024, 
        width=1024, 
        generator=generator, 
        num_images_per_prompt=num_images_per_prompt,
        output="images"
    )
    for i, image in enumerate(images_output):
        image.save(f"{out_folder}/test11_out_inpainting_inpaint_unet_{i}.png")
    print(f" save modular output ({len(images_output)} images) to {out_folder}/test11_out_inpainting_inpaint_unet.png")

clear_memory()
print(f" after update with inpaint_unet")
print(components)


# test12: SDXL(inpainting) with inpaint_unet + padding_mask_crop
print(f" ") 
print(f" running test12: SDXL(inpainting) with inpaint_unet (padding_mask_crop=33)")

if 12 in tests_to_run:
    generator = torch.Generator(device="cuda").manual_seed(0)
    images_output = auto_pipeline.run(
        prompt=prompt, 
        image=inpaint_image, 
        mask_image=inpaint_mask, 
        height=1024, 
        width=1024, 
        generator=generator, 
        padding_mask_crop=33, 
        num_images_per_prompt=num_images_per_prompt,
        strength=inpaint_strength,  # make sure to use `strength` below 1.0
        output="images"
    )

    for i, image in enumerate(images_output):
        image.save(f"{out_folder}/test12_out_inpainting_inpaint_unet_padding_mask_crop_{i}.png")
    print(f" save modular output ({len(images_output)} images) to {out_folder}/test12_out_inpainting_inpaint_unet_padding_mask_crop.png")

clear_memory()


# test13: apg

print(f" ")
print(f" running test13: apg")

auto_pipeline.loader.update(guider=apg_guider_spec, unet=sdxl_unet_spec.load(torch_dtype=dtype))
print(f" autopipeline loader after update with apg guider and unet")
print(auto_pipeline.loader)
print(f" ")

print(f" ")
print(f" components info")
print(components)
print(f" ")

if 13 in tests_to_run:
    generator = torch.Generator().manual_seed(0)
    images_output = auto_pipeline.run(
      prompt=prompt, 
      generator=generator,
      num_inference_steps=20,
      num_images_per_prompt=1, # yiyi: apg does not work with num_images_per_prompt > 1
      height=896,
      width=768,
      output="images"
    )

    for i, image in enumerate(images_output):
        image.save(f"{out_folder}/test13_out_apg_{i}.png")
    print(f" save modular output ({len(images_output)} images) to {out_folder}/test13_out_apg.png")

clear_memory()


# test14: SDXL(text2img) with controlnet_union

auto_pipeline.loader.update(
    controlnet=controlnet_union_spec.load(torch_dtype=dtype), 
    guider=pag_guider_spec
)

print(f" autopipeline loader after update with controlnet (controlnet_union), unet (sdxl_auto), and guider (pag_guider)")
print(auto_pipeline.loader)
print(f" ")

print(f" ")
print(f" components info")
print(components)
print(f" ")

# we are going to pass a new input now `control_mode` so the workflow will be automatically converted to controlnet use case
# let's checkout the info for controlnet use case
print(f" auto_pipeline info (controlnet union use case)")
print(auto_pipeline.get_execution_blocks("control_mode"))
print(" ")
print(f" ")
print(f" running test14: SDXL(text2img) with controlnet_union")

if 14 in tests_to_run:
    generator = torch.Generator(device="cuda").manual_seed(0)

    images_output = auto_pipeline.run(
        prompt=prompt, 
        control_mode=[3],
        control_image=[controlnet_union_image], 
        num_images_per_prompt=num_images_per_prompt,
        height=1024,
        width=1024,
        generator=generator,
        output="images"
    )

    for i, image in enumerate(images_output):
        image.save(f"{out_folder}/test14_out_text2img_control_union_{i}.png")
    print(f" save modular output ({len(images_output)} images) to {out_folder}/test14_out_text2img_control_union.png")

clear_memory()


# test15: SDXL(img2img) with controlnet_union

print(f" ")
print(f" auto_pipeline info (img2img controlnet union use case)")
print(auto_pipeline.get_execution_blocks("image", "control_mode"))
print(" ")

print(f" ")
print(f" running test15: SDXL(img2img) with controlnet_union")

if 15 in tests_to_run:
    generator = torch.Generator(device="cuda").manual_seed(0)
    images_output = auto_pipeline.run(
        prompt=prompt, 
        image=init_image, 
        generator=generator, 
        control_mode=[3], 
        control_image=[controlnet_union_image], 
        num_images_per_prompt=num_images_per_prompt, 
        height=1024, 
        width=1024, 
        output="images"
    )

    for i, image in enumerate(images_output):
        image.save(f"{out_folder}/test15_out_img2img_control_union_{i}.png")
    print(f" save modular output ({len(images_output)} images) to {out_folder}/test15_out_img2img_control_union.png")

clear_memory()

# test16: SDXL(inpainting) with controlnet_union
print(f" ")
print(f" auto_pipeline info (inpainting controlnet union use case)")
print(auto_pipeline.get_execution_blocks("mask", "control_mode"))
print(" ")

print(f" ")
print(f" running test16: SDXL(inpainting) with controlnet_union")

if 16 in tests_to_run:

    generator = torch.Generator(device="cuda").manual_seed(0)
    images_output = auto_pipeline.run(
        prompt=prompt, 
        image=init_image, 
        mask_image=inpaint_mask, 
        control_image=controlnet_union_image,
        control_mode=[3],
        height=1024, 
        width=1024, 
        generator=generator, 
        output="images"
    )

    for i, image in enumerate(images_output):
        image.save(f"{out_folder}/test16_out_inpainting_control_union_{i}.png")
    print(f" save modular output ({len(images_output)} images) to {out_folder}/test16_out_inpainting_control_union.png")

clear_memory()

print_memory("the end")

print(f" components info after the end")
print(components)










@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

a-r-r-o-w and others added 2 commits April 26, 2025 03:42
* cfg; slg; pag; sdxl without controlnet

* support sdxl controlnet

* support controlnet union

* update

* update

* cfg zero*

* use unwrap_module for torch compiled modules

* remove guider kwargs

* remove commented code

* remove old guider

* fix slg bug

* remove debug print

* autoguidance

* smoothed energy guidance

* add note about seg

* tangential cfg

* cfg plus plus

* support cfgpp in ddim

* apply review suggestions

* refactor

* rename enable/disable

* remove cfg++ for now

* rename do_classifier_free_guidance->prepare_unconditional_embeds

* remove unused
* cfg; slg; pag; sdxl without controlnet

---------

Co-authored-by: Aryan <aryan@huggingface.co>