Skip to content

Modular Diffusers Guiders #11311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Apr 25, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

_import_structure = {
"configuration_utils": ["ConfigMixin"],
"guiders": [],
"hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
Expand Down Expand Up @@ -129,12 +130,22 @@
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]

else:
_import_structure["guiders"].extend(
[
"AdaptiveProjectedGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"SkipLayerGuidance",
]
)
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"LayerSkipConfig",
"apply_faster_cache",
"apply_layer_skip",
"apply_pyramid_attention_broadcast",
]
)
Expand Down Expand Up @@ -711,10 +722,18 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .guiders import (
AdaptiveProjectedGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
SkipLayerGuidance,
)
from .hooks import (
FasterCacheConfig,
HookRegistry,
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
apply_layer_skip,
apply_faster_cache,
apply_pyramid_attention_broadcast,
)
Expand Down
26 changes: 26 additions & 0 deletions src/diffusers/guiders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

from ..utils import is_torch_available


if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .skip_layer_guidance import SkipLayerGuidance

GuiderType = Union[ClassifierFreeGuidance, SkipLayerGuidance]
174 changes: 174 additions & 0 deletions src/diffusers/guiders/adaptive_projected_guidance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Union, Tuple, List

import torch

from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs


class AdaptiveProjectedGuidance(BaseGuidance):
"""
Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416

Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""

_input_predictions = ["pred_cond", "pred_uncond"]

def __init__(
self,
guidance_scale: float = 7.5,
adaptive_projected_guidance_momentum: Optional[float] = None,
adaptive_projected_guidance_rescale: float = 15.0,
eta: float = 1.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)

self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None

def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
return _default_prepare_inputs(denoiser, self.num_conditions, *args)

def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
self._num_outputs_prepared += 1
if self._num_outputs_prepared > self.num_conditions:
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
self._preds[key] = pred

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None

if not self._is_cfg_enabled():
pred = pred_cond
else:
pred = normalized_guidance(
pred_cond,
pred_uncond,
self.guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)

if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred

@property
def is_conditional(self) -> bool:
return self._num_outputs_prepared == 0

@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions

def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False

is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step

is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)

return is_within_range and not is_close


class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0

def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average


def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + (guidance_scale - 1) * normalized_update
return pred
127 changes: 127 additions & 0 deletions src/diffusers/guiders/classifier_free_guidance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Union, Tuple, List

import torch

from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs


class ClassifierFreeGuidance(BaseGuidance):
"""
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598

CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
inference. This allows the model to tradeoff between generation quality and sample diversity.
The original paper proposes scaling and shifting the conditional distribution based on the difference between
conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]

Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]

The intution behind the original formulation can be thought of as moving the conditional distribution estimates
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)

The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.

Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""

_input_predictions = ["pred_cond", "pred_uncond"]

def __init__(
self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False, start: float = 0.0, stop: float = 1.0
):
super().__init__(start, stop)

self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation

def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
return _default_prepare_inputs(denoiser, self.num_conditions, *args)

def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
self._num_outputs_prepared += 1
if self._num_outputs_prepared > self.num_conditions:
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
self._preds[key] = pred

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None

if not self._is_cfg_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift

if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred

@property
def is_conditional(self) -> bool:
return self._num_outputs_prepared == 0

@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions

def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False

is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step

is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)

return is_within_range and not is_close
Loading