Skip to content

Commit ff718d9

Browse files
committed
Support passing of keyword args to rejuv. kernels.
1 parent 5c3b1e8 commit ff718d9

File tree

2 files changed

+54
-33
lines changed

2 files changed

+54
-33
lines changed

src/rejuvenate.jl

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,44 +2,48 @@
22
export pf_rejuvenate!, pf_move_accept!, pf_move_reweight!
33
export move_reweight
44

5+
using Gen: check_observations
6+
57
"""
68
pf_rejuvenate!(state::ParticleFilterState, kern, kern_args::Tuple=(),
7-
n_iters::Int=1; method=:move)
9+
n_iters::Int=1; method=:move, kwargs...)
810
911
Rejuvenates particles by repeated application of a kernel `kern`. `kern`
1012
should be a callable which takes a trace as its first argument, and returns
1113
a tuple with a trace as the first return value. `method` specifies the
1214
rejuvenation method: `:move` for MCMC moves without a reweighting step,
13-
and `:reweight` for rejuvenation with a reweighting step.
15+
and `:reweight` for rejuvenation with a reweighting step. Additional keyword
16+
arguments are passed to the kernel.
1417
"""
1518
function pf_rejuvenate!(state::ParticleFilterView, kern, kern_args::Tuple=(),
16-
n_iters::Int=1; method::Symbol=:move)
19+
n_iters::Int=1; method::Symbol=:move, kwargs...)
1720
if method == :move
18-
return pf_move_accept!(state, kern, kern_args, n_iters)
21+
return pf_move_accept!(state, kern, kern_args, n_iters; kwargs...)
1922
elseif method == :reweight
20-
return pf_move_reweight!(state, kern, kern_args, n_iters)
23+
return pf_move_reweight!(state, kern, kern_args, n_iters; kwargs...)
2124
else
2225
error("Method $method not recognized.")
2326
end
2427
end
2528

2629
"""
2730
pf_move_accept!(state::ParticleFilterState, kern,
28-
kern_args::Tuple=(), n_iters::Int=1)
31+
kern_args::Tuple=(), n_iters::Int=1; kwargs...)
2932
3033
Rejuvenates particles by repeated application of a MCMC kernel `kern`. `kern`
3134
should be a callable which takes a trace as its first argument, and returns
3235
a tuple `(trace, accept)`, where `trace` is the (potentially) new trace, and
3336
`accept` is true if the MCMC move was accepted. Subsequent arguments to `kern`
34-
can be supplied with `kern_args`. The kernel is repeatedly applied to each trace
35-
for `n_iters`.
37+
can be supplied with `kern_args` or `kwargs`. The kernel is repeatedly applied
38+
to each trace for `n_iters`.
3639
"""
3740
function pf_move_accept!(state::ParticleFilterView,
38-
kern, kern_args::Tuple=(), n_iters::Int=1)
41+
kern, kern_args::Tuple=(), n_iters::Int=1;
42+
kwargs...)
3943
# Potentially rejuvenate each trace
4044
for (i, trace) in enumerate(state.traces)
4145
for k = 1:n_iters
42-
trace, accept = kern(trace, kern_args...)
46+
trace, accept = kern(trace, kern_args...; kwargs...)
4347
@debug "Accepted: $accept"
4448
end
4549
state.new_traces[i] = trace
@@ -50,28 +54,31 @@ end
5054

5155
"""
5256
pf_move_reweight!(state::ParticleFilterState, kern,
53-
kern_args::Tuple=(), n_iters::Int=1)
57+
kern_args::Tuple=(), n_iters::Int=1; kwargs...)
5458
5559
Rejuvenates and reweights particles by repeated application of a reweighting
5660
kernel `kern`, as described in [1]. `kern` should be a callable which takes a
5761
trace as its first argument, and returns a tuple `(trace, rel_weight)`,
5862
where `trace` is the new trace, and `rel_weight` is the relative log-importance
59-
weight. Subsequent arguments to `kern` can be supplied with `kern_args`.
60-
The kernel is repeatedly applied to each trace for `n_iters`, and the weights
61-
accumulated accordingly. Both the [`move_reweight`](@ref) function and
63+
weight. Subsequent arguments to `kern` can be supplied with `kern_args` or
64+
`kwargs`. The kernel is repeatedly applied to each trace for `n_iters`, and the
65+
weights accumulated accordingly.
66+
67+
Both the [`move_reweight`](@ref) function and
6268
[symmetric trace translators](https://www.gen.dev/stable/ref/trace_translators/)
6369
can serve as reweighting kernels.
6470
6571
[1] R. A. G. Marques and G. Storvik, "Particle move-reweighting strategies for
6672
online inference," Preprint series. Statistical Research Report, 2013.
6773
"""
6874
function pf_move_reweight!(state::ParticleFilterView,
69-
kern, kern_args::Tuple=(), n_iters::Int=1)
75+
kern, kern_args::Tuple=(), n_iters::Int=1;
76+
kwargs...)
7077
# Move and reweight each trace
7178
for (i, trace) in enumerate(state.traces)
7279
weight = 0
7380
for k = 1:n_iters
74-
trace, rel_weight = kern(trace, kern_args...)
81+
trace, rel_weight = kern(trace, kern_args...; kwargs...)
7582
weight += rel_weight
7683
@debug "Rel. Weight: $rel_weight"
7784
end
@@ -83,15 +90,15 @@ function pf_move_reweight!(state::ParticleFilterView,
8390
end
8491

8592
"""
86-
move_reweight(trace, selection)
87-
move_reweight(trace, proposal, proposal_args)
88-
move_reweight(trace, proposal, proposal_args, involution)
89-
move_reweight(trace, proposal_fwd, args_fwd,
90-
proposal_bwd, args_bwd, involution)
93+
move_reweight(trace, selection; kwargs...)
94+
move_reweight(trace, proposal, proposal_args; kwargs...)
95+
move_reweight(trace, proposal, proposal_args, involution; kwargs...)
96+
move_reweight(trace, proposal_fwd, args_fwd, proposal_bwd, args_bwd,
97+
involution ; kwargs...)
9198
92-
Move-reweight MCMC kernel, which takes in a `trace` and returns a new trace
93-
along with a relative importance weight. This can be used for rejuvenation
94-
within a particle filter, as described in [1].
99+
Move-reweight rejuvenation kernel, which takes in a `trace` and returns a
100+
new trace along with a relative importance weight. This can be used for
101+
rejuvenation within a particle filter, as described in [1].
95102
96103
Several variants of `move_reweight` exist, differing in the complexity
97104
involved in proposing and re-weighting random choices:
@@ -108,18 +115,25 @@ involved in proposing and re-weighting random choices:
108115
adjusts the computation of the relative importance weight by scoring
109116
the backward choices under the backward proposal.
110117
118+
Similar to `metropolis_hastings`, a `check` flag and `observations` choicemap
119+
can be provided as keyword arguments to ensure that observed choices are
120+
preserved in the new trace.
121+
111122
[1] R. A. G. Marques and G. Storvik, "Particle move-reweighting strategies for
112123
online inference," Preprint series. Statistical Research Report, 2013.
113124
"""
114-
function move_reweight(trace::Trace, selection::Selection)
125+
function move_reweight(trace::Trace, selection::Selection;
126+
check=false, observations=EmptyChoiceMap())
115127
args = get_args(trace)
116128
argdiffs = map((_) -> NoChange(), args)
117129
new_trace, rel_weight = regenerate(trace, args, argdiffs, selection)
130+
check && check_observations(get_choices(new_trace), observations)
118131
return new_trace, rel_weight
119132
end
120133

121134
function move_reweight(trace::Trace, proposal::GenerativeFunction,
122-
proposal_args::Tuple)
135+
proposal_args::Tuple;
136+
check=false, observations=EmptyChoiceMap())
123137
model_args = Gen.get_args(trace)
124138
argdiffs = map((_) -> NoChange(), model_args)
125139
fwd_choices, fwd_score, fwd_ret =
@@ -128,6 +142,7 @@ function move_reweight(trace::Trace, proposal::GenerativeFunction,
128142
update(trace, model_args, argdiffs, fwd_choices)
129143
bwd_score, bwd_ret =
130144
assess(proposal, (new_trace, proposal_args...), discard)
145+
check && check_observations(get_choices(new_trace), observations)
131146
rel_weight = weight - fwd_score + bwd_score
132147
return new_trace, rel_weight
133148
end
@@ -140,19 +155,22 @@ function move_reweight(trace::Trace, proposal::GenerativeFunction,
140155
involution(trace, fwd_choices, fwd_ret, proposal_args)
141156
bwd_score, bwd_ret =
142157
assess(proposal, (new_trace, proposal_args...), bwd_choices)
158+
check && check_observations(get_choices(new_trace), observations)
143159
rel_weight = weight - fwd_score + bwd_score
144160
return new_trace, rel_weight
145161
end
146162

147163
function move_reweight(trace::Trace, proposal_fwd::GenerativeFunction,
148164
args_fwd::Tuple, proposal_bwd::GenerativeFunction,
149-
args_bwd::Tuple, involution)
165+
args_bwd::Tuple, involution;
166+
check=false, observations=EmptyChoiceMap())
150167
fwd_choices, fwd_score, fwd_ret =
151168
propose(proposal_fwd, (trace, args_fwd...,))
152169
new_trace, bwd_choices, weight =
153170
involution(trace, fwd_choices, fwd_ret, args_fwd)
154171
bwd_score, bwd_ret =
155172
assess(proposal_bwd, (new_trace, args_bwd...), bwd_choices)
173+
check && check_observations(get_choices(new_trace), observations)
156174
rel_weight = weight - fwd_score + bwd_score
157175
return new_trace, rel_weight
158176
end

test/runtests.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -246,19 +246,21 @@ trace, _ = generate(line_model, (1,), observations)
246246
slope, outlier = trace[:slope], trace[out_addr]
247247

248248
# Test selection variant
249-
expected_w(out_old, out_new, slope) =
249+
expected_w = (out_old, out_new, slope) -> begin
250250
logpdf(normal, 0, slope, out_new ? 10. : 1.) -
251251
logpdf(normal, 0, slope, out_old ? 10. : 1.)
252+
end
252253
trs_ws = [move_reweight(trace, select(out_addr)) for i in 1:100]
253254
@test all(w expected_w(outlier, tr[out_addr], slope) for (tr, w) in trs_ws)
254255

255256
# Test proposal variant
256-
expected_w(out_old, out_new, slope) =
257+
expected_w = (out_old, out_new, slope) -> begin
257258
logpdf(bernoulli, out_new, 0.1) - logpdf(bernoulli, out_old, 0.1) +
258259
logpdf(normal, 0, slope, out_new ? 10. : 1.) -
259260
logpdf(normal, 0, slope, out_old ? 10. : 1.) +
260261
(out_old == out_new ? 0.0 :
261262
logpdf(bernoulli, out_old, 0.9) - logpdf(bernoulli, out_old, 0.1))
263+
end
262264
@gen outlier_propose(tr, idx) = {:line => idx => :outlier} ~ bernoulli(0.9)
263265
trs_ws = [move_reweight(trace, outlier_propose, (1,)) for i in 1:100]
264266
@test all(w expected_w(outlier, tr[out_addr], slope) for (tr, w) in trs_ws)
@@ -271,7 +273,7 @@ logger = SimpleLogger(buffer, Logging.Debug)
271273
state = pf_initialize(line_model, (10,), generate_line(10, 1.), 100)
272274
old_traces = get_traces(state)
273275
with_logger(logger) do
274-
pf_move_accept!(state, metropolis_hastings, (select(:slope),), 1)
276+
pf_move_accept!(state, mh, (select(:slope),), 1; check=false)
275277
end
276278

277279
# Extract acceptances from debug log
@@ -293,7 +295,7 @@ logger = SimpleLogger(buffer, Logging.Debug)
293295
state = pf_initialize(line_model, (10,), generate_line(10, 1.), 100)
294296
old_weights = copy(get_log_weights(state))
295297
with_logger(logger) do
296-
pf_move_reweight!(state, move_reweight, (select(:slope),), 1)
298+
pf_move_reweight!(state, move_reweight, (select(:slope),), 1; check=false)
297299
end
298300
new_weights = copy(get_log_weights(state))
299301

@@ -315,9 +317,10 @@ state = pf_initialize(line_model, (10,), generate_line(10, 1.), 100)
315317
old_traces = get_traces(state)[1:50]
316318
old_weights = get_log_weights(state)[51:end]
317319

320+
kern_args = (select(:slope),)
318321
with_logger(logger) do
319-
pf_move_accept!(state[1:50], metropolis_hastings, (select(:slope),), 1)
320-
pf_move_reweight!(state[51:end], move_reweight, (select(:slope),), 1)
322+
pf_rejuvenate!(state[1:50], mh, kern_args, 1; method=:move)
323+
pf_rejuvenate!(state[51:end], move_reweight, kern_args, 1; method=:reweight)
321324
end
322325

323326
# Extract acceptances and relative weights from debug log

0 commit comments

Comments
 (0)