"""FairChem model for computing energies, forces and stresses.

This module provides a PyTorch wrapper around pretrained FairChem models,
exposing them through the torchsim ModelInterface.
"""

from __future__ import annotations

import copy
from types import MappingProxyType
from typing import TYPE_CHECKING

import torch
from fairchem.core.common.registry import registry
from fairchem.core.common.utils import (
    load_config,
    setup_imports,
    setup_logging,
    update_config,
)
from fairchem.core.models.model_registry import model_name_to_local_file
from torch_geometric.data import Batch

from torchsim.models.interface import ModelInterface


if TYPE_CHECKING:
    from collections.abc import Callable
    from pathlib import Path

DTYPE_DICT = {
    torch.float16: "float16",
    torch.float32: "float32",
    torch.float64: "float64",
}


class FairChemModel(torch.nn.Module, ModelInterface):
    """Computes energies, forces and stresses using a FairChem model.

    Attributes:
        pbc (bool): Whether to use periodic boundary conditions
        neighbor_list_fn (Callable | None): The neighbor list function to use
        r_max (float): Maximum cutoff radius for atomic interactions
            (accepted for interface compatibility; not currently used)
        config (dict): Model configuration dictionary
        trainer: The FairChem trainer object
        data_object (Batch): Data object containing system information
        implemented_properties (list): List of implemented model outputs
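
    Example:
        A minimal usage sketch; the checkpoint path below is hypothetical and
        the available output keys depend on the loaded model's configured
        outputs:

            import torch

            model = FairChemModel(model="path/to/checkpoint.pt", cpu=True)
            positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])  # (n_atoms, 3)
            cell = torch.eye(3).unsqueeze(0) * 10.0  # (n_systems, 3, 3) box vectors
            atomic_numbers = torch.tensor([1, 1])  # (n_atoms,)
            results = model(positions, cell, atomic_numbers)
            energy = results["energy"]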
    """

    _reshaped_props = MappingProxyType(
        {"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)}
    )

    def __init__(  # noqa: C901, PLR0915
        self,
        model: str | Path | None,
        neighbor_list_fn: Callable | None = None,
        *,  # force remaining arguments to be keyword-only
        config_yml: str | None = None,
        model_name: str | None = None,
        local_cache: str | None = None,
        trainer: str | None = None,
        cpu: bool = False,
        seed: int | None = None,
        pbc: bool = True,
        r_max: float | None = None,  # noqa: ARG002
        dtype: torch.dtype | None = None,
        compute_stress: bool = False,
    ) -> None:
        """Initialize the FairChemModel.

        Args:
            model: Path to model checkpoint
            neighbor_list_fn: Neighbor list function (not currently supported)
            config_yml: Path to config YAML file
            model_name: Name of pretrained model
            local_cache: Path to local model cache
            trainer: Name of trainer to use
            cpu: Whether to use CPU instead of GPU
            seed: Random seed for reproducibility
            pbc: Whether to use periodic boundary conditions
            r_max: Maximum cutoff radius (accepted for interface compatibility;
                not currently used)
            dtype: Data type to use for the model
            compute_stress: Whether to compute stress
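
        Example:
            A hedged sketch of loading a pretrained model by name; the model
            name and cache directory are placeholders, not verified values:

                model = FairChemModel(
                    model=None,
                    model_name="SOME-PRETRAINED-MODEL",
                    local_cache="/tmp/fairchem_cache",
                    cpu=True,
                )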
        """
        setup_imports()
        setup_logging()
        super().__init__()

        self._dtype = dtype or torch.float32
        self._compute_stress = compute_stress
        self._compute_force = True

        if model_name is not None:
            if model is not None:
                raise RuntimeError(
                    "model_name and model path were both specified, "
                    "please use only one at a time"
                )
            if local_cache is None:
                raise NotImplementedError(
                    "Local cache must be set when specifying a model name"
                )
            model = model_name_to_local_file(
                model_name=model_name, local_cache=local_cache
            )

        # Either the config path or the checkpoint path needs to be provided
        assert config_yml or model is not None

        checkpoint = None
        if config_yml is not None:
            if isinstance(config_yml, str):
                config, duplicates_warning, duplicates_error = load_config(config_yml)
                if len(duplicates_warning) > 0:
                    print(
                        "Overwritten config parameters from included configs "
                        f"(non-included parameters take precedence): {duplicates_warning}"
                    )
                if len(duplicates_error) > 0:
                    raise ValueError(
                        "Conflicting (duplicate) parameters in simultaneously "
                        f"included configs: {duplicates_error}"
                    )
            else:
                config = config_yml

            # Only keeps the train data that might have normalizer values
            if isinstance(config["dataset"], list):
                config["dataset"] = config["dataset"][0]
            elif isinstance(config["dataset"], dict):
                config["dataset"] = config["dataset"].get("train", None)
        else:
            # Loads the config from the checkpoint directly (always on CPU).
            checkpoint = torch.load(model, map_location=torch.device("cpu"))
            config = checkpoint["config"]

        if trainer is not None:
            config["trainer"] = trainer
        else:
            config["trainer"] = config.get("trainer", "ocp")

        if "model_attributes" in config:
            config["model_attributes"]["name"] = config.pop("model")
            config["model"] = config["model_attributes"]

        self.pbc = pbc
        self.neighbor_list_fn = neighbor_list_fn

        if neighbor_list_fn is None:
            # Calculate the edge indices on the fly
            config["model"]["otf_graph"] = True
        else:
            raise NotImplementedError(
                "Custom neighbor list is not supported for FairChemModel."
            )

        if "backbone" in config["model"]:
            config["model"]["backbone"]["use_pbc"] = pbc
            config["model"]["backbone"]["use_pbc_single"] = False
            if dtype is not None:
                try:
                    config["model"]["backbone"].update({"dtype": DTYPE_DICT[dtype]})
                    for key in config["model"]["heads"]:
                        config["model"]["heads"][key].update({"dtype": DTYPE_DICT[dtype]})
                except KeyError:
                    print("dtype not found in backbone, using default float32")
        else:
            config["model"]["use_pbc"] = pbc
            config["model"]["use_pbc_single"] = False
            if dtype is not None:
                try:
                    config["model"].update({"dtype": DTYPE_DICT[dtype]})
                except KeyError:
                    print("dtype not found in model config, using default dtype")

        ### backwards compatibility with OCP v<2.0
        config = update_config(config)

        self.config = copy.deepcopy(config)
        self.config["checkpoint"] = str(model)
        del config["dataset"]["src"]

        self.trainer = registry.get_trainer_class(config["trainer"])(
            task=config.get("task", {}),
            model=config["model"],
            dataset=[config["dataset"]],
            outputs=config["outputs"],
            loss_functions=config["loss_functions"],
            evaluation_metrics=config["evaluation_metrics"],
            optimizer=config["optim"],
            identifier="",
            slurm=config.get("slurm", {}),
            local_rank=config.get("local_rank", 0),
            is_debug=config.get("is_debug", True),
            cpu=cpu,
            amp=False if dtype is not None else config.get("amp", False),
            inference_only=True,
        )
        if dtype is not None:
            # Convert model parameters to specified dtype
            self.trainer.model.to(dtype=self.dtype)

        if model is not None:
            self.load_checkpoint(checkpoint_path=model, checkpoint=checkpoint)

        seed = seed if seed is not None else self.trainer.config["cmd"]["seed"]
        if seed is None:
            print(
                "No seed has been set in the model checkpoint or FairChemModel! "
                "Results may not be reproducible on re-run"
            )
        else:
            self.trainer.set_seed(seed)

        self.implemented_properties = list(self.config["outputs"])

        self._device = self.trainer.device

        stress_output = "stress" in self.implemented_properties
        if not stress_output and compute_stress:
            raise NotImplementedError("Stress output not implemented for this model")

    def load_checkpoint(
        self, checkpoint_path: str, checkpoint: dict | None = None
    ) -> None:
        """Load an existing trained model.

        Args:
            checkpoint_path: Path to the trained model checkpoint
            checkpoint: A pretrained checkpoint dict
        """
        try:
            self.trainer.load_checkpoint(
                checkpoint_path, checkpoint, inference_only=True
            )
        except NotImplementedError:
            print("Unable to load checkpoint!")

    def forward(
        self,
        positions: torch.Tensor,
        cell: torch.Tensor,
        atomic_numbers: torch.Tensor,
        batch: torch.Tensor | None = None,
        **_,
    ) -> dict:  # TODO: what are the shapes?
        """Forward pass of the model.

        Args:
            positions: Atomic positions tensor
            cell: Box vectors tensor
            atomic_numbers: Atomic numbers tensor
            batch: Batch tensor mapping each atom to its system

        Returns:
            Dictionary of model predictions
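
        Example:
            A hedged sketch for a single two-atom system, given an initialized
            model instance (shapes assumed from how the tensors are used in
            this method):

                positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
                cell = torch.eye(3).unsqueeze(0) * 10.0
                atomic_numbers = torch.tensor([1, 1])
                batch = torch.tensor([0, 0])
                results = model(positions, cell, atomic_numbers, batch)
                energy = results["energy"]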
        """
        if positions.device != self._device:
            positions = positions.to(self._device)
        if cell.device != self._device:
            cell = cell.to(self._device)

        if batch is None:
            batch = torch.zeros(positions.shape[0], dtype=torch.int)

        natoms = torch.bincount(batch)
        pbc = torch.tensor(
            [self.pbc, self.pbc, self.pbc] * len(natoms), dtype=torch.bool
        ).view(-1, 3)
        # Per-atom "fixed" flags (no atoms are constrained here)
        fixed = torch.zeros(batch.size(0), dtype=torch.int)
        self.data_object = Batch(
            pos=positions,
            cell=cell,
            atomic_numbers=atomic_numbers,
            natoms=natoms,
            batch=batch,
            fixed=fixed,
            pbc=pbc,
        )

        if self._dtype is not None:
            self.data_object.pos = self.data_object.pos.to(self._dtype)
            self.data_object.cell = self.data_object.cell.to(self._dtype)

        predictions = self.trainer.predict(
            self.data_object, per_image=False, disable_tqdm=True
        )

        results = {}

        for key in predictions:
            _pred = predictions[key]
            if key in self._reshaped_props:
                _pred = _pred.reshape(self._reshaped_props.get(key)).squeeze()
            results[key] = _pred

        results["energy"] = results["energy"].squeeze()
        return results