Skip to content

Commit aa80654

Browse files
authored
Fix unit_cell_fire RuntimeError from bad dtype (#27)
* fix RuntimeError: scatter(): Expected self.dtype to be equal to src.dtype torchsim/optimizers.py", line 768, in fire_step atomic_power_per_batch.scatter_add_( RuntimeError: scatter(): Expected self.dtype to be equal to src.dtype * fix MaceModel init by setting self.pbc_template ahead of first potential use in setup_from_batch
1 parent 7944feb commit aa80654

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

torchsim/models/mace.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
)
9595

9696
# setup system boundary conditions
97-
pbc = [True] * 3 if periodic else [False] * 3
97+
pbc = [periodic] * 3
9898
self.pbc = torch.tensor([pbc], device=self.device)
9999

100100
if atomic_numbers is not None:
@@ -331,6 +331,11 @@ def __init__(
331331
# Store flag to track if atomic numbers were provided at init
332332
self.atomic_numbers_in_init = atomic_numbers is not None
333333

334+
# Set PBC
335+
pbc = [periodic] * 3
336+
self.pbc_template = torch.tensor([pbc], device=self._device)
337+
self.pbc = None # Will be set in forward
338+
334339
# Set up batch information if atomic numbers are provided
335340
if atomic_numbers is not None:
336341
if batch is None:
@@ -341,11 +346,6 @@ def __init__(
341346

342347
self.setup_from_batch(atomic_numbers, batch)
343348

344-
# Set PBC
345-
pbc = [True] * 3 if periodic else [False] * 3
346-
self.pbc_template = torch.tensor([pbc], device=self._device)
347-
self.pbc = None # Will be set in forward
348-
349349
def setup_from_batch(self, atomic_numbers: torch.Tensor, batch: torch.Tensor) -> None:
350350
"""Set up internal state from atomic numbers and batch indices.
351351

torchsim/optimizers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,9 @@ def fire_step( # noqa: PLR0915
764764

765765
# Calculate power (F·V) for atoms
766766
atomic_power = (state.forces * state.velocity).sum(dim=1) # [n_atoms]
767-
atomic_power_per_batch = torch.zeros(n_batches, device=device, dtype=dtype)
767+
atomic_power_per_batch = torch.zeros(
768+
n_batches, device=device, dtype=atomic_power.dtype
769+
)
768770
atomic_power_per_batch.scatter_add_(
769771
dim=0, index=state.batch, src=atomic_power
770772
) # [n_batches]
@@ -774,7 +776,9 @@ def fire_step( # noqa: PLR0915
774776
dim=1
775777
) # [n_batches*3]
776778
cell_batch = torch.arange(n_batches, device=device).repeat_interleave(3)
777-
cell_power_per_batch = torch.zeros(n_batches, device=device, dtype=dtype)
779+
cell_power_per_batch = torch.zeros(
780+
n_batches, device=device, dtype=cell_power.dtype
781+
)
778782
cell_power_per_batch.scatter_add_(
779783
dim=0, index=cell_batch, src=cell_power
780784
) # [n_batches]

0 commit comments

Comments
 (0)