Skip to content

Commit 28c1f3b

Browse files
committed
test_add_noise_device
1 parent 7b53e97 commit 28c1f3b

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

tests/schedulers/test_scheduler_lcm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_add_noise_device(self, num_inference_steps=10):
9999
scaled_sample = scheduler.scale_model_input(sample, 0.0)
100100
self.assertEqual(sample.shape, scaled_sample.shape)
101101

102-
noise = torch.randn_like(scaled_sample).to(torch_device)
102+
noise = torch.randn(scaled_sample.shape).to(torch_device)
103103
t = scheduler.timesteps[5][None]
104104
noised = scheduler.add_noise(scaled_sample, noise, t)
105105
self.assertEqual(noised.shape, scaled_sample.shape)

tests/schedulers/test_schedulers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ def test_add_noise_device(self):
722722
scaled_sample = scheduler.scale_model_input(sample, 0.0)
723723
self.assertEqual(sample.shape, scaled_sample.shape)
724724

725-
noise = torch.randn_like(scaled_sample).to(torch_device)
725+
noise = torch.randn(scaled_sample.shape).to(torch_device)
726726
t = scheduler.timesteps[5][None]
727727
noised = scheduler.add_noise(scaled_sample, noise, t)
728728
self.assertEqual(noised.shape, scaled_sample.shape)

0 commit comments

Comments
 (0)