Skip to content

Commit 6b8c171

Browse files
committed
use vector instead of image flags
1 parent a550917 commit 6b8c171

File tree

2 files changed

+56
-45
lines changed

2 files changed

+56
-45
lines changed

src/loggers.jl

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -303,35 +303,6 @@ function virial_wrapper(sys, neighbors, step_n; n_threads, kwargs...)
303303
end
304304

305305

306-
image_flag_wrapper(sys, args...; kwargs...) = copy(sys.image_flags)
307-
308-
"""
309-
ImageFlagLogger(n_steps; dims::Integer=3)
310-
311-
Log the image flags of a atoms in the system throughout a simulation.
312-
"""
313-
function ImageFlagLogger(n_steps::Integer; dims::Integer=3)
314-
return GeneralObservableLogger(
315-
image_flag_wrapper,
316-
Array{SArray{Tuple{dims}, Int32, 1, dims}, 1},
317-
n_steps,
318-
)
319-
end
320-
321-
displacement_helper(sys, args...; kwargs...) = #*TODO CALCULATE DISPLACEMENTS HERE
322-
323-
"""
324-
DisplacementLogger(n_steps; dims=3)
325-
"""
326-
function DisplacementLogger(T, n_steps::Integer; dims::Integer=3)
327-
return GeneralObservableLogger(
328-
disp_wrapper,
329-
Array{SArray{Tuple{dims}, T, 1, dims}, 1},
330-
n_steps,
331-
)
332-
end
333-
334-
335306
"""
336307
VirialLogger(n_steps)
337308
VirialLogger(T, n_steps)
@@ -988,3 +959,56 @@ function log_property!(mcl::MonteCarloLogger{T},
988959
push!(mcl.state_changed, success)
989960
push!(mcl.energy_rates, energy_rate)
990961
end
962+
963+
964+
"""
965+
DisplacementLogger(n_steps; n_update::Integer=1, dims::Integer=3)
966+
DisplacementLogger(T, n_steps; n_update::Integer=1, dims::Integer=3)
967+
968+
Log the displacements of atoms in a system throughout a simulation. Displacements are
969+
updated every `n_update` steps and saved every `n_steps` steps.
970+
971+
The logger assumes a particle does not cross 2 periodic boxes in `n_update` steps.
972+
By default `n_update` is set to one to mitigate this assumption, but it can be
973+
set to a higher value to reduce cost. `n_update` must be a multiple of `n_steps`.
974+
"""
975+
mutable struct DisplacementLogger{A, B}
976+
displacements::Vector{A}
977+
last_coords::Vector{B}
978+
last_displacements::Vector{B}
979+
n_steps::Int
980+
n_update::Int
981+
end
982+
983+
function DisplacementLogger(T, n_steps::Integer; n_update::Integer = 1, dims::Integer = 3)
984+
return DisplacementLogger(n_steps; T = T, n_update = n_update, dims = dims)
985+
end
986+
987+
function DisplacementLogger(n_steps; T = typeof(one(DefaultFloat)u"nm"), n_update::Integer = 1, dims::Integer = 3)
988+
B = SArray{Tuple{dims}, T, 1, dims}
989+
A = Array{B, 1}
990+
if n_update % n_steps != 0
991+
throw(ArgumentError("DisplacementLogger: n_update ($n_update) must be a multiple of n_steps ($(n_steps)) and >= n_steps"))
992+
end
993+
return DisplacementLogger{A, B}(A[], B[], B[], n_steps, n_update)
994+
end
995+
996+
Base.values(dl::DisplacementLogger) = dl.displacements
997+
998+
function log_property!(dl::AverageObservableLogger{T}, s::System, neighbors=nothing,
999+
step_n::Integer=0; kwargs...) where T
1000+
1001+
if (step_n % dl.n_update) == 0
1002+
dl.last_displacements .+= vector.(dl.last_coords, s.coords, s.boundary)
1003+
dl.last_coords .= s.coords
1004+
if (step_n % dl.n_update) == 0
1005+
push!(dl.displacements, copy(dl.last_displacements))
1006+
end
1007+
end
1008+
end
1009+
1010+
function Base.show(io::IO, dl::DisplacementLogger)
1011+
print(io, "DisplacementLogger with updating every ", dl.n_update, " steps, saving every ",
1012+
dl.n_steps, " steps with", length(dl.displacements), " displacements in storage.")
1013+
end
1014+

src/types.jl

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ interface described there.
489489
- `data::DA=nothing`: arbitrary data associated with the system.
490490
"""
491491
mutable struct System{D, AT, T, A, C, B, V, AD, TO, PI, SI, GI, CN, NF,
492-
L, F, E, K, M, DA, IF} <: AtomsBase.AbstractSystem{D}
492+
L, F, E, K, M, DA} <: AtomsBase.AbstractSystem{D}
493493
atoms::A
494494
coords::C
495495
boundary::B
@@ -508,7 +508,6 @@ mutable struct System{D, AT, T, A, C, B, V, AD, TO, PI, SI, GI, CN, NF,
508508
k::K
509509
masses::M
510510
data::DA
511-
image_flags::IF
512511
end
513512

514513
function System(;
@@ -604,13 +603,10 @@ function System(;
604603
check_units(atoms, coords, vels, energy_units, force_units, pairwise_inters,
605604
specific_inter_lists, general_inters, boundary)
606605

607-
img_flags = similar(sys.coords, Int32)
608-
IF = typeof(img_flags)
609-
610606
return System{D, AT, T, A, C, B, V, AD, TO, PI, SI, GI, CN, NF, L, F, E, K, M, DA, IF}(
611607
atoms, coords, boundary, vels, atoms_data, topology, pairwise_inters,
612608
specific_inter_lists, general_inters, constraints, neighbor_finder, loggers,
613-
df, force_units, energy_units, k_converted, atom_masses, data, img_flags)
609+
df, force_units, energy_units, k_converted, atom_masses, data)
614610
end
615611

616612
"""
@@ -964,15 +960,6 @@ function ReplicaSystem(;
964960
end
965961
end
966962

967-
# Check if we need to calculate image flags
968-
track_image_flags = map(replica_loggers) do loggers
969-
if any(map(l -> l isa ImageFlagLogger || l isa DisplacementLogger, loggers))
970-
return true
971-
else
972-
return false
973-
end
974-
end
975-
976963
if isnothing(exchange_logger)
977964
exchange_logger = ReplicaExchangeLogger(T, n_replicas)
978965
end
@@ -1048,7 +1035,7 @@ function ReplicaSystem(;
10481035
replica_topology[i], replica_pairwise_inters[i], replica_specific_inter_lists[i],
10491036
replica_general_inters[i], replica_constraints[i],
10501037
deepcopy(neighbor_finder), replica_loggers[i], replica_dfs[i],
1051-
force_units, energy_units, k_converted, atom_masses, nothing, track_image_flags[i]) for i in 1:n_replicas)
1038+
force_units, energy_units, k_converted, atom_masses, nothing) for i in 1:n_replicas)
10521039
R = typeof(replicas)
10531040

10541041
return ReplicaSystem{D, AT, T, A, AD, EL, F, E, K, R, DA}(

0 commit comments

Comments
 (0)