diff --git a/tests/fixtures/performance/checker/singletensor.py b/tests/fixtures/performance/checker/singletensor.py
new file mode 100644
index 0000000..0e1a6cb
--- /dev/null
+++ b/tests/fixtures/performance/checker/singletensor.py
@@ -0,0 +1,17 @@
+import torch
+import torch.nn as nn
+
+x = torch.ones((100, 100))
+model = nn.Sequential()
+
+
+# These should raise flags
+optimizer_adam = torch.optim.Adam(model.parameters())
+optimizer_sgd = torch.optim.SGD(model.parameters(), lr=0.01)
+optimizer_adamw = torch.optim.AdamW(model.parameters())
+
+# These should not raise flags
+optimizer_adam = torch.optim.Adam(model.parameters(), foreach=True)
+optimizer_sgd = torch.optim.SGD(model.parameters(), lr=0.01, foreach=True)
+optimizer_adamw = torch.optim.AdamW(model.parameters(), foreach=True)
+optimizer_adamw = torch.optim.AdamW(model.parameters(), foreach=False)
\ No newline at end of file
diff --git a/tests/fixtures/performance/checker/singletensor.txt b/tests/fixtures/performance/checker/singletensor.txt
new file mode 100644
index 0000000..8c4a847
--- /dev/null
+++ b/tests/fixtures/performance/checker/singletensor.txt
@@ -0,0 +1,3 @@
+9:18 TOR403 Detected optimizer running with single tensor implementation. Please enable multi tensor implementation by passing 'foreach=True' to the optimizer.
+10:17 TOR403 Detected optimizer running with single tensor implementation. Please enable multi tensor implementation by passing 'foreach=True' to the optimizer.
+11:19 TOR403 Detected optimizer running with single tensor implementation. Please enable multi tensor implementation by passing 'foreach=True' to the optimizer.
\ No newline at end of file
diff --git a/tests/fixtures/performance/checker/zerograd.py b/tests/fixtures/performance/checker/zerograd.py
index 8f0d6fc..9188224 100644
--- a/tests/fixtures/performance/checker/zerograd.py
+++ b/tests/fixtures/performance/checker/zerograd.py
@@ -3,7 +3,7 @@
 x = torch.ones((100, 100))
 model = nn.Sequential()
 
-optimizer = torch.optim.Adam(model.parameters())
+optimizer = torch.optim.Adam(model.parameters(), foreach=True)
 
 # This should raise flags
 optimizer.zero_grad(set_to_none=False)
diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py
index 5e96e38..11c079c 100644
--- a/torchfix/torchfix.py
+++ b/torchfix/torchfix.py
@@ -22,6 +22,7 @@
     TorchVisionDeprecatedToTensorVisitor,
     TorchVisionSingletonImportVisitor,
     TorchGradNotSetToNonePatternVisitor,
+    TorchOptimizerSingleTensorPatternVisitor,
 )
 
 __version__ = "0.7.0"
@@ -45,6 +46,7 @@
     TorchVisionDeprecatedToTensorVisitor,
     TorchVisionSingletonImportVisitor,
     TorchGradNotSetToNonePatternVisitor,
+    TorchOptimizerSingleTensorPatternVisitor,
 ]
 
 
diff --git a/torchfix/visitors/__init__.py b/torchfix/visitors/__init__.py
index 45f2438..e2a9e3d 100644
--- a/torchfix/visitors/__init__.py
+++ b/torchfix/visitors/__init__.py
@@ -11,6 +11,7 @@
 from .performance import (
     TorchSynchronizedDataLoaderVisitor,
     TorchGradNotSetToNonePatternVisitor,
+    TorchOptimizerSingleTensorPatternVisitor,
 )
 from .security import TorchUnsafeLoadVisitor
 from .vision import (
@@ -34,4 +35,5 @@
     "TorchVisionDeprecatedToTensorVisitor",
     "TorchVisionSingletonImportVisitor",
     "TorchGradNotSetToNonePatternVisitor",
+    "TorchOptimizerSingleTensorPatternVisitor",
 ]
diff --git a/torchfix/visitors/performance/__init__.py b/torchfix/visitors/performance/__init__.py
index 0558af5..d9bac76 100644
--- a/torchfix/visitors/performance/__init__.py
+++ b/torchfix/visitors/performance/__init__.py
@@ -65,3 +65,40 @@ def visit_Call(self, node):
                         error_code=self.ERRORS[0].error_code,
                         message=self.ERRORS[0].message(),
                     )
+
+
+class TorchOptimizerSingleTensorPatternVisitor(TorchVisitor):
+    """
+    Reimplementation of OptimizerSingleTensorPattern from
+    https://github.com/pytorch/pytorch/blob/main/torch/profiler/_pattern_matcher.py
+    """
+
+    ERRORS = [
+        TorchError(
+            "TOR403",
+            (
+                "Detected optimizer running with single tensor implementation. "
+                "Please enable multi tensor implementation by passing "
+                "'foreach=True' to the optimizer."
+            ),
+        )
+    ]
+
+    # Optimizers that accept the keyword-only `foreach` flag.
+    optimizers_with_foreach = ["Adam", "SGD", "AdamW"]
+
+    def visit_Call(self, node):
+        qualified_name = self.get_qualified_name_for_call(node)
+        for optimizer in self.optimizers_with_foreach:
+            if qualified_name and qualified_name.endswith(optimizer):
+                # `foreach` is keyword-only in these optimizers,
+                # so look it up by name only, not by position.
+                foreach_arg = self.get_specific_arg(
+                    node, arg_name="foreach", arg_pos=-1
+                )
+                if foreach_arg is None:
+                    self.add_violation(
+                        node,
+                        error_code=self.ERRORS[0].error_code,
+                        message=self.ERRORS[0].message(),
+                    )
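For context on what the new TOR403 check targets, here is a minimal sketch of the flagged pattern and its fix; the model and variable names are illustrative only. When `foreach` is left unset, PyTorch may fall back to the single-tensor for-loop implementation (e.g., on CPU), whereas `foreach=True` requests the multi-tensor implementation, which batches per-parameter updates into combined kernel calls:

```python
import torch
import torch.nn as nn

# Illustrative model; any nn.Module with parameters behaves the same way.
model = nn.Linear(10, 10)

# Flagged by TOR403: `foreach` is unspecified, so the optimizer may use
# the slower single-tensor (per-parameter for-loop) implementation.
opt_default = torch.optim.Adam(model.parameters())

# Not flagged: the multi-tensor implementation is requested explicitly.
opt_foreach = torch.optim.Adam(model.parameters(), foreach=True)

# Also not flagged: an explicit `foreach=False` is a deliberate opt-out;
# the check only fires when the argument is missing entirely.
opt_optout = torch.optim.SGD(model.parameters(), lr=0.01, foreach=False)
```

This mirrors the `singletensor.py` fixture above: only calls where the `foreach` argument is absent are reported, so an explicit `False` is respected as a conscious choice.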