Skip to content

Commit 27fb2a8

Browse files
committed
move E3 statistics initialization into dataset, optimize nested tensor support in hr2hk and eigvals compute
1 parent 4a93b57 commit 27fb2a8

File tree

6 files changed

+37
-29
lines changed

6 files changed

+37
-29
lines changed

dptb/data/AtomicData.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,10 +305,6 @@ def _process_dict(kwargs, ignore_fields=[]):
305305
if num_frames > 1 and v.size(0) != num_frames:
306306
raise ValueError(f"Wrong shape for NESTED property {k}")
307307

308-
309-
310-
311-
312308

313309
class AtomicData(Data):
314310
"""A neighbor graph for points in (periodic triclinic) real space.

dptb/data/dataset/_default_dataset.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def raw_dir(self):
353353
# TODO: this is not implemented.
354354
return self.root
355355

356-
def E3statistics(self, decay=False):
356+
def E3statistics(self, model: torch.nn.Module=None, decay=False):
357357
assert self.transform is not None
358358
idp = self.transform
359359

@@ -369,6 +369,19 @@ def E3statistics(self, decay=False):
369369
stats["node"] = self._E3nodespecies_stat(typed_dataset=typed_dataset)
370370
stats["edge"] = self._E3edgespecies_stat(typed_dataset=typed_dataset, decay=decay)
371371

372+
if model is not None:
373+
# initilize the model param with statistics
374+
scalar_mask = torch.BoolTensor([ir.dim==1 for ir in model.idp.orbpair_irreps])
375+
node_shifts = stats["node"]["scalar_ave"]
376+
node_scales = stats["node"]["norm_ave"]
377+
node_scales[:,scalar_mask] = stats["node"]["scalar_std"]
378+
379+
edge_shifts = stats["edge"]["scalar_ave"]
380+
edge_scales = stats["edge"]["norm_ave"]
381+
edge_scales[:,scalar_mask] = stats["edge"]["scalar_std"]
382+
model.node_prediction_h.set_scale_shift(scales=node_scales, shifts=node_shifts)
383+
model.edge_prediction_h.set_scale_shift(scales=edge_scales, shifts=edge_shifts)
384+
372385
return stats
373386

374387
def _E3edgespecies_stat(self, typed_dataset, decay):

dptb/entrypoints/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ def train(
183183
# include the init model and from scratch
184184
# build model will handle the init model cases where the model options provided is not equals to the ones in checkpoint.
185185
checkpoint = init_model if init_model else None
186-
model = build_model(checkpoint=checkpoint, model_options=jdata["model_options"], common_options=jdata["common_options"], statistics=train_datasets.E3statistics())
186+
model = build_model(checkpoint=checkpoint, model_options=jdata["model_options"], common_options=jdata["common_options"])
187+
train_datasets.E3statistics(model=model)
187188
trainer = Trainer(
188189
train_options=jdata["train_options"],
189190
common_options=jdata["common_options"],

dptb/nn/build.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
def build_model(
1212
checkpoint: str=None,
1313
model_options: dict={},
14-
common_options: dict={},
15-
statistics: dict=None
14+
common_options: dict={}
1615
):
1716
"""
1817
The build model method should composed of the following steps:
@@ -141,23 +140,8 @@ def build_model(
141140
if from_scratch:
142141
if init_nnenv:
143142
model = NNENV(**model_options, **common_options)
144-
145-
# do initialization from statistics if NNENV is e3tb and statistics is provided
146-
if model.method == "e3tb" and statistics is not None:
147-
scalar_mask = torch.BoolTensor([ir.dim==1 for ir in model.idp.orbpair_irreps])
148-
node_shifts = statistics["node"]["scalar_ave"]
149-
node_scales = statistics["node"]["norm_ave"]
150-
node_scales[:,scalar_mask] = statistics["node"]["scalar_std"]
151-
152-
edge_shifts = statistics["edge"]["scalar_ave"]
153-
edge_scales = statistics["edge"]["norm_ave"]
154-
edge_scales[:,scalar_mask] = statistics["edge"]["scalar_std"]
155-
model.node_prediction_h.set_scale_shift(scales=node_scales, shifts=node_shifts)
156-
model.edge_prediction_h.set_scale_shift(scales=edge_scales, shifts=edge_shifts)
157-
158143
elif init_nnsk:
159144
model = NNSK(**model_options["nnsk"], **common_options)
160-
161145
elif init_mixed:
162146
model = MIX(**model_options, **common_options)
163147
elif init_dftbsk:

dptb/nn/energy.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,19 @@ def __init__(
5656

5757

5858
def forward(self, data: AtomicDataDict.Type, nk: Optional[int]=None) -> AtomicDataDict.Type:
59-
num_k = data[AtomicDataDict.KPOINT_KEY][0].shape[0]
60-
kpoints = data[AtomicDataDict.KPOINT_KEY][0] # slice the first dimension, since it is nested tensor by default
59+
kpoints = data[AtomicDataDict.KPOINT_KEY]
60+
if kpoints.is_nested:
61+
nested = True
62+
assert kpoints.size(0) == 1
63+
kpoints = kpoints[0]
64+
else:
65+
nested = False
66+
num_k = kpoints.shape[0]
6167
eigvals = []
6268
if nk is None:
6369
nk = num_k
6470
for i in range(int(np.ceil(num_k / nk))):
65-
data[AtomicDataDict.KPOINT_KEY] = torch.nested.as_nested_tensor([kpoints[i*nk:(i+1)*nk]])
71+
data[AtomicDataDict.KPOINT_KEY] = kpoints[i*nk:(i+1)*nk]
6672
data = self.h2k(data)
6773
if self.overlap:
6874
data = self.s2k(data)
@@ -74,5 +80,9 @@ def forward(self, data: AtomicDataDict.Type, nk: Optional[int]=None) -> AtomicDa
7480

7581
eigvals.append(torch.linalg.eigvalsh(data[self.h_out_field]))
7682
data[self.out_field] = torch.nested.as_nested_tensor([torch.cat(eigvals, dim=0)])
83+
if nested:
84+
data[AtomicDataDict.KPOINT_KEY] = torch.nested.as_nested_tensor([kpoints])
85+
else:
86+
data[AtomicDataDict.KPOINT_KEY] = kpoints
7787

7888
return data

dptb/nn/hr2hk.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
5656
bondwise_hopping.to(self.device)
5757
bondwise_hopping.type(self.dtype)
5858
onsite_block = torch.zeros((len(data[AtomicDataDict.ATOM_TYPE_KEY]), self.idp.full_basis_norb, self.idp.full_basis_norb,), dtype=self.dtype, device=self.device)
59+
kpoints = data[AtomicDataDict.KPOINT_KEY]
60+
if kpoints.is_nested:
61+
assert kpoints.size(0) == 1
62+
kpoints = kpoints[0]
5963

6064
soc = data.get(AtomicDataDict.NODE_SOC_SWITCH_KEY, False)
6165
if isinstance(soc, torch.Tensor):
@@ -111,7 +115,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
111115

112116
# R2K procedure can be done for all kpoint at once.
113117
all_norb = self.idp.atom_norb[data[AtomicDataDict.ATOM_TYPE_KEY]].sum()
114-
block = torch.zeros(data[AtomicDataDict.KPOINT_KEY][0].shape[0], all_norb, all_norb, dtype=self.ctype, device=self.device)
118+
block = torch.zeros(kpoints.shape[0], all_norb, all_norb, dtype=self.ctype, device=self.device)
115119
# block = torch.complex(block, torch.zeros_like(block))
116120
# if data[AtomicDataDict.NODE_SOC_SWITCH_KEY].all():
117121
# block_uu = torch.zeros(data[AtomicDataDict.KPOINT_KEY].shape[0], all_norb, all_norb, dtype=self.ctype, device=self.device)
@@ -149,13 +153,13 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
149153
masked_hblock = hblock[imask][:,jmask]
150154

151155
block[:,iatom_indices,jatom_indices] += masked_hblock.squeeze(0).type_as(block) * \
152-
torch.exp(-1j * 2 * torch.pi * (data[AtomicDataDict.KPOINT_KEY][0] @ data[AtomicDataDict.EDGE_CELL_SHIFT_KEY][i])).reshape(-1,1,1)
156+
torch.exp(-1j * 2 * torch.pi * (kpoints @ data[AtomicDataDict.EDGE_CELL_SHIFT_KEY][i])).reshape(-1,1,1)
153157

154158
block = block + block.transpose(1,2).conj()
155159
block = block.contiguous()
156160

157161
if soc:
158-
HK_SOC = torch.zeros(data[AtomicDataDict.KPOINT_KEY][0].shape[0], 2*all_norb, 2*all_norb, dtype=self.ctype, device=self.device)
162+
HK_SOC = torch.zeros(kpoints.shape[0], 2*all_norb, 2*all_norb, dtype=self.ctype, device=self.device)
159163
#HK_SOC[:,:all_norb,:all_norb] = block + block_uu
160164
#HK_SOC[:,:all_norb,all_norb:] = block_ud
161165
#HK_SOC[:,all_norb:,:all_norb] = block_ud.conj()

0 commit comments

Comments
 (0)