import torch
import torch.nn as nn
from torch.nn import Linear
- from typing import List, Optional, Tuple
+ from typing import List, Optional
from e3nn.o3 import xyz_to_angles, Irreps
from e3nn.util.jit import compile_mode

-
_Jd_file = os.path.join(os.path.dirname(__file__), "Jd.pt")
if os.path.exists(_Jd_file):
    _Jd = torch.load(_Jd_file)
else:
-     print(f"Warning: Jd.pt not found at {_Jd_file}. Wigner D functions will fail.")
-     _Jd = []
+     raise RuntimeError(f"Jd.pt not found at {_Jd_file}. Wigner D functions will fail.")


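# _Jd caches the precomputed per-order J matrices that wigner_D assembles into
# rotation matrices; raising at import time surfaces a missing Jd.pt immediately
# instead of deep inside the first wigner_D call.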
def wigner_D(l: int, alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
-     if not _Jd:
-         raise RuntimeError("Jd.pt was not loaded. Cannot compute Wigner D matrices.")
    if not l < len(_Jd):
        raise NotImplementedError(
            f"wigner D maximum l implemented is {len(_Jd) - 1}, send us an email to ask for more"
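    # (remainder of wigner_D unchanged below: it composes the order-l matrix from
    # _Jd[l] and per-angle z-rotation factors, presumably as in e3nn's reference
    # implementation)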
@@ -59,111 +55,97 @@ def __init__(
        extra_m0_outsize: int = 0,
    ):
        super().__init__()
-         if not _Jd:
-             raise RuntimeError("Jd.pt was not loaded. SO2_Linear cannot be initialized.")
        self.Jd: List[torch.Tensor] = _Jd

        irreps_in_s = irreps_in.simplify()
        irreps_out_s = (Irreps(f"{extra_m0_outsize}x0e") + irreps_out).simplify()

-         self.irreps_out: Irreps = irreps_out_s
+         self.irreps_out = irreps_out_s
        self.in_dim = irreps_in_s.dim
        self.out_dim = irreps_out_s.dim
        self.in_num_irreps = irreps_in_s.num_irreps
        self.out_num_irreps = irreps_out_s.num_irreps
        self.has_radial = radial_emb

-         if radial_channels is None:
-             radial_channels = []
-
-         in_offsets_list: List[int] = []
-         in_mul_list: List[int] = []
-         in_l_list: List[int] = []
-         current_offset = 0
-         for mul, (l, p_val) in irreps_in_s:
-             in_offsets_list.append(current_offset)
-             in_mul_list.append(mul)
-             in_l_list.append(l)
-             current_offset += mul * (2 * l + 1)
-         in_offsets_list.append(current_offset)
-         self.register_buffer('in_offsets', torch.tensor(in_offsets_list, dtype=torch.long))
-         self.register_buffer('in_mul', torch.tensor(in_mul_list, dtype=torch.long))
-         self.register_buffer('in_l', torch.tensor(in_l_list, dtype=torch.long))
-
-         out_offsets_list: List[int] = []
-         out_mul_list: List[int] = []
-         out_l_list: List[int] = []
-         current_offset = 0
-         for mul, (l, p_val) in irreps_out_s:
-             out_offsets_list.append(current_offset)
-             out_mul_list.append(mul)
-             out_l_list.append(l)
-             current_offset += mul * (2 * l + 1)
-         out_offsets_list.append(current_offset)
-         self.register_buffer('out_offsets', torch.tensor(out_offsets_list, dtype=torch.long))
-         self.register_buffer('out_mul', torch.tensor(out_mul_list, dtype=torch.long))
-         self.register_buffer('out_l', torch.tensor(out_l_list, dtype=torch.long))
-
+         # Buffers for irreps layout
+         in_offsets, in_mul, in_l = [], [], []
+         offset = 0
+         for mul, (l, _) in irreps_in_s:
+             in_offsets.append(offset)
+             in_mul.append(mul)
+             in_l.append(l)
+             offset += mul * (2 * l + 1)
+         in_offsets.append(offset)
+         self.register_buffer('in_offsets', torch.tensor(in_offsets, dtype=torch.long))
+         self.register_buffer('in_mul', torch.tensor(in_mul, dtype=torch.long))
+         self.register_buffer('in_l', torch.tensor(in_l, dtype=torch.long))
+
+         out_offsets, out_mul, out_l = [], [], []
+         offset = 0
+         for mul, (l, _) in irreps_out_s:
+             out_offsets.append(offset)
+             out_mul.append(mul)
+             out_l.append(l)
+             offset += mul * (2 * l + 1)
+         out_offsets.append(offset)
+         self.register_buffer('out_offsets', torch.tensor(out_offsets, dtype=torch.long))
+         self.register_buffer('out_mul', torch.tensor(out_mul, dtype=torch.long))
+         self.register_buffer('out_l', torch.tensor(out_l, dtype=torch.long))
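        # e.g. for a hypothetical irreps_in_s = "2x0e + 1x1o": in_dim = 5,
        # in_offsets = [0, 2, 5], in_mul = [2, 1], in_l = [0, 1]; the two scalar
        # channels occupy columns 0-1 and the single l=1 channel occupies columns 2-4.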
+
+         # m-in mask and count
        m_in_mask = torch.zeros(irreps_in_s.lmax + 1, self.in_dim, dtype=torch.bool)
        cnt_list = [0] * (irreps_in_s.lmax + 1)
-
-         current_offset_for_mask = 0
-         for i in range(len(irreps_in_s)):
-             mul, (l, p_val) = irreps_in_s[i]
-             for k_mul in range(mul):
-                 base_idx = current_offset_for_mask + k_mul * (2 * l + 1)
+         cur = 0
+         for mul, (l, _) in irreps_in_s:
+             for k in range(mul):
+                 base = cur + k * (2 * l + 1)
                for m_val in range(l + 1):
                    if m_val == 0:
-                         m_in_mask[m_val, base_idx + l] = True
+                         m_in_mask[m_val, base + l] = True
                        cnt_list[m_val] += 1
                    else:
-                         m_in_mask[m_val, base_idx + l + m_val] = True
-                         m_in_mask[m_val, base_idx + l - m_val] = True
+                         m_in_mask[m_val, base + l + m_val] = True
+                         m_in_mask[m_val, base + l - m_val] = True
                        cnt_list[m_val] += 1
-             current_offset_for_mask += mul * (2 * l + 1)
+             cur += mul * (2 * l + 1)
        self.register_buffer('m_in_mask', m_in_mask)
        self.register_buffer('cnt', torch.tensor(cnt_list, dtype=torch.long))
+         self.register_buffer('m_idx', torch.cat([torch.tensor([0], dtype=torch.long), torch.cumsum(torch.tensor(cnt_list, dtype=torch.long), dim=0)]))
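        # continuing the hypothetical "2x0e + 1x1o" example: cnt = [3, 1]
        # (m=0 covers both scalars plus the m=0 component of the l=1 channel,
        # m=1 covers one channel), so m_idx = [0, 3, 4] and
        # w[:, m_idx[m]:m_idx[m+1]] later selects the radial gates for order m.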

-         m_idx = torch.cat([torch.tensor([0], dtype=torch.long), torch.cumsum(self.cnt, dim=0)])
-         self.register_buffer('m_idx', m_idx)
-
+         # m-out mask
        m_out_mask = torch.zeros(irreps_out_s.lmax + 1, self.out_dim, dtype=torch.bool)
-         current_offset_for_mask = 0
-         for i in range(len(irreps_out_s)):
-             mul, (l, p_val) = irreps_out_s[i]
-             for k_mul in range(mul):
-                 base_idx = current_offset_for_mask + k_mul * (2 * l + 1)
+         cur = 0
+         for mul, (l, _) in irreps_out_s:
+             for k in range(mul):
+                 base = cur + k * (2 * l + 1)
                for m_val in range(l + 1):
                    if m_val <= irreps_in_s.lmax:
                        if m_val == 0:
-                             m_out_mask[m_val, base_idx + l] = True
+                             m_out_mask[m_val, base + l] = True
                        else:
-                             m_out_mask[m_val, base_idx + l + m_val] = True
-                             m_out_mask[m_val, base_idx + l - m_val] = True
-             current_offset_for_mask += mul * (2 * l + 1)
+                             m_out_mask[m_val, base + l + m_val] = True
+                             m_out_mask[m_val, base + l - m_val] = True
+             cur += mul * (2 * l + 1)
        self.register_buffer('m_out_mask', m_out_mask)
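        # output orders with m > irreps_in_s.lmax are never marked, so those
        # components of `out` receive no contribution and stay zero in forward().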

+         # fc0 and m_linears
        self.fc0 = Linear(self.in_num_irreps, self.out_num_irreps, bias=True)
+         self.m_linears = nn.ModuleList([SO2_m_Linear(mv, irreps_in_s, irreps_out_s) for mv in range(1, irreps_out_s.lmax + 1)])

-         self.m_linears = nn.ModuleList([
-             SO2_m_Linear(m, irreps_in_s, irreps_out_s) for m in range(1, irreps_out_s.lmax + 1)
-         ])
-
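        # fc0 mixes the m=0 channels (one per input irrep) with a bias; each
        # SO2_m_Linear handles the ±m components of every channel with l >= m,
        # for one order m > 0.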
+         # radial embedding
        if self.has_radial:
-             if latent_dim <= 0:
-                 raise ValueError("latent_dim must be > 0 if radial_emb is True")
            layers_list: List[nn.Module] = []
-             current_ch_radial = latent_dim
-             all_radial_net_channels = radial_channels + [int(m_idx[-1].item())]
-             for i, next_ch_radial in enumerate(all_radial_net_channels):
-                 layers_list.append(Linear(current_ch_radial, next_ch_radial, bias=True))
-                 current_ch_radial = next_ch_radial
-                 if i < len(all_radial_net_channels) - 1:
-                     layers_list.append(nn.LayerNorm(next_ch_radial))
+             current_dim = latent_dim
+             all_radial_layer_dims = (radial_channels if radial_channels is not None else []) + [int(self.m_idx[-1].item())]
+             for i, out_ch in enumerate(all_radial_layer_dims):
+                 layers_list.append(Linear(current_dim, out_ch, bias=True))
+                 current_dim = out_ch
+                 if i < len(all_radial_layer_dims) - 1:  # Not the last layer
+                     layers_list.append(nn.LayerNorm(out_ch))
                    layers_list.append(nn.SiLU())
-             self.radial: nn.Module = nn.Sequential(*layers_list)
+             self.radial = nn.Sequential(*layers_list)
        else:
-             self.radial: nn.Module = nn.Identity()  # Explicitly type self.radial here for clarity
+             self.radial = nn.Identity()
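        # the radial MLP maps latents [n, latent_dim] -> [n, m_idx[-1]], i.e. one
        # scalar gate per (channel, m) pair, sliced per order m in forward().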

    def _wigner(self, l: int, alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
        J = self.Jd[l].to(dtype=alpha.dtype, device=alpha.device)
@@ -179,114 +161,67 @@ def forward(
        alpha, beta = xyz_to_angles(R[:, [1, 2, 0]])
        gamma = torch.zeros_like(alpha)
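        # xyz_to_angles returns the two spherical angles of each (permuted) edge
        # direction; gamma is fixed to zero since aligning a single direction
        # determines only two of the three Euler angles.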

-         # MODIFIED PART FOR w CALCULATION
-         w: Optional[torch.Tensor] = None
+         # radial gates default to ones; they are overwritten below when radial_emb is used
+         w = torch.ones(n, int(self.m_idx[-1].item()), dtype=x.dtype, device=x.device)
        if self.has_radial:
            if latents is None:
-                 raise RuntimeError("`latents` must be provided and be a Tensor when `radial_emb=True`")
+                 raise RuntimeError("`latents` must be provided when `radial_emb=True`")
            w = self.radial(latents)
-         # END OF MODIFIED PART

-         x_rot = x.clone()
+         # rotated copy of x; l=0 blocks stay zero, so scalar inputs reach fc0
+         # as zeros (same behaviour as the explicit zeroing this replaces)
+         x_rot = torch.zeros_like(x)
        for i in range(len(self.in_mul)):
            start = int(self.in_offsets[i].item())
            end = int(self.in_offsets[i + 1].item())
            mul = int(self.in_mul[i].item())
            l_val = int(self.in_l[i].item())
-
            if l_val > 0:
-                 rot_mat = self._wigner(l_val, alpha, beta, gamma)
-                 vals = x_rot[:, start:end].reshape(n, mul, 2 * l_val + 1)
-                 rotated_vals = torch.einsum('nji,nmj->nmi', rot_mat, vals)
-                 x_rot[:, start:end] = rotated_vals.reshape(n, -1)
+                 rot = self._wigner(l_val, alpha, beta, gamma)
+                 vals = x[:, start:end].reshape(n, mul, 2 * l_val + 1)
+                 x_rot[:, start:end] = torch.einsum('nji,nmj->nmi', rot, vals).reshape(n, -1)

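        # 'nji,nmj->nmi' applies the transposed Wigner matrix to each of the mul
        # copies of the order-l block, expressing the features in the
        # edge-aligned frame before the per-m linear layers.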
        out = x.new_zeros(n, self.out_dim)
-
-         seg0_raw = x_rot[:, self.m_in_mask[0]]
-         seg0_for_fc0 = seg0_raw.clone()
-
-         current_col_in_seg0 = 0
-         for i_irrep in range(len(self.in_l)):
-             l_val_of_input_irrep = int(self.in_l[i_irrep].item())
-             mul_of_input_irrep = int(self.in_mul[i_irrep].item())
-
-             if l_val_of_input_irrep == 0:
-                 seg0_for_fc0[:, current_col_in_seg0:current_col_in_seg0 + mul_of_input_irrep] = 0.0
-             current_col_in_seg0 += mul_of_input_irrep
-
-         if w is not None:
-             start_w = int(self.m_idx[0].item())
-             end_w = int(self.m_idx[1].item())
-             w_m0 = w[:, start_w:end_w]
-             if seg0_for_fc0.size(1) == w_m0.size(1):  # Ensure dimensions match for broadcasting/element-wise mul
-                 seg0_for_fc0 = seg0_for_fc0 * w_m0
-             elif seg0_for_fc0.size(1) != 0 and w_m0.size(1) != 0:  # Both non-zero but mismatch
-                 raise RuntimeError(
-                     f"Dimension mismatch for radial weights at m=0: seg0 has {seg0_for_fc0.size(1)}, w_m0 has {w_m0.size(1)}")
-             # If one is zero dim, multiplication might be okay or do nothing, depends on exact case.
-             # For safety, only multiply if dims match and are non-zero. If seg0 is empty, w_m0 should also be.
-
-         out[:, self.m_out_mask[0]] += self.fc0(seg0_for_fc0)
-
-         for idx, m_linear_layer in enumerate(self.m_linears):
+         # m=0
+         seg0 = x_rot[:, self.m_in_mask[0]]
+         if w is not None and seg0.numel() > 0:
+             seg0 = seg0 * w[:, self.m_idx[0]:self.m_idx[1]]
+         out[:, self.m_out_mask[0]] += self.fc0(seg0)
+         # m>0
+         for idx, layer in enumerate(self.m_linears):
            m_val = idx + 1
-             if self.m_in_mask[m_val].any():
-                 seg_m = x_rot[:, self.m_in_mask[m_val]].reshape(n, 2, -1)
-
-                 if w is not None:
-                     start_w = int(self.m_idx[m_val].item())
-                     end_w = int(self.m_idx[m_val + 1].item())
-                     w_slice = w[:, start_w:end_w]
-                     if seg_m.size(2) == w_slice.size(1) and seg_m.size(2) > 0:
-                         seg_m = seg_m * w_slice.unsqueeze(1)
-                     elif seg_m.size(2) != 0 and w_slice.size(1) != 0:
-                         raise RuntimeError(
-                             f"Dimension mismatch for radial weights at m={m_val}: seg_m has {seg_m.size(2)}, w_slice has {w_slice.size(1)}")
-
-                 processed_seg_m = m_linear_layer(seg_m).reshape(n, -1)
-                 out[:, self.m_out_mask[m_val]] += processed_seg_m
-
+             mask = self.m_in_mask[m_val]
+             if mask.any():
+                 seg = x_rot[:, mask].reshape(n, 2, -1)
+                 if w is not None and seg.numel() > 0:
+                     seg = seg * w[:, self.m_idx[m_val]:self.m_idx[m_val + 1]].unsqueeze(1)
+                 out[:, self.m_out_mask[m_val]] += layer(seg).reshape(n, -1)
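        # each m>0 segment has shape [n, 2, cnt[m]] and is gated per channel by
        # the matching radial slice before the complex-style mix in SO2_m_Linear.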
+         # final rotation
        for i in range(len(self.out_mul)):
            start = int(self.out_offsets[i].item())
            end = int(self.out_offsets[i + 1].item())
-             mul = int(self.out_mul[i].item())
            l_val = int(self.out_l[i].item())
-
+             mul = int(self.out_mul[i].item())
            if l_val > 0:
-                 rot_mat = self._wigner(l_val, alpha, beta, gamma)
+                 rot = self._wigner(l_val, alpha, beta, gamma)
                vals = out[:, start:end].reshape(n, mul, 2 * l_val + 1)
-                 out[:, start:end] = torch.einsum('nji,nmj->nmi', rot_mat, vals).reshape(n, -1)
-
+                 out[:, start:end] = torch.einsum('nji,nmj->nmi', rot, vals).reshape(n, -1)
        return out


@compile_mode("script")
class SO2_m_Linear(nn.Module):
-     def __init__(self, m: int, irreps_in_s: Irreps, irreps_out_s: Irreps):
+     def __init__(self, m_val: int, irreps_in_s: Irreps, irreps_out_s: Irreps):
        super().__init__()
-         num_in = sum(mul for mul, (l, p_val) in irreps_in_s if l >= m)
-         num_out = sum(mul for mul, (l, p_val) in irreps_out_s if l >= m)
-
+         # count input/output channels for order m_val
+         num_in = sum(mul for mul, (l, _) in irreps_in_s if l >= m_val)
+         num_out = sum(mul for mul, (l, _) in irreps_out_s if l >= m_val)
        self.fc = Linear(num_in, 2 * num_out, bias=False)
        if num_in > 0 and num_out > 0:
            self.fc.weight.data.mul_(1.0 / math.sqrt(2.0))
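            # each output entry in forward() sums two W·x terms, so scaling the
            # weights by 1/sqrt(2) keeps the output variance roughly that of a
            # plain Linear (an initialization choice, not required for correctness)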

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-         if x.size(2) == 0:
-             if self.fc.out_features == 0:
-                 return torch.empty((x.size(0), 2, 0), dtype=x.dtype, device=x.device)
-             # If num_in is 0, but num_out > 0, fc(x) will still produce output of shape [N, 2, 2*num_out]
-             # where the input to fc was effectively zeros.
-             # So, proceed with fc(x) even if x.size(2) == 0, as fc handles it.
-
        y = self.fc(x)
-
-         num_out_channels = y.size(2) // 2
-         if num_out_channels == 0:
-             return torch.empty((x.size(0), 2, 0), dtype=x.dtype, device=x.device)
-
-         out_re = y[:, 0, :num_out_channels] - y[:, 1, num_out_channels:]
-         out_im = y[:, 0, num_out_channels:] + y[:, 1, :num_out_channels]
-
-         return torch.stack((out_re, out_im), dim=1)
-
+         num_out = y.size(2) // 2
+         re = y[:, 0, :num_out] - y[:, 1, num_out:]
+         im = y[:, 0, num_out:] + y[:, 1, :num_out]
+         return torch.stack((re, im), dim=1)
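        # with the fc weight viewed as [W_re | W_im] along its output dimension,
        # the lines above compute (W_re + i*W_im) @ (x[:, 0] + i*x[:, 1]):
        # re = W_re@x_re - W_im@x_im and im = W_im@x_re + W_re@x_im.

# Minimal usage sketch (illustrative only; the irreps, sizes and the
# forward(x, R, latents) call below are assumptions inferred from how the
# arguments are used above, not values taken from this repo):
#
#   so2 = SO2_Linear(
#       irreps_in=Irreps("8x0e + 8x1o"),
#       irreps_out=Irreps("8x0e + 8x1o"),
#       radial_emb=True,
#       latent_dim=16,
#       radial_channels=[32],
#   )
#   x = torch.randn(10, so2.in_dim)      # features per edge
#   R = torch.randn(10, 3)               # edge vectors
#   latents = torch.randn(10, 16)        # per-edge scalar latents
#   y = so2(x, R, latents)               # -> shape [10, so2.out_dim]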