Floaters No More: Radiance Field Gradient Scaling for Improved Near-Camera Training

Adobe Research

Quadratically scaling gradients near cameras during NeRF training removes floaters.

Left: Original methods (MipNeRF360, DVGOv2) without scaling.
Right: With our gradient scaling.

Abstract

NeRF acquisition typically requires a careful choice of near planes for the different cameras; otherwise, it suffers from background collapse, which creates floating artifacts at the edges of the captured scene. The key insight of this work is that background collapse is caused by a higher density of samples in regions near the cameras. As a result of this sampling imbalance, near-camera volumes receive significantly more gradients, leading to incorrect density buildup. We propose a gradient scaling approach that counter-balances this sampling imbalance, removing the need for near planes while preventing background collapse. Our method can be implemented in a few lines, does not introduce any significant overhead, and is compatible with most NeRF implementations.
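The sampling imbalance can be made concrete with a back-of-the-envelope calculation. The sketch below is a simplified model, not the paper's derivation: it assumes the rays of one camera fill a cone of fixed solid angle and that every ray is sampled uniformly in distance (real samplers, such as proposal networks, are not uniform). A thin shell at distance r then has a volume proportional to r² dr, while each ray places the same number of samples in it, so samples per unit volume fall off as 1/r². Multiplying by min(r², 1), the factor used in the PyTorch and JAX snippets on this page, evens out this density for samples closer than one scene unit. The helper names and the numbers (1000 rays, 64 samples per unit length, 0.5 sr) are hypothetical and purely illustrative.

```python
def shell_volume(r0: float, r1: float, solid_angle: float) -> float:
    """Volume of the part of a cone (given solid angle) between distances r0 and r1."""
    return solid_angle * (r1 ** 3 - r0 ** 3) / 3.0


def samples_per_volume(r: float, num_rays: int, samples_per_unit_length: float,
                       solid_angle: float, dr: float = 1e-4) -> float:
    """Sample density in the thin shell [r, r + dr] when each ray is sampled uniformly in distance."""
    num_samples = num_rays * samples_per_unit_length * dr
    return num_samples / shell_volume(r, r + dr, solid_angle)


def gradient_scale(r: float) -> float:
    """Per-sample scaling factor min(r^2, 1)."""
    return min(r * r, 1.0)


# Illustrative setup: hypothetical camera with 1000 rays in a 0.5 sr cone.
for r in (0.1, 0.25, 0.5, 1.0):
    density = samples_per_volume(r, num_rays=1000, samples_per_unit_length=64.0, solid_angle=0.5)
    print(f"r={r:<5} samples/volume={density:12.1f} scaled={density * gradient_scale(r):10.1f}")
```

In this toy model the raw sample density at r = 0.1 is about 100 times that at r = 1, while the scaled column stays nearly constant, which is the counter-balancing effect the gradient scaling aims for.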

Code

PyTorch


##### Define with: #####

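import torch

class GradientScaler(torch.autograd.Function):
  # Identity in the forward pass. In the backward pass, the gradients of colors and
  # sigmas are multiplied by min(ray_dist^2, 1) to counter-balance the higher sample
  # density near the cameras.
  @staticmethod
  def forward(ctx, colors, sigmas, ray_dist):
    # Expected shapes: colors [..., num_samples, 3]; sigmas and ray_dist [..., num_samples]
    ctx.save_for_backward(ray_dist)
    return colors, sigmas, ray_dist
  @staticmethod
  def backward(ctx, grad_output_colors, grad_output_sigmas, grad_output_ray_dist):
    (ray_dist,) = ctx.saved_tensors
    scaling = torch.square(ray_dist).clamp(0, 1)
    # Broadcast the per-sample scaling over the RGB channels of colors.
    return grad_output_colors * scaling.unsqueeze(-1), grad_output_sigmas * scaling, grad_output_ray_dist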

##### Call with: #####

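# Wrap the per-sample outputs of the radiance field before volume rendering.
# ray_dist is the distance from the camera origin to each sample.
if gradient_scaling:
  colors, sigmas, ray_dist = GradientScaler.apply(colors, sigmas, ray_dist)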
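To check what GradientScaler does, the sketch below applies it to a toy batch and backpropagates a loss whose gradient with respect to every color and density value would be exactly 1 without the scaler. It restates the class (with 4-space indentation) so it runs standalone; the shapes ([rays, samples, 3] for colors, [rays, samples] for sigmas and ray_dist) and the distance values are hypothetical, chosen to match the broadcasting in backward.

```python
import torch


class GradientScaler(torch.autograd.Function):
    """Identity in forward; scales color/density gradients by min(ray_dist^2, 1) in backward."""

    @staticmethod
    def forward(ctx, colors, sigmas, ray_dist):
        ctx.save_for_backward(ray_dist)
        return colors, sigmas, ray_dist

    @staticmethod
    def backward(ctx, grad_colors, grad_sigmas, grad_ray_dist):
        (ray_dist,) = ctx.saved_tensors
        scaling = torch.square(ray_dist).clamp(0, 1)
        return grad_colors * scaling.unsqueeze(-1), grad_sigmas * scaling, grad_ray_dist


# Toy batch: 2 rays x 4 samples. Distances are made-up values in scene units.
ray_dist = torch.tensor([[0.1, 0.5, 1.0, 2.0],
                         [0.2, 0.8, 1.5, 3.0]])
colors = torch.rand(2, 4, 3, requires_grad=True)  # [rays, samples, RGB]
sigmas = torch.rand(2, 4, requires_grad=True)     # [rays, samples]

scaled_colors, scaled_sigmas, _ = GradientScaler.apply(colors, sigmas, ray_dist)

# The forward pass is an identity, so the rendered image is unaffected.
assert torch.equal(scaled_colors, colors) and torch.equal(scaled_sigmas, sigmas)

# Without the scaler, d(loss)/d(input) would be 1 for every element.
loss = scaled_colors.sum() + scaled_sigmas.sum()
loss.backward()

print(sigmas.grad)  # per-sample values equal min(ray_dist^2, 1)
```

Samples closer than one unit receive proportionally smaller gradients, while samples at distance 1 or more keep their full gradient because of the clamp.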

JAX (compatible with MultiNeRF)


##### Define with: #####

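import jax.numpy as jnp
from jax import custom_jvp

@custom_jvp
def gradientScaling(origins, gaussians, rgb, density):
  # Identity in the forward pass: colors and densities are returned unchanged.
  return rgb, density

@gradientScaling.defjvp
def gradientScaling_jvp(primals, tangents):
  origins, gaussians, rgb, density = primals
  origins_dot, gaussians_dot, rgb_dot, density_dot = tangents
  ans = gradientScaling(origins, gaussians, rgb, density)
  # gaussians[0] holds the sample means; origins[..., None, :] broadcasts each ray
  # origin over its samples. scaling = min(distance^2, 1), shape [..., num_samples, 1].
  scaling = jnp.square(jnp.linalg.norm(gaussians[0] - origins[..., None, :], axis=-1, keepdims=True)).clip(0, 1)
  # The tangent map is linear in rgb_dot and density_dot, so reverse mode (jax.grad)
  # transposes it and the gradients reaching the MLP are scaled by the same factor.
  ans_dot = (rgb_dot * scaling, density_dot * scaling[..., 0])
  return ans, ans_dot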

##### Call with: #####

# Wrap the MLP outputs before they are volume rendered.
if gradient_scaling:
  ray_results['rgb'], ray_results['density'] = gradientScaling(rays.origins, gaussians, ray_results['rgb'], ray_results['density'])
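The same check can be run for the JAX version with jax.grad. The sketch below is a standalone restatement with snake_case names and hypothetical toy inputs: one ray starting at the origin with four samples along +z at distances 0.25, 0.5, 1 and 2, and gaussians passed as a (means, covariances) pair of which only the means are read, mirroring the gaussians[0] access in the snippet above. The covariance array is a placeholder, not the actual MultiNeRF data layout.

```python
import jax
import jax.numpy as jnp
from jax import custom_jvp


@custom_jvp
def gradient_scaling(origins, gaussians, rgb, density):
    # Identity in the forward pass.
    return rgb, density


@gradient_scaling.defjvp
def gradient_scaling_jvp(primals, tangents):
    origins, gaussians, rgb, density = primals
    _, _, rgb_dot, density_dot = tangents
    ans = gradient_scaling(origins, gaussians, rgb, density)
    # Distance from each ray origin to each sample mean, shape [..., num_samples, 1].
    dist = jnp.linalg.norm(gaussians[0] - origins[..., None, :], axis=-1, keepdims=True)
    scaling = jnp.clip(jnp.square(dist), 0.0, 1.0)
    return ans, (rgb_dot * scaling, density_dot * scaling[..., 0])


# Hypothetical toy input: 1 ray from the origin, 4 samples along +z.
origins = jnp.zeros((1, 3))
t = jnp.array([0.25, 0.5, 1.0, 2.0])
means = jnp.stack([jnp.zeros_like(t), jnp.zeros_like(t), t], axis=-1)[None]  # [1, 4, 3]
covs = jnp.zeros((1, 4, 3))  # placeholder, never read by the scaling rule
gaussians = (means, covs)
rgb = jnp.full((1, 4, 3), 0.5)
density = jnp.ones((1, 4))


def loss(rgb, density):
    r, d = gradient_scaling(origins, gaussians, rgb, density)
    return r.sum() + d.sum()


grad_rgb, grad_density = jax.grad(loss, argnums=(0, 1))(rgb, density)
print(grad_density)  # per-sample values equal min(t^2, 1)
```

Because the custom rule only multiplies tangents by a factor computed from the primals, JAX can transpose it for reverse mode, so no separate custom_vjp is needed in this sketch.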

Nerfstudio

Our method has also been implemented by the community in Nerfstudio, where it can be used directly.

BibTeX


@inproceedings{10.2312:sr.20231122,
  booktitle = {Eurographics Symposium on Rendering},
  editor = {Ritschel, Tobias and Weidlich, Andrea},
  title = {{Floaters No More: Radiance Field Gradient Scaling for Improved Near-Camera Training}},
  author = {Philip, Julien and Deschaintre, Valentin},
  year = {2023},
  publisher = {The Eurographics Association},
  ISSN = {1727-3463},
  ISBN = {978-3-03868-229-5},
  DOI = {10.2312/sr.20231122}
}