Skip to content

Commit a3cc641

Browse files
Refac training utils.py (#9815)
* Refac training utils.py * quality --------- Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
1 parent 13e8fde commit a3cc641

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

src/diffusers/training_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ def set_seed(seed: int):
4343
4444
Args:
4545
seed (`int`): The seed to set.
46+
47+
Returns:
48+
`None`
4649
"""
4750
random.seed(seed)
4851
np.random.seed(seed)
@@ -58,6 +61,17 @@ def compute_snr(noise_scheduler, timesteps):
5861
"""
5962
Computes SNR as per
6063
https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
64+
for the given timesteps using the provided noise scheduler.
65+
66+
Args:
67+
noise_scheduler (`NoiseScheduler`):
68+
An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute
69+
the SNR values.
70+
timesteps (`torch.Tensor`):
71+
A tensor of timesteps for which the SNR is computed.
72+
73+
Returns:
74+
`torch.Tensor`: A tensor containing the computed SNR values for each timestep.
6175
"""
6276
alphas_cumprod = noise_scheduler.alphas_cumprod
6377
sqrt_alphas_cumprod = alphas_cumprod**0.5

0 commit comments

Comments
 (0)