Skip to content

Commit d3101d4

Browse files
Update gradient descent
1 parent 6feb03a commit d3101d4

File tree

1 file changed

+57
-39
lines changed

1 file changed

+57
-39
lines changed

torchsim/unbatched_optimizers.py

Lines changed: 57 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,18 @@
22

33
from collections.abc import Callable
44
from dataclasses import dataclass
5+
from typing import Literal
56

67
import torch
78

89
from torchsim.state import BaseState
910
from torchsim.unbatched_integrators import velocity_verlet
1011

1112

13+
StateDict = dict[
14+
Literal["positions", "masses", "cell", "pbc", "atomic_numbers", "batch"], torch.Tensor
15+
]
16+
1217
eps = 1e-8
1318

1419

@@ -41,19 +46,12 @@ class GDState(OptimizerState):
4146
lr: Learning rate for position updates
4247
"""
4348

44-
lr: torch.Tensor
45-
4649

4750
def gradient_descent(
4851
*,
49-
positions: torch.Tensor,
50-
masses: torch.Tensor,
51-
cell: torch.Tensor,
52-
pbc: bool,
5352
model: torch.nn.Module,
54-
learning_rate: float = 0.01,
55-
**extra_state_kwargs,
56-
) -> tuple[GDState, Callable[[GDState], GDState]]:
53+
lr: float = 0.01,
54+
) -> tuple[Callable[[StateDict], GDState], Callable[[GDState], GDState]]:
5755
"""Initialize a simple gradient descent optimization.
5856
5957
Gradient descent updates atomic positions by moving along the direction of the forces
@@ -63,31 +61,67 @@ def gradient_descent(
6361
6462
Args:
6563
model: Neural network model that computes energies and forces
66-
positions: Atomic positions tensor of shape (n_atoms, 3)
67-
masses: Atomic masses tensor of shape (n_atoms,)
68-
cell: Unit cell tensor of shape (3, 3)
69-
pbc: Periodic boundary conditions flags
70-
learning_rate: Step size for position updates (default: 0.01)
71-
**extra_state_kwargs: Additional keyword arguments to pass to the state
64+
lr: Step size for position updates (default: 0.01)
7265
7366
Returns:
7467
Tuple containing:
75-
- Initial GDState with system state
68+
- Initialization function that creates the initial GDState
7669
- Update function that performs one gradient descent step
7770
7871
Notes:
7972
- Best suited for systems close to their minimum energy configuration
8073
"""
81-
device = positions.device
82-
dtype = positions.dtype
74+
device = model.device
75+
dtype = model.dtype
8376

8477
# Convert learning rate to tensor
85-
lr = torch.tensor(learning_rate, device=device, dtype=dtype)
78+
if not isinstance(lr, torch.Tensor):
79+
lr = torch.tensor(lr, device=device, dtype=dtype)
80+
81+
def gd_init(state: BaseState | StateDict, **extra_state_kwargs) -> GDState:
82+
"""Initialize the gradient descent optimizer state.
83+
84+
Args:
85+
state: Initial system state
86+
**extra_state_kwargs: Additional keyword arguments for state initialization
87+
88+
Returns:
89+
Initial GDState with system configuration and forces
90+
"""
91+
if not isinstance(state, BaseState):
92+
state = BaseState(**state)
93+
94+
atomic_numbers = extra_state_kwargs.get("atomic_numbers", state.atomic_numbers)
95+
96+
# Get initial forces and energy from model
97+
model_output = model(
98+
positions=state.positions,
99+
cell=state.cell,
100+
atomic_numbers=atomic_numbers,
101+
)
86102

87-
def gd_step(state: GDState) -> GDState:
88-
"""Perform one gradient descent optimization step."""
103+
return GDState(
104+
positions=state.positions,
105+
masses=state.masses,
106+
cell=state.cell,
107+
pbc=state.pbc,
108+
atomic_numbers=state.atomic_numbers,
109+
forces=model_output["forces"],
110+
energy=model_output["energy"],
111+
)
112+
113+
def gd_step(state: GDState, lr: torch.Tensor = lr) -> GDState:
114+
"""Perform one gradient descent optimization step.
115+
116+
Args:
117+
state: Current optimization state
118+
lr: Learning rate for position updates (default: value from initialization)
119+
120+
Returns:
121+
Updated state after one optimization step
122+
"""
89123
# Update positions using forces and learning rate
90-
state.positions = state.positions + state.lr * state.forces
124+
state.positions = state.positions + lr * state.forces
91125

92126
# Update forces and energy at new positions
93127
results = model(
@@ -100,23 +134,7 @@ def gd_step(state: GDState) -> GDState:
100134

101135
return state
102136

103-
model_output = model(
104-
positions=positions,
105-
cell=cell,
106-
atomic_numbers=extra_state_kwargs.get("atomic_numbers"),
107-
)
108-
109-
initial_state = GDState(
110-
positions=positions,
111-
masses=masses,
112-
cell=cell,
113-
pbc=pbc,
114-
atomic_numbers=extra_state_kwargs.get("atomic_numbers"),
115-
forces=model_output["forces"],
116-
energy=model_output["energy"],
117-
lr=lr,
118-
)
119-
return initial_state, gd_step
137+
return gd_init, gd_step
120138

121139

122140
@dataclass

0 commit comments

Comments
 (0)