File tree Expand file tree Collapse file tree 4 files changed +49
-0
lines changed
tests/fixtures/security/checker Expand file tree Collapse file tree 4 files changed +49
-0
lines changed Original file line number Diff line number Diff line change
1
+ import random
2
+ import torch
3
+ # Not OK
4
+ torch .load ('tensors.pt' )
5
+ # All these are OK
6
+ torch .load ('tensors.pt' , weights_only = True )
7
+ torch .load ('tensors.pt' , weights_only = False )
8
+ use_weights_only = random .choice ([False , True ])
9
+ torch .load ('tensors.pt' , weights_only = use_weights_only )
Original file line number Diff line number Diff line change
1
+ 4:1 TOR102 `torch.load` without `weights_only` parameter is unsafe. Explicitly set `weights_only` to False only if you trust the data you load and full pickle functionality is needed, otherwise set `weights_only=True`.
Original file line number Diff line number Diff line change 16
16
TorchVisionDeprecatedPretrainedVisitor ,
17
17
TorchVisionDeprecatedToTensorVisitor ,
18
18
)
19
+ from .visitors .security import TorchUnsafeLoadVisitor
19
20
20
21
__version__ = "0.1.1"
21
22
@@ -31,6 +32,7 @@ def GET_ALL_VISITORS():
31
32
TorchSynchronizedDataLoaderVisitor (),
32
33
TorchVisionDeprecatedPretrainedVisitor (),
33
34
TorchVisionDeprecatedToTensorVisitor (),
35
+ TorchUnsafeLoadVisitor (),
34
36
]
35
37
36
38
Original file line number Diff line number Diff line change
1
+ import libcst as cst
2
+ from ...common import TorchVisitor , LintViolation
3
+
4
+
5
+ class TorchUnsafeLoadVisitor (TorchVisitor ):
6
+ """
7
+ Warn on `torch.load` not having explicit `weights_only`.
8
+ See https://github.com/pytorch/pytorch/issues/31875.
9
+ """
10
+
11
+ ERROR_CODE = "TOR102"
12
+ MESSAGE = (
13
+ "`torch.load` without `weights_only` parameter is unsafe. "
14
+ "Explicitly set `weights_only` to False only if you trust the data you load "
15
+ "and full pickle functionality is needed, otherwise set "
16
+ "`weights_only=True`."
17
+ )
18
+
19
+ def visit_Call (self , node ):
20
+ qualified_name = self .get_qualified_name_for_call (node )
21
+ if qualified_name == "torch.load" :
22
+ weights_only_arg = self .get_specific_arg (node , "weights_only" , - 1 )
23
+ if weights_only_arg is None :
24
+ position_metadata = self .get_metadata (
25
+ cst .metadata .WhitespaceInclusivePositionProvider , node
26
+ )
27
+
28
+ self .violations .append (
29
+ LintViolation (
30
+ error_code = self .ERROR_CODE ,
31
+ message = self .MESSAGE ,
32
+ line = position_metadata .start .line ,
33
+ column = position_metadata .start .column ,
34
+ node = node ,
35
+ replacement = None ,
36
+ )
37
+ )
You can’t perform that action at this time.
0 commit comments