Skip to content

wire coord grads from kernel #145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions wisp/csrc/ops/hashgrid_interpolate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ at::Tensor hashgrid_interpolate_cuda(
#endif // WITH_CUDA
}

at::Tensor hashgrid_interpolate_backward_cuda(
std::vector<at::Tensor> hashgrid_interpolate_backward_cuda(
at::Tensor coords,
at::Tensor grad_output,
at::Tensor codebook,
Expand Down Expand Up @@ -93,7 +93,7 @@ at::Tensor hashgrid_interpolate_backward_cuda(
resolution[i], i, num_lods, require_grad_coords,
coords, codebook, codebook_first_idx, grad_output, grad_codebook, grad_coords);
}
return grad_codebook;
return {grad_codebook, grad_coords};
#else
AT_ERROR(__func__);
#endif // WITH_CUDA
Expand Down
2 changes: 1 addition & 1 deletion wisp/csrc/ops/hashgrid_interpolate.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ at::Tensor hashgrid_interpolate_cuda(
std::vector<int32_t> resolution,
int32_t codebook_bitwidth);

at::Tensor hashgrid_interpolate_backward_cuda(
std::vector<at::Tensor> hashgrid_interpolate_backward_cuda(
at::Tensor coords,
at::Tensor grad_output,
at::Tensor codebook,
Expand Down
10 changes: 6 additions & 4 deletions wisp/ops/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,22 @@ def forward(ctx, coords, resolutions, codebook_bitwidth, lod_idx, codebook, code
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, grad_output):

coords = ctx.saved_tensors[0]
codebook = ctx.saved_tensors[1]
codebook_first_idx = ctx.saved_tensors[2]
resolutions = ctx.resolutions
feature_dim = ctx.feature_dim
codebook_bitwidth = ctx.codebook_bitwidth

grad_codebook = wisp_C.ops.hashgrid_interpolate_backward_cuda(
is_needs_grad_by_coords = ctx.needs_input_grad[0]
grad_codebook, grad_coords = wisp_C.ops.hashgrid_interpolate_backward_cuda(
coords.float().contiguous(), grad_output.contiguous(), codebook,
codebook_first_idx,
resolutions,
codebook_bitwidth, feature_dim, ctx.needs_input_grad[0])
return (None, None, None, None, grad_codebook, None, None)
codebook_bitwidth, feature_dim, is_needs_grad_by_coords)
if not is_needs_grad_by_coords:
grad_coords = None
return grad_coords, None, None, None, grad_codebook, None, None

def hashgrid(coords, resolutions, codebook_bitwidth, lod_idx, codebook, codebook_sizes, codebook_first_idx):
"""A hash-grid query + interpolation function, accelerated with CUDA.
Expand Down