Skip to content

Commit 7944feb

Browse files
Update optimizers (#9)
* Update gradient descent * Update fire, examples for it and run unbatched optimizer tests in CI * Update unitcellfilter fire and add an example for it * Update batched gd and its example * Update batched unit cell gradient descent * Update batched unit cell fire
1 parent b54236f commit 7944feb

12 files changed

+1126
-1065
lines changed

examples/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,19 +78,24 @@
7878
# Run initial simulation and get results
7979
results = model(positions=positions, cell=cell, atomic_numbers=atomic_numbers)
8080

81+
state = {
82+
"positions": positions,
83+
"masses": masses,
84+
"cell": cell,
85+
"pbc": PERIODIC,
86+
"atomic_numbers": atomic_numbers,
87+
}
88+
8189
# Initialize FIRE (Fast Inertial Relaxation Engine) optimizer
8290
# FIRE is an efficient method for finding local energy minima in molecular systems
83-
state, fire_update = fire(
91+
fire_init, fire_update = fire(
8492
model=model,
85-
positions=positions,
86-
masses=masses,
87-
cell=cell,
88-
pbc=PERIODIC,
8993
dt_start=0.005, # Initial timestep
9094
dt_max=0.01, # Maximum timestep
91-
atomic_numbers=atomic_numbers,
9295
)
9396

97+
state = fire_init(state=state)
98+
9499
# Run optimization for 1000 steps
95100
for step in range(2_000):
96101
if step % 100 == 0:

examples/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@
6262
# Cu atomic mass in atomic mass units
6363
masses = torch.full((positions.shape[0],), 63.546, device=device, dtype=dtype)
6464

65-
6665
# Initialize the Soft Sphere model
6766
model = SoftSphereModel(
6867
sigma=2.5,
@@ -77,21 +76,23 @@
7776
# Run initial simulation and get results
7877
results = model(positions=positions, cell=cell, atomic_numbers=atomic_numbers)
7978

80-
print(f"Initial Energy: {results['energy']}")
81-
print(f"Initial Forces shape: {results['forces'].shape}")
79+
state = {
80+
"positions": positions,
81+
"masses": masses,
82+
"cell": cell,
83+
"pbc": PERIODIC,
84+
"atomic_numbers": atomic_numbers,
85+
}
8286

8387
# Initialize FIRE (Fast Inertial Relaxation Engine) optimizer
84-
state, fire_update = fire(
88+
fire_init, fire_update = fire(
8589
model=model,
86-
positions=positions,
87-
masses=masses,
88-
cell=cell,
89-
atomic_numbers=atomic_numbers,
90-
pbc=PERIODIC,
9190
dt_start=0.005, # Initial timestep
9291
dt_max=0.01, # Maximum timestep
9392
)
9493

94+
state = fire_init(state=state)
95+
9596
# Run optimization for 2000 steps
9697
for step in range(2_000):
9798
if step % 100 == 0:

examples/2_Structural_optimization/2.3_MACE_FIRE.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,23 @@
5555
# Run initial inference
5656
results = model(positions=positions, cell=cell, atomic_numbers=atomic_numbers)
5757

58+
state = {
59+
"positions": positions,
60+
"masses": masses,
61+
"cell": cell,
62+
"pbc": PERIODIC,
63+
"atomic_numbers": atomic_numbers,
64+
}
5865
# Initialize FIRE optimizer for structural relaxation
59-
state, fire_update = fire(
66+
fire_init, fire_update = fire(
6067
model=model,
61-
positions=positions,
62-
masses=masses,
63-
cell=cell,
64-
pbc=PERIODIC,
65-
atomic_numbers=atomic_numbers,
6668
)
6769

70+
state = fire_init(state=state)
71+
6872
# Run optimization loop
69-
for step in range(2_000):
70-
if step % 100 == 0:
73+
for step in range(1_000):
74+
if step % 10 == 0:
7175
print(f"{step=}: Total energy: {state.energy.item()} eV")
7276
state = fire_update(state)
7377

examples/2_Structural_optimization/2.4_Batched_MACE_Gradient_Desent.py

Lines changed: 25 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from ase.build import bulk
55

66
# Import torchsim models and neighbors list
7-
from torchsim.models.mace import MaceModel, UnbatchedMaceModel
7+
from torchsim.models.mace import MaceModel
88
from torchsim.neighbors import vesin_nl_ts
9-
from torchsim.optimizers import batched_gradient_descent
10-
from torchsim.unbatched_optimizers import gradient_descent
9+
from torchsim.optimizers import gradient_descent
10+
from torchsim.runners import atoms_to_state
1111

1212
from mace.calculators.foundations_models import mace_mp
1313

@@ -33,14 +33,14 @@
3333
rng = np.random.default_rng()
3434

3535
# Create diamond cubic Silicon systems
36-
si_dc = bulk("Si", "diamond", a=5.43).repeat((4, 4, 4))
36+
si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2))
3737
si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape)
3838

39-
si_dc_small = bulk("Si", "diamond", a=5.43).repeat((3, 3, 3))
40-
si_dc_small.positions += 0.2 * rng.standard_normal(si_dc_small.positions.shape)
39+
fe = bulk("Fe", "bcc", a=2.8665, cubic=True).repeat((3, 3, 3))
40+
fe.positions += 0.2 * rng.standard_normal(fe.positions.shape)
4141

4242
# Create a list of our atomic systems
43-
atoms_list = [si_dc, si_dc, si_dc_small]
43+
atoms_list = [si_dc, fe]
4444

4545
# Create batched model
4646
batched_model = MaceModel(
@@ -54,19 +54,7 @@
5454
enable_cueq=False,
5555
)
5656

57-
# Create unbatched model for comparison
58-
unbatched_model = UnbatchedMaceModel(
59-
model=loaded_model,
60-
atomic_numbers=si_dc.get_atomic_numbers(),
61-
device=device,
62-
neighbor_list_fn=vesin_nl_ts,
63-
periodic=PERIODIC,
64-
compute_force=True,
65-
compute_stress=True,
66-
dtype=dtype,
67-
enable_cueq=False,
68-
)
69-
57+
"""
7058
# Convert data to tensors
7159
positions_list = [
7260
torch.tensor(atoms.positions, device=device, dtype=dtype) for atoms in atoms_list
@@ -105,63 +93,37 @@
10593
batch_indices = torch.repeat_interleave(
10694
torch.arange(len(atoms_per_batch), device=device), atoms_per_batch
10795
)
96+
"""
97+
98+
state = atoms_to_state(atoms_list, device=device, dtype=dtype)
10899

109-
print(f"Positions shape: {positions.shape}")
110-
print(f"Cell shape: {cell.shape}")
111-
print(f"Batch indices shape: {batch_indices.shape}")
100+
print(f"Positions shape: {state.positions.shape}")
101+
print(f"Cell shape: {state.cell.shape}")
102+
print(f"Batch indices shape: {state.batch.shape}")
112103

113104
# Run initial inference
114105
results = batched_model(
115-
positions=positions, cell=cell, atomic_numbers=atomic_numbers, batch=batch_indices
106+
positions=state.positions,
107+
cell=state.cell,
108+
atomic_numbers=state.atomic_numbers,
109+
batch=state.batch,
116110
)
117111
# Use different learning rates for each batch
118112
learning_rate = 0.01
119-
learning_rates = torch.tensor(
120-
[learning_rate] * len(atoms_list), device=device, dtype=dtype
121-
)
122113

123114
# Initialize batched gradient descent optimizer
124-
batch_state, gd_update = batched_gradient_descent(
115+
gd_init, gd_update = gradient_descent(
125116
model=batched_model,
126-
positions=positions,
127-
cell=cell,
128-
atomic_numbers=atomic_numbers,
129-
masses=masses,
130-
batch=batch_indices,
131-
learning_rates=learning_rates,
117+
lr=learning_rate,
132118
)
133119

120+
state = gd_init(state)
134121
# Run batched optimization for a few steps
135122
print("\nRunning batched gradient descent:")
136123
for step in range(100):
137124
if step % 10 == 0:
138-
print(f"Step {step}, Energy: {batch_state.energy}")
139-
batch_state = gd_update(batch_state)
140-
141-
print(f"Final batched energy: {batch_state.energy}")
142-
143-
# Compare with unbatched optimization
144-
print("\nRunning unbatched gradient descent for comparison:")
145-
unbatched_pos = torch.tensor(si_dc.positions, device=device, dtype=dtype)
146-
unbatched_cell = torch.tensor(si_dc.cell.array, device=device, dtype=dtype)
147-
unbatched_masses = torch.tensor(si_dc.get_masses(), device=device, dtype=dtype)
148-
149-
state, single_gd_update = gradient_descent(
150-
model=unbatched_model,
151-
positions=unbatched_pos,
152-
masses=unbatched_masses,
153-
cell=unbatched_cell,
154-
pbc=PERIODIC,
155-
learning_rate=learning_rate,
156-
)
157-
158-
for step in range(100):
159-
if step % 10 == 0:
160-
print(f"Step {step}, Energy: {state.energy}")
161-
state = single_gd_update(state)
162-
163-
print(f"Final unbatched energy: {state.energy}")
125+
print(f"Step {step}, Energy: {[res.item() for res in state.energy]} eV")
126+
state = gd_update(state)
164127

165-
# Compare final results between batched and unbatched
166-
print("\nComparison between batched and unbatched results:")
167-
print(f"Energy difference: {torch.max(torch.abs(batch_state.energy[0] - state.energy))}")
128+
print(f"Initial energies: {[res.item() for res in results['energy']]} eV")
129+
print(f"Final energies: {[res.item() for res in state.energy]} eV")

0 commit comments

Comments
 (0)