Skip to content

Commit 0269780

Browse files
committed
Reduce computation and test instability in stratified resampling.
1 parent 42b0464 commit 0269780

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GenParticleFilters"
22
uuid = "56b76ac4-72ef-411e-b419-6d312ed86a6f"
33
authors = ["Xuan <tanqazx@gmail.com>"]
4-
version = "0.1.3"
4+
version = "0.1.4"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

src/resample.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,17 @@ function pf_stratified_resample!(state::ParticleFilterState;
148148
state.log_weights : priority_fn.(state.log_weights)
149149
weights = exp.(lognorm(log_priorities))
150150
# Optionally sort particles by weight before resampling
151-
order = sort_particles ? sortperm(log_priorities) : collect(1:n_particles)
151+
order = sort_particles ?
152+
sortperm(log_priorities, rev=true) : collect(1:n_particles)
152153
# Sample particles within each weight stratum [i, i+1/n)
153-
i_old, accum_weight = 0, 0.0
154-
for (i_new, lower) in enumerate((0:n_particles-1)/n_particles)
155-
u = rand() * (1/n_particles) + lower
156-
while accum_weight < u
157-
accum_weight += weights[order[i_old+1]]
158-
i_old += 1
154+
i_old, weight_step, accum_weight = 0, 1/n_particles, 0.0
155+
for (i_new, lower) in enumerate(0.0:weight_step:1.0-weight_step)
156+
if lower + weight_step > accum_weight
157+
u = rand() * weight_step + lower
158+
while accum_weight < u
159+
accum_weight += weights[order[i_old+1]]
160+
i_old += 1
161+
end
159162
end
160163
state.parents[i_new] = order[i_old]
161164
state.new_traces[i_new] = state.traces[order[i_old]]

0 commit comments

Comments
 (0)