Skip to content

Commit d0f6652

Browse files
fallettajanosh
andauthored
Phonon examples: bands + DOS, quasi-harmonic, Wigner lattice conductivity (#91)
* included relaxation + plotting of band structure * included example for wigner conductivity * added QHA * included QHA examples + updated plotting * added fully batched QHA example * removed redundant examples * updated readme * fix variable overwriting function in docs/conf.py --------- Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>
1 parent a1c2b93 commit d0f6652

8 files changed

+804
-315
lines changed

docs/conf.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
# add these directories to sys.path here. If the directory is relative to the
99
# documentation root, use os.path.abspath to make it absolute, like shown here.
1010

11+
import importlib.metadata
1112
import os
1213
import sys
13-
from importlib.metadata import version
1414

1515

1616
sys.path.insert(0, os.path.abspath("../../"))
@@ -22,16 +22,16 @@
2222
author = "Abhijeet Gangan, Orion Cohen, Janosh Riebesell"
2323

2424
# The short X.Y version
25-
version = version("torch-sim")
25+
version = importlib.metadata.version("torch-sim")
2626
# The full version, including alpha/beta/rc tags
27-
release = version("torch-sim")
27+
release = importlib.metadata.version("torch-sim")
2828

2929
# -- General configuration ---------------------------------------------------
3030

3131
# Add any Sphinx extension module names here, as strings. They can be
3232
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
3333
# ones.
34-
extensions = [
34+
extensions = (
3535
"sphinx.ext.autodoc",
3636
"sphinx.ext.napoleon",
3737
"sphinx.ext.intersphinx",
@@ -43,7 +43,7 @@
4343
"nbsphinx",
4444
"sphinx_design",
4545
"sphinx_copybutton",
46-
]
46+
)
4747

4848
# Add any paths that contain templates here, relative to this directory.
4949
templates_path = ["_templates"]

examples/scripts/6_Phonons/6.1_Phonon_dos_batched_MACE.py

Lines changed: 0 additions & 95 deletions
This file was deleted.
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
"""Calculate Phonon DOS and band structure with MACE in batched mode."""
2+
3+
# /// script
4+
# dependencies = [
5+
# "mace-torch>=0.3.11",
6+
# "phonopy>=2.35",
7+
# "pymatviz[export-figs]>=0.15.1",
8+
# "seekpath",
9+
# "ase",
10+
# ]
11+
# ///
12+
13+
import numpy as np
14+
import pymatviz as pmv
15+
import seekpath
16+
import torch
17+
from ase import Atoms
18+
from ase.build import bulk
19+
from mace.calculators.foundations_models import mace_mp
20+
from phonopy import Phonopy
21+
from phonopy.phonon.band_structure import (
22+
get_band_qpoints_and_path_connections,
23+
get_band_qpoints_by_seekpath,
24+
)
25+
26+
from torch_sim import optimize
27+
from torch_sim.io import phonopy_to_state, state_to_phonopy
28+
from torch_sim.models.mace import MaceModel
29+
from torch_sim.neighbors import vesin_nl_ts
30+
from torch_sim.optimizers import frechet_cell_fire
31+
32+
33+
def get_qpts_and_connections(
34+
ase_atoms: Atoms,
35+
n_points: int = 101,
36+
) -> tuple[list[list[float]], list[bool]]:
37+
"""Get the high symmetry points and path connections for the band structure."""
38+
# Define seekpath data
39+
seekpath_data = seekpath.get_path(
40+
(ase_atoms.cell, ase_atoms.get_scaled_positions(), ase_atoms.numbers)
41+
)
42+
43+
# Extract high symmetry points and path
44+
points = seekpath_data["point_coords"]
45+
path = []
46+
for segment in seekpath_data["path"]:
47+
start_point = points[segment[0]]
48+
end_point = points[segment[1]]
49+
path.append([start_point, end_point])
50+
qpts, connections = get_band_qpoints_and_path_connections(path, npoints=n_points)
51+
52+
return qpts, connections
53+
54+
55+
def get_labels_qpts(ph: Phonopy, n_points: int = 101) -> tuple[list[str], list[bool]]:
56+
"""Get the labels and coordinates of qpoints for the phonon band structure."""
57+
# Get labels and coordinates for high-symmetry points
58+
_, qpts_labels, connections = get_band_qpoints_by_seekpath(
59+
ph.primitive, npoints=n_points, is_const_interval=True
60+
)
61+
connections = [True, *connections]
62+
connections[-1] = True
63+
qpts_labels_connections = []
64+
idx = 0
65+
for connection in connections:
66+
if connection:
67+
qpts_labels_connections.append(qpts_labels[idx])
68+
idx += 1
69+
else:
70+
qpts_labels_connections.append(f"{qpts_labels[idx]}|{qpts_labels[idx + 1]}")
71+
idx += 2
72+
73+
qpts_labels_arr = [
74+
q_label.replace("\\Gamma", "Γ")
75+
.replace("$", "")
76+
.replace("\\", "")
77+
.replace("mathrm", "")
78+
.replace("{", "")
79+
.replace("}", "")
80+
for q_label in qpts_labels_connections
81+
]
82+
bands_dict = ph.get_band_structure_dict()
83+
npaths = len(bands_dict["frequencies"])
84+
qpts_coord = [bands_dict["distances"][n][0] for n in range(npaths)] + [
85+
bands_dict["distances"][-1][-1]
86+
]
87+
88+
return qpts_labels_arr, qpts_coord
89+
90+
91+
# Set device and data type
92+
device = "cuda" if torch.cuda.is_available() else "cpu"
93+
dtype = torch.float32
94+
95+
# Load the raw model
96+
mace_checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model"
97+
loaded_model = mace_mp(
98+
model=mace_checkpoint_url,
99+
return_raw_model=True,
100+
default_dtype=dtype,
101+
device=device,
102+
)
103+
104+
# Structure and input parameters
105+
struct = bulk("Si", "diamond", a=5.431, cubic=True) # ASE structure
106+
supercell_matrix = 2 * np.eye(3) # supercell matrix for phonon calculation
107+
mesh = [20, 20, 20] # Phonon mesh
108+
Nrelax = 300 # number of relaxation steps
109+
displ = 0.01 # atomic displacement for phonons (in Angstrom)
110+
111+
# Relax atomic positions
112+
model = MaceModel(
113+
model=loaded_model,
114+
device=device,
115+
neighbor_list_fn=vesin_nl_ts,
116+
compute_forces=True,
117+
compute_stress=True,
118+
dtype=dtype,
119+
enable_cueq=False,
120+
)
121+
final_state = optimize(
122+
system=struct,
123+
model=model,
124+
optimizer=frechet_cell_fire,
125+
constant_volume=True,
126+
hydrostatic_strain=True,
127+
max_steps=Nrelax,
128+
)
129+
130+
# Define atoms and Phonopy object
131+
atoms = state_to_phonopy(final_state)[0]
132+
ph = Phonopy(atoms, supercell_matrix)
133+
134+
# Generate FC2 displacements
135+
ph.generate_displacements(distance=displ)
136+
supercells = ph.supercells_with_displacements
137+
138+
# Convert PhonopyAtoms to state
139+
state = phonopy_to_state(supercells, device, dtype)
140+
results = model(state)
141+
142+
# Extract forces and convert back to list of numpy arrays for phonopy
143+
n_atoms_per_supercell = [len(cell) for cell in supercells]
144+
force_sets = []
145+
start_idx = 0
146+
for n_atoms in n_atoms_per_supercell:
147+
end_idx = start_idx + n_atoms
148+
force_sets.append(results["forces"][start_idx:end_idx].detach().cpu().numpy())
149+
start_idx = end_idx
150+
151+
# Produce force constants
152+
ph.forces = force_sets
153+
ph.produce_force_constants()
154+
155+
# Set mesh for DOS calculation
156+
ph.run_mesh(mesh)
157+
ph.run_total_dos()
158+
159+
# Calculate phonon band structure
160+
ase_atoms = Atoms(
161+
symbols=atoms.symbols,
162+
positions=atoms.positions,
163+
cell=atoms.cell,
164+
pbc=True,
165+
)
166+
qpts, connections = get_qpts_and_connections(ase_atoms)
167+
ph.run_band_structure(qpts, connections)
168+
169+
# Define axis style for plots
170+
axis_style = dict(
171+
showgrid=False,
172+
zeroline=False,
173+
linecolor="black",
174+
showline=True,
175+
ticks="inside",
176+
mirror=True,
177+
linewidth=3,
178+
tickwidth=3,
179+
ticklen=10,
180+
)
181+
182+
# Plot phonon DOS
183+
fig = pmv.phonon_dos(ph.total_dos)
184+
fig.update_traces(line_width=3)
185+
fig.update_layout(
186+
xaxis_title="Frequency (THz)",
187+
yaxis_title="DOS",
188+
font=dict(size=24),
189+
xaxis=axis_style,
190+
yaxis=axis_style,
191+
width=800,
192+
height=600,
193+
plot_bgcolor="white",
194+
)
195+
fig.show()
196+
197+
# Plot phonon band structure
198+
ph.auto_band_structure(plot=False)
199+
fig = pmv.phonon_bands(
200+
ph.band_structure,
201+
line_kwargs={"width": 3},
202+
)
203+
qpts_labels, qpts_coord = get_labels_qpts(ph)
204+
for q_pt in qpts_coord:
205+
fig.add_vline(x=q_pt, line_dash="dash", line_color="black", line_width=2, opacity=1)
206+
fig.update_layout(
207+
xaxis_title="Wave Vector",
208+
yaxis_title="Frequency (THz)",
209+
font=dict(size=24),
210+
xaxis=dict(
211+
tickmode="array",
212+
tickvals=qpts_coord,
213+
ticktext=qpts_labels,
214+
showgrid=False,
215+
zeroline=False,
216+
linecolor="black",
217+
showline=True,
218+
ticks="inside",
219+
mirror=True,
220+
linewidth=3,
221+
tickwidth=3,
222+
ticklen=10,
223+
),
224+
yaxis=axis_style,
225+
width=800,
226+
height=600,
227+
plot_bgcolor="white",
228+
)
229+
fig.show()

0 commit comments

Comments
 (0)