1 parent a7f07c1 commit 4d5a96e
tests/quantization/bnb/test_mixed_int8.py
@@ -221,7 +221,7 @@ def test_keep_modules_in_fp32(self):
         self.assertTrue(module.weight.dtype == torch.int8)

         # test if inference works.
-        with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16):
+        with torch.no_grad() and torch.autocast(model.device.type, dtype=torch.float16):
             input_dict_for_transformer = self.get_dummy_inputs()
             model_inputs = {
                 k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
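The change makes the autocast context device-agnostic: instead of hard-coding "cuda", the device type is read off the model, so the test can also run on CPU and other accelerators. Below is a minimal sketch of that pattern; the toy model, inputs, and amp_dtype selection are illustrative assumptions, not the test's actual setup. Note also that the idiomatic way to stack the two context managers is a single with statement: `with a() and b():` enters only b(), because `and` returns its second operand when the first is truthy.

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for the quantized transformer under test.
    model = nn.Linear(16, 4)
    x = torch.randn(2, 16)

    # A plain nn.Module has no .device attribute, so read it from a
    # parameter; the model class in the test exposes model.device directly.
    device_type = next(model.parameters()).device.type

    # float16 autocast targets CUDA; CPU autocast is better served by bfloat16.
    amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

    # Enter both managers in one with statement so no_grad actually applies.
    with torch.no_grad(), torch.autocast(device_type, dtype=amp_dtype):
        out = model(x)

    print(out.dtype)  # low-precision dtype where autocast covers the op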