diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 214835062a05..a5e38278cdf2 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -207,7 +207,7 @@ class StableDiffusionControlNetPipeline( model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "image"] def __init__( self, @@ -1323,6 +1323,7 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + image = callback_outputs.pop("image", image) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 73ffeeb5e79c..be2874f48e69 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -185,7 +185,7 @@ class StableDiffusionControlNetImg2ImgPipeline( model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "control_image"] def __init__( self, @@ -1294,6 +1294,7 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + control_image = callback_outputs.pop("control_image", control_image) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 875dbed38c4d..40092e5f47f3 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -184,7 +184,7 @@ class StableDiffusionControlNetInpaintPipeline( model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "control_image"] def __init__( self, @@ -1476,6 +1476,7 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + control_image = callback_outputs.pop("control_image", control_image) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):