Skip to content

Commit f6cd006

Browse files
Improve Typing of ModelInterface (#215)
Signed-off-by: Rhys Goodall <rhys.goodall@outlook.com> Co-authored-by: Rhys Goodall <rhys.goodall@outlook.com>
1 parent 88abcff commit f6cd006

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+364
-269
lines changed

.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
Before a pull request can be merged, the following items must be checked:
99

1010
* [ ] Doc strings have been added in the [Google docstring format](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html#example-google).
11-
Run [ruff](https://beta.ruff.rs/docs/rules/#pydocstyle-d) on your code.
11+
* [ ] Run [ruff](https://beta.ruff.rs/docs/rules/#pydocstyle-d) on your code.
12+
* [ ] Run `uvx ty check` on the repo.
1213
* [ ] Tests have been added for any new functionality or bug fixes.
1314

1415
We highly recommended installing the pre-commit hooks running in CI locally to speedup the development process. Simply run `pip install pre-commit && pre-commit install` to install the hooks which will check your code before each commit.

examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@
8989
positions=positions,
9090
masses=masses,
9191
cell=cell.unsqueeze(0),
92-
pbc=True,
9392
atomic_numbers=atomic_numbers,
93+
pbc=True,
9494
)
9595

9696
# Run initial simulation and get results

examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@
8080
positions=positions,
8181
masses=masses,
8282
cell=cell.unsqueeze(0),
83-
pbc=True,
8483
atomic_numbers=atomic_numbers,
84+
pbc=True,
8585
)
8686

8787
# Initialize the Soft Sphere model

examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@
9797
positions=positions,
9898
masses=masses,
9999
cell=cell.unsqueeze(0),
100-
pbc=True,
101100
atomic_numbers=atomic_numbers,
101+
pbc=True,
102102
)
103103
# Run initial simulation and get results
104104
results = model(state)
@@ -148,11 +148,11 @@
148148

149149

150150
stress = model(state)["stress"]
151-
calc_kinetic_energy = calc_kinetic_energy(
151+
kinetic_energy = calc_kinetic_energy(
152152
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
153153
)
154154
volume = torch.linalg.det(state.cell)
155-
pressure = get_pressure(stress, calc_kinetic_energy, volume)
155+
pressure = get_pressure(stress, kinetic_energy, volume)
156156
pressure = pressure.item() / Units.pressure
157157
print(f"Final {pressure=:.4f}")
158158
print(stress * UnitConversion.eV_per_Ang3_to_GPa)

examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,7 @@
7878
masses = torch.full((positions.shape[0],), 39.948, device=device, dtype=dtype)
7979

8080
state = ts.SimState(
81-
positions=positions,
82-
masses=masses,
83-
cell=cell,
84-
pbc=True,
85-
atomic_numbers=atomic_numbers,
81+
positions=positions, masses=masses, cell=cell, atomic_numbers=atomic_numbers, pbc=True
8682
)
8783
# Initialize the Lennard-Jones model
8884
# Parameters:

examples/scripts/3_Dynamics/3.2_MACE_NVE.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,7 @@
6060
)
6161

6262
state = ts.SimState(
63-
positions=positions,
64-
masses=masses,
65-
cell=cell,
66-
pbc=True,
67-
atomic_numbers=atomic_numbers,
63+
positions=positions, masses=masses, cell=cell, atomic_numbers=atomic_numbers, pbc=True
6864
)
6965

7066
# Run initial inference

examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,7 @@
5959
)
6060

6161
state = ts.SimState(
62-
positions=positions,
63-
masses=masses,
64-
cell=cell,
65-
pbc=True,
66-
atomic_numbers=atomic_numbers,
62+
positions=positions, masses=masses, cell=cell, atomic_numbers=atomic_numbers, pbc=True
6763
)
6864

6965
dt = 0.002 * Units.time # Timestep (ps)

examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@
9696
positions=positions,
9797
masses=masses,
9898
cell=cell.unsqueeze(0),
99-
pbc=True,
10099
atomic_numbers=atomic_numbers,
100+
pbc=True,
101101
)
102102
# Run initial simulation and get results
103103
results = model(state)

examples/scripts/4_High_level_api/4.1_high_level_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@
5454
prop_calculators = {
5555
10: {"potential_energy": lambda state: state.energy},
5656
20: {
57-
"kinetic_energy": lambda state: calc_kinetic_energy(state.momenta, state.masses)
57+
"kinetic_energy": lambda state: calc_kinetic_energy(
58+
momenta=state.momenta, masses=state.masses
59+
)
5860
},
5961
}
6062

examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@
2424
from phonopy.structure.atoms import PhonopyAtoms
2525

2626
import torch_sim as ts
27+
from torch_sim.models.interface import ModelInterface
2728
from torch_sim.models.mace import MaceModel, MaceUrls
2829

2930

3031
def get_relaxed_structure(
3132
struct: Atoms,
32-
model: torch.nn.Module | None,
33+
model: ModelInterface,
3334
Nrelax: int = 300,
3435
fmax: float = 1e-3,
3536
*,
@@ -80,7 +81,7 @@ def get_relaxed_structure(
8081
def get_qha_structures(
8182
state: ts.state.SimState,
8283
length_factors: np.ndarray,
83-
model: torch.nn.Module | None,
84+
model: ModelInterface,
8485
Nmax: int = 300,
8586
fmax: float = 1e-3,
8687
*,
@@ -129,7 +130,7 @@ def get_qha_structures(
129130

130131
def get_qha_phonons(
131132
scaled_structures: list[PhonopyAtoms],
132-
model: torch.nn.Module | None,
133+
model: ModelInterface,
133134
supercell_matrix: np.ndarray | None,
134135
displ: float = 0.05,
135136
*,

0 commit comments

Comments
 (0)