Skip to content

Commit 6ecdce9

Browse files
AdeeshKolluru authored and orionarcher committed
Support for fairchem models.
1 parent 63192d8 commit 6ecdce9

File tree

2 files changed

+385
-0
lines changed

2 files changed

+385
-0
lines changed

tests/models/test_fairchem.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import pytest
2+
import torch
3+
from ase.build import bulk
4+
from fairchem.core import OCPCalculator
5+
from fairchem.core.models.model_registry import model_name_to_local_file
6+
7+
from torchsim.models.fairchem import FairChemModel
8+
from torchsim.models.interface import validate_model_outputs
9+
from torchsim.runners import atoms_to_state
10+
from torchsim.state import BaseState
11+
12+
13+
@pytest.fixture(scope="session")
def model_path(tmp_path_factory: pytest.TempPathFactory) -> str:
    """Resolve the pretrained checkpoint name to a local file, once per session.

    A fresh temporary directory is used as the local cache handed to
    ``model_name_to_local_file``.
    """
    checkpoint_name = "EquiformerV2-31M-S2EF-OC20-All+MD"
    cache_dir = tmp_path_factory.mktemp("fairchem_checkpoints")
    return model_name_to_local_file(checkpoint_name, local_cache=str(cache_dir))
19+
20+
21+
@pytest.fixture
def si_system(dtype: torch.dtype, device: torch.device) -> BaseState:
    """Return a single-structure state for diamond-cubic silicon (a=5.43)."""
    silicon = bulk("Si", "diamond", a=5.43)
    structures = [silicon]
    return atoms_to_state(structures, device, dtype)
27+
28+
29+
@pytest.fixture
def fairchem_calculator(model_path: str, device: torch.device) -> FairChemModel:
    """Build a seeded, periodic FairChemModel from the session checkpoint.

    The model is asked to run on CPU exactly when the ``device`` fixture is a
    CPU device.
    """
    return FairChemModel(
        model=model_path, cpu=device.type == "cpu", seed=0, pbc=True
    )
38+
39+
40+
@pytest.fixture
def ocp_calculator(model_path: str, device: torch.device) -> OCPCalculator:
    """Build the reference ASE calculator from the same checkpoint.

    The CPU flag is derived from the ``device`` fixture, mirroring the
    ``fairchem_calculator`` fixture, so that both calculators are asked to run
    on the same kind of device instead of this one always requesting the GPU.
    """
    return OCPCalculator(
        checkpoint_path=model_path, cpu=device.type == "cpu", seed=0
    )
43+
44+
45+
def test_fairchem_ocp_consistency(
    fairchem_calculator: FairChemModel,
    ocp_calculator: OCPCalculator,
    device: torch.device,
) -> None:
    """Check FairChemModel energy and forces against OCPCalculator for bulk Si."""
    tolerances = {"rtol": 1e-2, "atol": 1e-2}

    # Reference structure, evaluated through the ASE calculator interface
    silicon = bulk("Si", "diamond", a=5.43)
    silicon.calc = ocp_calculator

    # Same structure, evaluated through the torchsim model wrapper
    state = atoms_to_state([silicon], device, torch.float32)
    model_out = fairchem_calculator(state.positions, state.cell, state.atomic_numbers)

    # Bring the reference forces onto the test device in the model's dtype
    model_forces = model_out["forces"]
    reference_forces = torch.tensor(
        silicon.get_forces(), device=device, dtype=model_forces.dtype
    )

    torch.testing.assert_close(
        model_out["energy"].item(), silicon.get_potential_energy(), **tolerances
    )
    torch.testing.assert_close(model_forces, reference_forces, **tolerances)
77+
78+
79+
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="Batching does not work properly on CPU"
)
def test_validate_model_outputs(
    fairchem_calculator: FairChemModel, device: torch.device
) -> None:
    """Run the shared model-output validation on the FairChem model (CUDA only)."""
    model_dtype = torch.float32
    validate_model_outputs(fairchem_calculator, device, model_dtype)

torchsim/models/fairchem.py

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
"""FairChem model for computing energies, forces and stresses.
2+
3+
This module provides a PyTorch implementation of the FairChem model.
4+
"""
5+
6+
from __future__ import annotations
7+
8+
import copy
9+
from types import MappingProxyType
10+
from typing import TYPE_CHECKING
11+
12+
import torch
13+
from fairchem.core.common.registry import registry
14+
from fairchem.core.common.utils import (
15+
load_config,
16+
setup_imports,
17+
setup_logging,
18+
update_config,
19+
)
20+
from fairchem.core.models.model_registry import model_name_to_local_file
21+
from torch_geometric.data import Batch
22+
23+
from torchsim.models.interface import ModelInterface
24+
25+
26+
if TYPE_CHECKING:
27+
from collections.abc import Callable
28+
from pathlib import Path
29+
30+
# Maps each supported torch floating-point dtype to its bare name
# (e.g. torch.float32 -> "float32"), i.e. the repr without the "torch." prefix.
DTYPE_DICT = {
    torch_dtype: str(torch_dtype).split(".")[-1]
    for torch_dtype in (torch.float16, torch.float32, torch.float64)
}
35+
36+
37+
class FairChemModel(torch.nn.Module, ModelInterface):
    """Computes energies, forces and stresses using a FairChem model.

    The constructor builds a FairChem trainer in inference-only mode from either
    a checkpoint or a config file; ``forward`` packs the inputs into a
    ``torch_geometric`` ``Batch`` and delegates to ``trainer.predict``.

    Attributes:
        pbc (bool): Whether to use periodic boundary conditions
        neighbor_list_fn (Callable | None): The neighbor list function to use
            (always ``None``; custom neighbor lists are rejected)
        config (dict): Deep copy of the model configuration, plus a
            ``"checkpoint"`` entry holding the checkpoint path as a string
        trainer: The FairChem trainer object
        data_object (Batch): Data object built by the most recent ``forward`` call
        implemented_properties (list): Output names listed in ``config["outputs"]``
    """

    # Outputs that are reshaped to (n, 3, 3) tensors before being returned
    _reshaped_props = MappingProxyType(
        {"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)}
    )

    def __init__(  # noqa: C901, PLR0915
        self,
        model: str | Path | None,
        neighbor_list_fn: Callable | None = None,
        *,  # force remaining arguments to be keyword-only
        config_yml: str | None = None,
        model_name: str | None = None,
        local_cache: str | None = None,
        trainer: str | None = None,
        cpu: bool = False,
        seed: int | None = None,
        pbc: bool = True,
        r_max: float | None = None,  # noqa: ARG002
        dtype: torch.dtype | None = None,
        compute_stress: bool = False,
    ) -> None:
        """Initialize the FairChemModel.

        Args:
            model: Path to model checkpoint. May be None when ``model_name`` or
                ``config_yml`` is given instead.
            neighbor_list_fn: Neighbor list function (not currently supported;
                must be None)
            config_yml: Path to config YAML file, or an already-loaded config
            model_name: Name of pretrained model (requires ``local_cache``)
            local_cache: Path to local model cache
            trainer: Name of trainer to use; defaults to the config's trainer,
                falling back to "ocp"
            cpu: Whether to use CPU instead of GPU
            seed: Random seed for reproducibility
            pbc: Whether to use periodic boundary conditions
            r_max: Accepted but currently ignored by this class
            dtype: Data type to use for the model (float32 inputs if None)
            compute_stress: Whether to compute stress

        Raises:
            RuntimeError: If both ``model_name`` and ``model`` are given.
            NotImplementedError: If ``model_name`` is given without
                ``local_cache``, if ``neighbor_list_fn`` is not None, or if
                ``compute_stress`` is requested but the model has no stress
                output.
            ValueError: If neither ``config_yml`` nor a checkpoint is given, or
                if the config file contains conflicting duplicate parameters.
        """
        setup_imports()
        setup_logging()
        super().__init__()

        self._dtype = dtype or torch.float32
        self._compute_stress = compute_stress
        self._compute_force = True

        if model_name is not None:
            if model is not None:
                raise RuntimeError(
                    "model_name and model were both specified, "
                    "please use only one at a time"
                )
            if local_cache is None:
                raise NotImplementedError(
                    "Local cache must be set when specifying a model name"
                )
            model = model_name_to_local_file(
                model_name=model_name, local_cache=local_cache
            )

        # Either the config path or the checkpoint path needs to be provided.
        # Raised explicitly (not asserted) so the check survives `python -O`.
        if not config_yml and model is None:
            raise ValueError(
                "Either config_yml or a model checkpoint path must be provided"
            )

        checkpoint = None
        if config_yml is not None:
            if isinstance(config_yml, str):
                config, duplicates_warning, duplicates_error = load_config(config_yml)
                if len(duplicates_warning) > 0:
                    print(
                        "Overwritten config parameters from included configs "
                        f"(non-included parameters take precedence): {duplicates_warning}"
                    )
                if len(duplicates_error) > 0:
                    raise ValueError(
                        "Conflicting (duplicate) parameters in simultaneously "
                        f"included configs: {duplicates_error}"
                    )
            else:
                config = config_yml

            # Only keeps the train data that might have normalizer values
            if isinstance(config["dataset"], list):
                config["dataset"] = config["dataset"][0]
            elif isinstance(config["dataset"], dict):
                config["dataset"] = config["dataset"].get("train", None)
        else:
            # Loads the config from the checkpoint directly (always on CPU).
            # NOTE: torch.load unpickles the file, so only load checkpoints
            # that come from a trusted source.
            checkpoint = torch.load(model, map_location=torch.device("cpu"))
            config = checkpoint["config"]

        if trainer is not None:
            config["trainer"] = trainer
        else:
            config["trainer"] = config.get("trainer", "ocp")

        if "model_attributes" in config:
            config["model_attributes"]["name"] = config.pop("model")
            config["model"] = config["model_attributes"]

        self.pbc = pbc
        self.neighbor_list_fn = neighbor_list_fn

        if neighbor_list_fn is None:
            # Calculate the edge indices on the fly
            config["model"]["otf_graph"] = True
        else:
            raise NotImplementedError(
                "Custom neighbor list is not supported for FairChemModel."
            )

        if "backbone" in config["model"]:
            config["model"]["backbone"]["use_pbc"] = pbc
            config["model"]["backbone"]["use_pbc_single"] = False
            if dtype is not None:
                # KeyError here means either the dtype is not in DTYPE_DICT or
                # the config has no "heads" section; the config is then left
                # (partially) unchanged and only the message is printed.
                try:
                    config["model"]["backbone"].update({"dtype": DTYPE_DICT[dtype]})
                    for key in config["model"]["heads"]:
                        config["model"]["heads"][key].update(
                            {"dtype": DTYPE_DICT[dtype]}
                        )
                except KeyError:
                    print("dtype not found in backbone, using default float32")
        else:
            config["model"]["use_pbc"] = pbc
            config["model"]["use_pbc_single"] = False
            if dtype is not None:
                try:
                    config["model"].update({"dtype": DTYPE_DICT[dtype]})
                except KeyError:
                    print("dtype not found in backbone, using default dtype")

        ### backwards compatibility with OCP v<2.0
        config = update_config(config)

        self.config = copy.deepcopy(config)
        self.config["checkpoint"] = str(model)
        # Drop the dataset source from the config handed to the trainer (the
        # deep copy in self.config keeps it). Tolerate a missing dataset or a
        # dataset without "src" instead of crashing on the delete.
        if isinstance(config["dataset"], dict):
            config["dataset"].pop("src", None)

        self.trainer = registry.get_trainer_class(config["trainer"])(
            task=config.get("task", {}),
            model=config["model"],
            dataset=[config["dataset"]],
            outputs=config["outputs"],
            loss_functions=config["loss_functions"],
            evaluation_metrics=config["evaluation_metrics"],
            optimizer=config["optim"],
            identifier="",
            slurm=config.get("slurm", {}),
            local_rank=config.get("local_rank", 0),
            is_debug=config.get("is_debug", True),
            cpu=cpu,
            # An explicit dtype disables mixed precision
            amp=False if dtype is not None else config.get("amp", False),
            inference_only=True,
        )
        if dtype is not None:
            # Convert model parameters to specified dtype
            # (self._dtype equals dtype here; it is the attribute set above)
            self.trainer.model.to(dtype=self._dtype)

        if model is not None:
            self.load_checkpoint(checkpoint_path=model, checkpoint=checkpoint)

        seed = seed if seed is not None else self.trainer.config["cmd"]["seed"]
        if seed is None:
            print(
                "No seed has been set in model checkpoint or FairChemModel! Results may "
                "not be reproducible on re-run"
            )
        else:
            self.trainer.set_seed(seed)

        self.implemented_properties = list(self.config["outputs"])

        self._device = self.trainer.device

        stress_output = "stress" in self.implemented_properties
        if not stress_output and compute_stress:
            raise NotImplementedError("Stress output not implemented for this model")

    def load_checkpoint(
        self, checkpoint_path: str, checkpoint: dict | None = None
    ) -> None:
        """Load existing trained model.

        Delegates to the trainer's ``load_checkpoint`` in inference-only mode.
        A ``NotImplementedError`` from the trainer is reported with a printed
        message rather than propagated.

        Args:
            checkpoint_path: string
                Path to trained model
            checkpoint: dict
                A pretrained checkpoint dict
        """
        try:
            self.trainer.load_checkpoint(checkpoint_path, checkpoint, inference_only=True)
        except NotImplementedError:
            print("Unable to load checkpoint!")

    def forward(
        self,
        positions: torch.Tensor,
        cell: torch.Tensor,
        atomic_numbers: torch.Tensor,
        batch: torch.Tensor | None = None,
        **_,
    ) -> dict:  # TODO: what are the shapes?
        """Forward pass of the model.

        Args:
            positions: Atomic positions tensor; its first dimension is taken as
                the number of atoms
            cell: Box vectors tensor
            atomic_numbers: Atomic numbers tensor
            batch: Batch tensor assigning each atom to a system. If None, all
                atoms are assigned to a single system (index 0).

        Returns:
            Dictionary of model predictions as returned by the trainer, with
            "stress"/"dielectric_tensor" reshaped to (-1, 3, 3) and squeezed,
            and "energy" squeezed.
        """
        # Inputs are moved to the trainer's device and cast to the model dtype
        positions = positions.to(device=self._device, dtype=self._dtype)
        cell = cell.to(device=self._device, dtype=self._dtype)

        n_atoms = positions.shape[0]
        if batch is None:
            # Single system: create the index on the same device as positions
            batch = torch.zeros(n_atoms, dtype=torch.int, device=positions.device)

        natoms = torch.bincount(batch)  # atoms per system
        pbc = torch.tensor(
            [self.pbc, self.pbc, self.pbc] * len(natoms), dtype=torch.bool
        ).view(-1, 3)
        # One "fixed" flag per atom (none fixed); a 1-D mask avoids allocating
        # a quadratic (n_atoms x n_atoms) tensor.
        fixed = torch.zeros(batch.size(0), dtype=torch.int)
        self.data_object = Batch(
            pos=positions,
            cell=cell,
            atomic_numbers=atomic_numbers,
            natoms=natoms,
            batch=batch,
            fixed=fixed,
            pbc=pbc,
        )

        predictions = self.trainer.predict(
            self.data_object, per_image=False, disable_tqdm=True
        )

        results = {}
        for key, pred in predictions.items():
            if key in self._reshaped_props:
                pred = pred.reshape(self._reshaped_props[key]).squeeze()
            results[key] = pred

        results["energy"] = results["energy"].squeeze()
        return results

0 commit comments

Comments
 (0)