Skip to content

Commit b8ade2a

Browse files
committed
Added functions to generate structured directions
1 parent ee82a3a commit b8ade2a

File tree

1 file changed

+26
-13
lines changed

1 file changed

+26
-13
lines changed

dect/directions.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33
"""
44

55
import itertools
6+
67
import torch
78

89

910
def generate_uniform_directions(num_thetas: int, d: int, seed: int, device: str):
1011
"""
1112
Generate randomly sampled directions from a sphere in d dimensions.
1213
13-
A standard normal is sampled and projected onto the unit sphere to
14-
yield a randomly sampled set of points on the unit spere. Please
14+
A standard normal is sampled and projected onto the unit sphere to
15+
yield a randomly sampled set of points on the unit spere. Please
1516
note that the generated tensor has shape [d, num_thetas].
1617
1718
Parameters
@@ -58,8 +59,6 @@ def generate_2d_directions(num_thetas: int = 64):
5859

5960
def generate_multiview_directions(num_thetas: int, d: int):
6061
"""
61-
NOTE: Partially depreciated.
62-
6362
Generates multiple sets of structured directions in n dimensions.
6463
6564
We generate sets of directions by embedding the 2d unit circle in d
@@ -79,20 +78,34 @@ def generate_multiview_directions(num_thetas: int, d: int):
7978
d: int
8079
The dimension of the unit sphere. Default is 3 (hence R^3)
8180
"""
81+
82+
# We obtain n choose 2 channels.
83+
idx_pairs = list(itertools.combinations(range(d), r=2))
84+
85+
num_directions_per_circle = num_thetas // len(idx_pairs)
86+
remainder = num_thetas % len(idx_pairs)
87+
8288
w = torch.vstack(
8389
[
84-
torch.sin(torch.linspace(0, 2 * torch.pi, num_thetas)),
85-
torch.cos(torch.linspace(0, 2 * torch.pi, num_thetas)),
90+
torch.sin(
91+
torch.linspace(0, 2 * torch.pi, num_directions_per_circle + remainder)
92+
),
93+
torch.cos(
94+
torch.linspace(0, 2 * torch.pi, num_directions_per_circle + remainder)
95+
),
8696
]
8797
)
8898

89-
# We obtain n choose 2 channels.
90-
idx_pairs = list(itertools.combinations(range(d), r=2))
99+
multiview_dirs = []
100+
for idx, idx_pair in enumerate(idx_pairs):
101+
num_t = num_directions_per_circle
102+
if idx == 0 and remainder != 0:
103+
num_t = num_directions_per_circle + remainder
91104

92-
v = torch.zeros(size=(len(idx_pairs), d, num_thetas))
105+
v = torch.zeros(size=(d, num_t))
106+
v[idx_pair[0], :] = w[0, :num_t]
107+
v[idx_pair[1], :] = w[1, :num_t]
93108

94-
for idx, idx_pair in enumerate(idx_pairs):
95-
v[idx, idx_pair[0], :] = w[0]
96-
v[idx, idx_pair[1], :] = w[1]
109+
multiview_dirs.append(v)
97110

98-
return v
111+
return torch.hstack(multiview_dirs)

0 commit comments

Comments
 (0)