|
| 1 | +"""Calculate Phonon DOS and band structure with MACE in batched mode.""" |
| 2 | + |
| 3 | +# /// script |
| 4 | +# dependencies = [ |
| 5 | +# "mace-torch>=0.3.11", |
| 6 | +# "phonopy>=2.35", |
| 7 | +# "pymatviz[export-figs]>=0.15.1", |
| 8 | +# "seekpath", |
| 9 | +# "ase", |
| 10 | +# ] |
| 11 | +# /// |
| 12 | + |
| 13 | +import numpy as np |
| 14 | +import pymatviz as pmv |
| 15 | +import seekpath |
| 16 | +import torch |
| 17 | +from ase import Atoms |
| 18 | +from ase.build import bulk |
| 19 | +from mace.calculators.foundations_models import mace_mp |
| 20 | +from phonopy import Phonopy |
| 21 | +from phonopy.phonon.band_structure import ( |
| 22 | + get_band_qpoints_and_path_connections, |
| 23 | + get_band_qpoints_by_seekpath, |
| 24 | +) |
| 25 | + |
| 26 | +from torch_sim import optimize |
| 27 | +from torch_sim.io import phonopy_to_state, state_to_phonopy |
| 28 | +from torch_sim.models.mace import MaceModel |
| 29 | +from torch_sim.neighbors import vesin_nl_ts |
| 30 | +from torch_sim.optimizers import frechet_cell_fire |
| 31 | + |
| 32 | + |
| 33 | +def get_qpts_and_connections( |
| 34 | + ase_atoms: Atoms, |
| 35 | + n_points: int = 101, |
| 36 | +) -> tuple[list[list[float]], list[bool]]: |
| 37 | + """Get the high symmetry points and path connections for the band structure.""" |
| 38 | + # Define seekpath data |
| 39 | + seekpath_data = seekpath.get_path( |
| 40 | + (ase_atoms.cell, ase_atoms.get_scaled_positions(), ase_atoms.numbers) |
| 41 | + ) |
| 42 | + |
| 43 | + # Extract high symmetry points and path |
| 44 | + points = seekpath_data["point_coords"] |
| 45 | + path = [] |
| 46 | + for segment in seekpath_data["path"]: |
| 47 | + start_point = points[segment[0]] |
| 48 | + end_point = points[segment[1]] |
| 49 | + path.append([start_point, end_point]) |
| 50 | + qpts, connections = get_band_qpoints_and_path_connections(path, npoints=n_points) |
| 51 | + |
| 52 | + return qpts, connections |
| 53 | + |
| 54 | + |
| 55 | +def get_labels_qpts(ph: Phonopy, n_points: int = 101) -> tuple[list[str], list[bool]]: |
| 56 | + """Get the labels and coordinates of qpoints for the phonon band structure.""" |
| 57 | + # Get labels and coordinates for high-symmetry points |
| 58 | + _, qpts_labels, connections = get_band_qpoints_by_seekpath( |
| 59 | + ph.primitive, npoints=n_points, is_const_interval=True |
| 60 | + ) |
| 61 | + connections = [True, *connections] |
| 62 | + connections[-1] = True |
| 63 | + qpts_labels_connections = [] |
| 64 | + idx = 0 |
| 65 | + for connection in connections: |
| 66 | + if connection: |
| 67 | + qpts_labels_connections.append(qpts_labels[idx]) |
| 68 | + idx += 1 |
| 69 | + else: |
| 70 | + qpts_labels_connections.append(f"{qpts_labels[idx]}|{qpts_labels[idx + 1]}") |
| 71 | + idx += 2 |
| 72 | + |
| 73 | + qpts_labels_arr = [ |
| 74 | + q_label.replace("\\Gamma", "Γ") |
| 75 | + .replace("$", "") |
| 76 | + .replace("\\", "") |
| 77 | + .replace("mathrm", "") |
| 78 | + .replace("{", "") |
| 79 | + .replace("}", "") |
| 80 | + for q_label in qpts_labels_connections |
| 81 | + ] |
| 82 | + bands_dict = ph.get_band_structure_dict() |
| 83 | + npaths = len(bands_dict["frequencies"]) |
| 84 | + qpts_coord = [bands_dict["distances"][n][0] for n in range(npaths)] + [ |
| 85 | + bands_dict["distances"][-1][-1] |
| 86 | + ] |
| 87 | + |
| 88 | + return qpts_labels_arr, qpts_coord |
| 89 | + |
| 90 | + |
| 91 | +# Set device and data type |
| 92 | +device = "cuda" if torch.cuda.is_available() else "cpu" |
| 93 | +dtype = torch.float32 |
| 94 | + |
| 95 | +# Load the raw model |
| 96 | +mace_checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model" |
| 97 | +loaded_model = mace_mp( |
| 98 | + model=mace_checkpoint_url, |
| 99 | + return_raw_model=True, |
| 100 | + default_dtype=dtype, |
| 101 | + device=device, |
| 102 | +) |
| 103 | + |
| 104 | +# Structure and input parameters |
| 105 | +struct = bulk("Si", "diamond", a=5.431, cubic=True) # ASE structure |
| 106 | +supercell_matrix = 2 * np.eye(3) # supercell matrix for phonon calculation |
| 107 | +mesh = [20, 20, 20] # Phonon mesh |
| 108 | +Nrelax = 300 # number of relaxation steps |
| 109 | +displ = 0.01 # atomic displacement for phonons (in Angstrom) |
| 110 | + |
| 111 | +# Relax atomic positions |
| 112 | +model = MaceModel( |
| 113 | + model=loaded_model, |
| 114 | + device=device, |
| 115 | + neighbor_list_fn=vesin_nl_ts, |
| 116 | + compute_forces=True, |
| 117 | + compute_stress=True, |
| 118 | + dtype=dtype, |
| 119 | + enable_cueq=False, |
| 120 | +) |
| 121 | +final_state = optimize( |
| 122 | + system=struct, |
| 123 | + model=model, |
| 124 | + optimizer=frechet_cell_fire, |
| 125 | + constant_volume=True, |
| 126 | + hydrostatic_strain=True, |
| 127 | + max_steps=Nrelax, |
| 128 | +) |
| 129 | + |
| 130 | +# Define atoms and Phonopy object |
| 131 | +atoms = state_to_phonopy(final_state)[0] |
| 132 | +ph = Phonopy(atoms, supercell_matrix) |
| 133 | + |
| 134 | +# Generate FC2 displacements |
| 135 | +ph.generate_displacements(distance=displ) |
| 136 | +supercells = ph.supercells_with_displacements |
| 137 | + |
| 138 | +# Convert PhonopyAtoms to state |
| 139 | +state = phonopy_to_state(supercells, device, dtype) |
| 140 | +results = model(state) |
| 141 | + |
| 142 | +# Extract forces and convert back to list of numpy arrays for phonopy |
| 143 | +n_atoms_per_supercell = [len(cell) for cell in supercells] |
| 144 | +force_sets = [] |
| 145 | +start_idx = 0 |
| 146 | +for n_atoms in n_atoms_per_supercell: |
| 147 | + end_idx = start_idx + n_atoms |
| 148 | + force_sets.append(results["forces"][start_idx:end_idx].detach().cpu().numpy()) |
| 149 | + start_idx = end_idx |
| 150 | + |
| 151 | +# Produce force constants |
| 152 | +ph.forces = force_sets |
| 153 | +ph.produce_force_constants() |
| 154 | + |
| 155 | +# Set mesh for DOS calculation |
| 156 | +ph.run_mesh(mesh) |
| 157 | +ph.run_total_dos() |
| 158 | + |
| 159 | +# Calculate phonon band structure |
| 160 | +ase_atoms = Atoms( |
| 161 | + symbols=atoms.symbols, |
| 162 | + positions=atoms.positions, |
| 163 | + cell=atoms.cell, |
| 164 | + pbc=True, |
| 165 | +) |
| 166 | +qpts, connections = get_qpts_and_connections(ase_atoms) |
| 167 | +ph.run_band_structure(qpts, connections) |
| 168 | + |
| 169 | +# Define axis style for plots |
| 170 | +axis_style = dict( |
| 171 | + showgrid=False, |
| 172 | + zeroline=False, |
| 173 | + linecolor="black", |
| 174 | + showline=True, |
| 175 | + ticks="inside", |
| 176 | + mirror=True, |
| 177 | + linewidth=3, |
| 178 | + tickwidth=3, |
| 179 | + ticklen=10, |
| 180 | +) |
| 181 | + |
| 182 | +# Plot phonon DOS |
| 183 | +fig = pmv.phonon_dos(ph.total_dos) |
| 184 | +fig.update_traces(line_width=3) |
| 185 | +fig.update_layout( |
| 186 | + xaxis_title="Frequency (THz)", |
| 187 | + yaxis_title="DOS", |
| 188 | + font=dict(size=24), |
| 189 | + xaxis=axis_style, |
| 190 | + yaxis=axis_style, |
| 191 | + width=800, |
| 192 | + height=600, |
| 193 | + plot_bgcolor="white", |
| 194 | +) |
| 195 | +fig.show() |
| 196 | + |
| 197 | +# Plot phonon band structure |
| 198 | +ph.auto_band_structure(plot=False) |
| 199 | +fig = pmv.phonon_bands( |
| 200 | + ph.band_structure, |
| 201 | + line_kwargs={"width": 3}, |
| 202 | +) |
| 203 | +qpts_labels, qpts_coord = get_labels_qpts(ph) |
| 204 | +for q_pt in qpts_coord: |
| 205 | + fig.add_vline(x=q_pt, line_dash="dash", line_color="black", line_width=2, opacity=1) |
| 206 | +fig.update_layout( |
| 207 | + xaxis_title="Wave Vector", |
| 208 | + yaxis_title="Frequency (THz)", |
| 209 | + font=dict(size=24), |
| 210 | + xaxis=dict( |
| 211 | + tickmode="array", |
| 212 | + tickvals=qpts_coord, |
| 213 | + ticktext=qpts_labels, |
| 214 | + showgrid=False, |
| 215 | + zeroline=False, |
| 216 | + linecolor="black", |
| 217 | + showline=True, |
| 218 | + ticks="inside", |
| 219 | + mirror=True, |
| 220 | + linewidth=3, |
| 221 | + tickwidth=3, |
| 222 | + ticklen=10, |
| 223 | + ), |
| 224 | + yaxis=axis_style, |
| 225 | + width=800, |
| 226 | + height=600, |
| 227 | + plot_bgcolor="white", |
| 228 | +) |
| 229 | +fig.show() |
0 commit comments