Skip to content

Commit 514f6cc

Browse files
committed
Fixed radius for geotorch
1 parent e8edde7 commit 514f6cc

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

dect/nn.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,17 @@
22
NOTE: Under construction.
33
TODO: Needs implementation and refactoring.
44
5-
65
Implementation of the ECT with learnable parameters.
76
"""
87

9-
from typing import TypeAlias, Literal
108
from dataclasses import dataclass
9+
from typing import Literal, TypeAlias
1110

11+
import geotorch
1212
import torch
1313
from torch import nn
14-
import geotorch
15-
from dect.ect import (
16-
compute_ect_points,
17-
compute_ect_edges,
18-
compute_ect_mesh,
19-
)
14+
15+
from dect.ect import compute_ect_edges, compute_ect_mesh, compute_ect_points
2016
from dect.ect_fn import scaled_sigmoid
2117

2218
Tensor: TypeAlias = torch.Tensor
@@ -145,7 +141,7 @@ def __init__(self, config: ECTConfig, v=None):
145141
else:
146142
# Movedim to make geotorch happy, me not happy.
147143
self.v = nn.Parameter(torch.zeros_like(v.movedim(-1, -2)))
148-
geotorch.constraints.sphere(self, "v", radius=config.radius)
144+
geotorch.constraints.sphere(self, "v", radius=1.0)
149145

150146
# Since geotorch randomizes the vector during initialization, we
151147
# assign the values after registering it with spherical constraints.

0 commit comments

Comments
 (0)