Skip to content

Commit a550917

Browse files
committed
almost there
1 parent 0a04a4c commit a550917

File tree

4 files changed

+89
-32
lines changed

4 files changed

+89
-32
lines changed

src/loggers.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ export
2323
AutoCorrelationLogger,
2424
AverageObservableLogger,
2525
ReplicaExchangeLogger,
26-
MonteCarloLogger
26+
MonteCarloLogger,
27+
ImageFlagLogger,
28+
DisplacementLogger
2729

2830
"""
2931
apply_loggers!(system, neighbors=nothing, step_n=0, run_loggers=true;
@@ -300,6 +302,36 @@ function virial_wrapper(sys, neighbors, step_n; n_threads, kwargs...)
300302
return virial(sys, neighbors, step_n; n_threads=n_threads)
301303
end
302304

305+
306+
image_flag_wrapper(sys, args...; kwargs...) = copy(sys.image_flags)
307+
308+
"""
309+
ImageFlagLogger(n_steps; dims::Integer=3)
310+
311+
Log the image flags of a atoms in the system throughout a simulation.
312+
"""
313+
function ImageFlagLogger(n_steps::Integer; dims::Integer=3)
314+
return GeneralObservableLogger(
315+
image_flag_wrapper,
316+
Array{SArray{Tuple{dims}, Int32, 1, dims}, 1},
317+
n_steps,
318+
)
319+
end
320+
321+
displacement_helper(sys, args...; kwargs...) = #*TODO CALCULATE DISPLACEMENTS HERE
322+
323+
"""
324+
DisplacementLogger(n_steps; dims=3)
325+
"""
326+
function DisplacementLogger(T, n_steps::Integer; dims::Integer=3)
327+
return GeneralObservableLogger(
328+
disp_wrapper,
329+
Array{SArray{Tuple{dims}, T, 1, dims}, 1},
330+
n_steps,
331+
)
332+
end
333+
334+
303335
"""
304336
VirialLogger(n_steps)
305337
VirialLogger(T, n_steps)

src/simulators.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ Custom simulators should implement this function.
6868
run_loggers=false,
6969
rng=Random.default_rng())
7070
# @inline needed to avoid Enzyme error
71-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
71+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
7272
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
7373
E = potential_energy(sys, neighbors; n_threads=n_threads)
7474
apply_loggers!(sys, neighbors, 0, run_loggers; n_threads=n_threads, current_potential_energy=E)
@@ -89,7 +89,7 @@ Custom simulators should implement this function.
8989
coords_copy .= sys.coords
9090
sys.coords .+= hn .* F ./ max_force
9191
using_constraints && apply_position_constraints!(sys, coords_copy; n_threads=n_threads)
92-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
92+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
9393

9494
neighbors_copy = neighbors
9595
neighbors = find_neighbors(sys, sys.neighbor_finder, neighbors, step_n;
@@ -145,7 +145,7 @@ end
145145
n_threads::Integer=Threads.nthreads(),
146146
run_loggers=true,
147147
rng=Random.default_rng())
148-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
148+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
149149
!iszero(sim.remove_CM_motion) && remove_CM_motion!(sys)
150150
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
151151
forces_nounits_t = ustrip_vec.(zero(sys.coords))
@@ -172,7 +172,7 @@ end
172172
sys.coords .+= sys.velocities .* sim.dt .+ ((accels_t .* sim.dt .^ 2) ./ 2)
173173
using_constraints && apply_position_constraints!(sys, cons_coord_storage, cons_vel_storage,
174174
sim.dt; n_threads=n_threads)
175-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
175+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
176176

177177
forces_nounits_t_dt .= forces_nounits!(forces_nounits_t_dt, sys, neighbors, forces_buffer,
178178
step_n; n_threads=n_threads)
@@ -236,7 +236,7 @@ end
236236
n_threads::Integer=Threads.nthreads(),
237237
run_loggers=true,
238238
rng=Random.default_rng())
239-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
239+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
240240
!iszero(sim.remove_CM_motion) && remove_CM_motion!(sys)
241241
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
242242
apply_loggers!(sys, neighbors, 0, run_loggers; n_threads=n_threads)
@@ -268,7 +268,7 @@ end
268268
sys.velocities .= (sys.coords .- cons_coord_storage) ./ sim.dt
269269
end
270270

271-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
271+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
272272

273273
if !iszero(sim.remove_CM_motion) && step_n % sim.remove_CM_motion == 0
274274
remove_CM_motion!(sys)
@@ -312,7 +312,7 @@ StormerVerlet(; dt, coupling=NoCoupling()) = StormerVerlet(dt, coupling)
312312
n_threads::Integer=Threads.nthreads(),
313313
run_loggers=true,
314314
rng=Random.default_rng())
315-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
315+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
316316
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
317317
apply_loggers!(sys, neighbors, 0, run_loggers; n_threads=n_threads)
318318
coords_last, coords_copy = zero(sys.coords), zero(sys.coords)
@@ -339,7 +339,7 @@ StormerVerlet(; dt, coupling=NoCoupling()) = StormerVerlet(dt, coupling)
339339

340340
using_constraints && apply_position_constraints!(sys, coords_copy; n_threads=n_threads)
341341

342-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
342+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
343343
# This is accurate to O(dt)
344344
sys.velocities .= vector.(coords_copy, sys.coords, (sys.boundary,)) ./ sim.dt
345345

@@ -396,7 +396,7 @@ end
396396
n_threads::Integer=Threads.nthreads(),
397397
run_loggers=true,
398398
rng=Random.default_rng())
399-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
399+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
400400
!iszero(sim.remove_CM_motion) && remove_CM_motion!(sys)
401401
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
402402
apply_loggers!(sys, neighbors, 0, run_loggers; n_threads=n_threads)
@@ -432,7 +432,7 @@ end
432432

433433
using_constraints && apply_position_constraints!(sys, cons_coord_storage, cons_vel_storage,
434434
sim.dt; n_threads=n_threads)
435-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
435+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
436436

437437
if !iszero(sim.remove_CM_motion) && step_n % sim.remove_CM_motion == 0
438438
remove_CM_motion!(sys)
@@ -504,7 +504,7 @@ end
504504
α_eff = exp.(-sim.friction * sim.dt .* M_inv / count('O', sim.splitting))
505505
σ_eff = sqrt.((1 * unit(eltype(α_eff))) .- (α_eff .^ 2))
506506

507-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
507+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
508508
!iszero(sim.remove_CM_motion) && remove_CM_motion!(sys)
509509
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
510510
apply_loggers!(sys, neighbors, 0, run_loggers; n_threads=n_threads)
@@ -553,7 +553,7 @@ end
553553
step!(args..., neighbors, step_n)
554554
end
555555

556-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
556+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
557557
if !iszero(sim.remove_CM_motion) && step_n % sim.remove_CM_motion == 0
558558
remove_CM_motion!(sys)
559559
end
@@ -568,7 +568,7 @@ end
568568

569569
function A_step!(sys, dt_eff, neighbors, step_n)
570570
sys.coords .+= sys.velocities .* dt_eff
571-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
571+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
572572
return sys
573573
end
574574

@@ -626,7 +626,7 @@ end
626626
@warn "OverdampedLangevin is not currently compatible with constraints, " *
627627
"constraints will be ignored"
628628
end
629-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
629+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
630630
!iszero(sim.remove_CM_motion) && remove_CM_motion!(sys)
631631
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
632632
apply_loggers!(sys, neighbors, 0, run_loggers; n_threads=n_threads)
@@ -644,7 +644,7 @@ end
644644

645645
random_velocities!(noise, sys, sim.temperature; rng=rng)
646646
sys.coords .+= (accels_t ./ sim.friction) .* sim.dt .+ sqrt((2 / sim.friction) * sim.dt) .* noise
647-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
647+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
648648

649649
if !iszero(sim.remove_CM_motion) && step_n % sim.remove_CM_motion == 0
650650
remove_CM_motion!(sys)
@@ -701,7 +701,7 @@ end
701701
@warn "NoseHoover is not currently compatible with constraints, " *
702702
"constraints will be ignored"
703703
end
704-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
704+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
705705
!iszero(sim.remove_CM_motion) && remove_CM_motion!(sys)
706706
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
707707
forces_nounits_t = ustrip_vec.(zero(sys.coords))
@@ -721,7 +721,7 @@ end
721721
v_half .= sys.velocities .+ (accels_t .- (sys.velocities .* zeta)) .* (sim.dt ./ 2)
722722

723723
sys.coords .+= v_half .* sim.dt
724-
sys.coords .= wrap_coords.(sys.coords, (sys.boundary,))
724+
sys.coords, sys.img_flags .= wrap_coords.(sys.coords, (sys.boundary,), sys.img_flags)
725725

726726
zeta_half = zeta + (sim.dt / (2 * (sim.damping^2))) *
727727
((temperature(sys) / sim.temperature) - 1)

src/spatial.jl

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -465,15 +465,19 @@ trim3D(v::SVector{3, T}, boundary::RectangularBoundary{T}) where T = SVector{2,
465465
trim3D(v::SVector{3}, boundary) = v
466466

467467
"""
468-
wrap_coord_1D(c, side_length)
468+
wrap_coord_1D(c, side_length, image_flag)
469469
470470
Ensure a 1D coordinate is within the bounding box and return the coordinate.
471+
Update image flags to track which box each atom is in.
471472
"""
472-
function wrap_coord_1D(c, side_length)
473+
function wrap_coord_1D(c, side_length, image_flag)
473474
if isinf(side_length)
474-
return c
475+
return c, image_flag
475476
else
476-
return c - floor(c / side_length) * side_length
477+
shift = floor(c / side_length)
478+
c -= shift * side_length
479+
image_flag += Int32(shift)
480+
return c, image_flag
477481
end
478482
end
479483

@@ -482,21 +486,29 @@ end
482486
483487
Ensure a coordinate is within the bounding box and return the coordinate.
484488
"""
485-
wrap_coords(v, boundary::Union{CubicBoundary, RectangularBoundary}) = wrap_coord_1D.(v, boundary)
489+
wrap_coords(v, boundary::Union{CubicBoundary, RectangularBoundary}, img_flags) = wrap_coord_1D.(v, boundary, img_flags)
486490

487-
function wrap_coords(v, boundary::TriclinicBoundary)
491+
function wrap_coords(v, boundary::TriclinicBoundary, img_flags)
488492
bv, rs = boundary.basis_vectors, boundary.reciprocal_size
489493
v_wrap = v
490494
# Bound in z-axis
491-
v_wrap -= bv[3] * floor(v_wrap[3] * rs[3])
495+
iz = floor(v_wrap[3] * rs[3])
496+
v_wrap -= bv[3] * iz
492497
# Bound in y-axis
493-
v_wrap -= bv[2] * floor((v_wrap[2] - v_wrap[3] / boundary.tan_bprojyz_cprojyz) * rs[2])
498+
y_term = (v_wrap[2] - v_wrap[3] / boundary.tan_bprojyz_cprojyz)
499+
iy = floor(y_term * rs[2])
500+
v_wrap -= bv[2] * iy
501+
494502
dz_projxy = v_wrap[3] / boundary.tan_c_cprojxy
495503
dx = dz_projxy * boundary.cos_a_cprojxy
496504
dy = dz_projxy * boundary.sin_a_cprojxy
505+
497506
# Bound in x-axis
498-
v_wrap -= bv[1] * floor((v_wrap[1] - dx - (v_wrap[2] - dy) / boundary.tan_a_b) * rs[1])
499-
return v_wrap
507+
x_term = (v_wrap[1] - dx - (v_wrap[2] - dy) / boundary.tan_a_b)
508+
ix = floor(x_term * rs[1])
509+
v_wrap -= bv[1] * ix
510+
511+
return v_wrap, img_flags .+ SVector{Int32}(ix, iy, iz)
500512
end
501513

502514
"""

src/types.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ interface described there.
489489
- `data::DA=nothing`: arbitrary data associated with the system.
490490
"""
491491
mutable struct System{D, AT, T, A, C, B, V, AD, TO, PI, SI, GI, CN, NF,
492-
L, F, E, K, M, DA} <: AtomsBase.AbstractSystem{D}
492+
L, F, E, K, M, DA, IF} <: AtomsBase.AbstractSystem{D}
493493
atoms::A
494494
coords::C
495495
boundary::B
@@ -508,6 +508,7 @@ mutable struct System{D, AT, T, A, C, B, V, AD, TO, PI, SI, GI, CN, NF,
508508
k::K
509509
masses::M
510510
data::DA
511+
image_flags::IF
511512
end
512513

513514
function System(;
@@ -603,10 +604,13 @@ function System(;
603604
check_units(atoms, coords, vels, energy_units, force_units, pairwise_inters,
604605
specific_inter_lists, general_inters, boundary)
605606

606-
return System{D, AT, T, A, C, B, V, AD, TO, PI, SI, GI, CN, NF, L, F, E, K, M, DA}(
607+
img_flags = similar(sys.coords, Int32)
608+
IF = typeof(img_flags)
609+
610+
return System{D, AT, T, A, C, B, V, AD, TO, PI, SI, GI, CN, NF, L, F, E, K, M, DA, IF}(
607611
atoms, coords, boundary, vels, atoms_data, topology, pairwise_inters,
608612
specific_inter_lists, general_inters, constraints, neighbor_finder, loggers,
609-
df, force_units, energy_units, k_converted, atom_masses, data)
613+
df, force_units, energy_units, k_converted, atom_masses, data, img_flags)
610614
end
611615

612616
"""
@@ -960,6 +964,15 @@ function ReplicaSystem(;
960964
end
961965
end
962966

967+
# Check if we need to calculate image flags
968+
track_image_flags = map(replica_loggers) do loggers
969+
if any(map(l -> l isa ImageFlagLogger || l isa DisplacementLogger, loggers))
970+
return true
971+
else
972+
return false
973+
end
974+
end
975+
963976
if isnothing(exchange_logger)
964977
exchange_logger = ReplicaExchangeLogger(T, n_replicas)
965978
end
@@ -1035,7 +1048,7 @@ function ReplicaSystem(;
10351048
replica_topology[i], replica_pairwise_inters[i], replica_specific_inter_lists[i],
10361049
replica_general_inters[i], replica_constraints[i],
10371050
deepcopy(neighbor_finder), replica_loggers[i], replica_dfs[i],
1038-
force_units, energy_units, k_converted, atom_masses, nothing) for i in 1:n_replicas)
1051+
force_units, energy_units, k_converted, atom_masses, nothing, track_image_flags[i]) for i in 1:n_replicas)
10391052
R = typeof(replicas)
10401053

10411054
return ReplicaSystem{D, AT, T, A, AD, EL, F, E, K, R, DA}(

0 commit comments

Comments
 (0)