NeRF acquisition typically requires careful choice of near planes for the different cameras or suffers from background collapse, creating floating artifacts on the edges of the captured scene. The key insight of this work is that background collapse is caused by a higher density of samples in regions near cameras. As a result of this sampling bias, near-camera volumes receive significantly more gradients, leading to incorrect density buildup. We propose a gradient scaling approach to counter-balance this bias, removing the need for near planes, while preventing background collapse. Our method can be implemented in a few lines, does not induce any significant overhead, and is compatible with most NeRF implementations.
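Concretely, both reference snippets below leave the forward pass unchanged and multiply the backward gradient of each sample's color and density by

$s(t) = \min(t^2, 1)$

where $t$ is the distance along the ray from the camera origin to the sample; the clamp leaves samples beyond unit distance untouched, and how $t$ is normalized with respect to the scene scale is left to the host implementation.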
##### Define with (PyTorch): #####

import torch

class GradientScaler(torch.autograd.Function):
    """Identity in the forward pass; rescales gradients in the backward pass."""

    @staticmethod
    def forward(ctx, colors, sigmas, ray_dist):
        ctx.save_for_backward(ray_dist)
        return colors, sigmas, ray_dist

    @staticmethod
    def backward(ctx, grad_output_colors, grad_output_sigmas, grad_output_ray_dist):
        (ray_dist,) = ctx.saved_tensors
        # Squared sample-to-camera distance, clamped to 1: samples within unit
        # distance get down-weighted gradients, samples beyond it are untouched.
        scaling = torch.square(ray_dist).clamp(0, 1)
        return grad_output_colors * scaling.unsqueeze(-1), grad_output_sigmas * scaling, grad_output_ray_dist
##### Call with: #####

if gradient_scaling:
    colors, sigmas, ray_dist = GradientScaler.apply(colors, sigmas, ray_dist)
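Here, ray_dist holds the per-sample distance from the camera origin and has the same shape as sigmas. A minimal sketch of how it could be derived from typical NeRF sampling variables (the names pts and rays_o are placeholders, not part of the method):

# pts:    [..., num_samples, 3] sample positions along each ray (hypothetical name)
# rays_o: [..., 3]              per-ray camera origin (hypothetical name)
ray_dist = (pts - rays_o.unsqueeze(-2)).norm(dim=-1)  # [..., num_samples]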
##### Define with (JAX): #####

from jax import custom_jvp
import jax.numpy as jnp

@custom_jvp
def gradientScaling(origins, gaussians, rgb, density):
    # Identity in the forward pass; the custom JVP below rescales the tangents.
    return rgb, density

@gradientScaling.defjvp
def gradientScaling_jvp(primals, tangents):
    origins, gaussians, rgb, density = primals
    origins_dot, gaussians_dot, rgb_dot, density_dot = tangents
    ans = gradientScaling(origins, gaussians, rgb, density)
    # Squared distance from each sample mean (gaussians[0]) to its ray origin, clipped to 1.
    scaling = jnp.square(jnp.linalg.norm(gaussians[0] - origins[..., None, :], axis=-1, keepdims=True)).clip(0, 1)
    ans_dot = (rgb_dot * scaling, density_dot * scaling[..., 0])
    return ans, ans_dot
##### Call with: #####

if gradient_scaling:
    ray_results['rgb'], ray_results['density'] = gradientScaling(
        rays.origins, gaussians, ray_results['rgb'], ray_results['density'])
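In the JAX variant the scaling is expressed as a custom JVP: the tangent map is linear in rgb_dot and density_dot, so reverse-mode transposition applies the same per-sample factor to the incoming gradients, matching the PyTorch backward above. In both versions the scaling factor is treated as a constant with respect to the sample positions.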
Coming Soon, see arXiv for now.