|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | + |
| 3 | +from typing import Optional |
| 4 | +import torch |
| 5 | + |
| 6 | +from pytorch3d import _C |
| 7 | +from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc |
| 8 | + |
| 9 | + |
| 10 | +# TODO(jcjohns): Support non-square images |
def rasterize_points(
    pointclouds,
    image_size: int = 256,
    radius: float = 0.01,
    points_per_pixel: int = 8,
    bin_size: Optional[int] = None,
    max_points_per_bin: Optional[int] = None,
):
    """
    Pointcloud rasterization

    Args:
        pointclouds: A Pointclouds object representing a batch of point clouds to be
            rasterized. This is a batch of N pointclouds, where each point cloud
            can have a different number of points; the coordinates of each point
            are (x, y, z). The coordinates are expected to
            be in normalized device coordinates (NDC): [-1, 1]^3 with the camera at
            (0, 0, 0); the x-axis goes from left-to-right, the y-axis goes from
            top-to-bottom, and the z-axis goes from back-to-front.
        image_size: Integer giving the resolution of the rasterized image
        radius (Optional): Float giving the radius (in NDC units) of the disk to
            be rasterized for each point.
        points_per_pixel (Optional): We will keep track of this many points per
            pixel, returning the nearest points_per_pixel points along the z-axis
        bin_size: Size of bins to use for coarse-to-fine rasterization. Setting
            bin_size=0 uses naive rasterization; setting bin_size=None attempts to
            set it heuristically based on the shape of the input (0 for non-CUDA
            inputs, otherwise a value that grows with image_size). This should not
            affect the output, but can affect the speed of the forward pass.
        max_points_per_bin: Only applicable when using coarse-to-fine rasterization
            (bin_size > 0); this is the maximum number of points allowed within each
            bin. If more than this many points actually fall into a bin, an error
            will be raised. This should not affect the output values, but can affect
            the memory usage in the forward pass. If None, it defaults to
            max(10000, P / 5) where P is the total number of packed points.

    Returns:
        3-element tuple containing

        - **idx**: int32 Tensor of shape (N, image_size, image_size, points_per_pixel)
          giving the indices of the nearest points at each pixel, in ascending
          z-order. Concretely `idx[n, y, x, k] = p` means that `points[p]` is the kth
          closest point (along the z-direction) to pixel (y, x) - note that points
          represents the packed points of shape (P, 3).
          Pixels that are hit by fewer than points_per_pixel are padded with -1.
        - **zbuf**: Tensor of shape (N, image_size, image_size, points_per_pixel)
          giving the z-coordinates of the nearest points at each pixel, sorted in
          z-order. Concretely, if `idx[n, y, x, k] = p` then
          `zbuf[n, y, x, k] = points[p, 2]`. Pixels hit by fewer than
          points_per_pixel are padded with -1
        - **dists2**: Tensor of shape (N, image_size, image_size, points_per_pixel)
          giving the squared Euclidean distance (in NDC units) in the x/y plane
          for each point closest to the pixel. Concretely if `idx[n, y, x, k] = p`
          then `dists[n, y, x, k]` is the squared distance between the pixel (y, x)
          and the point `(points[p, 0], points[p, 1])`. Pixels hit with fewer
          than points_per_pixel are padded with -1.
    """
    points_packed = pointclouds.points_packed()
    cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
    num_points_per_cloud = pointclouds.num_points_per_cloud()

    if bin_size is None:
        if not points_packed.is_cuda:
            # Binned CPU rasterization not fully implemented
            bin_size = 0
        # TODO: These heuristics are not well-thought out!
        elif image_size <= 64:
            bin_size = 8
        elif image_size <= 256:
            bin_size = 16
        elif image_size <= 512:
            bin_size = 32
        elif image_size <= 1024:
            bin_size = 64
        else:
            # Images larger than 1024 used to fall through every branch and
            # leave bin_size as None, which was then handed to the autograd
            # function. Keep doubling instead so there are at most 16 bins
            # per side, the same ratio as the 256 / 512 / 1024 tiers.
            bin_size = 128
            while bin_size * 16 < image_size:
                bin_size *= 2

    if max_points_per_bin is None:
        max_points_per_bin = int(max(10000, points_packed.shape[0] / 5))

    # Function.apply cannot take keyword args, so we handle defaults in this
    # wrapper and call apply with positional args only
    return _RasterizePoints.apply(
        points_packed,
        cloud_to_packed_first_idx,
        num_points_per_cloud,
        image_size,
        radius,
        points_per_pixel,
        bin_size,
        max_points_per_bin,
    )
| 100 | + |
| 101 | + |
class _RasterizePoints(torch.autograd.Function):
    """
    Autograd bridge to the compiled point rasterizer in `_C`.

    `forward` delegates to `_C.rasterize_points` and `backward` to
    `_C.rasterize_points_backward`; only `points` receives a gradient.
    """

    @staticmethod
    def forward(
        ctx,
        points,  # (P, 3)
        cloud_to_packed_first_idx,
        num_points_per_cloud,
        image_size: int = 256,
        radius: float = 0.01,
        points_per_pixel: int = 8,
        bin_size: int = 0,
        max_points_per_bin: int = 0,
    ):
        """
        Run the compiled rasterizer and stash what backward needs.

        All arguments are forwarded positionally, in this order, to
        `_C.rasterize_points`.

        Returns:
            The (idx, zbuf, dists) tuple produced by `_C.rasterize_points`.
        """
        # TODO: Add better error handling for when there are more than
        # max_points_per_bin in any bin.
        args = (
            points,
            cloud_to_packed_first_idx,
            num_points_per_cloud,
            image_size,
            radius,
            points_per_pixel,
            bin_size,
            max_points_per_bin,
        )
        idx, zbuf, dists = _C.rasterize_points(*args)
        ctx.save_for_backward(points, idx)
        # idx holds point indices, not values to differentiate; tell autograd
        # so it does not track gradients for this output.
        ctx.mark_non_differentiable(idx)
        return idx, zbuf, dists

    @staticmethod
    def backward(ctx, grad_idx, grad_zbuf, grad_dists):
        """
        Compute the gradient with respect to `points`.

        `grad_idx` is accepted because autograd passes one gradient per
        forward output, but it is not used.

        Returns:
            An 8-tuple with one entry per forward input (excluding ctx): the
            gradient for `points` followed by None for the seven remaining
            inputs.
        """
        points, idx = ctx.saved_tensors
        grad_points = _C.rasterize_points_backward(
            points, idx, grad_zbuf, grad_dists
        )
        return (
            grad_points,  # points
            None,  # cloud_to_packed_first_idx
            None,  # num_points_per_cloud
            None,  # image_size
            None,  # radius
            None,  # points_per_pixel
            None,  # bin_size
            None,  # max_points_per_bin
        )
| 155 | + |
| 156 | + |
def rasterize_points_python(
    pointclouds,
    image_size: int = 256,
    radius: float = 0.01,
    points_per_pixel: int = 8,
):
    """
    Reference pointcloud rasterizer written with plain Python loops.

    Takes the same `pointclouds`, `image_size`, `radius` and
    `points_per_pixel` arguments as `rasterize_points` and returns the same
    `(idx, zbuf, dists)` tuple, each of shape
    (N, image_size, image_size, points_per_pixel) and padded with -1.
    """
    batch_size = len(pointclouds)
    side, keep = image_size, points_per_pixel
    device = pointclouds.device

    packed = pointclouds.points_packed()
    first_idx = pointclouds.cloud_to_packed_first_idx()
    counts = pointclouds.num_points_per_cloud()

    # All three outputs share one shape and start out filled with -1.
    out_shape = (batch_size, side, side, keep)
    point_idxs = torch.full(
        out_shape, fill_value=-1, dtype=torch.int32, device=device
    )
    zbuf = torch.full(
        out_shape, fill_value=-1, dtype=torch.float32, device=device
    )
    pix_dists = torch.full(
        out_shape, fill_value=-1, dtype=torch.float32, device=device
    )

    # A point covers a pixel when its squared x/y distance is below this.
    radius_sq = radius * radius

    for n in range(batch_size):
        # Range of this cloud's points inside the packed tensor.
        start = first_idx[n]
        stop = start + counts[n]

        # Rows, top to bottom. The row index is flipped before conversion
        # to NDC so that +Y points up in the image.
        for yi in range(side):
            yf = pix_to_ndc(side - 1 - yi, side)

            # Pixels in the row, left to right. The column index is flipped
            # so that +X points to the left in the image.
            for xi in range(side):
                xf = pix_to_ndc(side - 1 - xi, side)

                # Up to `keep` nearest hits for this pixel, kept sorted as
                # (z, packed index, squared distance).
                nearest = []
                for p in range(start, stop):
                    px, py, pz = packed[p, :]
                    if pz < 0:
                        # Points with negative z are skipped.
                        continue
                    dx, dy = px - xf, py - yf
                    dist2 = dx * dx + dy * dy
                    if not (dist2 < radius_sq):
                        continue
                    nearest.append((pz, p, dist2))
                    nearest.sort()
                    # Drop whatever falls beyond the first `keep` entries.
                    del nearest[keep:]

                for k, (pz, p, dist2) in enumerate(nearest):
                    zbuf[n, yi, xi, k] = pz
                    point_idxs[n, yi, xi, k] = p
                    pix_dists[n, yi, xi, k] = dist2

    return point_idxs, zbuf, pix_dists
0 commit comments