Skip to content

Commit eeb6bd3

Browse files
authored
Merge pull request #114 from nikhilaravi/fixup-T64213310-master
Re-sync with internal repository
2 parents 2480723 + 3d3b2fd commit eeb6bd3

File tree

7 files changed

+2805
-0
lines changed

7 files changed

+2805
-0
lines changed

pytorch3d/renderer/points/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
from typing import Optional
4+
import torch
5+
6+
from pytorch3d import _C
7+
from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc
8+
9+
10+
# TODO(jcjohns): Support non-square images
11+
def rasterize_points(
12+
pointclouds,
13+
image_size: int = 256,
14+
radius: float = 0.01,
15+
points_per_pixel: int = 8,
16+
bin_size: Optional[int] = None,
17+
max_points_per_bin: Optional[int] = None,
18+
):
19+
"""
20+
Pointcloud rasterization
21+
22+
Args:
23+
pointclouds: A Pointclouds object representing a batch of point clouds to be
24+
rasterized. This is a batch of N pointclouds, where each point cloud
25+
can have a different number of points; the coordinates of each point
26+
are (x, y, z). The coordinates are expected to
27+
be in normalized device coordinates (NDC): [-1, 1]^3 with the camera at
28+
(0, 0, 0); the x-axis goes from left-to-right, the y-axis goes from
29+
top-to-bottom, and the z-axis goes from back-to-front.
30+
image_size: Integer giving the resolution of the rasterized image
31+
radius (Optional): Float giving the radius (in NDC units) of the disk to
32+
be rasterized for each point.
33+
points_per_pixel (Optional): We will keep track of this many points per
34+
pixel, returning the nearest points_per_pixel points along the z-axis
35+
bin_size: Size of bins to use for coarse-to-fine rasterization. Setting
36+
bin_size=0 uses naive rasterization; setting bin_size=None attempts to
37+
set it heuristically based on the shape of the input. This should not
38+
affect the output, but can affect the speed of the forward pass.
39+
points_per_bin: Only applicable when using coarse-to-fine rasterization
40+
(bin_size > 0); this is the maxiumum number of points allowed within each
41+
bin. If more than this many points actually fall into a bin, an error
42+
will be raised. This should not affect the output values, but can affect
43+
the memory usage in the forward pass.
44+
45+
Returns:
46+
3-element tuple containing
47+
48+
- **idx**: int32 Tensor of shape (N, image_size, image_size, points_per_pixel)
49+
giving the indices of the nearest points at each pixel, in ascending
50+
z-order. Concretely `idx[n, y, x, k] = p` means that `points[p]` is the kth
51+
closest point (along the z-direction) to pixel (y, x) - note that points
52+
represents the packed points of shape (P, 3).
53+
Pixels that are hit by fewer than points_per_pixel are padded with -1.
54+
- **zbuf**: Tensor of shape (N, image_size, image_size, points_per_pixel)
55+
giving the z-coordinates of the nearest points at each pixel, sorted in
56+
z-order. Concretely, if `idx[n, y, x, k] = p` then
57+
`zbuf[n, y, x, k] = points[n, p, 2]`. Pixels hit by fewer than
58+
points_per_pixel are padded with -1
59+
- **dists2**: Tensor of shape (N, image_size, image_size, points_per_pixel)
60+
giving the squared Euclidean distance (in NDC units) in the x/y plane
61+
for each point closest to the pixel. Concretely if `idx[n, y, x, k] = p`
62+
then `dists[n, y, x, k]` is the squared distance between the pixel (y, x)
63+
and the point `(points[n, p, 0], points[n, p, 1])`. Pixels hit with fewer
64+
than points_per_pixel are padded with -1.
65+
"""
66+
points_packed = pointclouds.points_packed()
67+
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
68+
num_points_per_cloud = pointclouds.num_points_per_cloud()
69+
70+
if bin_size is None:
71+
if not points_packed.is_cuda:
72+
# Binned CPU rasterization not fully implemented
73+
bin_size = 0
74+
else:
75+
# TODO: These heuristics are not well-thought out!
76+
if image_size <= 64:
77+
bin_size = 8
78+
elif image_size <= 256:
79+
bin_size = 16
80+
elif image_size <= 512:
81+
bin_size = 32
82+
elif image_size <= 1024:
83+
bin_size = 64
84+
85+
if max_points_per_bin is None:
86+
max_points_per_bin = int(max(10000, points_packed.shape[0] / 5))
87+
88+
# Function.apply cannot take keyword args, so we handle defaults in this
89+
# wrapper and call apply with positional args only
90+
return _RasterizePoints.apply(
91+
points_packed,
92+
cloud_to_packed_first_idx,
93+
num_points_per_cloud,
94+
image_size,
95+
radius,
96+
points_per_pixel,
97+
bin_size,
98+
max_points_per_bin,
99+
)
100+
101+
102+
class _RasterizePoints(torch.autograd.Function):
103+
@staticmethod
104+
def forward(
105+
ctx,
106+
points, # (P, 3)
107+
cloud_to_packed_first_idx,
108+
num_points_per_cloud,
109+
image_size: int = 256,
110+
radius: float = 0.01,
111+
points_per_pixel: int = 8,
112+
bin_size: int = 0,
113+
max_points_per_bin: int = 0,
114+
):
115+
# TODO: Add better error handling for when there are more than
116+
# max_points_per_bin in any bin.
117+
args = (
118+
points,
119+
cloud_to_packed_first_idx,
120+
num_points_per_cloud,
121+
image_size,
122+
radius,
123+
points_per_pixel,
124+
bin_size,
125+
max_points_per_bin,
126+
)
127+
idx, zbuf, dists = _C.rasterize_points(*args)
128+
ctx.save_for_backward(points, idx)
129+
return idx, zbuf, dists
130+
131+
@staticmethod
132+
def backward(ctx, grad_idx, grad_zbuf, grad_dists):
133+
grad_points = None
134+
grad_cloud_to_packed_first_idx = None
135+
grad_num_points_per_cloud = None
136+
grad_image_size = None
137+
grad_radius = None
138+
grad_points_per_pixel = None
139+
grad_bin_size = None
140+
grad_max_points_per_bin = None
141+
points, idx = ctx.saved_tensors
142+
args = (points, idx, grad_zbuf, grad_dists)
143+
grad_points = _C.rasterize_points_backward(*args)
144+
grads = (
145+
grad_points,
146+
grad_cloud_to_packed_first_idx,
147+
grad_num_points_per_cloud,
148+
grad_image_size,
149+
grad_radius,
150+
grad_points_per_pixel,
151+
grad_bin_size,
152+
grad_max_points_per_bin,
153+
)
154+
return grads
155+
156+
157+
def rasterize_points_python(
158+
pointclouds,
159+
image_size: int = 256,
160+
radius: float = 0.01,
161+
points_per_pixel: int = 8,
162+
):
163+
"""
164+
Naive pure PyTorch implementation of pointcloud rasterization.
165+
166+
Inputs / Outputs: Same as above
167+
"""
168+
N = len(pointclouds)
169+
S, K = image_size, points_per_pixel
170+
device = pointclouds.device
171+
172+
points_packed = pointclouds.points_packed()
173+
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
174+
num_points_per_cloud = pointclouds.num_points_per_cloud()
175+
176+
# Intialize output tensors.
177+
point_idxs = torch.full(
178+
(N, S, S, K), fill_value=-1, dtype=torch.int32, device=device
179+
)
180+
zbuf = torch.full(
181+
(N, S, S, K), fill_value=-1, dtype=torch.float32, device=device
182+
)
183+
pix_dists = torch.full(
184+
(N, S, S, K), fill_value=-1, dtype=torch.float32, device=device
185+
)
186+
187+
# NDC is from [-1, 1]. Get pixel size using specified image size.
188+
radius2 = radius * radius
189+
190+
# Iterate through the batch of point clouds.
191+
for n in range(N):
192+
point_start_idx = cloud_to_packed_first_idx[n]
193+
point_stop_idx = point_start_idx + num_points_per_cloud[n]
194+
195+
# Iterate through the horizontal lines of the image from top to bottom.
196+
for yi in range(S):
197+
# Y coordinate of one end of the image. Reverse the ordering
198+
# of yi so that +Y is pointing up in the image.
199+
yfix = S - 1 - yi
200+
yf = pix_to_ndc(yfix, S)
201+
202+
# Iterate through pixels on this horizontal line, left to right.
203+
for xi in range(S):
204+
# X coordinate of one end of the image. Reverse the ordering
205+
# of xi so that +X is pointing to the left in the image.
206+
xfix = S - 1 - xi
207+
xf = pix_to_ndc(xfix, S)
208+
209+
top_k_points = []
210+
# Check whether each point in the batch affects this pixel.
211+
for p in range(point_start_idx, point_stop_idx):
212+
px, py, pz = points_packed[p, :]
213+
if pz < 0:
214+
continue
215+
dx = px - xf
216+
dy = py - yf
217+
dist2 = dx * dx + dy * dy
218+
if dist2 < radius2:
219+
top_k_points.append((pz, p, dist2))
220+
top_k_points.sort()
221+
if len(top_k_points) > K:
222+
top_k_points = top_k_points[:K]
223+
for k, (pz, p, dist2) in enumerate(top_k_points):
224+
zbuf[n, yi, xi, k] = pz
225+
point_idxs[n, yi, xi, k] = p
226+
pix_dists[n, yi, xi, k] = dist2
227+
return point_idxs, zbuf, pix_dists

0 commit comments

Comments
 (0)