|
| 1 | +# Import dependencies |
| 2 | +import numpy as np |
| 3 | +import torch |
| 4 | +from ase.build import bulk |
| 5 | + |
| 6 | +# Import torchsim models and optimizers |
| 7 | +from torchsim.models.mace import UnbatchedMaceModel |
| 8 | +from torchsim.neighbors import vesin_nl_ts |
| 9 | +from torchsim.unbatched_optimizers import unit_cell_fire |
| 10 | +from torchsim.units import UnitConversion |
| 11 | + |
| 12 | +from mace.calculators.foundations_models import mace_mp |
| 13 | + |
| 14 | +# Set device and data type |
| 15 | +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 16 | +dtype = torch.float32 |
| 17 | + |
| 18 | +# Option 1: Load the raw model from the downloaded model |
| 19 | +mace_checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model" |
| 20 | +loaded_model = mace_mp( |
| 21 | + model=mace_checkpoint_url, |
| 22 | + return_raw_model=True, |
| 23 | + default_dtype=dtype, |
| 24 | + device=device, |
| 25 | +) |
| 26 | + |
| 27 | +# Option 2: Load from local file (comment out Option 1 to use this) |
| 28 | +# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" |
| 29 | +# loaded_model = torch.load(MODEL_PATH, map_location=device) |
| 30 | + |
| 31 | +PERIODIC = True |
| 32 | + |
| 33 | +# Create diamond cubic Silicon with random displacements and a 5% volume compression |
| 34 | +rng = np.random.default_rng() |
| 35 | +si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2)) |
| 36 | +si_dc.positions = si_dc.positions + 0.2 * rng.standard_normal(si_dc.positions.shape) |
| 37 | +si_dc.cell = si_dc.cell.array * 0.95 |
| 38 | + |
| 39 | +# Prepare input tensors |
| 40 | +positions = torch.tensor(si_dc.positions, device=device, dtype=dtype) |
| 41 | +cell = torch.tensor(si_dc.cell.array, device=device, dtype=dtype) |
| 42 | +atomic_numbers = torch.tensor(si_dc.get_atomic_numbers(), device=device, dtype=torch.int) |
| 43 | +masses = torch.tensor(si_dc.get_masses(), device=device, dtype=dtype) |
| 44 | + |
| 45 | +# Initialize the unbatched MACE model |
| 46 | +model = UnbatchedMaceModel( |
| 47 | + model=loaded_model, |
| 48 | + device=device, |
| 49 | + neighbor_list_fn=vesin_nl_ts, |
| 50 | + periodic=PERIODIC, |
| 51 | + compute_force=True, |
| 52 | + compute_stress=True, |
| 53 | + dtype=dtype, |
| 54 | + enable_cueq=False, |
| 55 | +) |
| 56 | + |
| 57 | +# Run initial inference |
| 58 | +results = model(positions=positions, cell=cell, atomic_numbers=atomic_numbers) |
| 59 | + |
| 60 | +state = { |
| 61 | + "positions": positions, |
| 62 | + "masses": masses, |
| 63 | + "cell": cell, |
| 64 | + "pbc": PERIODIC, |
| 65 | + "atomic_numbers": atomic_numbers, |
| 66 | +} |
| 67 | +# Initialize FIRE optimizer for structural relaxation |
| 68 | +fire_init, fire_update = unit_cell_fire( |
| 69 | + model=model, |
| 70 | +) |
| 71 | + |
| 72 | +state = fire_init(state=state) |
| 73 | + |
| 74 | +# Run optimization loop |
| 75 | +for step in range(1_000): |
| 76 | + if step % 10 == 0: |
| 77 | + PE = state.energy.item() |
| 78 | + P = torch.trace(state.stress).item() / 3.0 * UnitConversion.eV_per_Ang3_to_GPa |
| 79 | + print(f"{step=}: Total energy: {PE} eV, pressure: {P} GPa") |
| 80 | + state = fire_update(state) |
| 81 | + |
| 82 | +print(f"Initial energy: {results['energy'].item()} eV") |
| 83 | +print(f"Final energy: {state.energy.item()} eV") |
| 84 | + |
| 85 | + |
| 86 | +print(f"Initial max force: {torch.max(torch.abs(results['forces'])).item()} eV/Å") |
| 87 | +print(f"Final max force: {torch.max(torch.abs(state.forces)).item()} eV/Å") |
| 88 | + |
| 89 | +print( |
| 90 | + f"Initial pressure: {torch.trace(results['stress']).item() / 3.0 * UnitConversion.eV_per_Ang3_to_GPa} GPa" |
| 91 | +) |
| 92 | +print( |
| 93 | + f"Final pressure: {torch.trace(state.stress).item() / 3.0 * UnitConversion.eV_per_Ang3_to_GPa} GPa" |
| 94 | +) |
0 commit comments