@@ -590,7 +590,9 @@ def fire_init(
590
590
atomic_numbers = state .atomic_numbers .clone (),
591
591
system_idx = state .system_idx .clone (),
592
592
pbc = state .pbc ,
593
- velocities = None ,
593
+ velocities = torch .full (
594
+ state .positions .shape , torch .nan , device = device , dtype = dtype
595
+ ),
594
596
forces = forces ,
595
597
energy = energy ,
596
598
# Optimization attributes
@@ -863,13 +865,17 @@ def fire_init(
863
865
atomic_numbers = state .atomic_numbers .clone (),
864
866
system_idx = state .system_idx .clone (),
865
867
pbc = state .pbc ,
866
- velocities = None ,
868
+ velocities = torch .full (
869
+ state .positions .shape , torch .nan , device = device , dtype = dtype
870
+ ),
867
871
forces = forces ,
868
872
energy = energy ,
869
873
stress = stress ,
870
874
# Cell attributes
871
875
cell_positions = torch .zeros (n_systems , 3 , 3 , device = device , dtype = dtype ),
872
- cell_velocities = None ,
876
+ cell_velocities = torch .full (
877
+ cell_forces .shape , torch .nan , device = device , dtype = dtype
878
+ ),
873
879
cell_forces = cell_forces ,
874
880
cell_masses = cell_masses ,
875
881
# Optimization attributes
@@ -1162,13 +1168,17 @@ def fire_init(
1162
1168
atomic_numbers = state .atomic_numbers ,
1163
1169
system_idx = state .system_idx ,
1164
1170
pbc = state .pbc ,
1165
- velocities = None ,
1171
+ velocities = torch .full (
1172
+ state .positions .shape , torch .nan , device = device , dtype = dtype
1173
+ ),
1166
1174
forces = forces ,
1167
1175
energy = energy ,
1168
1176
stress = stress ,
1169
1177
# Cell attributes
1170
1178
cell_positions = cell_positions ,
1171
- cell_velocities = None ,
1179
+ cell_velocities = torch .full (
1180
+ cell_forces .shape , torch .nan , device = device , dtype = dtype
1181
+ ),
1172
1182
cell_forces = cell_forces ,
1173
1183
cell_masses = cell_masses ,
1174
1184
# Optimization attributes
@@ -1245,15 +1255,19 @@ def _vv_fire_step( # noqa: C901, PLR0915
1245
1255
dtype = state .positions .dtype
1246
1256
deform_grad_new : torch .Tensor | None = None
1247
1257
1248
- if state .velocities is None :
1249
- state .velocities = torch .zeros_like (state .positions )
1258
+ nan_velocities = state .velocities .isnan ().any (dim = 1 )
1259
+ if nan_velocities .any ():
1260
+ state .velocities [nan_velocities ] = torch .zeros_like (
1261
+ state .positions [nan_velocities ]
1262
+ )
1250
1263
if is_cell_optimization :
1251
1264
if not isinstance (state , AnyFireCellState ):
1252
1265
raise ValueError (
1253
1266
f"Cell optimization requires one of { get_args (AnyFireCellState )} ."
1254
1267
)
1255
- state .cell_velocities = torch .zeros (
1256
- (n_systems , 3 , 3 ), device = device , dtype = dtype
1268
+ nan_cell_velocities = state .cell_velocities .isnan ().any (dim = (1 , 2 ))
1269
+ state .cell_velocities [nan_cell_velocities ] = torch .zeros_like (
1270
+ state .cell_positions [nan_cell_velocities ]
1257
1271
)
1258
1272
1259
1273
alpha_start_system = torch .full (
@@ -1462,16 +1476,20 @@ def _ase_fire_step( # noqa: C901, PLR0915
1462
1476
1463
1477
cur_deform_grad = None # Initialize cur_deform_grad to prevent UnboundLocalError
1464
1478
1465
- if state .velocities is None :
1466
- state .velocities = torch .zeros_like (state .positions )
1479
+ nan_velocities = state .velocities .isnan ().any (dim = 1 )
1480
+ if nan_velocities .any ():
1481
+ state .velocities [nan_velocities ] = torch .zeros_like (
1482
+ state .positions [nan_velocities ]
1483
+ )
1467
1484
forces = state .forces
1468
1485
if is_cell_optimization :
1469
1486
if not isinstance (state , AnyFireCellState ):
1470
1487
raise ValueError (
1471
1488
f"Cell optimization requires one of { get_args (AnyFireCellState )} ."
1472
1489
)
1473
- state .cell_velocities = torch .zeros (
1474
- (n_systems , 3 , 3 ), device = device , dtype = dtype
1490
+ nan_cell_velocities = state .cell_velocities .isnan ().any (dim = (1 , 2 ))
1491
+ state .cell_velocities [nan_cell_velocities ] = torch .zeros_like (
1492
+ state .cell_positions [nan_cell_velocities ]
1475
1493
)
1476
1494
cur_deform_grad = state .deform_grad ()
1477
1495
else :
0 commit comments