Radiance Field Gradient Scaling for Unbiased Near-Camera Training

Quadratically scaling gradients near cameras during NeRF training removes floaters.

Left: original methods (MipNeRF360, DVGO) without scaling.
Right: the same methods with our gradient scaling.

Abstract

NeRF acquisition typically requires a careful choice of near planes for the different cameras; otherwise it suffers from background collapse, which creates floating artifacts at the edges of the captured scene. The key insight of this work is that background collapse is caused by a higher density of samples in regions near cameras. As a result of this sampling bias, near-camera volumes receive significantly more gradient contributions, leading to incorrect density buildup. We propose a gradient scaling approach to counterbalance this bias, removing the need for near planes while preventing background collapse. Our method can be implemented in a few lines, does not incur any significant overhead, and is compatible with most NeRF implementations.
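The sampling bias can be illustrated with a back-of-the-envelope calculation. The sketch below assumes a simple pinhole camera whose rays diverge from a single center, with samples spaced uniformly along each ray, so the number of samples per unit volume at distance d falls off as 1/d² (normalized here to 1 at d = 1). Multiplying each sample's gradient by the factor used in the code snippets on this page, d² clamped to [0, 1], cancels that growth for d < 1 and leaves samples at d ≥ 1 untouched. The function names are illustrative only.

```python
def relative_sample_density(d: float) -> float:
    """Relative samples per unit volume at distance d from a pinhole camera.

    Assumption: rays diverge from one center and samples are evenly spaced
    along each ray, so density falls off as 1 / d**2 (equal to 1 at d = 1).
    """
    return 1.0 / (d * d)


def gradient_scale(d: float) -> float:
    """Per-sample gradient factor: d**2 clamped to [0, 1]."""
    return min(max(d * d, 0.0), 1.0)


for d in (0.1, 0.25, 0.5, 1.0, 2.0, 4.0):
    raw = relative_sample_density(d)
    scaled = raw * gradient_scale(d)
    # For d < 1 the product is ~1, so the near-camera bias is cancelled.
    # For d >= 1 the factor is 1, so those gradients are left unchanged.
    print(f"d={d:5.2f}  raw weight={raw:8.2f}  scaled weight={scaled:6.3f}")
```

Under this model, a region at d = 0.1 would otherwise accumulate roughly 100 times the gradient weight of a region at d = 1; the clamp keeps the correction from amplifying gradients far from the camera.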

Code

PyTorch


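##### Define with: #####

import torch


class GradientScaler(torch.autograd.Function):
  """Identity in the forward pass; scales color and sigma gradients by clamp(ray_dist^2, 0, 1)."""

  @staticmethod
  def forward(ctx, colors, sigmas, ray_dist):
    # Expected shapes: colors [..., 3]; sigmas and ray_dist [...] (same shape).
    ctx.save_for_backward(ray_dist)
    return colors, sigmas, ray_dist

  @staticmethod
  def backward(ctx, grad_output_colors, grad_output_sigmas, grad_output_ray_dist):
    (ray_dist,) = ctx.saved_tensors
    # Quadratic down-weighting near the camera; samples at distance >= 1 are unaffected.
    scaling = torch.square(ray_dist).clamp(0, 1)
    # unsqueeze(-1) broadcasts the per-sample factor over the RGB channels;
    # the ray_dist gradient is passed through unchanged.
    return grad_output_colors * scaling.unsqueeze(-1), grad_output_sigmas * scaling, grad_output_ray_dist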

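##### Call with: #####

  # On the per-sample network outputs, before volume rendering (compositing);
  # ray_dist is the distance from the ray origin to each sample.
  if gradient_scaling:
    colors, sigmas, ray_dist = GradientScaler.apply(colors, sigmas, ray_dist)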
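The PyTorch snippet leaves the computation of ray_dist to the caller. The following dependency-free sketch mirrors what GradientScaler.backward does for a single ray, assuming ray_dist is the Euclidean distance from the ray origin to each sample, as in the JAX version. Each sample's RGB gradient is multiplied by the same factor in all three channels (the role of unsqueeze(-1)); the pass-through ray_dist gradient is omitted. The helper names and sample values are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def ray_distances(origin: Sequence[float], samples: Sequence[Sequence[float]]) -> List[float]:
    """Euclidean distance from the ray origin to each sample point (assumed meaning of ray_dist)."""
    return [math.dist(origin, p) for p in samples]


def scale_gradients(
    grad_colors: Sequence[Sequence[float]],
    grad_sigmas: Sequence[float],
    ray_dist: Sequence[float],
) -> Tuple[List[Tuple[float, ...]], List[float]]:
    """Mimic GradientScaler.backward for one ray using plain Python lists."""
    # Same rule as torch.square(ray_dist).clamp(0, 1).
    scaling = [min(max(d * d, 0.0), 1.0) for d in ray_dist]
    # One factor per sample, shared by R, G and B (the role of scaling.unsqueeze(-1)).
    colors = [tuple(g * s for g in rgb) for rgb, s in zip(grad_colors, scaling)]
    sigmas = [g * s for g, s in zip(grad_sigmas, scaling)]
    return colors, sigmas


# Hypothetical ray along +z with three samples at distances 0.1, 0.5 and 2.0.
origin = (0.0, 0.0, 0.0)
samples = [(0.0, 0.0, 0.1), (0.0, 0.0, 0.5), (0.0, 0.0, 2.0)]
dists = ray_distances(origin, samples)
# Unit upstream gradients, so the results equal the scaling factors themselves.
g_rgb, g_sigma = scale_gradients([(1.0, 1.0, 1.0)] * 3, [1.0] * 3, dists)
print(dists)
print(g_sigma)  # the two samples closer than 1 are down-weighted; the last is not
```

Because the forward pass is the identity, rendered images are unchanged; only the strength of the updates reaching near-camera samples is reduced.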
        

JAX (compatible with MultiNeRF)


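##### Define with: #####

import jax.numpy as jnp
from jax import custom_jvp


@custom_jvp
def gradientScaling(origins, gaussians, rgb, density):
  # Identity in the forward pass.
  return rgb, density


@gradientScaling.defjvp
def gradientScaling_jvp(primals, tangents):
  origins, gaussians, rgb, density = primals
  origins_dot, gaussians_dot, rgb_dot, density_dot = tangents
  ans = gradientScaling(origins, gaussians, rgb, density)
  # gaussians[0] holds the sample means [..., num_samples, 3]; origins is [..., 3].
  # Squared origin-to-sample distance clamped to [0, 1], keeping a trailing axis
  # so that it broadcasts over the RGB channels.
  scaling = jnp.square(jnp.linalg.norm(gaussians[0] - origins[..., None, :], axis=-1, keepdims=True)).clip(0, 1)
  # The outputs do not depend on origins or gaussians, so their tangents are dropped;
  # only the rgb and density tangents are rescaled.
  ans_dot = (rgb_dot * scaling, density_dot * scaling[..., 0])
  return ans, ans_dot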

##### Call with: #####

  # After the MLP returns per-sample rgb and density, before volume rendering:
  if gradient_scaling:
    ray_results['rgb'], ray_results['density'] = gradientScaling(rays.origins, gaussians, ray_results['rgb'], ray_results['density'])
        

BibTeX


  Coming soon; see arXiv for now.