-
Notifications
You must be signed in to change notification settings - Fork 38
Fix concatenation of states in InFlightAutoBatcher
#229
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The `velocities` and `cell_velocities` are initialized to `None` in the `(FrechetCell)FIREState`. However, when using the `InFlightAutoBatcher` during an optimization, the current and new states are concatenated in `torch_sim.state.concatenate_states`. When trying to merge states that were already processed for a few iterations (i.e., velocities are not None anymore) and newly initialized ones, an error is raised because the code tries to merge a `Tensor` with a `None`. Here, we initialize the `(cell_)velocities` as tensors full of `nan` instead, so that one can merge already processed and newly initialized states. During the first initialization, the `fire` methods look for `nan` rows and replace them with zeros.
Important Review skippedDraft detected. Please check the settings in the CodeRabbit UI or the You can disable this status message by setting the ✨ Finishing Touches🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
examples/scripts/reproduce_err.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe we should put this in tests and make separate unit and integration dirs. Whilst the example scripts are tested this doesn't particularly make much sense to me in this location
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll remove this file when this PR is out of draft. it's pure testing code rn
torch_sim/optimizers.py
Outdated
@@ -587,9 +587,11 @@ def fire_init( | |||
masses=state.masses.clone(), | |||
cell=state.cell.clone(), | |||
atomic_numbers=state.atomic_numbers.clone(), | |||
system_idx=state.system_idx.clone(), | |||
system_index=state.system_idx.clone(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
system_index
vs system_idx
still looks inconsistently applied? see line 863 below
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes. this was a "hack" I was trying last night. but it will not work because of other parts of the code. I will have this renamed which makes it easier to see inconsistencies. Thanks for pointing out the inconsistency though
for visibility I have thought of a good way to resolve this problem and will visit it hopefully later tonight. We'll definitely have a good PR by Saturday |
closing since this PR is superseded by #232 |
see #219
Summary
This is actually kinda a serious issue and I'll outline it here in a clear manner.
MD SimStates often track
velocity
. But on the first iteration, the states do NOT have velocity - so they are currently initialized asnone
.But once the optimizer gets going, these states end up having a
velocity
attribute.The problem is how we concatenate SimStates. Inside the autobatcher, when some SimStates finish before others, we swap those finished states with fresh states. This means inside the entire SimState, we have some systems with velocity set to
none
(since they were just swapped in and are fresh) and other systems with a set velocity.When we concatenate these "mixed" SimStates (during the optimization process), we do
torch.concatenate([torch.Tensor, none, none])
. Where the first system's velocity exists (because it's a torch.Tensor, and the last 2 systems do NOT have a velocity - since they were just swapped in by the autobatcher.PyTmorch cannot concatenate this because we're passing in
none
as an input which is invalid.@t-reents 's solution works pretty well and is valid (which is why I'm touching it up in this PR). His solution is: "rather than initializing
vector
attributes asnone
, we initiliaze it asnan
so we can do torch.concatentate between states that are old, and states that have just been swapped in.Checklist
Before a pull request can be merged, the following items must be checked:
Run ruff on your code.
We highly recommended installing the pre-commit hooks running in CI locally to speedup the development process. Simply run
pip install pre-commit && pre-commit install
to install the hooks which will check your code before each commit.