3
3
"""
4
4
5
5
import itertools
6
+
6
7
import torch
7
8
8
9
9
10
def generate_uniform_directions (num_thetas : int , d : int , seed : int , device : str ):
10
11
"""
11
12
Generate randomly sampled directions from a sphere in d dimensions.
12
13
13
- A standard normal is sampled and projected onto the unit sphere to
14
- yield a randomly sampled set of points on the unit spere. Please
14
+ A standard normal is sampled and projected onto the unit sphere to
15
+ yield a randomly sampled set of points on the unit spere. Please
15
16
note that the generated tensor has shape [d, num_thetas].
16
17
17
18
Parameters
@@ -58,8 +59,6 @@ def generate_2d_directions(num_thetas: int = 64):
58
59
59
60
def generate_multiview_directions (num_thetas : int , d : int ):
60
61
"""
61
- NOTE: Partially depreciated.
62
-
63
62
Generates multiple sets of structured directions in n dimensions.
64
63
65
64
We generate sets of directions by embedding the 2d unit circle in d
@@ -79,20 +78,34 @@ def generate_multiview_directions(num_thetas: int, d: int):
79
78
d: int
80
79
The dimension of the unit sphere. Default is 3 (hence R^3)
81
80
"""
81
+
82
+ # We obtain n choose 2 channels.
83
+ idx_pairs = list (itertools .combinations (range (d ), r = 2 ))
84
+
85
+ num_directions_per_circle = num_thetas // len (idx_pairs )
86
+ remainder = num_thetas % len (idx_pairs )
87
+
82
88
w = torch .vstack (
83
89
[
84
- torch .sin (torch .linspace (0 , 2 * torch .pi , num_thetas )),
85
- torch .cos (torch .linspace (0 , 2 * torch .pi , num_thetas )),
90
+ torch .sin (
91
+ torch .linspace (0 , 2 * torch .pi , num_directions_per_circle + remainder )
92
+ ),
93
+ torch .cos (
94
+ torch .linspace (0 , 2 * torch .pi , num_directions_per_circle + remainder )
95
+ ),
86
96
]
87
97
)
88
98
89
- # We obtain n choose 2 channels.
90
- idx_pairs = list (itertools .combinations (range (d ), r = 2 ))
99
+ multiview_dirs = []
100
+ for idx , idx_pair in enumerate (idx_pairs ):
101
+ num_t = num_directions_per_circle
102
+ if idx == 0 and remainder != 0 :
103
+ num_t = num_directions_per_circle + remainder
91
104
92
- v = torch .zeros (size = (len (idx_pairs ), d , num_thetas ))
105
+ v = torch .zeros (size = (d , num_t ))
106
+ v [idx_pair [0 ], :] = w [0 , :num_t ]
107
+ v [idx_pair [1 ], :] = w [1 , :num_t ]
93
108
94
- for idx , idx_pair in enumerate (idx_pairs ):
95
- v [idx , idx_pair [0 ], :] = w [0 ]
96
- v [idx , idx_pair [1 ], :] = w [1 ]
109
+ multiview_dirs .append (v )
97
110
98
- return v
111
+ return torch . hstack ( multiview_dirs )
0 commit comments