Skip to content

Commit 908f721

Browse files
authored
Merge pull request #12 from aidos-lab/ect-channels
ECT Channels
2 parents 5fd3afc + f206d72 commit 908f721

File tree

1 file changed

+67
-3
lines changed

1 file changed

+67
-3
lines changed

dect/ect.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@
2525
"""@private"""
2626

2727

28-
def normalize(ect):
28+
def normalize_ect(ect):
2929
"""Returns the normalized ect, scaled to lie in the interval 0,1"""
30-
return ect / torch.amax(ect, dim=(2, 3)).unsqueeze(2).unsqueeze(2)
30+
breakpoint()
31+
return ect / torch.amax(ect, dim=(-2, -3))
3132

3233

3334
def compute_ect(
@@ -180,7 +181,7 @@ def compute_ect_point_cloud(
180181
ecc = torch.nn.functional.sigmoid(scale * torch.sub(lin, nh))
181182
ect = torch.sum(ecc, dim=2)
182183
if normalize:
183-
ect = ect / torch.amax(ect, dim=(-1, -2), keepdim=True)
184+
ect = normalize_ect(ect)
184185

185186
return ect
186187

@@ -431,3 +432,66 @@ def compute_ect_mesh(
431432

432433
# Returns the ect as [batch_len, num_thetas, resolution]
433434
return output.movedim(0, 1).movedim(-1, -2)
435+
436+
437+
def compute_ect_channels(
438+
x: Tensor,
439+
v: Tensor,
440+
radius: float,
441+
resolution: int,
442+
scale: float,
443+
channels: Tensor,
444+
index: Tensor | None = None,
445+
max_channels: int | None = None,
446+
normalize: bool = False,
447+
):
448+
"""
449+
Allows for channels within the point cloud to separated in different
450+
ECT's.
451+
452+
Input is a point cloud of size (B*num_point_per_pc,num_features) with an addtional feature vector with the
453+
channel number for each point and the output is ECT for shape [B,num_channels,num_thetas,resolution]
454+
"""
455+
456+
# Ensure that the scale is in the right device
457+
scale = torch.tensor([scale], device=x.device)
458+
459+
# Compute maximum channels.
460+
if max_channels is None:
461+
max_channels = int(channels.max())
462+
463+
if index is not None:
464+
batch_len = int(index.max() + 1)
465+
else:
466+
batch_len = 1
467+
index = torch.zeros(
468+
size=(len(x),),
469+
dtype=torch.int32,
470+
device=x.device,
471+
)
472+
473+
# Fix the index to interleave with the channel info.
474+
index = max_channels * index + channels
475+
476+
# v is of shape [ambient_dimension, num_thetas]
477+
num_thetas = v.shape[1]
478+
479+
out_shape = (resolution, batch_len * (max_channels + 1), num_thetas)
480+
481+
# Node heights have shape [num_points, num_directions]
482+
nh = x @ v
483+
lin = torch.linspace(-radius, radius, resolution, device=x.device).view(-1, 1, 1)
484+
ecc = torch.nn.functional.sigmoid(scale * torch.sub(lin, nh))
485+
output = torch.zeros(
486+
size=out_shape,
487+
device=nh.device,
488+
)
489+
490+
output.index_add_(1, index, ecc)
491+
ect = output.movedim(0, 1)
492+
493+
if normalize:
494+
ect = ect / torch.amax(ect, dim=(-2, -3))
495+
496+
# Returns the ect as [batch_len, num_thetas, resolution]
497+
return ect.reshape(-1, max_channels + 1, resolution, num_thetas)

0 commit comments

Comments
 (0)