Commit 47ba3bf

Update tensor_product.py
1 parent 89c18d9 commit 47ba3bf

File tree

1 file changed (+92, −157 lines)


dptb/nn/tensor_product.py

Lines changed: 92 additions & 157 deletions
@@ -3,22 +3,18 @@
 import torch
 import torch.nn as nn
 from torch.nn import Linear
-from typing import List, Optional, Tuple
+from typing import List, Optional
 from e3nn.o3 import xyz_to_angles, Irreps
 from e3nn.util.jit import compile_mode
 
-
 _Jd_file = os.path.join(os.path.dirname(__file__), "Jd.pt")
 if os.path.exists(_Jd_file):
     _Jd = torch.load(_Jd_file)
 else:
-    print(f"Warning: Jd.pt not found at {_Jd_file}. Wigner D functions will fail.")
-    _Jd = []
+    raise RuntimeError(f"Jd.pt not found at {_Jd_file}. Wigner D functions will fail.")
 
 
 def wigner_D(l: int, alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
-    if not _Jd:
-        raise RuntimeError("Jd.pt was not loaded. Cannot compute Wigner D matrices.")
     if not l < len(_Jd):
         raise NotImplementedError(
             f"wigner D maximum l implemented is {len(_Jd) - 1}, send us an email to ask for more"
@@ -59,111 +55,97 @@ def __init__(
         extra_m0_outsize: int = 0,
     ):
         super().__init__()
-        if not _Jd:
-            raise RuntimeError("Jd.pt was not loaded. SO2_Linear cannot be initialized.")
         self.Jd: List[torch.Tensor] = _Jd
 
         irreps_in_s = irreps_in.simplify()
         irreps_out_s = (Irreps(f"{extra_m0_outsize}x0e") + irreps_out).simplify()
 
-        self.irreps_out: Irreps = irreps_out_s
+        self.irreps_out = irreps_out_s
         self.in_dim = irreps_in_s.dim
         self.out_dim = irreps_out_s.dim
         self.in_num_irreps = irreps_in_s.num_irreps
         self.out_num_irreps = irreps_out_s.num_irreps
         self.has_radial = radial_emb
 
-        if radial_channels is None:
-            radial_channels = []
-
-        in_offsets_list: List[int] = []
-        in_mul_list: List[int] = []
-        in_l_list: List[int] = []
-        current_offset = 0
-        for mul, (l, p_val) in irreps_in_s:
-            in_offsets_list.append(current_offset)
-            in_mul_list.append(mul)
-            in_l_list.append(l)
-            current_offset += mul * (2 * l + 1)
-        in_offsets_list.append(current_offset)
-        self.register_buffer('in_offsets', torch.tensor(in_offsets_list, dtype=torch.long))
-        self.register_buffer('in_mul', torch.tensor(in_mul_list, dtype=torch.long))
-        self.register_buffer('in_l', torch.tensor(in_l_list, dtype=torch.long))
-
-        out_offsets_list: List[int] = []
-        out_mul_list: List[int] = []
-        out_l_list: List[int] = []
-        current_offset = 0
-        for mul, (l, p_val) in irreps_out_s:
-            out_offsets_list.append(current_offset)
-            out_mul_list.append(mul)
-            out_l_list.append(l)
-            current_offset += mul * (2 * l + 1)
-        out_offsets_list.append(current_offset)
-        self.register_buffer('out_offsets', torch.tensor(out_offsets_list, dtype=torch.long))
-        self.register_buffer('out_mul', torch.tensor(out_mul_list, dtype=torch.long))
-        self.register_buffer('out_l', torch.tensor(out_l_list, dtype=torch.long))
-
+        # Buffers for irreps layout
+        in_offsets, in_mul, in_l = [], [], []
+        offset = 0
+        for mul, (l, _) in irreps_in_s:
+            in_offsets.append(offset)
+            in_mul.append(mul)
+            in_l.append(l)
+            offset += mul * (2 * l + 1)
+        in_offsets.append(offset)
+        self.register_buffer('in_offsets', torch.tensor(in_offsets, dtype=torch.long))
+        self.register_buffer('in_mul', torch.tensor(in_mul, dtype=torch.long))
+        self.register_buffer('in_l', torch.tensor(in_l, dtype=torch.long))
+
+        out_offsets, out_mul, out_l = [], [], []
+        offset = 0
+        for mul, (l, _) in irreps_out_s:
+            out_offsets.append(offset)
+            out_mul.append(mul)
+            out_l.append(l)
+            offset += mul * (2 * l + 1)
+        out_offsets.append(offset)
+        self.register_buffer('out_offsets', torch.tensor(out_offsets, dtype=torch.long))
+        self.register_buffer('out_mul', torch.tensor(out_mul, dtype=torch.long))
+        self.register_buffer('out_l', torch.tensor(out_l, dtype=torch.long))
+
+        # m-in mask and count
         m_in_mask = torch.zeros(irreps_in_s.lmax + 1, self.in_dim, dtype=torch.bool)
         cnt_list = [0] * (irreps_in_s.lmax + 1)
-
-        current_offset_for_mask = 0
-        for i in range(len(irreps_in_s)):
-            mul, (l, p_val) = irreps_in_s[i]
-            for k_mul in range(mul):
-                base_idx = current_offset_for_mask + k_mul * (2 * l + 1)
+        cur = 0
+        for mul, (l, _) in irreps_in_s:
+            for k in range(mul):
+                base = cur + k * (2 * l + 1)
                 for m_val in range(l + 1):
                     if m_val == 0:
-                        m_in_mask[m_val, base_idx + l] = True
+                        m_in_mask[m_val, base + l] = True
                         cnt_list[m_val] += 1
                     else:
-                        m_in_mask[m_val, base_idx + l + m_val] = True
-                        m_in_mask[m_val, base_idx + l - m_val] = True
+                        m_in_mask[m_val, base + l + m_val] = True
+                        m_in_mask[m_val, base + l - m_val] = True
                         cnt_list[m_val] += 1
-            current_offset_for_mask += mul * (2 * l + 1)
+            cur += mul * (2 * l + 1)
         self.register_buffer('m_in_mask', m_in_mask)
         self.register_buffer('cnt', torch.tensor(cnt_list, dtype=torch.long))
+        self.register_buffer('m_idx', torch.cat([torch.tensor([0], dtype=torch.long), torch.cumsum(torch.tensor(cnt_list, dtype=torch.long), dim=0)]))
 
-        m_idx = torch.cat([torch.tensor([0], dtype=torch.long), torch.cumsum(self.cnt, dim=0)])
-        self.register_buffer('m_idx', m_idx)
-
+        # m-out mask
         m_out_mask = torch.zeros(irreps_out_s.lmax + 1, self.out_dim, dtype=torch.bool)
-        current_offset_for_mask = 0
-        for i in range(len(irreps_out_s)):
-            mul, (l, p_val) = irreps_out_s[i]
-            for k_mul in range(mul):
-                base_idx = current_offset_for_mask + k_mul * (2 * l + 1)
+        cur = 0
+        for mul, (l, _) in irreps_out_s:
+            for k in range(mul):
+                base = cur + k * (2 * l + 1)
                 for m_val in range(l + 1):
                     if m_val <= irreps_in_s.lmax:
                         if m_val == 0:
-                            m_out_mask[m_val, base_idx + l] = True
+                            m_out_mask[m_val, base + l] = True
                         else:
-                            m_out_mask[m_val, base_idx + l + m_val] = True
-                            m_out_mask[m_val, base_idx + l - m_val] = True
-            current_offset_for_mask += mul * (2 * l + 1)
+                            m_out_mask[m_val, base + l + m_val] = True
+                            m_out_mask[m_val, base + l - m_val] = True
+            cur += mul * (2 * l + 1)
         self.register_buffer('m_out_mask', m_out_mask)
 
+        # fc0 and m_linears
         self.fc0 = Linear(self.in_num_irreps, self.out_num_irreps, bias=True)
+        self.m_linears = nn.ModuleList([SO2_m_Linear(mv, irreps_in_s, irreps_out_s) for mv in range(1, irreps_out_s.lmax + 1)])
 
-        self.m_linears = nn.ModuleList([
-            SO2_m_Linear(m, irreps_in_s, irreps_out_s) for m in range(1, irreps_out_s.lmax + 1)
-        ])
-
+        # radial embedding
         if self.has_radial:
-            if latent_dim <= 0:
-                raise ValueError("latent_dim must be > 0 if radial_emb is True")
             layers_list: List[nn.Module] = []
-            current_ch_radial = latent_dim
-            all_radial_net_channels = radial_channels + [int(m_idx[-1].item())]
-            for i, next_ch_radial in enumerate(all_radial_net_channels):
-                layers_list.append(Linear(current_ch_radial, next_ch_radial, bias=True))
-                current_ch_radial = next_ch_radial
-                if i < len(all_radial_net_channels) - 1:
-                    layers_list.append(nn.LayerNorm(next_ch_radial))
+            current_dim = latent_dim
+            all_radial_layer_dims = (radial_channels if radial_channels is not None else []) + [int(self.m_idx[-1].item())]
+            for i, out_ch in enumerate(all_radial_layer_dims):
+                layers_list.append(Linear(current_dim, out_ch, bias=True))
+                current_dim = out_ch
+                if i < len(all_radial_layer_dims) - 1:  # Not the last layer
+                    layers_list.append(nn.LayerNorm(out_ch))
                     layers_list.append(nn.SiLU())
-            self.radial: nn.Module = nn.Sequential(*layers_list)
+            self.radial = nn.Sequential(*layers_list)
         else:
-            self.radial: nn.Module = nn.Identity() # Explicitly type self.radial here for clarity
+            self.radial = nn.Identity()
 
     def _wigner(self, l: int, alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
         J = self.Jd[l].to(dtype=alpha.dtype, device=alpha.device)
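
To make the new layout bookkeeping concrete, here is a standalone sketch that mirrors the m_in_mask / cnt loop above for a small, hypothetical input irreps (the irreps string and the worked-out numbers are illustrative, not taken from the repository):

    import torch
    from e3nn.o3 import Irreps

    irreps = Irreps("2x0e + 1x1o")  # in_dim = 2*1 + 1*3 = 5, lmax = 1
    mask = torch.zeros(2, 5, dtype=torch.bool)
    cnt = [0, 0]
    cur = 0
    for mul, (l, _) in irreps:
        for k in range(mul):
            base = cur + k * (2 * l + 1)
            for m in range(l + 1):
                if m == 0:
                    mask[m, base + l] = True      # m = 0 slot sits at the centre of each irrep block
                    cnt[m] += 1
                else:
                    mask[m, base + l + m] = True  # +m and -m components are selected together
                    mask[m, base + l - m] = True
                    cnt[m] += 1
        cur += mul * (2 * l + 1)
    # cnt == [3, 1], so m_idx == [0, 3, 4]

Cumulative sums of cnt give m_idx, which is exactly the slicing applied to the radial weights w in forward: w[:, 0:3] scales the m = 0 block and w[:, 3:4] the single m = 1 pair in this example.
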
@@ -179,114 +161,67 @@ def forward(
         alpha, beta = xyz_to_angles(R[:, [1, 2, 0]])
         gamma = torch.zeros_like(alpha)
 
-        # MODIFIED PART FOR w CALCULATION
-        w: Optional[torch.Tensor] = None
+        # initialize radial weights tensor to empty or computed
+        w = torch.ones(n, int(self.m_idx[-1].item()), dtype=x.dtype, device=x.device)
         if self.has_radial:
             if latents is None:
-                raise RuntimeError("`latents` must be provided and be a Tensor when `radial_emb=True`")
+                raise RuntimeError("`latents` must be provided when `radial_emb=True`")
             w = self.radial(latents)
-        # END OF MODIFIED PART
 
-        x_rot = x.clone()
+        # initialize x_rot to zero
+        x_rot = torch.zeros_like(x)
         for i in range(len(self.in_mul)):
             start = int(self.in_offsets[i].item())
             end = int(self.in_offsets[i + 1].item())
             mul = int(self.in_mul[i].item())
             l_val = int(self.in_l[i].item())
-
             if l_val > 0:
-                rot_mat = self._wigner(l_val, alpha, beta, gamma)
-                vals = x_rot[:, start:end].reshape(n, mul, 2 * l_val + 1)
-                rotated_vals = torch.einsum('nji,nmj->nmi', rot_mat, vals)
-                x_rot[:, start:end] = rotated_vals.reshape(n, -1)
+                rot = self._wigner(l_val, alpha, beta, gamma)
+                vals = x[:, start:end].reshape(n, mul, 2 * l_val + 1)
+                x_rot[:, start:end] = torch.einsum('nji,nmj->nmi', rot, vals).reshape(n, -1)
 
         out = x.new_zeros(n, self.out_dim)
-
-        seg0_raw = x_rot[:, self.m_in_mask[0]]
-        seg0_for_fc0 = seg0_raw.clone()
-
-        current_col_in_seg0 = 0
-        for i_irrep in range(len(self.in_l)):
-            l_val_of_input_irrep = int(self.in_l[i_irrep].item())
-            mul_of_input_irrep = int(self.in_mul[i_irrep].item())
-
-            if l_val_of_input_irrep == 0:
-                seg0_for_fc0[:, current_col_in_seg0: current_col_in_seg0 + mul_of_input_irrep] = 0.0
-            current_col_in_seg0 += mul_of_input_irrep
-
-        if w is not None:
-            start_w = int(self.m_idx[0].item())
-            end_w = int(self.m_idx[1].item())
-            w_m0 = w[:, start_w:end_w]
-            if seg0_for_fc0.size(1) == w_m0.size(1): # Ensure dimensions match for broadcasting/element-wise mul
-                seg0_for_fc0 = seg0_for_fc0 * w_m0
-            elif seg0_for_fc0.size(1) != 0 and w_m0.size(1) != 0: # Both non-zero but mismatch
-                raise RuntimeError(
-                    f"Dimension mismatch for radial weights at m=0: seg0 has {seg0_for_fc0.size(1)}, w_m0 has {w_m0.size(1)}")
-            # If one is zero dim, multiplication might be okay or do nothing, depends on exact case.
-            # For safety, only multiply if dims match and are non-zero. If seg0 is empty, w_m0 should also be.
-
-        out[:, self.m_out_mask[0]] += self.fc0(seg0_for_fc0)
-
-        for idx, m_linear_layer in enumerate(self.m_linears):
+        # m=0
+        seg0 = x_rot[:, self.m_in_mask[0]]
+        if w is not None and seg0.numel() > 0:
+            seg0 = seg0 * w[:, self.m_idx[0]:self.m_idx[1]]
+        out[:, self.m_out_mask[0]] += self.fc0(seg0)
+        # m>0
+        for idx, layer in enumerate(self.m_linears):
             m_val = idx + 1
-            if self.m_in_mask[m_val].any():
-                seg_m = x_rot[:, self.m_in_mask[m_val]].reshape(n, 2, -1)
-
-                if w is not None:
-                    start_w = int(self.m_idx[m_val].item())
-                    end_w = int(self.m_idx[m_val + 1].item())
-                    w_slice = w[:, start_w:end_w]
-                    if seg_m.size(2) == w_slice.size(1) and seg_m.size(2) > 0:
-                        seg_m = seg_m * w_slice.unsqueeze(1)
-                    elif seg_m.size(2) != 0 and w_slice.size(1) != 0:
-                        raise RuntimeError(
-                            f"Dimension mismatch for radial weights at m={m_val}: seg_m has {seg_m.size(2)}, w_slice has {w_slice.size(1)}")
-
-                processed_seg_m = m_linear_layer(seg_m).reshape(n, -1)
-                out[:, self.m_out_mask[m_val]] += processed_seg_m
-
+            mask = self.m_in_mask[m_val]
+            if mask.any():
+                seg = x_rot[:, mask].reshape(n, 2, -1)
+                if w is not None and seg.numel() > 0:
+                    seg = seg * w[:, self.m_idx[m_val]:self.m_idx[m_val+1]].unsqueeze(1)
+                out[:, self.m_out_mask[m_val]] += layer(seg).reshape(n, -1)
+        # final rotation
         for i in range(len(self.out_mul)):
             start = int(self.out_offsets[i].item())
             end = int(self.out_offsets[i + 1].item())
-            mul = int(self.out_mul[i].item())
             l_val = int(self.out_l[i].item())
-
+            mul = int(self.out_mul[i].item())
             if l_val > 0:
-                rot_mat = self._wigner(l_val, alpha, beta, gamma)
+                rot = self._wigner(l_val, alpha, beta, gamma)
                 vals = out[:, start:end].reshape(n, mul, 2 * l_val + 1)
-                out[:, start:end] = torch.einsum('nji,nmj->nmi', rot_mat, vals).reshape(n, -1)
-
+                out[:, start:end] = torch.einsum('nji,nmj->nmi', rot, vals).reshape(n, -1)
         return out
 
 
 @compile_mode("script")
 class SO2_m_Linear(nn.Module):
-    def __init__(self, m: int, irreps_in_s: Irreps, irreps_out_s: Irreps):
+    def __init__(self, m_val: int, irreps_in_s: Irreps, irreps_out_s: Irreps):
         super().__init__()
-        num_in = sum(mul for mul, (l, p_val) in irreps_in_s if l >= m)
-        num_out = sum(mul for mul, (l, p_val) in irreps_out_s if l >= m)
-
+        # count input/output channels for order m_val
+        num_in = sum(mul for mul, (l, _) in irreps_in_s if l >= m_val)
+        num_out = sum(mul for mul, (l, _) in irreps_out_s if l >= m_val)
         self.fc = Linear(num_in, 2 * num_out, bias=False)
         if num_in > 0 and num_out > 0:
            self.fc.weight.data.mul_(1.0 / math.sqrt(2.0))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if x.size(2) == 0:
-            if self.fc.out_features == 0:
-                return torch.empty((x.size(0), 2, 0), dtype=x.dtype, device=x.device)
-            # If num_in is 0, but num_out > 0, fc(x) will still produce output of shape [N, 2, 2*num_out]
-            # where the input to fc was effectively zeros.
-            # So, proceed with fc(x) even if x.size(2) == 0, as fc handles it.
-
         y = self.fc(x)
-
-        num_out_channels = y.size(2) // 2
-        if num_out_channels == 0:
-            return torch.empty((x.size(0), 2, 0), dtype=x.dtype, device=x.device)
-
-        out_re = y[:, 0, :num_out_channels] - y[:, 1, num_out_channels:]
-        out_im = y[:, 0, num_out_channels:] + y[:, 1, :num_out_channels]
-
-        return torch.stack((out_re, out_im), dim=1)
-
+        num_out = y.size(2) // 2
+        re = y[:, 0, :num_out] - y[:, 1, num_out:]
+        im = y[:, 0, num_out:] + y[:, 1, :num_out]
+        return torch.stack((re, im), dim=1)
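
For reference, the simplified SO2_m_Linear.forward amounts to a complex-style product over the two m-channels. A minimal standalone sketch with made-up sizes (n, num_in, num_out are illustrative) shows the shapes involved:

    import torch

    n, num_in, num_out = 4, 3, 5
    x = torch.randn(n, 2, num_in)               # the size-2 axis carries the paired m-components
    fc = torch.nn.Linear(num_in, 2 * num_out, bias=False)

    y = fc(x)                                   # (n, 2, 2 * num_out)
    re = y[:, 0, :num_out] - y[:, 1, num_out:]
    im = y[:, 0, num_out:] + y[:, 1, :num_out]
    out = torch.stack((re, im), dim=1)          # (n, 2, num_out), matching the forward above

Note that the explicit zero-width guards from the old forward are gone; a Linear layer accepts num_in == 0 or num_out == 0 and still returns a correctly shaped (possibly empty) tensor, which is presumably why the special cases could be dropped.
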
