Skip to content

Commit 9d6f47f

Browse files
authored
More assertive require_grad check (#4709)
Fixes pytorch/test-infra#4687
1 parent 7a9c8c0 commit 9d6f47f

File tree

5 files changed

+20
-8
lines changed

5 files changed

+20
-8
lines changed

tests/fixtures/misc/checker/require_grad.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import random
12
import torch
23
x = torch.zeros(1)
34
x.require_grad = False
45
x.require_grad = True
6+
grad = random.choice([False, True])
7+
x.require_grad = grad
58

69
# Don't trigger
710
x.requires_grad = False
811
require_grad = False
9-
x.require_grad = 10
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
3:1 TOR002 Likely typo `require_grad` in assignment. Did you mean `requires_grad`?
21
4:1 TOR002 Likely typo `require_grad` in assignment. Did you mean `requires_grad`?
2+
5:1 TOR002 Likely typo `require_grad` in assignment. Did you mean `requires_grad`?
3+
7:1 TOR002 Likely typo `require_grad` in assignment. Did you mean `requires_grad`?

tests/fixtures/misc/codemod/require_grad.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,9 @@
22
x = torch.zeros(1)
33
x.require_grad = False
44
x.require_grad = True
5+
6+
# from https://github.com/pytorch/test-infra/issues/4687
7+
import torch.nn as nn
8+
model = nn.Module()
9+
for name, param in model.named_parameters():
10+
param.require_grad = 'specific_layer' in name

tests/fixtures/misc/codemod/require_grad.py.out

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,9 @@ import torch
22
x = torch.zeros(1)
33
x.requires_grad = False
44
x.requires_grad = True
5+
6+
# from https://github.com/pytorch/test-infra/issues/4687
7+
import torch.nn as nn
8+
model = nn.Module()
9+
for name, param in model.named_parameters():
10+
param.requires_grad = 'specific_layer' in name

torchfix/visitors/misc/__init__.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,10 @@ class TorchRequireGradVisitor(TorchVisitor):
1414
MESSAGE = "Likely typo `require_grad` in assignment. Did you mean `requires_grad`?"
1515

1616
def visit_Assign(self, node):
17-
# Look for any assignment with `require_grad` attribute on the left
18-
# and `False` or `True` on the right.
17+
# Look for any assignment with `require_grad` attribute on the left.
1918
#
20-
# If this causes false-positives on real code (unlikely),
21-
# we can do type inference (not sure if feasible here) or
22-
# at least check that `torch` is imported in the file.
19+
# This is unlikely to cause false-positives on real code, especially
20+
# because TorchFix only looks at files that have a `torch` string.
2321
if m.matches(
2422
node,
2523
m.Assign(
@@ -28,7 +26,6 @@ def visit_Assign(self, node):
2826
target=m.Attribute(attr=m.Name(value="require_grad"))
2927
)
3028
],
31-
value=(m.Name("True") | m.Name("False")),
3229
),
3330
):
3431
replacement = node.with_deep_changes(

0 commit comments

Comments
 (0)