Skip to content

Commit 1de82e7

Browse files
Update unitcellfilter fire and add an example for it
1 parent 2466c5c commit 1de82e7

File tree

2 files changed

+274
-114
lines changed

2 files changed

+274
-114
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Import dependencies
2+
import numpy as np
3+
import torch
4+
from ase.build import bulk
5+
6+
# Import torchsim models and optimizers
7+
from torchsim.models.mace import UnbatchedMaceModel
8+
from torchsim.neighbors import vesin_nl_ts
9+
from torchsim.unbatched_optimizers import unit_cell_fire
10+
from torchsim.units import UnitConversion
11+
12+
from mace.calculators.foundations_models import mace_mp
13+
14+
# Set device and data type
15+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16+
dtype = torch.float32
17+
18+
# Option 1: Load the raw model from the downloaded model
19+
mace_checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model"
20+
loaded_model = mace_mp(
21+
model=mace_checkpoint_url,
22+
return_raw_model=True,
23+
default_dtype=dtype,
24+
device=device,
25+
)
26+
27+
# Option 2: Load from local file (comment out Option 1 to use this)
28+
# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model"
29+
# loaded_model = torch.load(MODEL_PATH, map_location=device)
30+
31+
PERIODIC = True
32+
33+
# Create diamond cubic Silicon with random displacements and a 5% volume compression
34+
rng = np.random.default_rng()
35+
si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2))
36+
si_dc.positions = si_dc.positions + 0.2 * rng.standard_normal(si_dc.positions.shape)
37+
si_dc.cell = si_dc.cell.array * 0.95
38+
39+
# Prepare input tensors
40+
positions = torch.tensor(si_dc.positions, device=device, dtype=dtype)
41+
cell = torch.tensor(si_dc.cell.array, device=device, dtype=dtype)
42+
atomic_numbers = torch.tensor(si_dc.get_atomic_numbers(), device=device, dtype=torch.int)
43+
masses = torch.tensor(si_dc.get_masses(), device=device, dtype=dtype)
44+
45+
# Initialize the unbatched MACE model
46+
model = UnbatchedMaceModel(
47+
model=loaded_model,
48+
device=device,
49+
neighbor_list_fn=vesin_nl_ts,
50+
periodic=PERIODIC,
51+
compute_force=True,
52+
compute_stress=True,
53+
dtype=dtype,
54+
enable_cueq=False,
55+
)
56+
57+
# Run initial inference
58+
results = model(positions=positions, cell=cell, atomic_numbers=atomic_numbers)
59+
60+
state = {
61+
"positions": positions,
62+
"masses": masses,
63+
"cell": cell,
64+
"pbc": PERIODIC,
65+
"atomic_numbers": atomic_numbers,
66+
}
67+
# Initialize FIRE optimizer for structural relaxation
68+
fire_init, fire_update = unit_cell_fire(
69+
model=model,
70+
)
71+
72+
state = fire_init(state=state)
73+
74+
# Run optimization loop
75+
for step in range(1_000):
76+
if step % 10 == 0:
77+
PE = state.energy.item()
78+
P = torch.trace(state.stress).item() / 3.0 * UnitConversion.eV_per_Ang3_to_GPa
79+
print(f"{step=}: Total energy: {PE} eV, pressure: {P} GPa")
80+
state = fire_update(state)
81+
82+
print(f"Initial energy: {results['energy'].item()} eV")
83+
print(f"Final energy: {state.energy.item()} eV")
84+
85+
86+
print(f"Initial max force: {torch.max(torch.abs(results['forces'])).item()} eV/Å")
87+
print(f"Final max force: {torch.max(torch.abs(state.forces)).item()} eV/Å")
88+
89+
print(
90+
f"Initial pressure: {torch.trace(results['stress']).item() / 3.0 * UnitConversion.eV_per_Ang3_to_GPa} GPa"
91+
)
92+
print(
93+
f"Final pressure: {torch.trace(state.stress).item() / 3.0 * UnitConversion.eV_per_Ang3_to_GPa} GPa"
94+
)

0 commit comments

Comments
 (0)