Skip to content

Commit 5dc3471

Browse files
authored
[SVD] support generators that are created on GPU (#6484)
* debug generator * fix? * fix? * fix * remove print. * revert none check
1 parent 9df566e commit 5dc3471

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -429,15 +429,20 @@ def __call__(
429429
fps = fps - 1
430430

431431
# 4. Encode input image using VAE
432-
image = self.image_processor.preprocess(image, height=height, width=width)
433-
noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
432+
image = self.image_processor.preprocess(image, height=height, width=width).to(device)
433+
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
434434
image = image + noise_aug_strength * noise
435435

436436
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
437437
if needs_upcasting:
438438
self.vae.to(dtype=torch.float32)
439439

440-
image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
440+
image_latents = self._encode_vae_image(
441+
image,
442+
device=device,
443+
num_videos_per_prompt=num_videos_per_prompt,
444+
do_classifier_free_guidance=self.do_classifier_free_guidance,
445+
)
441446
image_latents = image_latents.to(image_embeddings.dtype)
442447

443448
# cast back to fp16 if needed

0 commit comments

Comments
 (0)