 
 from collections.abc import Callable
 from dataclasses import dataclass
+from typing import Literal
 
 import torch
 
 from torchsim.state import BaseState
 from torchsim.unbatched_integrators import velocity_verlet
 
 
+StateDict = dict[
+    Literal["positions", "masses", "cell", "pbc", "atomic_numbers", "batch"], torch.Tensor
+]
+
 eps = 1e-8
 
 
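For reference, the new StateDict alias introduced above is a plain dictionary of tensors keyed by the BaseState field names. Below is a minimal sketch of one; the two-atom shapes and values are illustrative assumptions, not taken from this commit, and the optional "batch" key is omitted.

import torch

# Hypothetical StateDict for a 2-atom system; the gd_init function added
# later in this diff converts such a dict into a BaseState before calling
# the model.
state_dict = {
    "positions": torch.zeros(2, 3),          # (n_atoms, 3) Cartesian coordinates
    "masses": torch.ones(2),                  # (n_atoms,) atomic masses
    "cell": 10.0 * torch.eye(3),              # (3, 3) unit cell
    "pbc": torch.tensor(True),                # periodicity flag (typed as a tensor by the alias)
    "atomic_numbers": torch.tensor([1, 1]),   # (n_atoms,) atomic numbers
}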
@@ -41,19 +46,12 @@ class GDState(OptimizerState):
         lr: Learning rate for position updates
     """
 
-    lr: torch.Tensor
-
 
 def gradient_descent(
     *,
-    positions: torch.Tensor,
-    masses: torch.Tensor,
-    cell: torch.Tensor,
-    pbc: bool,
     model: torch.nn.Module,
-    learning_rate: float = 0.01,
-    **extra_state_kwargs,
-) -> tuple[GDState, Callable[[GDState], GDState]]:
+    lr: float = 0.01,
+) -> tuple[Callable[[StateDict], GDState], Callable[[GDState], GDState]]:
     """Initialize a simple gradient descent optimization.
 
     Gradient descent updates atomic positions by moving along the direction of the forces
@@ -63,31 +61,67 @@ def gradient_descent(
 
     Args:
         model: Neural network model that computes energies and forces
-        positions: Atomic positions tensor of shape (n_atoms, 3)
-        masses: Atomic masses tensor of shape (n_atoms,)
-        cell: Unit cell tensor of shape (3, 3)
-        pbc: Periodic boundary conditions flags
-        learning_rate: Step size for position updates (default: 0.01)
-        **extra_state_kwargs: Additional keyword arguments to pass to the state
+        lr: Step size for position updates (default: 0.01)
 
     Returns:
         Tuple containing:
-        - Initial GDState with system state
+        - Initialization function that creates the initial GDState
         - Update function that performs one gradient descent step
 
     Notes:
         - Best suited for systems close to their minimum energy configuration
     """
-    device = positions.device
-    dtype = positions.dtype
+    device = model.device
+    dtype = model.dtype
 
     # Convert learning rate to tensor
-    lr = torch.tensor(learning_rate, device=device, dtype=dtype)
+    if not isinstance(lr, torch.Tensor):
+        lr = torch.tensor(lr, device=device, dtype=dtype)
+
+    def gd_init(state: BaseState | StateDict, **extra_state_kwargs) -> GDState:
+        """Initialize the gradient descent optimizer state.
+
+        Args:
+            state: Initial system state
+            **extra_state_kwargs: Additional keyword arguments for state initialization
+
+        Returns:
+            Initial GDState with system configuration and forces
+        """
+        if not isinstance(state, BaseState):
+            state = BaseState(**state)
+
+        atomic_numbers = extra_state_kwargs.get("atomic_numbers", state.atomic_numbers)
+
+        # Get initial forces and energy from model
+        model_output = model(
+            positions=state.positions,
+            cell=state.cell,
+            atomic_numbers=atomic_numbers,
+        )
 
-    def gd_step(state: GDState) -> GDState:
-        """Perform one gradient descent optimization step."""
+        return GDState(
+            positions=state.positions,
+            masses=state.masses,
+            cell=state.cell,
+            pbc=state.pbc,
+            atomic_numbers=state.atomic_numbers,
+            forces=model_output["forces"],
+            energy=model_output["energy"],
+        )
+
+    def gd_step(state: GDState, lr: torch.Tensor = lr) -> GDState:
+        """Perform one gradient descent optimization step.
+
+        Args:
+            state: Current optimization state
+            lr: Learning rate for position updates (default: value from initialization)
+
+        Returns:
+            Updated state after one optimization step
+        """
         # Update positions using forces and learning rate
-        state.positions = state.positions + state.lr * state.forces
+        state.positions = state.positions + lr * state.forces
 
         # Update forces and energy at new positions
         results = model(
@@ -100,23 +134,7 @@ def gd_step(state: GDState) -> GDState:
 
         return state
 
-    model_output = model(
-        positions=positions,
-        cell=cell,
-        atomic_numbers=extra_state_kwargs.get("atomic_numbers"),
-    )
-
-    initial_state = GDState(
-        positions=positions,
-        masses=masses,
-        cell=cell,
-        pbc=pbc,
-        atomic_numbers=extra_state_kwargs.get("atomic_numbers"),
-        forces=model_output["forces"],
-        energy=model_output["energy"],
-        lr=lr,
-    )
-    return initial_state, gd_step
+    return gd_init, gd_step
 
 
 @dataclass
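To show how the refactored functional API fits together, here is a minimal usage sketch. The model object and initial_state are assumptions (any torchsim-compatible model exposing .device/.dtype and returning a dict with "forces" and "energy", plus a BaseState or StateDict); only the gradient_descent signature and its (gd_init, gd_step) return value come from this diff.

# Hypothetical usage of the new init/step pair; model and initial_state are
# assumed to exist and are not defined in this commit.
gd_init, gd_step = gradient_descent(model=model, lr=0.01)

# Build the optimizer state (accepts a BaseState or a plain StateDict) and
# attach the initial forces/energy computed by the model.
state = gd_init(initial_state)

# Relax the structure for a fixed number of steps; gd_step also accepts a
# per-call lr tensor to override the default set at initialization.
for _ in range(100):
    state = gd_step(state)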