Skip to content

Commit 5910d81

Browse files
EmGarrfacebook-github-bot
authored andcommitted
Add blurpool following MIPNerf paper.
Summary: Add blurpool has defined in [MIP-NeRF](https://arxiv.org/abs/2103.13415). It has been added has an option for RayPointRefiner. Reviewed By: shapovalov Differential Revision: D46356189 fbshipit-source-id: ad841bad86d2b591a68e1cb885d4f781cf26c111
1 parent ccf860f commit 5910d81

File tree

4 files changed

+97
-5
lines changed

4 files changed

+97
-5
lines changed

projects/implicitron_trainer/tests/experiment.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,8 @@ model_factory_ImplicitronModelFactory_args:
249249
append_coarse_samples_to_fine: true
250250
density_noise_std_train: 0.0
251251
return_weights: false
252+
blurpool_weights: false
253+
sample_pdf_eps: 1.0e-05
252254
raymarcher_CumsumRaymarcher_args:
253255
surface_thickness: 1
254256
bg_color:
@@ -679,6 +681,8 @@ model_factory_ImplicitronModelFactory_args:
679681
append_coarse_samples_to_fine: true
680682
density_noise_std_train: 0.0
681683
return_weights: false
684+
blurpool_weights: false
685+
sample_pdf_eps: 1.0e-05
682686
raymarcher_CumsumRaymarcher_args:
683687
surface_thickness: 1
684688
bg_color:

pytorch3d/implicitron/models/renderer/multipass_ea.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
6565
opacity field.
6666
return_weights: Enables returning the rendering weights of the EA raymarcher.
6767
Setting to `True` can lead to a prohibitivelly large memory consumption.
68+
blurpool_weights: Use blurpool defined in [3], on the input weights of
69+
each implicit_function except the first (implicit_functions[0]).
70+
sample_pdf_eps: Padding applied to the weights (alpha in equation 18 of [3]).
6871
raymarcher_class_type: The type of self.raymarcher corresponding to
6972
a child of `RaymarcherBase` in the registry.
7073
raymarcher: The raymarcher object used to convert per-point features
@@ -75,6 +78,8 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
7578
Fields for View Synthesis." ECCV 2020.
7679
[2] Lombardi, Stephen, et al. "Neural Volumes: Learning Dynamic Renderable
7780
Volumes from Images." SIGGRAPH 2019.
81+
[3] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation
82+
for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
7883
7984
"""
8085

@@ -88,18 +93,24 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
8893
append_coarse_samples_to_fine: bool = True
8994
density_noise_std_train: float = 0.0
9095
return_weights: bool = False
96+
blurpool_weights: bool = False
97+
sample_pdf_eps: float = 1e-5
9198

9299
def __post_init__(self):
93100
self._refiners = {
94101
EvaluationMode.TRAINING: RayPointRefiner(
95102
n_pts_per_ray=self.n_pts_per_ray_fine_training,
96103
random_sampling=self.stratified_sampling_coarse_training,
97104
add_input_samples=self.append_coarse_samples_to_fine,
105+
blurpool_weights=self.blurpool_weights,
106+
sample_pdf_eps=self.sample_pdf_eps,
98107
),
99108
EvaluationMode.EVALUATION: RayPointRefiner(
100109
n_pts_per_ray=self.n_pts_per_ray_fine_evaluation,
101110
random_sampling=self.stratified_sampling_coarse_evaluation,
102111
add_input_samples=self.append_coarse_samples_to_fine,
112+
blurpool_weights=self.blurpool_weights,
113+
sample_pdf_eps=self.sample_pdf_eps,
103114
),
104115
}
105116
run_auto_creation(self)

pytorch3d/implicitron/models/renderer/ray_point_refiner.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,45 +32,66 @@ class RayPointRefiner(Configurable, torch.nn.Module):
3232
sampling from that distribution.
3333
add_input_samples: Concatenates and returns the sampled values
3434
together with the input samples.
35+
blurpool_weights: Use blurpool defined in [1], on the input weights.
36+
sample_pdf_eps: A constant preventing division by zero in case empty bins
37+
are present.
38+
39+
References:
40+
[1] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation
41+
for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
3542
"""
3643

3744
n_pts_per_ray: int
3845
random_sampling: bool
3946
add_input_samples: bool = True
47+
blurpool_weights: bool = False
48+
sample_pdf_eps: float = 1e-5
4049

4150
def forward(
4251
self,
4352
input_ray_bundle: ImplicitronRayBundle,
4453
ray_weights: torch.Tensor,
54+
blurpool_weights: bool = False,
55+
sample_pdf_padding: float = 1e-5,
4556
**kwargs,
4657
) -> ImplicitronRayBundle:
4758
"""
4859
Args:
4960
input_ray_bundle: An instance of `ImplicitronRayBundle` specifying the
5061
source rays for sampling of the probability distribution.
5162
ray_weights: A tensor of shape
52-
`(..., input_ray_bundle.legths.shape[-1])` with non-negative
63+
`(..., input_ray_bundle.lengths.shape[-1])` with non-negative
5364
elements defining the probability distribution to sample
5465
ray points from.
66+
blurpool_weights: Use blurpool defined in [1], on the input weights.
67+
sample_pdf_padding: A constant preventing division by zero in case empty bins
68+
are present.
5569
5670
Returns:
5771
ray_bundle: A new `ImplicitronRayBundle` instance containing the input ray
5872
points together with `n_pts_per_ray` additionally sampled
5973
points per ray. For each ray, the lengths are sorted.
74+
75+
References:
76+
[1] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation
77+
for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
78+
6079
"""
6180

6281
z_vals = input_ray_bundle.lengths
6382
with torch.no_grad():
83+
if self.blurpool_weights:
84+
ray_weights = apply_blurpool_on_weights(ray_weights)
85+
6486
z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
6587
z_samples = sample_pdf(
6688
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
6789
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
6890
self.n_pts_per_ray,
6991
det=not self.random_sampling,
92+
eps=self.sample_pdf_eps,
7093
).view(*z_vals.shape[:-1], self.n_pts_per_ray)
71-
7294
if self.add_input_samples:
73-
# Add the new samples to the input ones.
7495
z_vals = torch.cat((z_vals, z_samples), dim=-1)
7596
else:
7697
z_vals = z_samples
@@ -80,3 +101,31 @@ def forward(
80101
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
81102
new_bundle.lengths = z_vals
82103
return new_bundle
104+
105+
106+
def apply_blurpool_on_weights(weights) -> torch.Tensor:
107+
"""
108+
Filter weights with a 2-tap max filters followed by a 2-tap blur filter,
109+
which produces a wide and smooth upper envelope on the weights.
110+
111+
Args:
112+
weights: Tensor of shape `(..., dim)`
113+
114+
Returns:
115+
blured_weights: Tensor of shape `(..., dim)`
116+
"""
117+
weights_pad = torch.concatenate(
118+
[
119+
weights[..., :1],
120+
weights,
121+
weights[..., -1:],
122+
],
123+
dim=-1,
124+
)
125+
126+
weights_max = torch.nn.functional.max_pool1d(
127+
weights_pad.flatten(end_dim=-2), 2, stride=1
128+
)
129+
return torch.lerp(weights_max[..., :-1], weights_max[..., 1:], 0.5).reshape_as(
130+
weights
131+
)

tests/implicitron/test_ray_point_refiner.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,14 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import unittest
8+
from itertools import product
89

910
import torch
10-
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner
11+
12+
from pytorch3d.implicitron.models.renderer.ray_point_refiner import (
13+
apply_blurpool_on_weights,
14+
RayPointRefiner,
15+
)
1116
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
1217
from tests.common_testing import TestCaseMixin
1318

@@ -17,11 +22,12 @@ def test_simple(self):
1722
length = 15
1823
n_pts_per_ray = 10
1924

20-
for add_input_samples in [False, True]:
25+
for add_input_samples, use_blurpool in product([False, True], [False, True]):
2126
ray_point_refiner = RayPointRefiner(
2227
n_pts_per_ray=n_pts_per_ray,
2328
random_sampling=False,
2429
add_input_samples=add_input_samples,
30+
blurpool_weights=use_blurpool,
2531
)
2632
lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
2733
bundle = ImplicitronRayBundle(
@@ -50,6 +56,7 @@ def test_simple(self):
5056
n_pts_per_ray=n_pts_per_ray,
5157
random_sampling=True,
5258
add_input_samples=add_input_samples,
59+
blurpool_weights=use_blurpool,
5360
)
5461
refined_random = ray_point_refiner_random(bundle, weights)
5562
lengths_random = refined_random.lengths
@@ -62,3 +69,24 @@ def test_simple(self):
6269
self.assertTrue(
6370
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
6471
)
72+
73+
def test_apply_blurpool_on_weights(self):
74+
weights = torch.tensor(
75+
[
76+
[0.5, 0.6, 0.7],
77+
[0.5, 0.3, 0.9],
78+
]
79+
)
80+
expected_weights = 0.5 * torch.tensor(
81+
[
82+
[0.5 + 0.6, 0.6 + 0.7, 0.7 + 0.7],
83+
[0.5 + 0.5, 0.5 + 0.9, 0.9 + 0.9],
84+
]
85+
)
86+
out_weights = apply_blurpool_on_weights(weights)
87+
self.assertTrue(torch.allclose(out_weights, expected_weights))
88+
89+
def test_shapes_apply_blurpool_on_weights(self):
90+
weights = torch.randn((5, 4, 3, 2, 1))
91+
out_weights = apply_blurpool_on_weights(weights)
92+
self.assertEqual((5, 4, 3, 2, 1), out_weights.shape)

0 commit comments

Comments
 (0)