Skip to content

Commit c87386c

Browse files
committed
option to not calculate forces in Ewald
1 parent f9bf88c commit c87386c

File tree

1 file changed

+61
-46
lines changed

1 file changed

+61
-46
lines changed

src/interactions/ewald.jl

Lines changed: 61 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ AtomsCalculators.@generate_interface function AtomsCalculators.potential_energy(
1111
inter::AbstractEwald;
1212
n_threads::Integer=Threads.nthreads(),
1313
kwargs...)
14-
fs = zero_forces(sys)
15-
pe = ewald_pe_forces!(fs, sys, inter; n_threads=n_threads)
14+
pe = ewald_pe_forces!(nothing, sys, inter; n_threads=n_threads)
1615
return pe
1716
end
1817

@@ -40,7 +39,7 @@ function AtomsCalculators.energy_forces(sys,
4039
kwargs...)
4140
fs = zero_forces(sys)
4241
pe = ewald_pe_forces!(fs, sys, inter; n_threads=n_threads)
43-
return (energy=E, forces=Fs)
42+
return (energy=pe, forces=fs)
4443
end
4544

4645
function find_excluded_pairs(eligible, special)
@@ -60,8 +59,8 @@ function find_excluded_pairs(eligible, special)
6059
return excluded_pairs
6160
end
6261

63-
function excluded_interactions_inner!(Fs, atoms, coords, boundary, α, f, i, j,
64-
::Val{T}, ::Val{atomic}) where {T, atomic}
62+
function excluded_interactions_inner!(Fs, atoms, coords, boundary, α, f, i, j, ::Val{T},
63+
::Val{calculate_forces}, ::Val{atomic}) where {T, calculate_forces, atomic}
6564
sqrt_π = sqrt(T(π))
6665
charge_ij = charge(atoms[i]) * charge(atoms[j])
6766
vec_ij = vector(coords[i], coords[j], boundary)
@@ -71,57 +70,63 @@ function excluded_interactions_inner!(Fs, atoms, coords, boundary, α, f, i, j,
7170
if erf_αr > T(1e-6)
7271
inv_r = inv(r)
7372
exclusion_E = -f * charge_ij * inv_r * erf_αr
74-
dE_dr = f * charge_ij * inv_r^3 * (erf_αr - 2 * αr * exp(-αr^2) / sqrt_π)
75-
F = dE_dr * vec_ij
76-
if atomic
77-
for dim in 1:3
78-
fval = ustrip(F[dim])
79-
Atomix.@atomic Fs[dim, i] += fval
80-
Atomix.@atomic Fs[dim, j] += -fval
73+
if calculate_forces
74+
dE_dr = f * charge_ij * inv_r^3 * (erf_αr - 2 * αr * exp(-αr^2) / sqrt_π)
75+
F = dE_dr * vec_ij
76+
if atomic
77+
for dim in 1:3
78+
fval = ustrip(F[dim])
79+
Atomix.@atomic Fs[dim, i] += fval
80+
Atomix.@atomic Fs[dim, j] += -fval
81+
end
82+
else
83+
Fs[i] += F
84+
Fs[j] -= F
8185
end
82-
else
83-
Fs[i] += F
84-
Fs[j] -= F
8586
end
8687
else
8788
exclusion_E = -α * 2 * f * charge_ij / sqrt_π
8889
end
8990
return exclusion_E
9091
end
9192

92-
function excluded_interactions!(Fs::Vector, buffer_Fs, buffer_Es, excluded_pairs,
93-
atoms, coords, boundary, α, f, force_units,
94-
energy_units, ::Val{T}) where T
93+
function excluded_interactions!(Fs, buffer_Fs, buffer_Es, excluded_pairs, atoms,
94+
coords::Vector, boundary, α, f, force_units, energy_units,
95+
calculate_forces, ::Val{T}) where T
9596
exclusion_E = zero(T) * energy_units
9697
for (i, j) in excluded_pairs
9798
E = excluded_interactions_inner!(Fs, atoms, coords, boundary, α, f,
98-
i, j, Val(T), Val(false))
99+
i, j, Val(T), Val(calculate_forces), Val(false))
99100
exclusion_E += E
100101
end
101102
return exclusion_E
102103
end
103104

104-
function excluded_interactions!(Fs::AbstractVector{SVector{D, C}}, buffer_Fs, buffer_Es,
105-
excluded_pairs, atoms, coords, boundary, α, f, force_units,
106-
energy_units, ::Val{T}) where {D, C, T}
107-
buffer_Fs .= zero(T)
108-
backend = get_backend(Fs)
105+
function excluded_interactions!(Fs, buffer_Fs, buffer_Es, excluded_pairs, atoms,
106+
coords::AbstractVector{SVector{D, C}}, boundary, α, f, force_units,
107+
energy_units, calculate_forces, ::Val{T}) where {D, C, T}
108+
if calculate_forces
109+
buffer_Fs .= zero(T)
110+
end
111+
backend = get_backend(atoms)
109112
n_threads_gpu = 128
110113
kernel! = excluded_interactions_kernel!(backend, n_threads_gpu)
111114
kernel!(buffer_Fs, buffer_Es, excluded_pairs, atoms, coords, boundary, α, f,
112-
energy_units, Val(T); ndrange=length(excluded_pairs))
113-
Fs .+= reinterpret(SVector{D, T}, vec(buffer_Fs)) .* force_units
115+
energy_units, Val(T), Val(calculate_forces); ndrange=length(excluded_pairs))
116+
if calculate_forces
117+
Fs .+= reinterpret(SVector{D, T}, vec(buffer_Fs)) .* force_units
118+
end
114119
return sum(buffer_Es) * energy_units
115120
end
116121

117122
@kernel function excluded_interactions_kernel!(Fs_mat, exclusion_Es, @Const(excluded_pairs),
118123
@Const(atoms), @Const(coords), boundary, α, f, energy_units,
119-
::Val{T}) where T
124+
::Val{T}, ::Val{calculate_forces}) where {T, calculate_forces}
120125
ei = @index(Global, Linear)
121126
if ei <= length(excluded_pairs)
122127
i, j = excluded_pairs[ei]
123128
E = excluded_interactions_inner!(Fs_mat, atoms, coords, boundary, α, f,
124-
i, j, Val(T), Val(true))
129+
i, j, Val(T), Val(calculate_forces), Val(true))
125130
exclusion_Es[ei] = ustrip(energy_units, E)
126131
end
127132
end
@@ -190,11 +195,18 @@ function ewald_params(side_length, α, error_tol)
190195
return k
191196
end
192197

193-
function ewald_pe_forces!(Fs, sys::System{3, AT}, inter::Ewald{T};
194-
n_threads::Integer=Threads.nthreads()) where {AT, T}
195-
n_atoms = length(sys)
196-
atoms_cpu, coords_cpu = from_device(sys.atoms), from_device(sys.coords)
197-
boundary, energy_units = sys.boundary, sys.energy_units
198+
function ewald_pe_forces!(Fs, sys::System{3}, inter::AbstractEwald;
199+
n_threads::Integer=Threads.nthreads())
200+
calculate_forces = !isnothing(Fs)
201+
return ewald_pe_forces!(Fs, inter, sys.atoms, sys.coords, sys.boundary, sys.force_units,
202+
sys.energy_units, calculate_forces; n_threads=n_threads)
203+
end
204+
205+
function ewald_pe_forces!(Fs, inter::Ewald{T}, atoms, coords, boundary, force_units, energy_units,
206+
calculate_forces=true; n_threads::Integer=Threads.nthreads()) where T
207+
AT = array_type(atoms)
208+
n_atoms = length(atoms)
209+
atoms_cpu, coords_cpu = from_device(atoms), from_device(coords)
198210
dist_cutoff, error_tol = inter.dist_cutoff, inter.error_tol
199211
α = inv(dist_cutoff) * sqrt(-log(2 * error_tol))
200212
nrx, nry, nrz = ewald_params.(boundary.side_lengths, α, error_tol)
@@ -205,15 +217,15 @@ function ewald_pe_forces!(Fs, sys::System{3, AT}, inter::Ewald{T};
205217
partial_charges_cpu = charge.(atoms_cpu)
206218
V = volume(boundary)
207219
f = (energy_units == NoUnits ? ustrip(T(Molly.coulomb_const)) : T(Molly.coulomb_const))
208-
if AT <: AbstractGPUArray
209-
Fs_cpu = zeros(SVector{3, typeof(zero(T) * sys.force_units)}, n_atoms)
220+
if AT <: AbstractGPUArray && calculate_forces
221+
Fs_cpu = zeros(SVector{3, typeof(zero(T) * force_units)}, n_atoms)
210222
else
211223
Fs_cpu = Fs
212224
end
213225

214226
exclusion_E = excluded_interactions!(Fs_cpu, nothing, nothing, inter.excluded_pairs,
215227
atoms_cpu, coords_cpu, boundary, α, f,
216-
sys.force_units, energy_units, Val(T))
228+
force_units, energy_units, calculate_forces, Val(T))
217229

218230
recip_box_size = (2 * T(π)) ./ boundary.side_lengths
219231
eir = zeros(Complex{T}, kmax * n_atoms * 3)
@@ -272,7 +284,9 @@ function ewald_pe_forces!(Fs, sys::System{3, AT}, inter::Ewald{T};
272284
ak = exp(k2 * factor_ewald) / k2
273285
for n in 1:n_atoms
274286
F = ak * (cs * imag(tab_qxyz[n]) - ss * real(tab_qxyz[n]))
275-
Fs_cpu[n] += 2 .* recip_coeff .* F .* SVector(kx, ky, kz)
287+
if calculate_forces
288+
Fs_cpu[n] += 2 .* recip_coeff .* F .* SVector(kx, ky, kz)
289+
end
276290
end
277291
reciprocal_space_E += recip_coeff * ak * (cs * cs + ss * ss)
278292
lowrz = 1 - nrz
@@ -284,7 +298,7 @@ function ewald_pe_forces!(Fs, sys::System{3, AT}, inter::Ewald{T};
284298
charge_E = -f * T(π) * sum(partial_charges_cpu)^2 / (2 * V * α^2)
285299
self_E = f * -sum(abs2, partial_charges_cpu) * α / sqrt(T(π)) + charge_E
286300
total_E = reciprocal_space_E + self_E + exclusion_E
287-
if AT <: AbstractGPUArray
301+
if calculate_forces && AT <: AbstractGPUArray
288302
Fs .+= to_device(Fs_cpu, AT)
289303
end
290304
return total_E
@@ -808,17 +822,16 @@ end
808822
grad_safe_fft!( charge_grid, fft_plan ) = fft_plan * charge_grid
809823
grad_safe_bfft!(charge_grid, bfft_plan) = bfft_plan * charge_grid
810824

811-
function ewald_pe_forces!(Fs, sys::System{3, AT}, inter::PME{T};
812-
n_threads::Integer=Threads.nthreads()) where {AT, T}
825+
function ewald_pe_forces!(Fs, inter::PME{T}, atoms, coords, boundary, force_units, energy_units,
826+
calculate_forces=true; n_threads::Integer=Threads.nthreads()) where T
813827
n_thr = (inter.grad_safe ? 1 : n_threads) # Enzyme error with multiple threads
814-
atoms, coords, boundary, energy_units = sys.atoms, sys.coords, sys.boundary, sys.energy_units
815828
order, ϵr, α, mesh_dims = inter.order, inter.ϵr, inter.α, inter.mesh_dims
816829
V = volume(boundary)
817830
f = (energy_units == NoUnits ? ustrip(T(Molly.coulomb_const)) : T(Molly.coulomb_const))
818831

819832
exclusion_E = excluded_interactions!(Fs, inter.excluded_buffer_Fs, inter.excluded_buffer_Es,
820833
inter.excluded_pairs, atoms, coords, boundary, α, f,
821-
sys.force_units, energy_units, Val(T))
834+
force_units, energy_units, calculate_forces, Val(T))
822835

823836
recip_box = invert_box_vectors(boundary)
824837
grid_placement!(inter.grid_indices, inter.grid_fractions, coords, recip_box, mesh_dims)
@@ -830,12 +843,14 @@ function ewald_pe_forces!(Fs, sys::System{3, AT}, inter::PME{T};
830843
inter.bsplines_moduli_x, inter.bsplines_moduli_y, inter.bsplines_moduli_z,
831844
recip_box, f / ϵr, α, mesh_dims, boundary, energy_units, n_thr)
832845
grad_safe_bfft!(inter.charge_grid, inter.bfft_plan)
833-
interpolate_force!(Fs, inter.charge_grid, inter.grid_indices, inter.bsplines_θ,
834-
inter.bsplines_dθ, recip_box, mesh_dims, order, energy_units, atoms,
835-
n_thr)
846+
if calculate_forces
847+
interpolate_force!(Fs, inter.charge_grid, inter.grid_indices, inter.bsplines_θ,
848+
inter.bsplines_dθ, recip_box, mesh_dims, order, energy_units, atoms,
849+
n_thr)
850+
end
836851

837852
if isnothing(inter.pc_sum) || inter.grad_safe
838-
partial_charges = charges(sys)
853+
partial_charges = charge.(atoms)
839854
pc_sum = sum(partial_charges)
840855
pc_abs2_sum = sum(abs2, partial_charges)
841856
else

0 commit comments

Comments
 (0)