Skip to content

Commit 7a9c8c0

Browse files
authored
[TorchFix] Deprecated codemods to honor aliases (#4680)
By introducing `torchfix.common.get_module_name(node: cst.Call)` and using it in deprecated symbols codemods Fixes pytorch/test-infra#4452
1 parent 3fdc1cc commit 7a9c8c0

File tree

6 files changed

+39
-8
lines changed

6 files changed

+39
-8
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import torch
2+
import torch as foo
3+
import torch as bar
4+
import torch as baz
5+
6+
A = torch.arange(9.0).reshape(3, 3)
7+
A3 = foo.chain_matmul(A, A, A)
8+
rc1 = bar.cholesky(torch.mm(A3.t(), A3))
9+
rc2 = baz.qr(torch.mm(A.t(), A))
10+
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import torch
2+
import torch as foo
3+
import torch as bar
4+
import torch as baz
5+
6+
A = torch.arange(9.0).reshape(3, 3)
7+
A3 = foo.linalg.multi_dot([A, A, A])
8+
rc1 = bar.linalg.cholesky(torch.mm(A3.t(), A3))
9+
rc2 = baz.linalg.qr(torch.mm(A.t(), A))
10+

torchfix/common.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,11 @@ def on_leave(self, original_node, updated_node):
120120
return updated_node
121121

122122
return tree.visit(MultiChildReplacementTransformer(replacement_map))
123+
124+
125+
def get_module_name(node: cst.Call, default: Optional[str] = None) -> Optional[str]:
126+
if not isinstance(node.func, cst.Attribute):
127+
return default
128+
if not isinstance(node.func.value, cst.Name):
129+
return default
130+
return node.func.value.value

torchfix/visitors/deprecated_symbols/chain_matmul.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import libcst as cst
2+
from ...common import get_module_name
23

34

45
def call_replacement_chain_matmul(node: cst.Call) -> cst.CSTNode:
@@ -19,7 +20,8 @@ def call_replacement_chain_matmul(node: cst.Call) -> cst.CSTNode:
1920
replacement_args = [matrices_arg]
2021
else:
2122
replacement_args = [matrices_arg, out_arg]
22-
replacement = cst.parse_expression("torch.linalg.multi_dot(args)")
23+
module_name = get_module_name(node, 'torch')
24+
replacement = cst.parse_expression(f"{module_name}.linalg.multi_dot(args)")
2325
replacement = replacement.with_changes(args=replacement_args)
2426

2527
return replacement
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import libcst as cst
2-
from ...common import TorchVisitor
2+
from ...common import (TorchVisitor, get_module_name)
33

44

55
def call_replacement_cholesky(node: cst.Call) -> cst.CSTNode:
@@ -12,20 +12,20 @@ def call_replacement_cholesky(node: cst.Call) -> cst.CSTNode:
1212
comma=cst.MaybeSentinel.DEFAULT
1313
)
1414
upper_arg = TorchVisitor.get_specific_arg(node, "upper", 1)
15+
module_name = get_module_name(node, "torch")
1516

1617
if (
1718
upper_arg is not None
1819
and cst.ensure_type(upper_arg.value, cst.Name).value == "True"
1920
):
20-
replacement = cst.parse_expression("torch.linalg.cholesky(A).mH")
21+
replacement = cst.parse_expression(f"{module_name}.linalg.cholesky(A).mH")
2122
replacement = replacement.with_deep_changes(
2223
# Ignore type error, see https://github.com/Instagram/LibCST/issues/963
2324
old_node=cst.ensure_type(replacement.value, cst.Call).args, # type: ignore
2425
value=[input_arg],
2526
)
2627
else:
27-
replacement = cst.parse_expression("torch.linalg.cholesky(A)").with_changes(
28-
args=[input_arg]
29-
)
28+
replacement = cst.parse_expression(f"{module_name}.linalg.cholesky(A)")
29+
replacement = replacement.with_changes(args=[input_arg])
3030

3131
return replacement

torchfix/visitors/deprecated_symbols/qr.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import libcst as cst
22
from typing import Optional
3-
from ...common import TorchVisitor
3+
from ...common import (TorchVisitor, get_module_name)
44

55

66
def call_replacement_qr(node: cst.Call) -> Optional[cst.CSTNode]:
@@ -27,7 +27,8 @@ def call_replacement_qr(node: cst.Call) -> Optional[cst.CSTNode]:
2727
comma=cst.MaybeSentinel.DEFAULT
2828
)
2929
replacement_args = [input_arg]
30-
replacement = cst.parse_expression("torch.linalg.qr(args)")
30+
module_name = get_module_name(node, "torch")
31+
replacement = cst.parse_expression(f"{module_name}.linalg.qr(args)")
3132
replacement = replacement.with_changes(args=replacement_args)
3233

3334
return replacement

0 commit comments

Comments
 (0)