7
7
import copy
8
8
import importlib
9
9
import warnings
10
- from dataclasses import dataclass , field
10
+ from dataclasses import dataclass
11
11
from typing import TYPE_CHECKING , Literal , Self
12
12
13
13
import torch
22
22
from pymatgen .core import Structure
23
23
24
24
25
- @dataclass
25
+ @dataclass ( init = False )
26
26
class SimState :
27
27
"""State representation for atomistic systems with batched operations support.
28
28
@@ -47,9 +47,8 @@ class SimState:
47
47
used by ASE.
48
48
pbc (bool): Boolean indicating whether to use periodic boundary conditions
49
49
atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,)
50
- system_idx (torch.Tensor, optional): Maps each atom index to its system index.
51
- Has shape (n_atoms,), defaults to None, must be unique consecutive
52
- integers starting from 0
50
+ system_idx (torch.Tensor): Maps each atom index to its system index.
51
+ Has shape (n_atoms,), must be unique consecutive integers starting from 0.
53
52
54
53
Properties:
55
54
wrap_positions (torch.Tensor): Positions wrapped according to periodic boundary
@@ -81,10 +80,35 @@ class SimState:
81
80
cell : torch .Tensor
82
81
pbc : bool # TODO: do all calculators support mixed pbc?
83
82
atomic_numbers : torch .Tensor
84
- system_idx : torch .Tensor | None = field (default = None , kw_only = True )
83
+ system_idx : torch .Tensor
84
+
85
+ def __init__ (
86
+ self ,
87
+ positions : torch .Tensor ,
88
+ masses : torch .Tensor ,
89
+ cell : torch .Tensor ,
90
+ pbc : bool , # noqa: FBT001 # TODO(curtis): maybe make the constructor be keyword-only (it can be easy to confuse positions vs masses, etc.)
91
+ atomic_numbers : torch .Tensor ,
92
+ system_idx : torch .Tensor | None = None ,
93
+ ) -> None :
94
+ """Initialize the SimState and validate the arguments.
95
+
96
+ Args:
97
+ positions (torch.Tensor): Atomic positions with shape (n_atoms, 3)
98
+ masses (torch.Tensor): Atomic masses with shape (n_atoms,)
99
+ cell (torch.Tensor): Unit cell vectors with shape (n_systems, 3, 3).
100
+ pbc (bool): Boolean indicating whether to use periodic boundary conditions
101
+ atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,)
102
+ system_idx (torch.Tensor | None): Maps each atom index to its system index.
103
+ Has shape (n_atoms,), must be unique consecutive integers starting from 0.
104
+ If not provided, it is initialized to zeros.
105
+ """
106
+ self .positions = positions
107
+ self .masses = masses
108
+ self .cell = cell
109
+ self .pbc = pbc
110
+ self .atomic_numbers = atomic_numbers
85
111
86
- def __post_init__ (self ) -> None :
87
- """Validate and process the state after initialization."""
88
112
# data validation and fill system_idx
89
113
# should make pbc a tensor here
90
114
# if devices aren't all the same, raise an error, in a clean way
@@ -107,24 +131,25 @@ def __post_init__(self) -> None:
107
131
f"masses { shapes [1 ]} , atomic_numbers { shapes [2 ]} "
108
132
)
109
133
110
- if self .cell .ndim != 3 and self .system_idx is None :
111
- self .cell = self .cell .unsqueeze (0 )
112
-
113
- if self .cell .shape [- 2 :] != (3 , 3 ):
114
- raise ValueError ("Cell must have shape (n_systems, 3, 3)" )
115
-
116
- if self .system_idx is None :
134
+ if system_idx is None :
117
135
self .system_idx = torch .zeros (
118
136
self .n_atoms , device = self .device , dtype = torch .int64
119
137
)
120
138
else :
139
+ self .system_idx = system_idx
121
140
# assert that system indices are unique consecutive integers
122
141
# TODO(curtis): I feel like this logic is not reliable.
123
142
# I'll come up with something better later.
124
143
_ , counts = torch .unique_consecutive (self .system_idx , return_counts = True )
125
144
if not torch .all (counts == torch .bincount (self .system_idx )):
126
145
raise ValueError ("System indices must be unique consecutive integers" )
127
146
147
+ if self .cell .ndim != 3 and system_idx is None :
148
+ self .cell = self .cell .unsqueeze (0 )
149
+
150
+ if self .cell .shape [- 2 :] != (3 , 3 ):
151
+ raise ValueError ("Cell must have shape (n_systems, 3, 3)" )
152
+
128
153
if self .cell .shape [0 ] != self .n_systems :
129
154
raise ValueError (
130
155
f"Cell must have shape (n_systems, 3, 3), got { self .cell .shape } "
0 commit comments