Suggestion for speeding up index_for_timestep by removing sequential nonzero() calls in samplers #9417

Open
@ethanweber

Description

Is your feature request related to a problem? Please describe.
First off, thanks for the great codebase and for providing so many resources! I just wanted to share an improvement I made for myself, in case you'd like to include it for all samplers. I'm using the FlowMatchEulerDiscreteScheduler and, after profiling, I noticed that it was unexpectedly slowing down my training. I'll describe the issue and the proposed solution here rather than making a PR, since the change would touch a lot of code and someone on the diffusers team may prefer to implement it.

Describe the solution you'd like.
This line in particular is very slow because it is a Python for loop, and each call to self.index_for_timestep() runs a nonzero() search:

step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
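
As a rough illustration of the scaling (a minimal standalone sketch rather than the scheduler's actual code; the schedule length and batch sizes here are made up), the per-element nonzero() cost grows with batch size:

import time

import torch

schedule_timesteps = torch.linspace(1000.0, 1.0, 1000)  # hypothetical schedule

for batch_size in (1, 64, 256):
    timestep = schedule_timesteps[torch.randint(0, 1000, (batch_size,))]
    t0 = time.perf_counter()
    # one nonzero() search per batch element, as in the list comprehension above
    step_indices = [(schedule_timesteps == t).nonzero()[0].item() for t in timestep]
    elapsed_ms = (time.perf_counter() - t0) * 1e3
    print(f"batch_size={batch_size}: {elapsed_ms:.2f} ms for {len(step_indices)} lookups")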

Describe alternatives you've considered.
I've changed the code as follows:

# huggingface code
def index_for_timestep(self, timestep, schedule_timesteps=None):
    if schedule_timesteps is None:
        schedule_timesteps = self.timesteps

    indices = (schedule_timesteps == timestep).nonzero()

    # The sigma index that is taken for the **very** first `step`
    # is always the second index (or the last index if there is only 1)
    # This way we can ensure we don't accidentally skip a sigma in
    # case we start in the middle of the denoising schedule (e.g. for image-to-image)
    pos = 1 if len(indices) > 1 else 0

    return indices[pos].item()

changed to =>

# my code
def index_for_timestep(self, timestep, schedule_timesteps=None):
    if schedule_timesteps is None:
        schedule_timesteps = self.timesteps

    # Map the timestep(s) onto [0, num_steps - 1] by linear interpolation between
    # the first and last schedule values; this works on a whole batch at once.
    num_steps = len(schedule_timesteps)
    start = schedule_timesteps[0].item()
    end = schedule_timesteps[-1].item()
    indices = torch.round(((timestep - start) / (end - start)) * (num_steps - 1)).long()

    return indices
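
As a quick sanity check (a standalone sketch with both versions written as free functions; it assumes a uniformly spaced schedule_timesteps, which is what the linear interpolation relies on), the vectorized lookup reproduces the original nonzero() lookup without the per-element loop:

import torch

def index_for_timestep_nonzero(timestep, schedule_timesteps):
    # original lookup: exact match via nonzero(), one timestep at a time
    indices = (schedule_timesteps == timestep).nonzero()
    pos = 1 if len(indices) > 1 else 0
    return indices[pos].item()

def index_for_timestep_interp(timestep, schedule_timesteps):
    # proposed lookup: linear interpolation, handles a whole batch at once
    num_steps = len(schedule_timesteps)
    start = schedule_timesteps[0].item()
    end = schedule_timesteps[-1].item()
    return torch.round(((timestep - start) / (end - start)) * (num_steps - 1)).long()

schedule_timesteps = torch.linspace(1000.0, 1.0, 50)        # hypothetical uniform schedule
timestep = schedule_timesteps[torch.randint(0, 50, (16,))]  # batch of training timesteps

old = [index_for_timestep_nonzero(t, schedule_timesteps) for t in timestep]
new = index_for_timestep_interp(timestep, schedule_timesteps)
assert new.tolist() == old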

and

# huggingface code
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
    step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]

changed to =>

# my code
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
    step_indices = self.index_for_timestep(timestep, schedule_timesteps)
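
With this change, step_indices comes back as a single LongTensor rather than a Python list of ints. As a small illustration of how it would be consumed (a hedged sketch with made-up values; the exact downstream code in diffusers may differ), tensor indexing for the subsequent sigma gather works the same way as list indexing:

import torch

# hypothetical values standing in for the scheduler's state
schedule_timesteps = torch.linspace(1000.0, 1.0, 50)
sigmas = torch.linspace(1.0, 0.0, 50)
timestep = schedule_timesteps[torch.randint(0, 50, (8,))]  # batch of training timesteps

num_steps = len(schedule_timesteps)
start, end = schedule_timesteps[0].item(), schedule_timesteps[-1].item()
step_indices = torch.round(((timestep - start) / (end - start)) * (num_steps - 1)).long()

# a LongTensor indexes sigmas just like a list of ints would
sigma = sigmas[step_indices].flatten()
print(step_indices.dtype, step_indices.shape, sigma.shape)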

Additional context.
Just wanted to bring this modification to your attention since it could be a training speedup for folks. 🙂 It should matter especially when someone is using a large batch size (> 1), since the for loop runs a nonzero() search once per batch element. Some other small changes might be necessary to keep the rest of the code compatible with the new function signature, but I suspect it could help everyone. Thanks for the consideration!

Labels

contributions-welcome, help wanted, performance, wip
