Skip to content

Commit e90b272

Browse files
authored
Initial fix for concatenation of states in InFlightAutoBatcher (#219)
1 parent f6cd006 commit e90b272

File tree

2 files changed

+32
-14
lines changed

2 files changed

+32
-14
lines changed

tests/test_autobatching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ def convergence_fn(state: ts.SimState) -> bool:
490490
break
491491

492492
# run 10 steps, arbitrary number
493-
for _ in range(10):
493+
for _ in range(5):
494494
state = fire_update(state)
495495
convergence_tensor = convergence_fn(state)
496496

torch_sim/optimizers.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,9 @@ def fire_init(
590590
atomic_numbers=state.atomic_numbers.clone(),
591591
system_idx=state.system_idx.clone(),
592592
pbc=state.pbc,
593-
velocities=None,
593+
velocities=torch.full(
594+
state.positions.shape, torch.nan, device=device, dtype=dtype
595+
),
594596
forces=forces,
595597
energy=energy,
596598
# Optimization attributes
@@ -863,13 +865,17 @@ def fire_init(
863865
atomic_numbers=state.atomic_numbers.clone(),
864866
system_idx=state.system_idx.clone(),
865867
pbc=state.pbc,
866-
velocities=None,
868+
velocities=torch.full(
869+
state.positions.shape, torch.nan, device=device, dtype=dtype
870+
),
867871
forces=forces,
868872
energy=energy,
869873
stress=stress,
870874
# Cell attributes
871875
cell_positions=torch.zeros(n_systems, 3, 3, device=device, dtype=dtype),
872-
cell_velocities=None,
876+
cell_velocities=torch.full(
877+
cell_forces.shape, torch.nan, device=device, dtype=dtype
878+
),
873879
cell_forces=cell_forces,
874880
cell_masses=cell_masses,
875881
# Optimization attributes
@@ -1162,13 +1168,17 @@ def fire_init(
11621168
atomic_numbers=state.atomic_numbers,
11631169
system_idx=state.system_idx,
11641170
pbc=state.pbc,
1165-
velocities=None,
1171+
velocities=torch.full(
1172+
state.positions.shape, torch.nan, device=device, dtype=dtype
1173+
),
11661174
forces=forces,
11671175
energy=energy,
11681176
stress=stress,
11691177
# Cell attributes
11701178
cell_positions=cell_positions,
1171-
cell_velocities=None,
1179+
cell_velocities=torch.full(
1180+
cell_forces.shape, torch.nan, device=device, dtype=dtype
1181+
),
11721182
cell_forces=cell_forces,
11731183
cell_masses=cell_masses,
11741184
# Optimization attributes
@@ -1245,15 +1255,19 @@ def _vv_fire_step( # noqa: C901, PLR0915
12451255
dtype = state.positions.dtype
12461256
deform_grad_new: torch.Tensor | None = None
12471257

1248-
if state.velocities is None:
1249-
state.velocities = torch.zeros_like(state.positions)
1258+
nan_velocities = state.velocities.isnan().any(dim=1)
1259+
if nan_velocities.any():
1260+
state.velocities[nan_velocities] = torch.zeros_like(
1261+
state.positions[nan_velocities]
1262+
)
12501263
if is_cell_optimization:
12511264
if not isinstance(state, AnyFireCellState):
12521265
raise ValueError(
12531266
f"Cell optimization requires one of {get_args(AnyFireCellState)}."
12541267
)
1255-
state.cell_velocities = torch.zeros(
1256-
(n_systems, 3, 3), device=device, dtype=dtype
1268+
nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2))
1269+
state.cell_velocities[nan_cell_velocities] = torch.zeros_like(
1270+
state.cell_positions[nan_cell_velocities]
12571271
)
12581272

12591273
alpha_start_system = torch.full(
@@ -1462,16 +1476,20 @@ def _ase_fire_step( # noqa: C901, PLR0915
14621476

14631477
cur_deform_grad = None # Initialize cur_deform_grad to prevent UnboundLocalError
14641478

1465-
if state.velocities is None:
1466-
state.velocities = torch.zeros_like(state.positions)
1479+
nan_velocities = state.velocities.isnan().any(dim=1)
1480+
if nan_velocities.any():
1481+
state.velocities[nan_velocities] = torch.zeros_like(
1482+
state.positions[nan_velocities]
1483+
)
14671484
forces = state.forces
14681485
if is_cell_optimization:
14691486
if not isinstance(state, AnyFireCellState):
14701487
raise ValueError(
14711488
f"Cell optimization requires one of {get_args(AnyFireCellState)}."
14721489
)
1473-
state.cell_velocities = torch.zeros(
1474-
(n_systems, 3, 3), device=device, dtype=dtype
1490+
nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2))
1491+
state.cell_velocities[nan_cell_velocities] = torch.zeros_like(
1492+
state.cell_positions[nan_cell_velocities]
14751493
)
14761494
cur_deform_grad = state.deform_grad()
14771495
else:

0 commit comments

Comments
 (0)