Skip to content

Commit f34da3d

Browse files
Darijan Gudelj and facebook-github-bot
Darijan Gudelj
authored and committed
packed_to_padded now accepts all sizes
Summary: We need to make packing/unpacking in 2 places for mixed frame raysampling (metrics and raysampler) but those tensors that need to be unpacked/packed have more than two dimensions. I could have reshaped and stored dimensions but this seems to just complicate code there with something which packed_to_padded should support. I could have made a separate function for implicitron but it would confusing to have two different padded_to_packed functions inside pytorch3d codebase one of which does packing for (b, max) and (b, max, f) and the other for (b, max, …) Reviewed By: bottler Differential Revision: D39729026 fbshipit-source-id: 2bdebf290dcc6c316b7fe1aeee49bbb5255e508c
1 parent c2d876c commit f34da3d

File tree

2 files changed

+88
-64
lines changed

2 files changed

+88
-64
lines changed

pytorch3d/ops/packed_to_padded.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,15 @@ def packed_to_padded(inputs, first_idxs, max_size):
6565
Torch wrapper that handles allowed input shapes. See description below.
6666
6767
Args:
68-
inputs: FloatTensor of shape (F,) or (F, D), representing the packed
68+
inputs: FloatTensor of shape (F,) or (F, ...), representing the packed
6969
batch tensor, e.g. areas for faces in a batch of meshes.
7070
first_idxs: LongTensor of shape (N,) where N is the number of
7171
elements in the batch and `first_idxs[i] = f`
7272
means that the inputs for batch element i begin at `inputs[f]`.
7373
max_size: Max length of an element in the batch.
7474
7575
Returns:
76-
inputs_padded: FloatTensor of shape (N, max_size) or (N, max_size, D)
76+
inputs_padded: FloatTensor of shape (N, max_size) or (N, max_size, ...)
7777
where max_size is max of `sizes`. The values for batch element i
7878
which start at `inputs[first_idxs[i]]` will be copied to
7979
`inputs_padded[i, :]`, with zeros padding out the extra inputs.
@@ -83,15 +83,20 @@ def packed_to_padded(inputs, first_idxs, max_size):
8383
(N, max_size, 1).
8484
"""
8585
# if inputs is of shape (F,), reshape into (F, 1)
86-
flat = False
87-
if inputs.dim() == 1:
88-
flat = True
86+
input_shape = inputs.shape
87+
n_dims = inputs.dim()
88+
if n_dims == 1:
8989
inputs = inputs.unsqueeze(1)
90+
else:
91+
inputs = inputs.reshape(input_shape[0], -1)
9092
inputs_padded = _PackedToPadded.apply(inputs, first_idxs, max_size)
9193
# if flat is True, reshape output to (N, max_size) from (N, max_size, 1)
92-
if flat:
93-
inputs_padded = inputs_padded.squeeze(2)
94-
return inputs_padded
94+
# else reshape output to (N, max_size, ...)
95+
if n_dims == 1:
96+
return inputs_padded.squeeze(2)
97+
if n_dims == 2:
98+
return inputs_padded
99+
return inputs_padded.view(*inputs_padded.shape[:2], *input_shape[1:])
95100

96101

97102
class _PaddedToPacked(Function):
@@ -147,7 +152,7 @@ def padded_to_packed(inputs, first_idxs, num_inputs):
147152
Torch wrapper that handles allowed input shapes. See description below.
148153
149154
Args:
150-
inputs: FloatTensor of shape (N, max_size) or (N, max_size, D),
155+
inputs: FloatTensor of shape (N, max_size) or (N, max_size, ...),
151156
representing the padded tensor, e.g. areas for faces in a batch of
152157
meshes.
153158
first_idxs: LongTensor of shape (N,) where N is the number of
@@ -156,20 +161,25 @@ def padded_to_packed(inputs, first_idxs, num_inputs):
156161
num_inputs: Number of packed entries (= F)
157162
158163
Returns:
159-
inputs_packed: FloatTensor of shape (F,) or (F, D) where
160-
`inputs_packed[first_idx[i]:] = inputs[i, :]`.
164+
inputs_packed: FloatTensor of shape (F,) or (F, ...) where
165+
`inputs_packed[first_idx[i]:first_idx[i+1]] = inputs[i, :]`.
161166
162167
To handle the allowed input shapes, we convert the inputs tensor of shape
163168
(N, max_size) to (N, max_size, 1). We reshape the output back to (F,) from
164169
(F, 1).
165170
"""
166171
# if inputs is of shape (N, max_size), reshape into (N, max_size, 1))
167-
flat = False
168-
if inputs.dim() == 2:
169-
flat = True
172+
input_shape = inputs.shape
173+
n_dims = inputs.dim()
174+
if n_dims == 2:
170175
inputs = inputs.unsqueeze(2)
176+
else:
177+
inputs = inputs.reshape(*input_shape[:2], -1)
171178
inputs_packed = _PaddedToPacked.apply(inputs, first_idxs, num_inputs)
172-
# if flat is True, reshape output to (F,) from (F, 1)
173-
if flat:
174-
inputs_packed = inputs_packed.squeeze(1)
175-
return inputs_packed
179+
# if input is flat, reshape output to (F,) from (F, 1)
180+
# else reshape output to (F, ...)
181+
if n_dims == 2:
182+
return inputs_packed.squeeze(1)
183+
if n_dims == 3:
184+
return inputs_packed
185+
return inputs_packed.view(-1, *input_shape[2:])

tests/test_packed_to_padded.py

Lines changed: 60 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,19 @@ def packed_to_padded_python(inputs, first_idxs, max_size, device):
4545
PyTorch implementation of packed_to_padded function.
4646
"""
4747
num_meshes = first_idxs.size(0)
48-
D = inputs.shape[1] if inputs.dim() == 2 else 0
49-
if D == 0:
48+
if inputs.dim() == 1:
5049
inputs_padded = torch.zeros((num_meshes, max_size), device=device)
5150
else:
52-
inputs_padded = torch.zeros((num_meshes, max_size, D), device=device)
51+
inputs_padded = torch.zeros(
52+
(num_meshes, max_size, *inputs.shape[1:]), device=device
53+
)
5354
for m in range(num_meshes):
5455
s = first_idxs[m]
5556
if m == num_meshes - 1:
5657
f = inputs.shape[0]
5758
else:
5859
f = first_idxs[m + 1]
59-
inputs_padded[m, :f] = inputs[s:f]
60+
inputs_padded[m, : f - s] = inputs[s:f]
6061

6162
return inputs_padded
6263

@@ -66,22 +67,21 @@ def padded_to_packed_python(inputs, first_idxs, num_inputs, device):
6667
PyTorch implementation of padded_to_packed function.
6768
"""
6869
num_meshes = inputs.size(0)
69-
D = inputs.shape[2] if inputs.dim() == 3 else 0
70-
if D == 0:
70+
if inputs.dim() == 2:
7171
inputs_packed = torch.zeros((num_inputs,), device=device)
7272
else:
73-
inputs_packed = torch.zeros((num_inputs, D), device=device)
73+
inputs_packed = torch.zeros((num_inputs, *inputs.shape[2:]), device=device)
7474
for m in range(num_meshes):
7575
s = first_idxs[m]
7676
if m == num_meshes - 1:
7777
f = num_inputs
7878
else:
7979
f = first_idxs[m + 1]
80-
inputs_packed[s:f] = inputs[m, :f]
80+
inputs_packed[s:f] = inputs[m, : f - s]
8181

8282
return inputs_packed
8383

84-
def _test_packed_to_padded_helper(self, D, device):
84+
def _test_packed_to_padded_helper(self, dims, device):
8585
"""
8686
Check the results from packed_to_padded and PyTorch implementations
8787
are the same.
@@ -91,10 +91,12 @@ def _test_packed_to_padded_helper(self, D, device):
9191
mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
9292
max_faces = meshes.num_faces_per_mesh().max().item()
9393

94-
if D == 0:
94+
if len(dims) == 0:
9595
values = torch.rand((faces.shape[0],), device=device, requires_grad=True)
9696
else:
97-
values = torch.rand((faces.shape[0], D), device=device, requires_grad=True)
97+
values = torch.rand(
98+
(faces.shape[0], *dims), device=device, requires_grad=True
99+
)
98100
values_torch = values.detach().clone()
99101
values_torch.requires_grad = True
100102
values_padded = packed_to_padded(
@@ -107,10 +109,10 @@ def _test_packed_to_padded_helper(self, D, device):
107109
self.assertClose(values_padded, values_padded_torch)
108110

109111
# check backward
110-
if D == 0:
112+
if len(dims) == 0:
111113
grad_inputs = torch.rand((len(meshes), max_faces), device=device)
112114
else:
113-
grad_inputs = torch.rand((len(meshes), max_faces, D), device=device)
115+
grad_inputs = torch.rand((len(meshes), max_faces, *dims), device=device)
114116
values_padded.backward(grad_inputs)
115117
grad_outputs = values.grad
116118
values_padded_torch.backward(grad_inputs)
@@ -122,27 +124,41 @@ def _test_packed_to_padded_helper(self, D, device):
122124
self.assertClose(grad_outputs, grad_outputs_torch2)
123125

124126
def test_packed_to_padded_flat_cpu(self):
125-
self._test_packed_to_padded_helper(0, "cpu")
127+
self._test_packed_to_padded_helper([], "cpu")
126128

127129
def test_packed_to_padded_D1_cpu(self):
128-
self._test_packed_to_padded_helper(1, "cpu")
130+
self._test_packed_to_padded_helper([1], "cpu")
129131

130132
def test_packed_to_padded_D16_cpu(self):
131-
self._test_packed_to_padded_helper(16, "cpu")
133+
self._test_packed_to_padded_helper([16], "cpu")
134+
135+
def test_packed_to_padded_D16_9_cpu(self):
136+
self._test_packed_to_padded_helper([16, 9], "cpu")
137+
138+
def test_packed_to_padded_D16_3_2_cpu(self):
139+
self._test_packed_to_padded_helper([16, 3, 2], "cpu")
132140

133141
def test_packed_to_padded_flat_cuda(self):
134142
device = get_random_cuda_device()
135-
self._test_packed_to_padded_helper(0, device)
143+
self._test_packed_to_padded_helper([], device)
136144

137145
def test_packed_to_padded_D1_cuda(self):
138146
device = get_random_cuda_device()
139-
self._test_packed_to_padded_helper(1, device)
147+
self._test_packed_to_padded_helper([1], device)
140148

141149
def test_packed_to_padded_D16_cuda(self):
142150
device = get_random_cuda_device()
143-
self._test_packed_to_padded_helper(16, device)
151+
self._test_packed_to_padded_helper([16], device)
152+
153+
def test_packed_to_padded_D16_9_cuda(self):
154+
device = get_random_cuda_device()
155+
self._test_packed_to_padded_helper([16, 9], device)
156+
157+
def test_packed_to_padded_D16_3_2_cuda(self):
158+
device = get_random_cuda_device()
159+
self._test_packed_to_padded_helper([16, 3, 2], device)
144160

145-
def _test_padded_to_packed_helper(self, D, device):
161+
def _test_padded_to_packed_helper(self, dims, device):
146162
"""
147163
Check the results from packed_to_padded and PyTorch implementations
148164
are the same.
@@ -151,10 +167,10 @@ def _test_padded_to_packed_helper(self, D, device):
151167
mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
152168
num_faces_per_mesh = meshes.num_faces_per_mesh()
153169
max_faces = num_faces_per_mesh.max().item()
154-
if D == 0:
170+
if len(dims) == 0:
155171
values = torch.rand((len(meshes), max_faces), device=device)
156172
else:
157-
values = torch.rand((len(meshes), max_faces, D), device=device)
173+
values = torch.rand((len(meshes), max_faces, *dims), device=device)
158174
for i, num in enumerate(num_faces_per_mesh):
159175
values[i, num:] = 0
160176
values.requires_grad = True
@@ -173,11 +189,11 @@ def _test_padded_to_packed_helper(self, D, device):
173189
self.assertClose(values_packed, values_packed_torch)
174190

175191
# check backward
176-
if D == 0:
192+
if len(dims) == 0:
177193
grad_inputs = torch.rand((num_faces_per_mesh.sum().item()), device=device)
178194
else:
179195
grad_inputs = torch.rand(
180-
(num_faces_per_mesh.sum().item(), D), device=device
196+
(num_faces_per_mesh.sum().item(), *dims), device=device
181197
)
182198
values_packed.backward(grad_inputs)
183199
grad_outputs = values.grad
@@ -190,41 +206,39 @@ def _test_padded_to_packed_helper(self, D, device):
190206
self.assertClose(grad_outputs, grad_outputs_torch2)
191207

192208
def test_padded_to_packed_flat_cpu(self):
193-
self._test_padded_to_packed_helper(0, "cpu")
209+
self._test_padded_to_packed_helper([], "cpu")
194210

195211
def test_padded_to_packed_D1_cpu(self):
196-
self._test_padded_to_packed_helper(1, "cpu")
212+
self._test_padded_to_packed_helper([1], "cpu")
197213

198214
def test_padded_to_packed_D16_cpu(self):
199-
self._test_padded_to_packed_helper(16, "cpu")
215+
self._test_padded_to_packed_helper([16], "cpu")
216+
217+
def test_padded_to_packed_D16_9_cpu(self):
218+
self._test_padded_to_packed_helper([16, 9], "cpu")
219+
220+
def test_padded_to_packed_D16_3_2_cpu(self):
221+
self._test_padded_to_packed_helper([16, 3, 2], "cpu")
200222

201223
def test_padded_to_packed_flat_cuda(self):
202224
device = get_random_cuda_device()
203-
self._test_padded_to_packed_helper(0, device)
225+
self._test_padded_to_packed_helper([], device)
204226

205227
def test_padded_to_packed_D1_cuda(self):
206228
device = get_random_cuda_device()
207-
self._test_padded_to_packed_helper(1, device)
229+
self._test_padded_to_packed_helper([1], device)
208230

209231
def test_padded_to_packed_D16_cuda(self):
210232
device = get_random_cuda_device()
211-
self._test_padded_to_packed_helper(16, device)
212-
213-
def test_invalid_inputs_shapes(self, device="cuda:0"):
214-
with self.assertRaisesRegex(ValueError, "input can only be 2-dimensional."):
215-
values = torch.rand((100, 50, 2), device=device)
216-
first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
217-
packed_to_padded(values, first_idxs, 100)
218-
219-
with self.assertRaisesRegex(ValueError, "input can only be 3-dimensional."):
220-
values = torch.rand((100,), device=device)
221-
first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
222-
padded_to_packed(values, first_idxs, 20)
223-
224-
with self.assertRaisesRegex(ValueError, "input can only be 3-dimensional."):
225-
values = torch.rand((100, 50, 2, 2), device=device)
226-
first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
227-
padded_to_packed(values, first_idxs, 20)
233+
self._test_padded_to_packed_helper([16], device)
234+
235+
def test_padded_to_packed_D16_9_cuda(self):
236+
device = get_random_cuda_device()
237+
self._test_padded_to_packed_helper([16, 9], device)
238+
239+
def test_padded_to_packed_D16_3_2_cuda(self):
240+
device = get_random_cuda_device()
241+
self._test_padded_to_packed_helper([16, 3, 2], device)
228242

229243
@staticmethod
230244
def packed_to_padded_with_init(

0 commit comments

Comments
 (0)