@@ -11,8 +11,7 @@ AtomsCalculators.@generate_interface function AtomsCalculators.potential_energy(
11
11
inter:: AbstractEwald ;
12
12
n_threads:: Integer = Threads. nthreads (),
13
13
kwargs... )
14
- fs = zero_forces (sys)
15
- pe = ewald_pe_forces! (fs, sys, inter; n_threads= n_threads)
14
+ pe = ewald_pe_forces! (nothing , sys, inter; n_threads= n_threads)
16
15
return pe
17
16
end
18
17
@@ -40,7 +39,7 @@ function AtomsCalculators.energy_forces(sys,
40
39
kwargs... )
41
40
fs = zero_forces (sys)
42
41
pe = ewald_pe_forces! (fs, sys, inter; n_threads= n_threads)
43
- return (energy= E , forces= Fs )
42
+ return (energy= pe , forces= fs )
44
43
end
45
44
46
45
function find_excluded_pairs (eligible, special)
@@ -60,8 +59,8 @@ function find_excluded_pairs(eligible, special)
60
59
return excluded_pairs
61
60
end
62
61
63
- function excluded_interactions_inner! (Fs, atoms, coords, boundary, α, f, i, j,
64
- :: Val{T } , :: Val{atomic} ) where {T, atomic}
62
+ function excluded_interactions_inner! (Fs, atoms, coords, boundary, α, f, i, j, :: Val{T} ,
63
+ :: Val{calculate_forces } , :: Val{atomic} ) where {T, calculate_forces , atomic}
65
64
sqrt_π = sqrt (T (π))
66
65
charge_ij = charge (atoms[i]) * charge (atoms[j])
67
66
vec_ij = vector (coords[i], coords[j], boundary)
@@ -71,57 +70,63 @@ function excluded_interactions_inner!(Fs, atoms, coords, boundary, α, f, i, j,
71
70
if erf_αr > T (1e-6 )
72
71
inv_r = inv (r)
73
72
exclusion_E = - f * charge_ij * inv_r * erf_αr
74
- dE_dr = f * charge_ij * inv_r^ 3 * (erf_αr - 2 * αr * exp (- αr^ 2 ) / sqrt_π)
75
- F = dE_dr * vec_ij
76
- if atomic
77
- for dim in 1 : 3
78
- fval = ustrip (F[dim])
79
- Atomix. @atomic Fs[dim, i] += fval
80
- Atomix. @atomic Fs[dim, j] += - fval
73
+ if calculate_forces
74
+ dE_dr = f * charge_ij * inv_r^ 3 * (erf_αr - 2 * αr * exp (- αr^ 2 ) / sqrt_π)
75
+ F = dE_dr * vec_ij
76
+ if atomic
77
+ for dim in 1 : 3
78
+ fval = ustrip (F[dim])
79
+ Atomix. @atomic Fs[dim, i] += fval
80
+ Atomix. @atomic Fs[dim, j] += - fval
81
+ end
82
+ else
83
+ Fs[i] += F
84
+ Fs[j] -= F
81
85
end
82
- else
83
- Fs[i] += F
84
- Fs[j] -= F
85
86
end
86
87
else
87
88
exclusion_E = - α * 2 * f * charge_ij / sqrt_π
88
89
end
89
90
return exclusion_E
90
91
end
91
92
92
- function excluded_interactions! (Fs:: Vector , buffer_Fs, buffer_Es, excluded_pairs,
93
- atoms, coords, boundary, α, f, force_units,
94
- energy_units , :: Val{T} ) where T
93
+ function excluded_interactions! (Fs, buffer_Fs, buffer_Es, excluded_pairs, atoms ,
94
+ coords:: Vector , boundary, α, f, force_units, energy_units ,
95
+ calculate_forces , :: Val{T} ) where T
95
96
exclusion_E = zero (T) * energy_units
96
97
for (i, j) in excluded_pairs
97
98
E = excluded_interactions_inner! (Fs, atoms, coords, boundary, α, f,
98
- i, j, Val (T), Val (false ))
99
+ i, j, Val (T), Val (calculate_forces), Val ( false ))
99
100
exclusion_E += E
100
101
end
101
102
return exclusion_E
102
103
end
103
104
104
- function excluded_interactions! (Fs:: AbstractVector{SVector{D, C}} , buffer_Fs, buffer_Es,
105
- excluded_pairs, atoms, coords, boundary, α, f, force_units,
106
- energy_units, :: Val{T} ) where {D, C, T}
107
- buffer_Fs .= zero (T)
108
- backend = get_backend (Fs)
105
+ function excluded_interactions! (Fs, buffer_Fs, buffer_Es, excluded_pairs, atoms,
106
+ coords:: AbstractVector{SVector{D, C}} , boundary, α, f, force_units,
107
+ energy_units, calculate_forces, :: Val{T} ) where {D, C, T}
108
+ if calculate_forces
109
+ buffer_Fs .= zero (T)
110
+ end
111
+ backend = get_backend (atoms)
109
112
n_threads_gpu = 128
110
113
kernel! = excluded_interactions_kernel! (backend, n_threads_gpu)
111
114
kernel! (buffer_Fs, buffer_Es, excluded_pairs, atoms, coords, boundary, α, f,
112
- energy_units, Val (T); ndrange= length (excluded_pairs))
113
- Fs .+ = reinterpret (SVector{D, T}, vec (buffer_Fs)) .* force_units
115
+ energy_units, Val (T), Val (calculate_forces); ndrange= length (excluded_pairs))
116
+ if calculate_forces
117
+ Fs .+ = reinterpret (SVector{D, T}, vec (buffer_Fs)) .* force_units
118
+ end
114
119
return sum (buffer_Es) * energy_units
115
120
end
116
121
117
122
@kernel function excluded_interactions_kernel! (Fs_mat, exclusion_Es, @Const (excluded_pairs),
118
123
@Const (atoms), @Const (coords), boundary, α, f, energy_units,
119
- :: Val{T} ) where T
124
+ :: Val{T} , :: Val{calculate_forces} ) where {T, calculate_forces}
120
125
ei = @index (Global, Linear)
121
126
if ei <= length (excluded_pairs)
122
127
i, j = excluded_pairs[ei]
123
128
E = excluded_interactions_inner! (Fs_mat, atoms, coords, boundary, α, f,
124
- i, j, Val (T), Val (true ))
129
+ i, j, Val (T), Val (calculate_forces), Val ( true ))
125
130
exclusion_Es[ei] = ustrip (energy_units, E)
126
131
end
127
132
end
@@ -190,11 +195,18 @@ function ewald_params(side_length, α, error_tol)
190
195
return k
191
196
end
192
197
193
- function ewald_pe_forces! (Fs, sys:: System{3, AT} , inter:: Ewald{T} ;
194
- n_threads:: Integer = Threads. nthreads ()) where {AT, T}
195
- n_atoms = length (sys)
196
- atoms_cpu, coords_cpu = from_device (sys. atoms), from_device (sys. coords)
197
- boundary, energy_units = sys. boundary, sys. energy_units
198
+ function ewald_pe_forces! (Fs, sys:: System{3} , inter:: AbstractEwald ;
199
+ n_threads:: Integer = Threads. nthreads ())
200
+ calculate_forces = ! isnothing (Fs)
201
+ return ewald_pe_forces! (Fs, inter, sys. atoms, sys. coords, sys. boundary, sys. force_units,
202
+ sys. energy_units, calculate_forces; n_threads= n_threads)
203
+ end
204
+
205
+ function ewald_pe_forces! (Fs, inter:: Ewald{T} , atoms, coords, boundary, force_units, energy_units,
206
+ calculate_forces= true ; n_threads:: Integer = Threads. nthreads ()) where T
207
+ AT = array_type (atoms)
208
+ n_atoms = length (atoms)
209
+ atoms_cpu, coords_cpu = from_device (atoms), from_device (coords)
198
210
dist_cutoff, error_tol = inter. dist_cutoff, inter. error_tol
199
211
α = inv (dist_cutoff) * sqrt (- log (2 * error_tol))
200
212
nrx, nry, nrz = ewald_params .(boundary. side_lengths, α, error_tol)
@@ -205,15 +217,15 @@ function ewald_pe_forces!(Fs, sys::System{3, AT}, inter::Ewald{T};
205
217
partial_charges_cpu = charge .(atoms_cpu)
206
218
V = volume (boundary)
207
219
f = (energy_units == NoUnits ? ustrip (T (Molly. coulomb_const)) : T (Molly. coulomb_const))
208
- if AT <: AbstractGPUArray
209
- Fs_cpu = zeros (SVector{3 , typeof (zero (T) * sys . force_units)}, n_atoms)
220
+ if AT <: AbstractGPUArray && calculate_forces
221
+ Fs_cpu = zeros (SVector{3 , typeof (zero (T) * force_units)}, n_atoms)
210
222
else
211
223
Fs_cpu = Fs
212
224
end
213
225
214
226
exclusion_E = excluded_interactions! (Fs_cpu, nothing , nothing , inter. excluded_pairs,
215
227
atoms_cpu, coords_cpu, boundary, α, f,
216
- sys . force_units, energy_units, Val (T))
228
+ force_units, energy_units, calculate_forces , Val (T))
217
229
218
230
recip_box_size = (2 * T (π)) ./ boundary. side_lengths
219
231
eir = zeros (Complex{T}, kmax * n_atoms * 3 )
@@ -272,7 +284,9 @@ function ewald_pe_forces!(Fs, sys::System{3, AT}, inter::Ewald{T};
272
284
ak = exp (k2 * factor_ewald) / k2
273
285
for n in 1 : n_atoms
274
286
F = ak * (cs * imag (tab_qxyz[n]) - ss * real (tab_qxyz[n]))
275
- Fs_cpu[n] += 2 .* recip_coeff .* F .* SVector (kx, ky, kz)
287
+ if calculate_forces
288
+ Fs_cpu[n] += 2 .* recip_coeff .* F .* SVector (kx, ky, kz)
289
+ end
276
290
end
277
291
reciprocal_space_E += recip_coeff * ak * (cs * cs + ss * ss)
278
292
lowrz = 1 - nrz
@@ -284,7 +298,7 @@ function ewald_pe_forces!(Fs, sys::System{3, AT}, inter::Ewald{T};
284
298
charge_E = - f * T (π) * sum (partial_charges_cpu)^ 2 / (2 * V * α^ 2 )
285
299
self_E = f * - sum (abs2, partial_charges_cpu) * α / sqrt (T (π)) + charge_E
286
300
total_E = reciprocal_space_E + self_E + exclusion_E
287
- if AT <: AbstractGPUArray
301
+ if calculate_forces && AT <: AbstractGPUArray
288
302
Fs .+ = to_device (Fs_cpu, AT)
289
303
end
290
304
return total_E
@@ -808,17 +822,16 @@ end
808
822
grad_safe_fft! ( charge_grid, fft_plan ) = fft_plan * charge_grid
809
823
grad_safe_bfft! (charge_grid, bfft_plan) = bfft_plan * charge_grid
810
824
811
- function ewald_pe_forces! (Fs, sys :: System{3, AT} , inter:: PME{T} ;
812
- n_threads:: Integer = Threads. nthreads ()) where {AT, T}
825
+ function ewald_pe_forces! (Fs, inter:: PME{T} , atoms, coords, boundary, force_units, energy_units,
826
+ calculate_forces = true ; n_threads:: Integer = Threads. nthreads ()) where T
813
827
n_thr = (inter. grad_safe ? 1 : n_threads) # Enzyme error with multiple threads
814
- atoms, coords, boundary, energy_units = sys. atoms, sys. coords, sys. boundary, sys. energy_units
815
828
order, ϵr, α, mesh_dims = inter. order, inter. ϵr, inter. α, inter. mesh_dims
816
829
V = volume (boundary)
817
830
f = (energy_units == NoUnits ? ustrip (T (Molly. coulomb_const)) : T (Molly. coulomb_const))
818
831
819
832
exclusion_E = excluded_interactions! (Fs, inter. excluded_buffer_Fs, inter. excluded_buffer_Es,
820
833
inter. excluded_pairs, atoms, coords, boundary, α, f,
821
- sys . force_units, energy_units, Val (T))
834
+ force_units, energy_units, calculate_forces , Val (T))
822
835
823
836
recip_box = invert_box_vectors (boundary)
824
837
grid_placement! (inter. grid_indices, inter. grid_fractions, coords, recip_box, mesh_dims)
@@ -830,12 +843,14 @@ function ewald_pe_forces!(Fs, sys::System{3, AT}, inter::PME{T};
830
843
inter. bsplines_moduli_x, inter. bsplines_moduli_y, inter. bsplines_moduli_z,
831
844
recip_box, f / ϵr, α, mesh_dims, boundary, energy_units, n_thr)
832
845
grad_safe_bfft! (inter. charge_grid, inter. bfft_plan)
833
- interpolate_force! (Fs, inter. charge_grid, inter. grid_indices, inter. bsplines_θ,
834
- inter. bsplines_dθ, recip_box, mesh_dims, order, energy_units, atoms,
835
- n_thr)
846
+ if calculate_forces
847
+ interpolate_force! (Fs, inter. charge_grid, inter. grid_indices, inter. bsplines_θ,
848
+ inter. bsplines_dθ, recip_box, mesh_dims, order, energy_units, atoms,
849
+ n_thr)
850
+ end
836
851
837
852
if isnothing (inter. pc_sum) || inter. grad_safe
838
- partial_charges = charges (sys )
853
+ partial_charges = charge .(atoms )
839
854
pc_sum = sum (partial_charges)
840
855
pc_abs2_sum = sum (abs2, partial_charges)
841
856
else
0 commit comments