Skip to content

Commit 01921ff

Browse files
authored
[TorchFix] Add codemod for unsafe load (#4715)
pytorch/test-infra#4671 added linter-only `TorchUnsafeLoadVisitor`, but it turned out that the issue is so widespread that manual fixes would be tedious. The codemod is somewhat unsafe correctness-wise because full pickling functionality may still be needed even without `pickle_module`, but I think it's OK because it fixes a security-related issue and the codemods need to be verified anyway. Maybe later we should add something like Ruff's recently added `--unsafe-fixes`: https://docs.astral.sh/ruff/linter/#fix-safety I used this for pytorch/vision#8105
1 parent bc8d45b commit 01921ff

File tree

3 files changed

+25
-2
lines changed

3 files changed

+25
-2
lines changed

tests/fixtures/security/checker/load.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import random
2+
import dill
23
import torch
4+
35
# Not OK
46
torch.load('tensors.pt')
7+
torch.load('f.pt', pickle_module=dill, encoding='utf-8')
8+
59
# All these are OK
610
torch.load('tensors.pt', weights_only=True)
711
torch.load('tensors.pt', weights_only=False)
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
4:1 TOR102 `torch.load` without `weights_only` parameter is unsafe. Explicitly set `weights_only` to False only if you trust the data you load and full pickle functionality is needed, otherwise set `weights_only=True`.
1+
6:1 TOR102 `torch.load` without `weights_only` parameter is unsafe. Explicitly set `weights_only` to False only if you trust the data you load and full pickle functionality is needed, otherwise set `weights_only=True`.
2+
7:1 TOR102 `torch.load` without `weights_only` parameter is unsafe. Explicitly set `weights_only` to False only if you trust the data you load and full pickle functionality is needed, otherwise set `weights_only=True`.

torchfix/visitors/security/__init__.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,31 @@ def visit_Call(self, node):
2525
cst.metadata.WhitespaceInclusivePositionProvider, node
2626
)
2727

28+
# Add `weights_only=True` if there is no `pickle_module`.
29+
# (do not add `weights_only=False` with `pickle_module`, as it
30+
# needs to be an explicit choice).
31+
#
32+
# This codemod is somewhat unsafe correctness-wise
33+
# because full pickling functionality may still be needed
34+
# even without `pickle_module`,
35+
# so the changes need to be verified/tested.
36+
replacement = None
37+
pickle_module_arg = self.get_specific_arg(node, "pickle_module", 2)
38+
if pickle_module_arg is None:
39+
weights_only_arg = cst.ensure_type(
40+
cst.parse_expression("f(weights_only=True)"), cst.Call
41+
).args[0]
42+
replacement = node.with_changes(
43+
args=node.args + (weights_only_arg,)
44+
)
45+
2846
self.violations.append(
2947
LintViolation(
3048
error_code=self.ERROR_CODE,
3149
message=self.MESSAGE,
3250
line=position_metadata.start.line,
3351
column=position_metadata.start.column,
3452
node=node,
35-
replacement=None,
53+
replacement=replacement,
3654
)
3755
)

0 commit comments

Comments
 (0)