Skip to content

Commit 3fdc1cc

Browse files
authored
[TorchFix] Add TorchUnsafeLoadVisitor (#4671)
See pytorch/pytorch#31875 and pytorch/pytorch#111806 for discussion.
1 parent 9f67b51 commit 3fdc1cc

File tree

4 files changed

+49
-0
lines changed

4 files changed

+49
-0
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import random
2+
import torch
3+
# Not OK
4+
torch.load('tensors.pt')
5+
# All these are OK
6+
torch.load('tensors.pt', weights_only=True)
7+
torch.load('tensors.pt', weights_only=False)
8+
use_weights_only = random.choice([False, True])
9+
torch.load('tensors.pt', weights_only=use_weights_only)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
4:1 TOR102 `torch.load` without `weights_only` parameter is unsafe. Explicitly set `weights_only` to False only if you trust the data you load and full pickle functionality is needed, otherwise set `weights_only=True`.

torchfix/torchfix.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
TorchVisionDeprecatedPretrainedVisitor,
1717
TorchVisionDeprecatedToTensorVisitor,
1818
)
19+
from .visitors.security import TorchUnsafeLoadVisitor
1920

2021
__version__ = "0.1.1"
2122

@@ -31,6 +32,7 @@ def GET_ALL_VISITORS():
3132
TorchSynchronizedDataLoaderVisitor(),
3233
TorchVisionDeprecatedPretrainedVisitor(),
3334
TorchVisionDeprecatedToTensorVisitor(),
35+
TorchUnsafeLoadVisitor(),
3436
]
3537

3638

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import libcst as cst
2+
from ...common import TorchVisitor, LintViolation
3+
4+
5+
class TorchUnsafeLoadVisitor(TorchVisitor):
6+
"""
7+
Warn on `torch.load` not having explicit `weights_only`.
8+
See https://github.com/pytorch/pytorch/issues/31875.
9+
"""
10+
11+
ERROR_CODE = "TOR102"
12+
MESSAGE = (
13+
"`torch.load` without `weights_only` parameter is unsafe. "
14+
"Explicitly set `weights_only` to False only if you trust the data you load "
15+
"and full pickle functionality is needed, otherwise set "
16+
"`weights_only=True`."
17+
)
18+
19+
def visit_Call(self, node):
20+
qualified_name = self.get_qualified_name_for_call(node)
21+
if qualified_name == "torch.load":
22+
weights_only_arg = self.get_specific_arg(node, "weights_only", -1)
23+
if weights_only_arg is None:
24+
position_metadata = self.get_metadata(
25+
cst.metadata.WhitespaceInclusivePositionProvider, node
26+
)
27+
28+
self.violations.append(
29+
LintViolation(
30+
error_code=self.ERROR_CODE,
31+
message=self.MESSAGE,
32+
line=position_metadata.start.line,
33+
column=position_metadata.start.column,
34+
node=node,
35+
replacement=None,
36+
)
37+
)

0 commit comments

Comments
 (0)