Skip to content

Commit 08e62fe

Browse files
authored
Scheduling fixes on MPS (#10549)
* use np.int32 in scheduling * test_add_noise_device * -np.int32, fixes
1 parent 9e1b8a0 commit 08e62fe

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

src/diffusers/schedulers/scheduling_heun_discrete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def set_timesteps(
342342
timesteps = torch.from_numpy(timesteps)
343343
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
344344

345-
self.timesteps = timesteps.to(device=device)
345+
self.timesteps = timesteps.to(device=device, dtype=torch.float32)
346346

347347
# empty dt and derivative
348348
self.prev_derivative = None

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
311311
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
312312

313313
self.sigmas = torch.from_numpy(sigmas).to(device=device)
314-
self.timesteps = torch.from_numpy(timesteps).to(device=device)
314+
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.float32)
315315
self._step_index = None
316316
self._begin_index = None
317317
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

tests/schedulers/test_scheduler_lcm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_add_noise_device(self, num_inference_steps=10):
9999
scaled_sample = scheduler.scale_model_input(sample, 0.0)
100100
self.assertEqual(sample.shape, scaled_sample.shape)
101101

102-
noise = torch.randn_like(scaled_sample).to(torch_device)
102+
noise = torch.randn(scaled_sample.shape).to(torch_device)
103103
t = scheduler.timesteps[5][None]
104104
noised = scheduler.add_noise(scaled_sample, noise, t)
105105
self.assertEqual(noised.shape, scaled_sample.shape)

tests/schedulers/test_schedulers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def model(sample, t, *args):
361361
if isinstance(t, torch.Tensor):
362362
num_dims = len(sample.shape)
363363
# pad t with 1s to match num_dims
364-
t = t.reshape(-1, *(1,) * (num_dims - 1)).to(sample.device).to(sample.dtype)
364+
t = t.reshape(-1, *(1,) * (num_dims - 1)).to(sample.device, dtype=sample.dtype)
365365

366366
return sample * t / (t + 1)
367367

@@ -722,7 +722,7 @@ def test_add_noise_device(self):
722722
scaled_sample = scheduler.scale_model_input(sample, 0.0)
723723
self.assertEqual(sample.shape, scaled_sample.shape)
724724

725-
noise = torch.randn_like(scaled_sample).to(torch_device)
725+
noise = torch.randn(scaled_sample.shape).to(torch_device)
726726
t = scheduler.timesteps[5][None]
727727
noised = scheduler.add_noise(scaled_sample, noise, t)
728728
self.assertEqual(noised.shape, scaled_sample.shape)

0 commit comments

Comments
 (0)