Skip to content

Commit 71e1d41

Browse files
authored
Fix simstate concatenation [2/2] (#232)
1 parent e90b272 commit 71e1d41

File tree

5 files changed

+56
-6
lines changed

5 files changed

+56
-6
lines changed

tests/test_autobatching.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,10 +448,20 @@ def test_in_flight_auto_batcher_restore_order(
448448
# batcher.restore_original_order([si_sim_state])
449449

450450

451+
@pytest.mark.parametrize(
452+
"num_steps_per_batch",
453+
[
454+
5, # At 5 steps, not every state will converge before the next batch.
455+
# This tests the merging of partially converged states with new states
456+
# which has been a bug in the past. See https://github.com/Radical-AI/torch-sim/pull/219
457+
10, # At 10 steps, all states will converge before the next batch
458+
],
459+
)
451460
def test_in_flight_with_fire(
452461
si_sim_state: ts.SimState,
453462
fe_supercell_sim_state: ts.SimState,
454463
lj_model: LennardJonesModel,
464+
num_steps_per_batch: int,
455465
) -> None:
456466
fire_init, fire_update = unit_cell_fire(lj_model)
457467

@@ -489,8 +499,7 @@ def convergence_fn(state: ts.SimState) -> bool:
489499
if state is None:
490500
break
491501

492-
# run 10 steps, arbitrary number
493-
for _ in range(5):
502+
for _ in range(num_steps_per_batch):
494503
state = fire_update(state)
495504
convergence_tensor = convergence_fn(state)
496505

tests/test_state.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,3 +635,16 @@ def test_deprecated_batch_properties_equal_to_new_system_properties(
635635
state.batch = new_system_idx
636636
assert torch.allclose(state.system_idx, new_system_idx)
637637
assert torch.allclose(state.batch, new_system_idx)
638+
639+
640+
def test_derived_classes_trigger_init_subclass() -> None:
641+
"""Test that derived classes cannot have attributes that are "tensors | None"."""
642+
643+
with pytest.raises(TypeError) as excinfo:
644+
645+
class DerivedState(SimState):
646+
invalid_attr: torch.Tensor | None = None
647+
648+
assert "is not allowed to be of type 'torch.Tensor | None' because torch.cat" in str(
649+
excinfo.value
650+
)

torch_sim/optimizers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ class FireState(SimState):
475475
# Required attributes not in SimState
476476
forces: torch.Tensor
477477
energy: torch.Tensor
478-
velocities: torch.Tensor | None
478+
velocities: torch.Tensor
479479

480480
# FIRE algorithm parameters
481481
dt: torch.Tensor
@@ -972,7 +972,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin):
972972

973973
# Cell attributes
974974
cell_positions: torch.Tensor
975-
cell_velocities: torch.Tensor | None
975+
cell_velocities: torch.Tensor
976976
cell_forces: torch.Tensor
977977
cell_masses: torch.Tensor
978978

torch_sim/runners.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,8 +538,8 @@ def static(
538538
@dataclass
539539
class StaticState(type(state)):
540540
energy: torch.Tensor
541-
forces: torch.Tensor | None
542-
stress: torch.Tensor | None
541+
forces: torch.Tensor
542+
stress: torch.Tensor
543543

544544
all_props: list[dict[str, torch.Tensor]] = []
545545
og_filenames = trajectory_reporter.filenames

torch_sim/state.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import copy
88
import importlib
9+
import typing
910
import warnings
1011
from dataclasses import dataclass
1112
from typing import TYPE_CHECKING, Literal, Self, cast
@@ -400,6 +401,33 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) ->
400401

401402
return _slice_state(self, system_indices)
402403

404+
def __init_subclass__(cls, **kwargs) -> None:
405+
"""Enforce that all derived states cannot have tensor attributes that can also be
406+
None. This is because torch.concatenate cannot concat between a tensor and a None.
407+
See https://github.com/Radical-AI/torch-sim/pull/219 for more details.
408+
"""
409+
# We need to use get_type_hints to correctly inspect the types
410+
type_hints = typing.get_type_hints(cls)
411+
for attr_name, attr_typehint in type_hints.items():
412+
origin = typing.get_origin(attr_typehint)
413+
414+
is_union = origin is typing.Union
415+
if not is_union and origin is not None:
416+
# For Python 3.10+ `|` syntax, origin is types.UnionType
417+
# We check by name to be robust against module reloading/patching issues
418+
is_union = origin.__module__ == "types" and origin.__name__ == "UnionType"
419+
if is_union:
420+
args = typing.get_args(attr_typehint)
421+
if torch.Tensor in args and type(None) in args:
422+
raise TypeError(
423+
f"Attribute '{attr_name}' in class '{cls.__name__}' is not "
424+
"allowed to be of type 'torch.Tensor | None' because torch.cat "
425+
"cannot concatenate between a tensor and a None. Please default "
426+
"the tensor with dummy values and track the 'None' case."
427+
)
428+
429+
super().__init_subclass__(**kwargs)
430+
403431

404432
class DeformGradMixin:
405433
"""Mixin for states that support deformation gradients."""

0 commit comments

Comments
 (0)