Skip to content

Commit 9a9ebbf

Browse files
committed
update temp
1 parent c9fb3a2 commit 9a9ebbf

File tree

4 files changed

+11
-9
lines changed

4 files changed

+11
-9
lines changed

dptb/data/interfaces/ham_to_feature.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ def feature_to_block(data, idp):
405405
blocks[block_index] = block.T
406406
else:
407407
blocks[block_index] += block.T
408+
408409
return blocks
409410

410411

dptb/entrypoints/run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from dptb.utils.argcheck import normalize_run
1010
from dptb.utils.tools import j_loader
1111
from dptb.utils.tools import j_must_have
12-
from dptb.postprocess.write_ham import write_ham
12+
from dptb.postprocess.write_block import write_block
1313
import torch
1414
import h5py
1515

@@ -88,7 +88,7 @@ def run(
8888

8989
elif task=='write_block':
9090
task = torch.load(init_model, map_location="cpu")["task"]
91-
block = write_ham(data=struct_file, AtomicData_options=jdata['AtomicData_options'], model=model, device=jdata["device"])
91+
block = write_block(data=struct_file, AtomicData_options=jdata['AtomicData_options'], model=model, device=jdata["device"])
9292
# write to h5 file, block is a dict, write to a h5 file
9393
with h5py.File(os.path.join(results_path, task+".h5"), 'w') as fid:
9494
default_group = fid.create_group("0")

dptb/postprocess/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from .bandstructure import Band
22
from .totbplas import TBPLaS
3-
from .write_ham import write_ham
3+
from .write_block import write_block
44

55

66
__all__ = [
77
Band,
88
TBPLaS,
9-
write_ham,
9+
write_block,
1010

1111
]

dptb/postprocess/write_ham.py renamed to dptb/postprocess/write_block.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
log = logging.getLogger(__name__)
1515

16-
def write_ham(
16+
def write_block(
1717
data: Union[AtomicData, ase.Atoms, str],
1818
model: torch.nn.Module,
1919
AtomicData_options: dict={},
@@ -35,11 +35,12 @@ def write_ham(
3535
data = data
3636

3737
data = AtomicData.to_AtomicDataDict(data.to(device))
38-
data = model.idp(data)
38+
with torch.no_grad():
39+
data = model.idp(data)
3940

40-
# set the kpoint of the AtomicData
41-
data = model(data)
42-
block = feature_to_block(data=data, idp=model.idp)
41+
# set the kpoint of the AtomicData
42+
data = model(data)
43+
block = feature_to_block(data=data, idp=model.idp)
4344

4445
return block
4546

0 commit comments

Comments
 (0)