Skip to content

Commit 88abcff

Browse files
authored
Make system_idx non-optional in SimState [1/2] (#231)
1 parent 926e043 commit 88abcff

File tree

2 files changed

+41
-15
lines changed

2 files changed

+41
-15
lines changed

torch_sim/integrators/nvt.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ def nvt_nose_hoover_init(
389389
cell=state.cell,
390390
pbc=state.pbc,
391391
atomic_numbers=atomic_numbers,
392+
system_idx=state.system_idx,
392393
chain=chain_fns.initialize(total_dof, KE, kT),
393394
_chain_fns=chain_fns, # Store the chain functions
394395
)

torch_sim/state.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import copy
88
import importlib
99
import warnings
10-
from dataclasses import dataclass, field
10+
from dataclasses import dataclass
1111
from typing import TYPE_CHECKING, Literal, Self
1212

1313
import torch
@@ -22,7 +22,7 @@
2222
from pymatgen.core import Structure
2323

2424

25-
@dataclass
25+
@dataclass(init=False)
2626
class SimState:
2727
"""State representation for atomistic systems with batched operations support.
2828
@@ -47,9 +47,8 @@ class SimState:
4747
used by ASE.
4848
pbc (bool): Boolean indicating whether to use periodic boundary conditions
4949
atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,)
50-
system_idx (torch.Tensor, optional): Maps each atom index to its system index.
51-
Has shape (n_atoms,), defaults to None, must be unique consecutive
52-
integers starting from 0
50+
system_idx (torch.Tensor): Maps each atom index to its system index.
51+
Has shape (n_atoms,), must be unique consecutive integers starting from 0.
5352
5453
Properties:
5554
wrap_positions (torch.Tensor): Positions wrapped according to periodic boundary
@@ -81,10 +80,35 @@ class SimState:
8180
cell: torch.Tensor
8281
pbc: bool # TODO: do all calculators support mixed pbc?
8382
atomic_numbers: torch.Tensor
84-
system_idx: torch.Tensor | None = field(default=None, kw_only=True)
83+
system_idx: torch.Tensor
84+
85+
def __init__(
86+
self,
87+
positions: torch.Tensor,
88+
masses: torch.Tensor,
89+
cell: torch.Tensor,
90+
pbc: bool, # noqa: FBT001 # TODO(curtis): maybe make the constructor be keyword-only (it can be easy to confuse positions vs masses, etc.)
91+
atomic_numbers: torch.Tensor,
92+
system_idx: torch.Tensor | None = None,
93+
) -> None:
94+
"""Initialize the SimState and validate the arguments.
95+
96+
Args:
97+
positions (torch.Tensor): Atomic positions with shape (n_atoms, 3)
98+
masses (torch.Tensor): Atomic masses with shape (n_atoms,)
99+
cell (torch.Tensor): Unit cell vectors with shape (n_systems, 3, 3).
100+
pbc (bool): Boolean indicating whether to use periodic boundary conditions
101+
atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,)
102+
system_idx (torch.Tensor | None): Maps each atom index to its system index.
103+
Has shape (n_atoms,), must be unique consecutive integers starting from 0.
104+
If not provided, it is initialized to zeros.
105+
"""
106+
self.positions = positions
107+
self.masses = masses
108+
self.cell = cell
109+
self.pbc = pbc
110+
self.atomic_numbers = atomic_numbers
85111

86-
def __post_init__(self) -> None:
87-
"""Validate and process the state after initialization."""
88112
# data validation and fill system_idx
89113
# should make pbc a tensor here
90114
# if devices aren't all the same, raise an error, in a clean way
@@ -107,24 +131,25 @@ def __post_init__(self) -> None:
107131
f"masses {shapes[1]}, atomic_numbers {shapes[2]}"
108132
)
109133

110-
if self.cell.ndim != 3 and self.system_idx is None:
111-
self.cell = self.cell.unsqueeze(0)
112-
113-
if self.cell.shape[-2:] != (3, 3):
114-
raise ValueError("Cell must have shape (n_systems, 3, 3)")
115-
116-
if self.system_idx is None:
134+
if system_idx is None:
117135
self.system_idx = torch.zeros(
118136
self.n_atoms, device=self.device, dtype=torch.int64
119137
)
120138
else:
139+
self.system_idx = system_idx
121140
# assert that system indices are unique consecutive integers
122141
# TODO(curtis): I feel like this logic is not reliable.
123142
# I'll come up with something better later.
124143
_, counts = torch.unique_consecutive(self.system_idx, return_counts=True)
125144
if not torch.all(counts == torch.bincount(self.system_idx)):
126145
raise ValueError("System indices must be unique consecutive integers")
127146

147+
if self.cell.ndim != 3 and system_idx is None:
148+
self.cell = self.cell.unsqueeze(0)
149+
150+
if self.cell.shape[-2:] != (3, 3):
151+
raise ValueError("Cell must have shape (n_systems, 3, 3)")
152+
128153
if self.cell.shape[0] != self.n_systems:
129154
raise ValueError(
130155
f"Cell must have shape (n_systems, 3, 3), got {self.cell.shape}"

0 commit comments

Comments
 (0)