Skip to content

Commit 24f545d

Browse files
committed
revert simulators
1 parent 926d0ed commit 24f545d

File tree

1 file changed

+41
-43
lines changed

1 file changed

+41
-43
lines changed

src/simulators.jl

Lines changed: 41 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,15 @@ Custom simulators should implement this function.
6868
run_loggers=false,
6969
rng=Random.default_rng())
7070
# @inline needed to avoid Enzyme error
71-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
71+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
7272
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
7373
E = potential_energy(sys, neighbors; n_threads=n_threads)
7474
apply_loggers!(sys, neighbors, 0, run_loggers; n_threads=n_threads, current_potential_energy=E)
7575
using_constraints = length(sys.constraints) > 0
7676
println(sim.log_stream, "Step 0 - potential energy ", E, " - max force N/A - N/A")
7777
hn = sim.step_size
78-
coords_copy = zero(sys.coords)
79-
F_nounits = ustrip_vec.(zero(sys.coords))
78+
coords_copy = similar(sys.coords)
79+
F_nounits = ustrip_vec.(similar(sys.coords))
8080
F = F_nounits .* sys.force_units
8181
forces_buffer = init_forces_buffer!(sys, F_nounits, n_threads)
8282

@@ -89,7 +89,7 @@ Custom simulators should implement this function.
8989
coords_copy .= sys.coords
9090
sys.coords .+= hn .* F ./ max_force
9191
using_constraints && apply_position_constraints!(sys, coords_copy; n_threads=n_threads)
92-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
92+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
9393

9494
neighbors_copy = neighbors
9595
neighbors = find_neighbors(sys, sys.neighbor_finder, neighbors, step_n;
@@ -145,23 +145,23 @@ end
145145
n_threads::Integer=Threads.nthreads(),
146146
run_loggers=true,
147147
rng=Random.default_rng())
148-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
148+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
149149
!iszero(sim.remove_CM_motion) && remove_CM_motion!(sys)
150150
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
151151
forces_nounits_t = ustrip_vec.(zero(sys.coords))
152152
forces_buffer = init_forces_buffer!(sys, forces_nounits_t, n_threads)
153-
forces_nounits_t .= forces_nounits!(forces_nounits_t, sys, neighbors, forces_buffer, 0;
154-
n_threads=n_threads)
153+
forces_nounits_t = forces_nounits!(forces_nounits_t, sys, neighbors, forces_buffer, 0;
154+
n_threads=n_threads)
155155
forces_t = forces_nounits_t .* sys.force_units
156156
accels_t = forces_t ./ masses(sys)
157-
forces_nounits_t_dt = ustrip_vec.(zero(sys.coords))
157+
forces_nounits_t_dt = ustrip_vec.(similar(sys.coords))
158158
forces_t_dt = forces_nounits_t_dt .* sys.force_units
159159
accels_t_dt = zero(accels_t)
160160
apply_loggers!(sys, neighbors, 0, run_loggers; n_threads=n_threads, current_forces=forces_t)
161161
using_constraints = length(sys.constraints) > 0
162162
if using_constraints
163-
cons_coord_storage = zero(sys.coords)
164-
cons_vel_storage = zero(sys.velocities)
163+
cons_coord_storage = similar(sys.coords)
164+
cons_vel_storage = similar(sys.velocities)
165165
end
166166

167167
for step_n in 1:n_steps
@@ -172,7 +172,7 @@ end
172172
sys.coords .+= sys.velocities .* sim.dt .+ ((accels_t .* sim.dt .^ 2) ./ 2)
173173
using_constraints && apply_position_constraints!(sys, cons_coord_storage, cons_vel_storage,
174174
sim.dt; n_threads=n_threads)
175-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
175+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
176176

177177
forces_nounits_t_dt .= forces_nounits!(forces_nounits_t_dt, sys, neighbors, forces_buffer,
178178
step_n; n_threads=n_threads)
@@ -236,17 +236,17 @@ end
236236
n_threads::Integer=Threads.nthreads(),
237237
run_loggers=true,
238238
rng=Random.default_rng())
239-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
239+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
240240
!iszero(sim.remove_CM_motion) && remove_CM_motion!(sys)
241241
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
242242
apply_loggers!(sys, neighbors, 0, run_loggers; n_threads=n_threads)
243-
forces_nounits_t = ustrip_vec.(zero(sys.coords))
243+
forces_nounits_t = ustrip_vec.(similar(sys.coords))
244244
forces_t = forces_nounits_t .* sys.force_units
245245
forces_buffer = init_forces_buffer!(sys, forces_nounits_t, n_threads)
246246
accels_t = forces_t ./ masses(sys)
247247
using_constraints = length(sys.constraints) > 0
248248
if using_constraints
249-
cons_coord_storage = zero(sys.coords)
249+
cons_coord_storage = similar(sys.coords)
250250
end
251251

252252
for step_n in 1:n_steps
@@ -268,7 +268,7 @@ end
268268
sys.velocities .= (sys.coords .- cons_coord_storage) ./ sim.dt
269269
end
270270

271-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
271+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
272272

273273
if !iszero(sim.remove_CM_motion) && step_n % sim.remove_CM_motion == 0
274274
remove_CM_motion!(sys)
@@ -312,11 +312,11 @@ StormerVerlet(; dt, coupling=NoCoupling()) = StormerVerlet(dt, coupling)
312312
n_threads::Integer=Threads.nthreads(),
313313
run_loggers=true,
314314
rng=Random.default_rng())
315-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
315+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
316316
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
317317
apply_loggers!(sys, neighbors, 0, run_loggers; n_threads=n_threads)
318-
coords_last, coords_copy = zero(sys.coords), zero(sys.coords)
319-
forces_nounits_t = ustrip_vec.(zero(sys.coords))
318+
coords_last, coords_copy = similar(sys.coords), similar(sys.coords)
319+
forces_nounits_t = ustrip_vec.(similar(sys.coords))
320320
forces_t = forces_nounits_t .* sys.force_units
321321
forces_buffer = init_forces_buffer!(sys, forces_nounits_t, n_threads)
322322
accels_t = forces_t ./ masses(sys)
@@ -339,7 +339,7 @@ StormerVerlet(; dt, coupling=NoCoupling()) = StormerVerlet(dt, coupling)
339339

340340
using_constraints && apply_position_constraints!(sys, coords_copy; n_threads=n_threads)
341341

342-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
342+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
343343
# This is accurate to O(dt)
344344
sys.velocities .= vector.(coords_copy, sys.coords, (sys.boundary,)) ./ sim.dt
345345

@@ -396,19 +396,19 @@ end
396396
n_threads::Integer=Threads.nthreads(),
397397
run_loggers=true,
398398
rng=Random.default_rng())
399-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
399+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
400400
!iszero(sim.remove_CM_motion) && remove_CM_motion!(sys)
401401
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
402402
apply_loggers!(sys, neighbors, 0, run_loggers; n_threads=n_threads)
403-
forces_nounits_t = ustrip_vec.(zero(sys.coords))
403+
forces_nounits_t = ustrip_vec.(similar(sys.coords))
404404
forces_t = forces_nounits_t .* sys.force_units
405405
forces_buffer = init_forces_buffer!(sys, forces_nounits_t, n_threads)
406406
accels_t = forces_t ./ masses(sys)
407-
noise = zero(sys.velocities)
407+
noise = similar(sys.velocities)
408408
using_constraints = length(sys.constraints) > 0
409409
if using_constraints
410-
cons_coord_storage = zero(sys.coords)
411-
cons_vel_storage = zero(sys.velocities)
410+
cons_coord_storage = similar(sys.coords)
411+
cons_vel_storage = similar(sys.velocities)
412412
end
413413

414414
for step_n in 1:n_steps
@@ -432,7 +432,7 @@ end
432432

433433
using_constraints && apply_position_constraints!(sys, cons_coord_storage, cons_vel_storage,
434434
sim.dt; n_threads=n_threads)
435-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
435+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
436436

437437
if !iszero(sim.remove_CM_motion) && step_n % sim.remove_CM_motion == 0
438438
remove_CM_motion!(sys)
@@ -504,17 +504,15 @@ end
504504
α_eff = exp.(-sim.friction * sim.dt .* M_inv / count('O', sim.splitting))
505505
σ_eff = sqrt.((1 * unit(eltype(α_eff))) .- (α_eff .^ 2))
506506

507-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
507+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
508508
!iszero(sim.remove_CM_motion) && remove_CM_motion!(sys)
509509
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
510510
apply_loggers!(sys, neighbors, 0, run_loggers; n_threads=n_threads)
511-
forces_nounits_t = ustrip_vec.(zero(sys.coords))
512-
forces_buffer = init_forces_buffer!(sys, forces_nounits_t, n_threads)
513-
forces_nounits_t .= forces_nounits!(forces_nounits_t, sys, neighbors, forces_buffer, 0;
514-
n_threads=n_threads)
511+
forces_nounits_t = ustrip_vec.(similar(sys.coords))
515512
forces_t = forces_nounits_t .* sys.force_units
513+
forces_buffer = init_forces_buffer!(sys, forces_nounits_t, n_threads)
516514
accels_t = forces_t ./ masses(sys)
517-
noise = zero(sys.velocities)
515+
noise = similar(sys.velocities)
518516

519517
effective_dts = [sim.dt / count(c, sim.splitting) for c in sim.splitting]
520518

@@ -553,7 +551,7 @@ end
553551
step!(args..., neighbors, step_n)
554552
end
555553

556-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
554+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
557555
if !iszero(sim.remove_CM_motion) && step_n % sim.remove_CM_motion == 0
558556
remove_CM_motion!(sys)
559557
end
@@ -568,7 +566,7 @@ end
568566

569567
function A_step!(sys, dt_eff, neighbors, step_n)
570568
sys.coords .+= sys.velocities .* dt_eff
571-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
569+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
572570
return sys
573571
end
574572

@@ -626,15 +624,15 @@ end
626624
@warn "OverdampedLangevin is not currently compatible with constraints, " *
627625
"constraints will be ignored"
628626
end
629-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
627+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
630628
!iszero(sim.remove_CM_motion) && remove_CM_motion!(sys)
631629
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
632630
apply_loggers!(sys, neighbors, 0, run_loggers; n_threads=n_threads)
633-
forces_nounits_t = ustrip_vec.(zero(sys.coords))
631+
forces_nounits_t = ustrip_vec.(similar(sys.coords))
634632
forces_t = forces_nounits_t .* sys.force_units
635633
forces_buffer = init_forces_buffer!(sys, forces_nounits_t, n_threads)
636634
accels_t = forces_t ./ masses(sys)
637-
noise = zero(sys.velocities)
635+
noise = similar(sys.velocities)
638636

639637
for step_n in 1:n_steps
640638
forces_nounits_t .= forces_nounits!(forces_nounits_t, sys, neighbors, forces_buffer, step_n;
@@ -644,7 +642,7 @@ end
644642

645643
random_velocities!(noise, sys, sim.temperature; rng=rng)
646644
sys.coords .+= (accels_t ./ sim.friction) .* sim.dt .+ sqrt((2 / sim.friction) * sim.dt) .* noise
647-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
645+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
648646

649647
if !iszero(sim.remove_CM_motion) && step_n % sim.remove_CM_motion == 0
650648
remove_CM_motion!(sys)
@@ -701,16 +699,16 @@ end
701699
@warn "NoseHoover is not currently compatible with constraints, " *
702700
"constraints will be ignored"
703701
end
704-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
702+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
705703
!iszero(sim.remove_CM_motion) && remove_CM_motion!(sys)
706704
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
707705
forces_nounits_t = ustrip_vec.(zero(sys.coords))
708706
forces_buffer = init_forces_buffer!(sys, forces_nounits_t, n_threads)
709-
forces_nounits_t .= forces_nounits!(forces_nounits_t, sys, neighbors, forces_buffer, 0;
710-
n_threads=n_threads)
707+
forces_nounits_t = forces_nounits!(forces_nounits_t, sys, neighbors, forces_buffer, 0;
708+
n_threads=n_threads)
711709
forces_t = forces_nounits_t .* sys.force_units
712710
accels_t = forces_t ./ masses(sys)
713-
forces_nounits_t_dt = ustrip_vec.(zero(sys.coords))
711+
forces_nounits_t_dt = ustrip_vec.(similar(sys.coords))
714712
forces_t_dt = forces_nounits_t_dt .* sys.force_units
715713
accels_t_dt = zero(accels_t)
716714
apply_loggers!(sys, neighbors, 0, run_loggers; n_threads=n_threads, current_forces=forces_t)
@@ -721,7 +719,7 @@ end
721719
v_half .= sys.velocities .+ (accels_t .- (sys.velocities .* zeta)) .* (sim.dt ./ 2)
722720

723721
sys.coords .+= v_half .* sim.dt
724-
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
722+
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
725723

726724
zeta_half = zeta + (sim.dt / (2 * (sim.damping^2))) *
727725
((temperature(sys) / sim.temperature) - 1)
@@ -1064,7 +1062,7 @@ end
10641062
rng=Random.default_rng())
10651063
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
10661064
E_old = potential_energy(sys, neighbors; n_threads=n_threads)
1067-
coords_old = zero(sys.coords)
1065+
coords_old = similar(sys.coords)
10681066

10691067
for step_n in 1:n_steps
10701068
coords_old .= sys.coords

0 commit comments

Comments
 (0)