From 4fdc6d090958aff08851c4e33c3b0cea489e0288 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 8 May 2025 17:12:52 -0700 Subject: [PATCH 1/4] slice scatter support for dynamic cases --- .../dynamo/lowering/_decompositions.py | 36 +++++---- .../py/dynamo/lowering/test_decompositions.py | 73 +++++++++++++++++++ 2 files changed, 93 insertions(+), 16 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 8037858151..acaeeb5ab2 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -213,22 +213,26 @@ def slice_scatter_decomposition( return src_tensor # Ensure start, end, and step are all integers - assert isinstance(start, int), "start must be an integer" - assert isinstance(end, int), "end must be an integer" - assert isinstance(step, int), "step must be an integer" - - cat_tensors = [] - index_tensor_shape = [] - for i, src_each_dim in enumerate(list(src_dim)): - if i != dim: - index_tensor_shape.append(src_each_dim) - for index in range(start, end, step): - cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.int64)) - index_tensor = torch.stack(cat_tensors, dim) - index_tensor = index_tensor.to(device_input_tensor) - index_tensor_64 = index_tensor.to(torch.int64) - output_tensor = torch.scatter(input_tensor, dim, index_tensor_64, src_tensor) - return output_tensor + # Ensure start, end, and step are all integers + assert isinstance(start, (int, torch.SymInt)), "start must be an int or SymInt" + assert isinstance(end, (int, torch.SymInt)), "end must be an int or SymInt" + assert isinstance(step, (int, torch.SymInt)), "step must be an int or SymInt" + + src_dim = src_tensor.shape + # step == 0 is not a valid torch case + # also src_dim should be equal to slice dimension + + if start == 0 and end == dim_size and step == 1: + return src_tensor + + indices = torch.arange( + start, end, step, device=device_input_tensor, dtype=torch.int64 + ) + index_tensor = indices.view( + [-1 if i == dim else 1 for i in range(input_tensor.dim())] + ) + index_tensor = index_tensor.expand_as(src_tensor) + return torch.scatter(input_tensor.clone(), dim, index_tensor, src_tensor) @register_torch_trt_decomposition( diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index b63e0f3bf7..9f0f53a4d8 100644 --- a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -812,6 +812,79 @@ def forward(self, x, src, dim, start, end, step): f"Slice_scatter TRT outputs don't match with the original model.", ) + def test_lowering_slice_scatter_dynamic_module(self): + class sliceScatter(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, src, dim, start=None, end=None, step=1): + y = torch.ops.aten.slice_scatter(x, src, dim, start, end, step) + return y + + # Operations expected to be removed in the traced graph after decompositions + expected_ops = { + torch.ops.aten.scatter.src, + } + unexpected_ops = {torch.ops.aten.select_scatter} + + a = torch.zeros(8, 8).cuda() + b = torch.ones(8, 2).cuda() + + # 0-D tensors for dynamic scalar values + start = torch.tensor(1, dtype=torch.int64).cuda() + end = torch.tensor(6, dtype=torch.int64).cuda() + step = torch.tensor(1, dtype=torch.int64).cuda() + + # Mark scalar tensors as dynamic (note: shape = ()) + torch._dynamo.mark_dynamic(start, (), min=1, max=3) + torch._dynamo.mark_dynamic(end, (), min=4, max=6) + + inputs = (a, b, start, end, None, step) + fx_graph = torch.fx.symbolic_trace(sliceScatter()) + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( + fx_graph, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + min_block_size=1, + ) + + self.assertEqual( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEqual( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + truncate_double=True, + pass_through_build_failures=True, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + f"Slice_scatter TRT outputs don't match with the original model.", + ) + def test_lowering_select_scatter_dimZero_module(self): class selectScatter(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: From 1b8fd586a52b34641c4d2bc70f5cefe7ce0bc67d Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 15 May 2025 15:46:32 -0700 Subject: [PATCH 2/4] :Using torch.export workflow since compile is showing error in tensor guard --- .../dynamo/lowering/_decompositions.py | 15 +--- .../py/dynamo/lowering/test_decompositions.py | 81 +++++-------------- 2 files changed, 24 insertions(+), 72 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index acaeeb5ab2..c0e44b1826 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -201,30 +201,21 @@ def slice_scatter_decomposition( start = get_positive_dim(start, input_tensor.shape[dim]) if end is None: # Ensure end is int end = dim_size - end = get_positive_dim(end, input_tensor.shape[dim]) + end = ( + get_positive_dim(end, input_tensor.shape[dim]) if isinstance(end, int) else end + ) if step is None: step = 1 - src_dim = src_tensor.shape # step == 0 is not a valid torch case - # also src_dim should be equal to slice dimension - if start == 0 and end == dim_size and step == 1: return src_tensor - # Ensure start, end, and step are all integers # Ensure start, end, and step are all integers assert isinstance(start, (int, torch.SymInt)), "start must be an int or SymInt" assert isinstance(end, (int, torch.SymInt)), "end must be an int or SymInt" assert isinstance(step, (int, torch.SymInt)), "step must be an int or SymInt" - src_dim = src_tensor.shape - # step == 0 is not a valid torch case - # also src_dim should be equal to slice dimension - - if start == 0 and end == dim_size and step == 1: - return src_tensor - indices = torch.arange( start, end, step, device=device_input_tensor, dtype=torch.int64 ) diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index 9f0f53a4d8..e7c7b33672 100644 --- a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -817,72 +817,33 @@ class sliceScatter(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - def forward(self, x, src, dim, start=None, end=None, step=1): - y = torch.ops.aten.slice_scatter(x, src, dim, start, end, step) + def forward(self, x, src): + y = torch.ops.aten.slice_scatter(x, src, 1, 6, None, 1) return y - # Operations expected to be removed in the traced graph after decompositions - expected_ops = { - torch.ops.aten.scatter.src, - } - unexpected_ops = {torch.ops.aten.select_scatter} - - a = torch.zeros(8, 8).cuda() - b = torch.ones(8, 2).cuda() - - # 0-D tensors for dynamic scalar values - start = torch.tensor(1, dtype=torch.int64).cuda() - end = torch.tensor(6, dtype=torch.int64).cuda() - step = torch.tensor(1, dtype=torch.int64).cuda() - - # Mark scalar tensors as dynamic (note: shape = ()) - torch._dynamo.mark_dynamic(start, (), min=1, max=3) - torch._dynamo.mark_dynamic(end, (), min=4, max=6) - - inputs = (a, b, start, end, None, step) fx_graph = torch.fx.symbolic_trace(sliceScatter()) - unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( - fx_graph, - inputs, - expected_ops=expected_ops, - unexpected_ops=unexpected_ops, - min_block_size=1, - ) - self.assertEqual( - len(unexpected_ops_seen), - 0, - f"The following unexpected ops were encountered: {unexpected_ops_seen}", - ) - - self.assertEqual( - len(expected_ops_unseen), - 0, - f"The following expected ops were not encountered: {expected_ops_unseen}", + dim1 = torch.export.Dim("dim1", min=8, max=10) + dynamic_shapes = { + "x": [torch.export.Dim.STATIC, dim1], + "src": [torch.export.Dim.STATIC, None], + } + inputs = (torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda()) + exported_program = torch.export.export( + sliceScatter(), tuple(inputs), dynamic_shapes=dynamic_shapes ) - + fx_graph = exported_program.module() + inputs = [ + torch_tensorrt.Input( + min_shape=[8, 8], opt_shape=[8, 10], max_shape=[8, 10] + ), + torch_tensorrt.Input(min_shape=[8, 2], opt_shape=[8, 2], max_shape=[8, 2]), + ] torch._dynamo.reset() - - # Validate that the results between Torch and Torch-TRT are similar - optimized_model = torch_tensorrt.compile( - fx_graph, - "torch_compile", - inputs, - min_block_size=1, - truncate_double=True, - pass_through_build_failures=True, - ) - optimized_model_results = optimized_model(*inputs).detach().cpu() - torch_model_results = fx_graph(*inputs).detach().cpu() - - max_diff = float( - torch.max(torch.abs(optimized_model_results - torch_model_results)) - ) - self.assertAlmostEqual( - max_diff, - 0, - DECIMALS_OF_AGREEMENT, - f"Slice_scatter TRT outputs don't match with the original model.", + trt_model = torch_tensorrt.dynamo.compile(exported_program, inputs) + inputs = (torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda()) + torch.testing.assert_close( + trt_model(*inputs), fx_graph(*inputs), rtol=RTOL, atol=ATOL ) def test_lowering_select_scatter_dimZero_module(self): From 5fd34b86c5c95697382d72510a5871d3f7c84413 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 16 May 2025 12:52:19 -0700 Subject: [PATCH 3/4] undoing the clone since it is not required --- py/torch_tensorrt/dynamo/lowering/_decompositions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index c0e44b1826..c5c191bb7b 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -223,7 +223,7 @@ def slice_scatter_decomposition( [-1 if i == dim else 1 for i in range(input_tensor.dim())] ) index_tensor = index_tensor.expand_as(src_tensor) - return torch.scatter(input_tensor.clone(), dim, index_tensor, src_tensor) + return torch.scatter(input_tensor, dim, index_tensor, src_tensor) @register_torch_trt_decomposition( From 409d15f11784685b7c857814f1cb99cef547367f Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 13 Jun 2025 12:16:01 -0700 Subject: [PATCH 4/4] Removing the fx_graph symbolic trace --- tests/py/dynamo/lowering/test_decompositions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index e7c7b33672..32bf7f8b98 100644 --- a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -821,8 +821,6 @@ def forward(self, x, src): y = torch.ops.aten.slice_scatter(x, src, 1, 6, None, 1) return y - fx_graph = torch.fx.symbolic_trace(sliceScatter()) - dim1 = torch.export.Dim("dim1", min=8, max=10) dynamic_shapes = { "x": [torch.export.Dim.STATIC, dim1],