11
11
12
12
13
13
StateDict = dict [
14
- Literal ["positions" , "masses" , "cell" , "pbc" , "atomic_numbers" , "batch" ], torch .Tensor
14
+ Literal ["positions" , "masses" , "cell" , "pbc" , "atomic_numbers" , "batch" ],
15
+ torch .Tensor ,
15
16
]
16
17
17
18
@@ -439,7 +440,7 @@ class BatchedUnitCellFireState(BaseState):
439
440
forces : torch .Tensor # [n_total_atoms, 3]
440
441
energy : torch .Tensor # [n_batches]
441
442
stress : torch .Tensor # [n_batches, 3, 3]
442
- velocity : torch .Tensor # [n_total_atoms, 3]
443
+ velocities : torch .Tensor # [n_total_atoms, 3]
443
444
444
445
# cell attributes
445
446
cell_positions : torch .Tensor # [n_batches * 3, 3]
@@ -459,6 +460,11 @@ class BatchedUnitCellFireState(BaseState):
459
460
hydrostatic_strain : bool
460
461
constant_volume : bool
461
462
463
+ @property
464
+ def momenta (self ) -> torch .Tensor :
465
+ """Atomwise momenta of the system."""
466
+ return self .velocities * self .masses .unsqueeze (- 1 )
467
+
462
468
463
469
def unit_cell_fire ( # noqa: C901, PLR0915
464
470
model : torch .nn .Module ,
@@ -531,7 +537,11 @@ def unit_cell_fire( # noqa: C901, PLR0915
531
537
# Setup parameters
532
538
params = [dt_max , dt_start , alpha_start , f_inc , f_dec , f_alpha , n_min ]
533
539
dt_max , dt_start , alpha_start , f_inc , f_dec , f_alpha , n_min = [
534
- p if isinstance (p , torch .Tensor ) else torch .tensor (p , device = device , dtype = dtype )
540
+ (
541
+ p
542
+ if isinstance (p , torch .Tensor )
543
+ else torch .tensor (p , device = device , dtype = dtype )
544
+ )
535
545
for p in params
536
546
]
537
547
@@ -639,7 +649,7 @@ def fire_init(
639
649
batch = state .batch .clone (),
640
650
pbc = state .pbc ,
641
651
# new attrs
642
- velocity = torch .zeros_like (state .positions ),
652
+ velocities = torch .zeros_like (state .positions ),
643
653
forces = forces ,
644
654
energy = energy ,
645
655
stress = stress ,
@@ -700,7 +710,7 @@ def fire_step( # noqa: PLR0915
700
710
atom_wise_dt = state .dt [state .batch ].unsqueeze (- 1 )
701
711
cell_wise_dt = state .dt .repeat_interleave (3 ).unsqueeze (- 1 )
702
712
703
- state .velocity += 0.5 * atom_wise_dt * state .forces / state .masses .unsqueeze (- 1 )
713
+ state .velocities += 0.5 * atom_wise_dt * state .forces / state .masses .unsqueeze (- 1 )
704
714
state .cell_velocities += (
705
715
0.5 * cell_wise_dt * state .cell_forces / state .cell_masses .unsqueeze (- 1 )
706
716
)
@@ -709,7 +719,7 @@ def fire_step( # noqa: PLR0915
709
719
atomic_positions = state .positions # shape: (n_atoms, 3)
710
720
711
721
# Update atomic and cell positions
712
- atomic_positions_new = atomic_positions + atom_wise_dt * state .velocity
722
+ atomic_positions_new = atomic_positions + atom_wise_dt * state .velocities
713
723
cell_positions_new = cell_positions + cell_wise_dt * state .cell_velocities
714
724
715
725
# Update cell with deformation gradient
@@ -757,13 +767,13 @@ def fire_step( # noqa: PLR0915
757
767
state .cell_forces = virial_flat
758
768
759
769
# Velocity Verlet first half step (v += 0.5*a*dt)
760
- state .velocity += 0.5 * atom_wise_dt * state .forces / state .masses .unsqueeze (- 1 )
770
+ state .velocities += 0.5 * atom_wise_dt * state .forces / state .masses .unsqueeze (- 1 )
761
771
state .cell_velocities += (
762
772
0.5 * cell_wise_dt * state .cell_forces / state .cell_masses .unsqueeze (- 1 )
763
773
)
764
774
765
775
# Calculate power (F·V) for atoms
766
- atomic_power = (state .forces * state .velocity ).sum (dim = 1 ) # [n_atoms]
776
+ atomic_power = (state .forces * state .velocities ).sum (dim = 1 ) # [n_atoms]
767
777
atomic_power_per_batch = torch .zeros (
768
778
n_batches , device = device , dtype = atomic_power .dtype
769
779
)
@@ -798,12 +808,12 @@ def fire_step( # noqa: PLR0915
798
808
state .dt [batch_idx ] = state .dt [batch_idx ] * f_dec
799
809
state .alpha [batch_idx ] = alpha_start [batch_idx ]
800
810
# Reset velocities for both atoms and cell
801
- state .velocity [state .batch == batch_idx ] = 0
811
+ state .velocities [state .batch == batch_idx ] = 0
802
812
cell_batch = torch .arange (n_batches , device = device ).repeat_interleave (3 )
803
813
state .cell_velocities [cell_batch == batch_idx ] = 0
804
814
805
815
# Mix velocity and force direction using FIRE for atoms
806
- v_norm = torch .norm (state .velocity , dim = 1 , keepdim = True )
816
+ v_norm = torch .norm (state .velocities , dim = 1 , keepdim = True )
807
817
f_norm = torch .norm (state .forces , dim = 1 , keepdim = True )
808
818
# Avoid division by zero
809
819
# mask = f_norm > 1e-10
@@ -814,9 +824,9 @@ def fire_step( # noqa: PLR0915
814
824
# state.velocity,
815
825
# )
816
826
batch_wise_alpha = state .alpha [state .batch ].unsqueeze (- 1 )
817
- state .velocity = (
827
+ state .velocities = (
818
828
1.0 - batch_wise_alpha
819
- ) * state .velocity + batch_wise_alpha * state .forces * v_norm / (f_norm + 1e-10 )
829
+ ) * state .velocities + batch_wise_alpha * state .forces * v_norm / (f_norm + 1e-10 )
820
830
821
831
# Mix velocity and force direction for cell DOFs
822
832
cell_v_norm = torch .norm (state .cell_velocities , dim = 1 , keepdim = True )
0 commit comments