diff --git a/advanced_source/static_quantization_tutorial.py b/advanced_source/static_quantization_tutorial.py index 72ea3cc703e..5e7c21ab46d 100644 --- a/advanced_source/static_quantization_tutorial.py +++ b/advanced_source/static_quantization_tutorial.py @@ -263,7 +263,7 @@ def accuracy(output, target, topk=(1,)): res = [] for k in topk: - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / batch_size)) return res