
Commit 8811615

add Trinity model and debug the lem (deepmodeling#256)
* update input and e3tb hands on doc * debug lem and add trinity * change init2b to only2b, move scalarmlpfunction to base.py * compute onsite shift mu from whole onsite block * drop E3statistics when model does is valid and does not have a scale shift * add edge component in onsite shift computation * feat(transforms): add full_mask_to_diag initialization for diagonal elements Initialize full_mask_to_diag tensor to track diagonal orbital pairs in the reduced matrix. This helps in identifying diagonal elements during further processing. * refactor(loss): extract shift_mu function and simplify onsite shift calculation The shift_mu function was extracted to avoid code duplication across multiple loss classes. The onsite shift calculation was simplified by using a more accurate formula that accounts for both node and edge features. * refactor(loss): simplify shift_mu calculation and usage remove redundant diagonal terms in shift_mu and update all dependent loss classes to use simplified return values --------- Co-authored-by: QG-phy <guqq_phy@qq.com> Co-authored-by: Qiangqiang Gu <98570179+QG-phy@users.noreply.github.com>
1 parent f1f2d2e commit 8811615
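Note: a rough sketch of the extracted loss helper described in the commit message. The name shift_mu comes from the message itself, but the signature and formula below are assumptions for illustration only, not the repository's implementation:

import torch

def shift_mu(pred: torch.Tensor, ref: torch.Tensor, diag_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical shared helper: mean residual over diagonal orbital-pair entries.

    Subtracting this scalar aligns the arbitrary energy zero of prediction and
    reference before the loss is evaluated (names and formula are illustrative).
    """
    resid = ref[:, diag_mask] - pred[:, diag_mask]
    return resid.mean()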

File tree

11 files changed (+1306, -306 lines)


dptb/data/dataset/_default_dataset.py

Lines changed: 4 additions & 0 deletions
@@ -401,6 +401,10 @@ def E3statistics(self, model: torch.nn.Module=None, decay=False):
 
         if self.data[AtomicDataDict.EDGE_FEATURES_KEY].abs().sum() < 1e-7:
            return None
+
+        if model is not None:
+            if not isinstance(model.node_prediction_h, torch.nn.Module):
+                return None
 
         typed_dataset = idp(self.data.clone().to_dict())
         e3h = E3Hamiltonian(basis=idp.basis, decompose=True)
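Note: the new guard works because the Trinity path in dptb/nn/deeptb.py (below) replaces the scale-shift heads with plain lambdas, which are not torch.nn.Module instances. A minimal check of that interplay:

import torch

node_prediction_h = lambda x: x  # Trinity-style identity stand-in
# not a torch.nn.Module, so E3statistics bails out with None
assert not isinstance(node_prediction_h, torch.nn.Module)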

dptb/data/dataset/lmdb_dataset.py

Lines changed: 4 additions & 0 deletions
@@ -198,6 +198,10 @@ def E3statistics(self, model: torch.nn.Module=None):
         if not self.get_Hamiltonian and not self.get_DM:
             return None
 
+        if model is not None:
+            if not isinstance(model.node_prediction_h, torch.nn.Module):
+                return None
+
         assert self.transform is not None
         idp = self.transform
 

dptb/data/transforms.py

Lines changed: 6 additions & 1 deletion
@@ -644,7 +644,12 @@ def __init__(
                     indices += sli.start
                     assert indices.max() < sli.stop
                     self.mask_to_ndiag[self.chemical_symbol_to_type[ib]][indices] = True
-
+
+        self.full_mask_to_diag = torch.zeros(self.reduced_matrix_element, dtype=torch.bool, device=self.device)
+        for orbs, islice in self.orbpair_maps.items():
+            fio, fjo = orbs.split('-')
+            if fio == fjo:
+                self.full_mask_to_diag[islice] = True
 
     def get_orbpairtype_maps(self):
        """

dptb/nn/base.py

Lines changed: 116 additions & 1 deletion
@@ -5,7 +5,11 @@
 from torch import Tensor
 from dptb.utils.constants import dtype_dict
 from dptb.utils.tools import _get_activation_fn
+from e3nn.util.codegen import CodeGenMixin
+from e3nn.math import normalize2mom
 import torch.nn.functional as F
+import math
+from torch import fx
 import torch.nn as nn
 
 class AtomicLinear(torch.nn.Module):
@@ -466,4 +470,115 @@ def forward(self, x):
             x = layer(x)
             x = self.activation(x)
 
-        return self.out_layer(x)
+        return self.out_layer(x)
+
+class ScalarMLPFunction(CodeGenMixin, torch.nn.Module):
+    """Module implementing an MLP according to provided options."""
+
+    in_features: int
+    out_features: int
+
+    def __init__(
+        self,
+        mlp_input_dimension: Optional[int],
+        mlp_latent_dimensions: List[int],
+        mlp_output_dimension: Optional[int],
+        mlp_nonlinearity: Optional[str] = "silu",
+        mlp_initialization: str = "normal",
+        mlp_dropout_p: float = 0.0,
+        mlp_batchnorm: bool = False,
+    ):
+        super().__init__()
+        nonlinearity = {
+            None: None,
+            "silu": torch.nn.functional.silu,
+            "ssp": ShiftedSoftPlus,
+        }[mlp_nonlinearity]
+        if nonlinearity is not None:
+            nonlin_const = normalize2mom(nonlinearity).cst
+        else:
+            nonlin_const = 1.0
+
+        dimensions = (
+            ([mlp_input_dimension] if mlp_input_dimension is not None else [])
+            + mlp_latent_dimensions
+            + ([mlp_output_dimension] if mlp_output_dimension is not None else [])
+        )
+        assert len(dimensions) >= 2  # Must have input and output dim
+        num_layers = len(dimensions) - 1
+
+        self.in_features = dimensions[0]
+        self.out_features = dimensions[-1]
+
+        # Code
+        params = {}
+        graph = fx.Graph()
+        tracer = fx.proxy.GraphAppendingTracer(graph)
+
+        def Proxy(n):
+            return fx.Proxy(n, tracer=tracer)
+
+        features = Proxy(graph.placeholder("x"))
+        norm_from_last: float = 1.0
+
+        base = torch.nn.Module()
+
+        for layer, (h_in, h_out) in enumerate(zip(dimensions, dimensions[1:])):
+            # do dropout
+            if mlp_dropout_p > 0:
+                # only dropout if it will do something
+                # dropout before linear projection - https://stats.stackexchange.com/a/245137
+                features = Proxy(graph.call_module("_dropout", (features.node,)))
+
+            # make weights
+            w = torch.empty(h_in, h_out)
+
+            if mlp_initialization == "normal":
+                w.normal_()
+            elif mlp_initialization == "uniform":
+                # these values give < x^2 > = 1
+                w.uniform_(-math.sqrt(3), math.sqrt(3))
+            elif mlp_initialization == "orthogonal":
+                # this rescaling gives < x^2 > = 1
+                torch.nn.init.orthogonal_(w, gain=math.sqrt(max(w.shape)))
+            else:
+                raise NotImplementedError(
+                    f"Invalid mlp_initialization {mlp_initialization}"
+                )
+
+            # generate code
+            params[f"_weight_{layer}"] = w
+            w = Proxy(graph.get_attr(f"_weight_{layer}"))
+            w = w * (
+                norm_from_last / math.sqrt(float(h_in))
+            )  # include any nonlinearity normalization from previous layers
+            features = torch.matmul(features, w)
+
+            if mlp_batchnorm:
+                # if we call batchnorm, do it after the nonlinearity
+                features = Proxy(graph.call_module(f"_bn_{layer}", (features.node,)))
+                setattr(base, f"_bn_{layer}", torch.nn.BatchNorm1d(h_out))
+
+            # generate nonlinearity code
+            if nonlinearity is not None and layer < num_layers - 1:
+                features = nonlinearity(features)
+                # add the normalization const in next layer
+                norm_from_last = nonlin_const
+
+        graph.output(features.node)
+
+        for pname, p in params.items():
+            setattr(base, pname, torch.nn.Parameter(p))
+
+        if mlp_dropout_p > 0:
+            # with normal dropout everything blows up
+            base._dropout = torch.nn.AlphaDropout(p=mlp_dropout_p)
+
+        self._codegen_register({"_forward": fx.GraphModule(base, graph)})
+
+    def forward(self, x):
+        return self._forward(x)
+
+@torch.jit.script
+def ShiftedSoftPlus(x: torch.Tensor):
+    return torch.nn.functional.softplus(x) - math.log(2.0)
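Note: a quick usage sketch for the relocated ScalarMLPFunction; dimensions and options below are illustrative:

import torch
from dptb.nn.base import ScalarMLPFunction

mlp = ScalarMLPFunction(
    mlp_input_dimension=16,
    mlp_latent_dimensions=[32, 32],
    mlp_output_dimension=8,
    mlp_nonlinearity="silu",       # None or "ssp" (ShiftedSoftPlus) also accepted
    mlp_initialization="uniform",  # weights drawn so that <x^2> = 1
)
x = torch.randn(100, 16)
y = mlp(x)  # shape (100, 8); forward executes the fx-generated graph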

dptb/nn/deeptb.py

Lines changed: 63 additions & 23 deletions
@@ -169,33 +169,54 @@ def __init__(
             self.overlaponsite_param = overlaponsite_param
 
         elif prediction_copy.get("method") == "e3tb":
-            self.node_prediction_h = E3PerSpeciesScaleShift(
-                field=AtomicDataDict.NODE_FEATURES_KEY,
-                num_types=n_species,
-                irreps_in=self.embedding.out_node_irreps,
-                out_field = AtomicDataDict.NODE_FEATURES_KEY,
-                shifts=0.,
-                scales=1.,
-                dtype=self.dtype,
-                device=self.device,
-                **prediction_copy,
-            )
-
-            self.edge_prediction_h = E3PerEdgeSpeciesScaleShift(
-                field=AtomicDataDict.EDGE_FEATURES_KEY,
-                num_types=n_species,
-                irreps_in=self.embedding.out_edge_irreps,
-                out_field = AtomicDataDict.EDGE_FEATURES_KEY,
-                shifts=0.,
-                scales=1.,
-                dtype=self.dtype,
-                device=self.device,
-                **prediction_copy,
-            )
+            if embedding.get("method") == "trinity":
+                # hack to pass the dataset operation
+                self.node_prediction_h = lambda x: x
+                self.edge_prediction_h = lambda x: x
+                self.node_prediction_h.set_scale_shift = lambda scales, shifts: 0
+                self.edge_prediction_h.set_scale_shift = lambda scales, shifts: 0
+            else:
+                self.node_prediction_h = E3PerSpeciesScaleShift(
+                    field=AtomicDataDict.NODE_FEATURES_KEY,
+                    num_types=n_species,
+                    irreps_in=self.embedding.out_node_irreps,
+                    out_field = AtomicDataDict.NODE_FEATURES_KEY,
+                    shifts=0.,
+                    scales=1.,
+                    dtype=self.dtype,
+                    device=self.device,
+                    **prediction_copy,
+                )
+
+                self.edge_prediction_h = E3PerEdgeSpeciesScaleShift(
+                    field=AtomicDataDict.EDGE_FEATURES_KEY,
+                    num_types=n_species,
+                    irreps_in=self.embedding.out_edge_irreps,
+                    out_field = AtomicDataDict.EDGE_FEATURES_KEY,
+                    shifts=0.,
+                    scales=1.,
+                    dtype=self.dtype,
+                    device=self.device,
+                    **prediction_copy,
+                )
+
+            if embedding.get("method") == "trinity":
+                self.idp_sk = OrbitalMapper(self.idp.basis, method="sktb", device=self.device)
+                prediction_copy = prediction_copy.copy()
+                prediction_copy["neurons"] = [self.embedding.latent_dim] + prediction_copy["neurons"] + [self.idp_sk.reduced_matrix_element]
+                prediction_copy["config"] = get_neuron_config(prediction_copy["neurons"])
+                self.edge_prediction_h2 = AtomicResNet(
+                    **prediction_copy,
+                    in_field=AtomicDataDict.EDGE_ATTRS_KEY,
+                    out_field=AtomicDataDict.EDGE_ATTRS_KEY,
+                    device=device,
+                    dtype=dtype
+                )
 
         if overlap:
             self.idp_sk = OrbitalMapper(self.idp.basis, method="sktb", device=self.device)
             self.idp_sk.get_skonsite_maps()
+            prediction_copy = prediction.copy()
             prediction_copy["neurons"] = [self.embedding.latent_dim] + prediction_copy["neurons"] + [self.idp_sk.reduced_matrix_element]
             prediction_copy["config"] = get_neuron_config(prediction_copy["neurons"])
             self.edge_prediction_s = AtomicResNet(
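Note: the identity lambdas in the Trinity branch exist only so that downstream code can keep calling the usual scale-shift interface on the prediction heads; a minimal illustration (the call shown is hypothetical):

node_prediction_h = lambda x: x                               # features pass through unchanged
node_prediction_h.set_scale_shift = lambda scales, shifts: 0  # no-op stub

node_prediction_h.set_scale_shift(1.0, 0.0)  # safely ignored
assert node_prediction_h([1.0, 2.0]) == [1.0, 2.0]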
@@ -257,6 +278,17 @@ def __init__(
                 dtype=self.dtype,
                 device=self.device,
             )
+        if hasattr(self, "edge_prediction_h2"):
+            self.h2miltonian = SKHamiltonian(
+                idp_sk=self.idp_sk,
+                edge_field=AtomicDataDict.EDGE_ATTRS_KEY,
+                node_field=AtomicDataDict.NODE_ATTRS_KEY,
+                onsite=True,
+                strain=False,
+                soc=False,
+                dtype=self.dtype,
+                device=self.device,
+            )
 
 
     def forward(self, data: AtomicDataDict.Type):
@@ -274,10 +306,18 @@ def forward(self, data: AtomicDataDict.Type):
             data[AtomicDataDict.NODE_OVERLAP_KEY] = self.overlaponsite_param[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]
             data[AtomicDataDict.NODE_OVERLAP_KEY][:,self.idp_sk.mask_diag] = 1.
 
+        # prediction for two-body part of e3tb
+        if hasattr(self, "edge_prediction_h2"):
+            data = self.edge_prediction_h2(data)
+
         if self.transform:
             data = self.hamiltonian(data)
             if hasattr(self, "overlap"):
                 data = self.overlap(data)
+            if hasattr(self, "edge_prediction_h2"):
+                data = self.h2miltonian(data)
+                data[AtomicDataDict.NODE_FEATURES_KEY] += data[AtomicDataDict.NODE_ATTRS_KEY]
+                data[AtomicDataDict.EDGE_FEATURES_KEY] += data[AtomicDataDict.EDGE_ATTRS_KEY]
 
         return data
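Note: in the forward pass the Trinity path thus superposes an SK-style two-body term onto the e3tb prediction. A schematic of the data flow with identity stand-ins (the field names follow the diff; everything else is illustrative):

def trinity_forward(data, edge_prediction_h2, hamiltonian, h2miltonian):
    data = edge_prediction_h2(data)  # two-body (SK-like) prediction into the *_ATTRS fields
    data = hamiltonian(data)         # e3tb prediction into the *_FEATURES fields
    data = h2miltonian(data)         # map the SK part into the same block layout
    data["node_features"] += data["node_attrs"]  # superpose two-body onto e3tb, as in the diff
    data["edge_features"] += data["edge_attrs"]
    return data

identity = lambda d: d
out = trinity_forward(
    {"node_features": 1.0, "node_attrs": 0.1, "edge_features": 2.0, "edge_attrs": 0.2},
    identity, identity, identity,
)
# out["node_features"] == 1.1, out["edge_features"] == 2.2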

dptb/nn/embedding/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
 from .e3baseline_local6 import E3BaseLineModel6
 from .slem import Slem
 from .lem import Lem
+from .trinity import Trinity
 from .e3baseline_nonlocal import E3BaseLineModelNonLocal
 
 __all__ = [
@@ -15,6 +16,7 @@
     "E3DeePH",
     "Lem",
     "Slem",
+    "Trinity",
     "E3BaseLineModel6",
     "E3BaseLineModelNonLocal",
 ]
