Skip to content

Commit 244eb80

Browse files
authored
correct naming convention for BatchedUnitCellFireState and add momenta property for agreement with other States (#30)
1 parent 8d3cfa4 commit 244eb80

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

torchsim/optimizers.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212

1313
StateDict = dict[
14-
Literal["positions", "masses", "cell", "pbc", "atomic_numbers", "batch"], torch.Tensor
14+
Literal["positions", "masses", "cell", "pbc", "atomic_numbers", "batch"],
15+
torch.Tensor,
1516
]
1617

1718

@@ -439,7 +440,7 @@ class BatchedUnitCellFireState(BaseState):
439440
forces: torch.Tensor # [n_total_atoms, 3]
440441
energy: torch.Tensor # [n_batches]
441442
stress: torch.Tensor # [n_batches, 3, 3]
442-
velocity: torch.Tensor # [n_total_atoms, 3]
443+
velocities: torch.Tensor # [n_total_atoms, 3]
443444

444445
# cell attributes
445446
cell_positions: torch.Tensor # [n_batches * 3, 3]
@@ -459,6 +460,11 @@ class BatchedUnitCellFireState(BaseState):
459460
hydrostatic_strain: bool
460461
constant_volume: bool
461462

463+
@property
464+
def momenta(self) -> torch.Tensor:
465+
"""Atomwise momenta of the system."""
466+
return self.velocities * self.masses.unsqueeze(-1)
467+
462468

463469
def unit_cell_fire( # noqa: C901, PLR0915
464470
model: torch.nn.Module,
@@ -531,7 +537,11 @@ def unit_cell_fire( # noqa: C901, PLR0915
531537
# Setup parameters
532538
params = [dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min]
533539
dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min = [
534-
p if isinstance(p, torch.Tensor) else torch.tensor(p, device=device, dtype=dtype)
540+
(
541+
p
542+
if isinstance(p, torch.Tensor)
543+
else torch.tensor(p, device=device, dtype=dtype)
544+
)
535545
for p in params
536546
]
537547

@@ -639,7 +649,7 @@ def fire_init(
639649
batch=state.batch.clone(),
640650
pbc=state.pbc,
641651
# new attrs
642-
velocity=torch.zeros_like(state.positions),
652+
velocities=torch.zeros_like(state.positions),
643653
forces=forces,
644654
energy=energy,
645655
stress=stress,
@@ -700,7 +710,7 @@ def fire_step( # noqa: PLR0915
700710
atom_wise_dt = state.dt[state.batch].unsqueeze(-1)
701711
cell_wise_dt = state.dt.repeat_interleave(3).unsqueeze(-1)
702712

703-
state.velocity += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1)
713+
state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1)
704714
state.cell_velocities += (
705715
0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1)
706716
)
@@ -709,7 +719,7 @@ def fire_step( # noqa: PLR0915
709719
atomic_positions = state.positions # shape: (n_atoms, 3)
710720

711721
# Update atomic and cell positions
712-
atomic_positions_new = atomic_positions + atom_wise_dt * state.velocity
722+
atomic_positions_new = atomic_positions + atom_wise_dt * state.velocities
713723
cell_positions_new = cell_positions + cell_wise_dt * state.cell_velocities
714724

715725
# Update cell with deformation gradient
@@ -757,13 +767,13 @@ def fire_step( # noqa: PLR0915
757767
state.cell_forces = virial_flat
758768

759769
# Velocity Verlet first half step (v += 0.5*a*dt)
760-
state.velocity += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1)
770+
state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1)
761771
state.cell_velocities += (
762772
0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1)
763773
)
764774

765775
# Calculate power (F·V) for atoms
766-
atomic_power = (state.forces * state.velocity).sum(dim=1) # [n_atoms]
776+
atomic_power = (state.forces * state.velocities).sum(dim=1) # [n_atoms]
767777
atomic_power_per_batch = torch.zeros(
768778
n_batches, device=device, dtype=atomic_power.dtype
769779
)
@@ -798,12 +808,12 @@ def fire_step( # noqa: PLR0915
798808
state.dt[batch_idx] = state.dt[batch_idx] * f_dec
799809
state.alpha[batch_idx] = alpha_start[batch_idx]
800810
# Reset velocities for both atoms and cell
801-
state.velocity[state.batch == batch_idx] = 0
811+
state.velocities[state.batch == batch_idx] = 0
802812
cell_batch = torch.arange(n_batches, device=device).repeat_interleave(3)
803813
state.cell_velocities[cell_batch == batch_idx] = 0
804814

805815
# Mix velocity and force direction using FIRE for atoms
806-
v_norm = torch.norm(state.velocity, dim=1, keepdim=True)
816+
v_norm = torch.norm(state.velocities, dim=1, keepdim=True)
807817
f_norm = torch.norm(state.forces, dim=1, keepdim=True)
808818
# Avoid division by zero
809819
# mask = f_norm > 1e-10
@@ -814,9 +824,9 @@ def fire_step( # noqa: PLR0915
814824
# state.velocity,
815825
# )
816826
batch_wise_alpha = state.alpha[state.batch].unsqueeze(-1)
817-
state.velocity = (
827+
state.velocities = (
818828
1.0 - batch_wise_alpha
819-
) * state.velocity + batch_wise_alpha * state.forces * v_norm / (f_norm + 1e-10)
829+
) * state.velocities + batch_wise_alpha * state.forces * v_norm / (f_norm + 1e-10)
820830

821831
# Mix velocity and force direction for cell DOFs
822832
cell_v_norm = torch.norm(state.cell_velocities, dim=1, keepdim=True)

0 commit comments

Comments
 (0)