4  | 4  | from ase.build import bulk
5  | 5  |
6  | 6  | # Import torchsim models and neighbors list
7  |    | -from torchsim.models.mace import MaceModel, UnbatchedMaceModel
   | 7  | +from torchsim.models.mace import MaceModel
8  | 8  | from torchsim.neighbors import vesin_nl_ts
9  |    | -from torchsim.optimizers import batched_gradient_descent
10 |    | -from torchsim.unbatched_optimizers import gradient_descent
   | 9  | +from torchsim.optimizers import gradient_descent
   | 10 | +from torchsim.runners import atoms_to_state
11 | 11 |
12 | 12 | from mace.calculators.foundations_models import mace_mp
13 | 13 |
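This import hunk is the heart of the change: the separate unbatched path (`UnbatchedMaceModel` plus `torchsim.unbatched_optimizers`) is dropped and a single batched API remains. Below is a minimal sketch of the call chain the new imports support, using only names that appear in this diff; the `mace_mp` flags, device setup, and `MaceModel` constructor arguments (mirrored from the removed `UnbatchedMaceModel` call) are assumptions, not confirmed API:

```python
import torch
from ase.build import bulk
from mace.calculators.foundations_models import mace_mp
from torchsim.models.mace import MaceModel
from torchsim.optimizers import gradient_descent
from torchsim.runners import atoms_to_state

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float64

# Assumed: mace_mp can hand back the raw torch model for wrapping.
loaded_model = mace_mp(model="small", return_raw_model=True)

# One state, one model, one optimizer -- no unbatched branch anymore.
state = atoms_to_state(
    [bulk("Si", "diamond", a=5.43, cubic=True)], device=device, dtype=dtype
)
batched_model = MaceModel(model=loaded_model, device=device, dtype=dtype)
gd_init, gd_update = gradient_descent(model=batched_model, lr=0.01)
state = gd_init(state)
state = gd_update(state)
```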
33 | 33 | rng = np.random.default_rng()
34 | 34 |
35 | 35 | # Create diamond cubic Silicon systems
36 |    | -si_dc = bulk("Si", "diamond", a=5.43).repeat((4, 4, 4))
   | 36 | +si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2))
37 | 37 | si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape)
38 | 38 |
39 |    | -si_dc_small = bulk("Si", "diamond", a=5.43).repeat((3, 3, 3))
40 |    | -si_dc_small.positions += 0.2 * rng.standard_normal(si_dc_small.positions.shape)
   | 39 | +fe = bulk("Fe", "bcc", a=2.8665, cubic=True).repeat((3, 3, 3))
   | 40 | +fe.positions += 0.2 * rng.standard_normal(fe.positions.shape)
41 | 41 |
42 | 42 | # Create a list of our atomic systems
43 |    | -atoms_list = [si_dc, si_dc, si_dc_small]
   | 43 | +atoms_list = [si_dc, fe]
44 | 44 |
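The replacement structures are deliberately heterogeneous. A cubic diamond cell holds 8 atoms, so the (2, 2, 2) repeat yields 64 Si atoms; a cubic bcc cell holds 2 atoms, so (3, 3, 3) yields 54 Fe atoms. Two systems with different sizes and species are exactly what the batched path must handle, whereas the old `[si_dc, si_dc, si_dc_small]` list only varied the size. A standalone check, with ASE alone:

```python
from ase.build import bulk

si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2))
fe = bulk("Fe", "bcc", a=2.8665, cubic=True).repeat((3, 3, 3))

# 8 atoms/cell * 2*2*2 = 64 Si; 2 atoms/cell * 3*3*3 = 54 Fe
print(len(si_dc), len(fe))  # 64 54
```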
45 | 45 | # Create batched model
46 | 46 | batched_model = MaceModel(
54 | 54 |     enable_cueq=False,
55 | 55 | )
56 | 56 |
57 |    | -# Create unbatched model for comparison
58 |    | -unbatched_model = UnbatchedMaceModel(
59 |    | -    model=loaded_model,
60 |    | -    atomic_numbers=si_dc.get_atomic_numbers(),
61 |    | -    device=device,
62 |    | -    neighbor_list_fn=vesin_nl_ts,
63 |    | -    periodic=PERIODIC,
64 |    | -    compute_force=True,
65 |    | -    compute_stress=True,
66 |    | -    dtype=dtype,
67 |    | -    enable_cueq=False,
68 |    | -)
69 |    | -
   | 57 | +"""
70 | 58 | # Convert data to tensors
71 | 59 | positions_list = [
72 | 60 |     torch.tensor(atoms.positions, device=device, dtype=dtype) for atoms in atoms_list
105 | 93 | batch_indices = torch.repeat_interleave(
106 | 94 |     torch.arange(len(atoms_per_batch), device=device), atoms_per_batch
107 | 95 | )
    | 96 | +"""
    | 97 | +
    | 98 | +state = atoms_to_state(atoms_list, device=device, dtype=dtype)
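Rather than deleting the manual tensor conversion, the commit fences it inside a module-level string (the `+"""` markers at new lines 57 and 96), keeping it as in-file documentation of what `atoms_to_state` now does: concatenate positions, stack cells, and build per-atom batch indices. A hedged, standalone equivalent of the batch-index step visible in the fenced block:

```python
import torch

# Per-system atom counts for the two systems above: 64 Si and 54 Fe.
atoms_per_batch = torch.tensor([64, 54])

# Same construction as the fenced-out block: atom i -> index of its system.
batch_indices = torch.repeat_interleave(
    torch.arange(len(atoms_per_batch)), atoms_per_batch
)
print(batch_indices.shape)  # torch.Size([118]): 64 zeros followed by 54 ones
```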
108 | 99  |
109 |     | -print(f"Positions shape: {positions.shape}")
110 |     | -print(f"Cell shape: {cell.shape}")
111 |     | -print(f"Batch indices shape: {batch_indices.shape}")
    | 100 | +print(f"Positions shape: {state.positions.shape}")
    | 101 | +print(f"Cell shape: {state.cell.shape}")
    | 102 | +print(f"Batch indices shape: {state.batch.shape}")
112 | 103 |
113 | 104 | # Run initial inference
114 | 105 | results = batched_model(
115 |     | -    positions=positions, cell=cell, atomic_numbers=atomic_numbers, batch=batch_indices
    | 106 | +    positions=state.positions,
    | 107 | +    cell=state.cell,
    | 108 | +    atomic_numbers=state.atomic_numbers,
    | 109 | +    batch=state.batch,
116 | 110 | )
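The initial inference is kept so its energies can be compared against the post-optimization ones in the final print (`results['energy']`). From that usage, `results` is a mapping with at least a per-system `'energy'` tensor; a `'forces'` entry of one 3-vector per atom is a plausible guess given the removed constructor's `compute_force=True`, but it is an assumption, as in this sketch:

```python
print(results["energy"].shape)  # one scalar per system: torch.Size([2])
# Assumed key, not shown in this diff:
# print(results["forces"].shape)  # (n_atoms_total, 3) = (118, 3)
```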
117 | 111 | # Use different learning rates for each batch
118 | 112 | learning_rate = 0.01
119 |     | -learning_rates = torch.tensor(
120 |     | -    [learning_rate] * len(atoms_list), device=device, dtype=dtype
121 |     | -)
122 | 113 |
123 | 114 | # Initialize batched gradient descent optimizer
124 |     | -batch_state, gd_update = batched_gradient_descent(
    | 115 | +gd_init, gd_update = gradient_descent(
125 | 116 |     model=batched_model,
126 |     | -    positions=positions,
127 |     | -    cell=cell,
128 |     | -    atomic_numbers=atomic_numbers,
129 |     | -    masses=masses,
130 |     | -    batch=batch_indices,
131 |     | -    learning_rates=learning_rates,
    | 117 | +    lr=learning_rate,
132 | 118 | )
133 | 119 |
    | 120 | +state = gd_init(state)
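The optimizer swap follows the same pattern as the model swap: `batched_gradient_descent` consumed raw tensors plus a per-system `learning_rates` tensor and returned a ready state, while the new `gradient_descent` is a factory that binds the model and a single `lr` and returns an (init, update) pair; everything per-system now lives in the state. The loop that follows then has the usual functional-optimizer shape, as in this sketch (names taken from this diff):

```python
# Bind the model and hyperparameters once...
gd_init, gd_update = gradient_descent(model=batched_model, lr=learning_rate)

# ...initialize optimizer quantities (e.g. energies) on the batched state...
state = gd_init(state)

# ...then step purely: state in, updated state out, all systems at once.
for _ in range(100):
    state = gd_update(state)
```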
134 | 121 | # Run batched optimization for a few steps
135 | 122 | print("\nRunning batched gradient descent:")
136 | 123 | for step in range(100):
137 | 124 |     if step % 10 == 0:
138 |     | -        print(f"Step {step}, Energy: {batch_state.energy}")
139 |     | -    batch_state = gd_update(batch_state)
140 |     | -
141 |     | -print(f"Final batched energy: {batch_state.energy}")
142 |     | -
143 |     | -# Compare with unbatched optimization
144 |     | -print("\nRunning unbatched gradient descent for comparison:")
145 |     | -unbatched_pos = torch.tensor(si_dc.positions, device=device, dtype=dtype)
146 |     | -unbatched_cell = torch.tensor(si_dc.cell.array, device=device, dtype=dtype)
147 |     | -unbatched_masses = torch.tensor(si_dc.get_masses(), device=device, dtype=dtype)
148 |     | -
149 |     | -state, single_gd_update = gradient_descent(
150 |     | -    model=unbatched_model,
151 |     | -    positions=unbatched_pos,
152 |     | -    masses=unbatched_masses,
153 |     | -    cell=unbatched_cell,
154 |     | -    pbc=PERIODIC,
155 |     | -    learning_rate=learning_rate,
156 |     | -)
157 |     | -
158 |     | -for step in range(100):
159 |     | -    if step % 10 == 0:
160 |     | -        print(f"Step {step}, Energy: {state.energy}")
161 |     | -    state = single_gd_update(state)
162 |     | -
163 |     | -print(f"Final unbatched energy: {state.energy}")
    | 125 | +        print(f"Step {step}, Energy: {[res.item() for res in state.energy]} eV")
    | 126 | +    state = gd_update(state)
164 | 127 |
165 |     | -# Compare final results between batched and unbatched
166 |     | -print("\nComparison between batched and unbatched results:")
167 |     | -print(f"Energy difference: {torch.max(torch.abs(batch_state.energy[0] - state.energy))}")
    | 128 | +print(f"Initial energies: {[res.item() for res in results['energy']]} eV")
    | 129 | +print(f"Final energies: {[res.item() for res in state.energy]} eV")
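Since `state.energy` carries one entry per system, each list comprehension above prints two values (Si, then Fe), and the sanity check that replaces the old batched-vs-unbatched comparison is simply that the final energies sit below the initial ones for both systems. A small post-processing sketch along those lines (the eV unit follows the prints above; the per-system ordering is an assumption):

```python
initial = [e.item() for e in results["energy"]]
final = [e.item() for e in state.energy]
for name, e0, e1 in zip(["Si", "Fe"], initial, final):
    print(f"{name}: {e0:.3f} eV -> {e1:.3f} eV")  # expect e1 < e0 after relaxation
```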