-
Notifications
You must be signed in to change notification settings - Fork 6k
Modify the implementation of retrieve_timesteps in CogView4-Control. #11125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 54 commits
a97fca2
c30ca7a
5c25cd2
44bfd4c
a9f448e
2cbdf35
df83bf2
8bba67a
b9d864b
5d2e994
ebeb1e4
95e8504
940c23b
7a68a3e
2a81772
1d91a24
dff4b29
050b97c
b007be0
25f4e4b
7ffecbc
b4e11e7
efa0f41
f55e3cc
9410e46
29b0c81
52d4ebf
65b3719
90830ed
71f9235
19d7d27
fe6287a
2f74c4e
264060e
9a10ceb
b6e10e7
692e5cc
fc3830c
98a2417
c774f45
687faa4
8abca19
cbfeb0b
347dd17
775bb8c
985baa9
c2a1985
88abb39
3e3387e
ddb31d3
64637ef
4174736
8cdd36f
785e230
07ef22e
1e93e98
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,7 +68,7 @@ def calculate_shift( | |
return mu | ||
|
||
|
||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps | ||
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps | ||
def retrieve_timesteps( | ||
scheduler, | ||
num_inference_steps: Optional[int] = None, | ||
|
@@ -100,10 +100,19 @@ def retrieve_timesteps( | |
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the | ||
second element is the number of inference steps. | ||
""" | ||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's update the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is that how I understand it? |
||
accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) | ||
|
||
if timesteps is not None and sigmas is not None: | ||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") | ||
if timesteps is not None: | ||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) | ||
if not accepts_timesteps and not accepts_sigmas: | ||
raise ValueError( | ||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" | ||
f" timestep or sigma schedules. Please check whether you are using the correct scheduler." | ||
) | ||
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs) | ||
timesteps = scheduler.timesteps | ||
num_inference_steps = len(timesteps) | ||
elif timesteps is not None and sigmas is None: | ||
if not accepts_timesteps: | ||
raise ValueError( | ||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" | ||
|
@@ -112,9 +121,8 @@ def retrieve_timesteps( | |
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) | ||
timesteps = scheduler.timesteps | ||
num_inference_steps = len(timesteps) | ||
elif sigmas is not None: | ||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) | ||
if not accept_sigmas: | ||
elif timesteps is None and sigmas is not None: | ||
if not accepts_sigmas: | ||
raise ValueError( | ||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" | ||
f" sigmas schedules. Please check whether you are using the correct scheduler." | ||
|
@@ -515,7 +523,7 @@ def __call__( | |
The output format of the generate image. Choose between | ||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. | ||
return_dict (`bool`, *optional*, defaults to `True`): | ||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead | ||
Whether or not to return a [`~pipelines.pipeline_CogView4.CogView4PipelineOutput`] instead | ||
of a plain tuple. | ||
attention_kwargs (`dict`, *optional*): | ||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under | ||
|
@@ -533,8 +541,6 @@ def __call__( | |
max_sequence_length (`int`, defaults to `224`): | ||
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. | ||
|
||
Examples: | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @zRzRzRzRzRzRzR These should not be removed, otherwise the tests will fail: https://github.com/huggingface/diffusers/actions/runs/14009637640/job/39228136089?pr=11125#step:15:64 Let's add this back and we should be good to merge There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
Returns: | ||
[`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`: | ||
[`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just the change here is correct @zRzRzRzRzRzRzR