InitVar dataclass initialization (and subclass checks) #230
Walkthrough

The changes refactor how system indices and initialization variables are handled in the simulation state classes.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Atoms
    participant io.py
    participant SimState
    User->>Atoms: Provides Atoms object
    io.py->>Atoms: Reads system_idx from Atoms
    io.py->>SimState: Constructs with init_system_idx=system_idx
    SimState->>SimState: __post_init__ sets system_idx
    SimState-->>io.py: Returns initialized SimState
```

```mermaid
sequenceDiagram
    participant Optimizer
    participant FireState
    participant construct_state
    Optimizer->>FireState: Requests initialization
    FireState->>construct_state: Uses new attribute dict with init_system_idx
    construct_state->>FireState: Instantiates with InitVar
    FireState-->>Optimizer: Returns new FireState instance
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~40 minutes
Actionable comments posted: 2
🔭 Outside diff range comments (4)
torch_sim/io.py (2)
306-313: Missing parameter rename in structures_to_state

The system_idx parameter should be renamed to init_system_idx to match the SimState refactoring, similar to the change made in atoms_to_state. Apply this diff:

```diff
 return ts.SimState(
     positions=positions,
     masses=masses,
     cell=cell,
     pbc=True,  # Structures are always periodic
     atomic_numbers=atomic_numbers,
-    system_idx=system_idx,
+    init_system_idx=system_idx,
 )
```
384-391: Missing parameter rename in phonopy_to_state

The system_idx parameter should be renamed to init_system_idx to match the SimState refactoring. Apply this diff:

```diff
 return ts.SimState(
     positions=positions,
     masses=masses,
     cell=cell,
     pbc=True,
     atomic_numbers=atomic_numbers,
-    system_idx=system_idx,
+    init_system_idx=system_idx,
 )
```

torch_sim/optimizers.py (2)
865-872: Type inconsistency with velocities initialization in UnitCellFireState

The velocities and cell_velocities fields are typed as non-optional but are initialized as None. Initialize them with zero tensors instead:

```diff
 pbc=state.pbc,
-velocities=None,
+velocities=torch.zeros_like(state.positions),
 forces=forces,
 energy=energy,
 stress=stress,
 # Cell attributes
 cell_positions=torch.zeros(n_systems, 3, 3, device=device, dtype=dtype),
-cell_velocities=None,
+cell_velocities=torch.zeros(n_systems, 3, 3, device=device, dtype=dtype),
 cell_forces=cell_forces,
```
1163-1171: Type inconsistency with velocities initialization in FrechetCellFIREState

The velocities and cell_velocities fields are typed as non-optional but are initialized as None. Initialize them with zero tensors instead:

```diff
 pbc=state.pbc,
-velocities=None,
+velocities=torch.zeros_like(state.positions),
 forces=forces,
 energy=energy,
 stress=stress,
 # Cell attributes
 cell_positions=cell_positions,
-cell_velocities=None,
+cell_velocities=torch.zeros(n_systems, 3, 3, device=device, dtype=dtype),
 cell_forces=cell_forces,
```
🧹 Nitpick comments (3)
torch_sim/optimizers.py (1)
1247-1257: Reconsider velocity initialization strategy

With velocities now being non-optional, the `if state.velocities is None:` check becomes problematic. Consider either:
- keeping velocities optional in the type system, or
- using a different mechanism to track whether velocities have been initialized (e.g., a separate boolean flag).

The current approach of initializing velocities as zero tensors and then checking for None won't work as intended.
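As a rough sketch of the boolean-flag option (the names FireStateSketch and maybe_init_velocities are made up for illustration, and plain Python lists stand in for torch.Tensor so the snippet runs without torch):

```python
from dataclasses import dataclass, field


@dataclass
class FireStateSketch:
    """Toy state where velocities are always a list, never None."""

    positions: list[float]
    velocities: list[float] = field(init=False)
    velocities_initialized: bool = field(init=False, default=False)

    def __post_init__(self) -> None:
        # Start from zeros; the flag records that no real velocities exist yet.
        self.velocities = [0.0] * len(self.positions)

    def maybe_init_velocities(self, new_velocities: list[float]) -> None:
        # Replaces the old `if state.velocities is None:` check.
        if not self.velocities_initialized:
            self.velocities = new_velocities
            self.velocities_initialized = True


state = FireStateSketch(positions=[0.0, 0.0, 0.0])
state.maybe_init_velocities([1.0, 2.0, 3.0])
print(state.velocities_initialized)  # True
```

The flag makes the "not yet initialized" condition explicit instead of overloading None for it, so the tensor-typed field stays non-optional.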
torch_sim/state.py (2)
379-432: Consider improving error messages for better developer experience

The validation logic is excellent, but the error messages could be more actionable. Consider adding examples to help developers fix issues quickly. For example:

```diff
 raise TypeError(
     f"Attribute '{attr_name}' in class '{cls.__name__}' is not "
     "allowed to be of type 'torch.Tensor | None'. "
     "Optional tensor attributes are disallowed in SimState "
     "subclasses to prevent concatenation errors.\n"
     "If this attribute will take on a default value in the "
     "post_init method, please use an InitVar for that attribute "
-    "but with a prepended 'init_' to the name. (e.g. init_system_idx)"
+    "but with a prepended 'init_' to the name.\n"
+    f"Example: Change '{attr_name}: torch.Tensor | None' to:\n"
+    f"    {attr_name}: torch.Tensor = field(init=False)\n"
+    f"    init_{attr_name}: InitVar[torch.Tensor | None]"
 )
```
786-787: Move TypeVar declaration to module level

TypeVar declarations are typically placed at the module level with other imports for better organization and reusability. Move this to the top of the file after imports:

```diff
+SimStateT = TypeVar("SimStateT", bound=SimState)
+
 @dataclass
 class SimState:
```

And remove lines 786-787.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- torch_sim/io.py (1 hunks)
- torch_sim/optimizers.py (3 hunks)
- torch_sim/state.py (10 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)

torch_sim/io.py (2)
- tests/test_io.py (2): test_single_atoms_to_state (58-70), test_multiple_atoms_to_state (73-88)
- tests/test_state.py (1): test_initialize_state_from_atoms (308-314)
🔇 Additional comments (3)
torch_sim/io.py (1)

238-245: LGTM!

The parameter rename from system_idx to init_system_idx correctly aligns with the SimState refactoring to use InitVar for initialization.

torch_sim/state.py (2)
86-88: LGTM! Clean InitVar implementation

The refactoring of system_idx to use the InitVar pattern is well-designed and provides explicit initialization control.
113-131: LGTM! Proper initialization order

The reordering of the cell shape adjustment to occur after system_idx initialization is correct, since the n_systems calculation depends on system_idx.
torch_sim/state.py
Outdated

```python
def construct_state(
    old_state: SimStateT,
    new_state_attrs: dict[str, typing.Any],
) -> SimStateT:
    """Construct a new state from an old state and new state parameters."""
    # 1) process the attrs so they are the init params
    processed_params = {}
    for param in inspect.signature(old_state.__class__).parameters:
        if param.startswith("init_"):
            # this is an InitVar field
            # we need to rename the corresponding field in system_attrs to have
            # an "init_" prefix
            non_init_attr_name = param.removeprefix("init_")
            processed_params[param] = new_state_attrs[non_init_attr_name]
        else:
            processed_params[param] = new_state_attrs[param]

    # 2) construct the new state
    return type(old_state)(**processed_params)
```
🛠️ Refactor suggestion
Add comprehensive docstring to construct_state
The construct_state function is a key utility but lacks documentation. Please add a docstring following Google style, as mentioned in the PR objectives.
```diff
 def construct_state(
     old_state: SimStateT,
     new_state_attrs: dict[str, typing.Any],
 ) -> SimStateT:
-    """Construct a new state from an old state and new state parameters."""
+    """Construct a new state from an old state and new state parameters.
+
+    This function handles the mapping of InitVar fields by automatically
+    prefixing the corresponding attribute names with 'init_' when calling
+    the constructor.
+
+    Args:
+        old_state: The state instance whose type will be used for construction.
+        new_state_attrs: Dictionary of attributes for the new state. Keys
+            corresponding to InitVar fields should not have the 'init_' prefix;
+            it will be added automatically.
+
+    Returns:
+        A new instance of the same type as old_state with the provided attributes.
+
+    Example:
+        >>> attrs = {'positions': tensor1, 'system_idx': tensor2}
+        >>> new_state = construct_state(old_state, attrs)
+        # This will call type(old_state)(positions=tensor1, init_system_idx=tensor2)
+    """
```
📝 Committable suggestion
```python
def construct_state(
    old_state: SimStateT,
    new_state_attrs: dict[str, typing.Any],
) -> SimStateT:
    """Construct a new state from an old state and new state parameters.

    This function handles the mapping of InitVar fields by automatically
    prefixing the corresponding attribute names with 'init_' when calling
    the constructor.

    Args:
        old_state: The state instance whose type will be used for construction.
        new_state_attrs: Dictionary of attributes for the new state. Keys
            corresponding to InitVar fields should not have the 'init_' prefix;
            it will be added automatically.

    Returns:
        A new instance of the same type as old_state with the provided attributes.

    Example:
        >>> attrs = {'positions': tensor1, 'system_idx': tensor2}
        >>> new_state = construct_state(old_state, attrs)
        # This will call type(old_state)(positions=tensor1, init_system_idx=tensor2)
    """
    # 1) process the attrs so they are the init params
    processed_params = {}
    for param in inspect.signature(old_state.__class__).parameters:
        if param.startswith("init_"):
            # this is an InitVar field
            # we need to rename the corresponding field in system_attrs to have
            # an "init_" prefix
            non_init_attr_name = param.removeprefix("init_")
            processed_params[param] = new_state_attrs[non_init_attr_name]
        else:
            processed_params[param] = new_state_attrs[param]

    # 2) construct the new state
    return type(old_state)(**processed_params)
```
```diff
@@ -107,24 +110,25 @@ def __post_init__(self) -> None:
             f"masses {shapes[1]}, atomic_numbers {shapes[2]}"
         )

-        if self.cell.ndim != 3 and self.system_idx is None:
```
I moved these checks down since they depend on self.system_idx (and also to bundle them with the self.cell.shape[0] check).
torch_sim/state.py
Outdated
```diff
@@ -272,7 +276,7 @@ def clone(self) -> Self:
         else:
             attrs[attr_name] = copy.deepcopy(attr_value)

-        return self.__class__(**attrs)
+        return construct_state(self, attrs)
```
Since we now have InitVar params in the constructor, creating a new instance is not as easy as __class__(**attrs): we need to properly handle the init_ constructor params.
Closing since I don't like this approach. It doesn't feel right to call initialization params …
Summary
In order to properly type SimState, I want to remove the | None type annotation on system_idx. This is solved via the InitVar variable init_system_idx.

In addition to typing, this PR introduces a __post_init__ check for subclasses of SimState. This paradigm allows us to verify properties of derived classes. In particular, this PR uses __post_init__ to verify that all InitVar fields start with the init_ prefix (which is required to properly mangle params during concatenation/splitting of SimStates).

In the next PR, we will use __post_init__ to enforce that derived SimState classes cannot have a | None attribute (which would break functions like torch.concatenate, since we cannot concatenate attributes that are tensors with attributes that are None; see the description of #229 for more info).

This PR is breaking since we modify the constructor for SimState.
Checklist
Before a pull request can be merged, the following items must be checked:
Run ruff on your code.
We highly recommend installing the pre-commit hooks running in CI locally to speed up the development process. Simply run pip install pre-commit && pre-commit install to install the hooks, which will check your code before each commit.
to install the hooks which will check your code before each commit.Summary by CodeRabbit