2
2
export pf_rejuvenate!, pf_move_accept!, pf_move_reweight!
3
3
export move_reweight
4
4
5
+ using Gen: check_observations
6
+
5
7
"""
6
8
pf_rejuvenate!(state::ParticleFilterState, kern, kern_args::Tuple=(),
7
- n_iters::Int=1; method=:move)
9
+ n_iters::Int=1; method=:move, kwargs... )
8
10
9
11
Rejuvenates particles by repeated application of a kernel `kern`. `kern`
10
12
should be a callable which takes a trace as its first argument, and returns
11
13
a tuple with a trace as the first return value. `method` specifies the
12
14
rejuvenation method: `:move` for MCMC moves without a reweighting step,
13
- and `:reweight` for rejuvenation with a reweighting step.
15
+ and `:reweight` for rejuvenation with a reweighting step. Additional keyword
16
+ arguments are passed to the kernel.
14
17
"""
15
18
function pf_rejuvenate! (state:: ParticleFilterView , kern, kern_args:: Tuple = (),
16
- n_iters:: Int = 1 ; method:: Symbol = :move )
19
+ n_iters:: Int = 1 ; method:: Symbol = :move , kwargs ... )
17
20
if method == :move
18
- return pf_move_accept! (state, kern, kern_args, n_iters)
21
+ return pf_move_accept! (state, kern, kern_args, n_iters; kwargs ... )
19
22
elseif method == :reweight
20
- return pf_move_reweight! (state, kern, kern_args, n_iters)
23
+ return pf_move_reweight! (state, kern, kern_args, n_iters; kwargs ... )
21
24
else
22
25
error (" Method $method not recognized." )
23
26
end
24
27
end
25
28
26
29
"""
27
30
pf_move_accept!(state::ParticleFilterState, kern,
28
- kern_args::Tuple=(), n_iters::Int=1)
31
+ kern_args::Tuple=(), n_iters::Int=1; kwargs... )
29
32
30
33
Rejuvenates particles by repeated application of a MCMC kernel `kern`. `kern`
31
34
should be a callable which takes a trace as its first argument, and returns
32
35
a tuple `(trace, accept)`, where `trace` is the (potentially) new trace, and
33
36
`accept` is true if the MCMC move was accepted. Subsequent arguments to `kern`
34
- can be supplied with `kern_args`. The kernel is repeatedly applied to each trace
35
- for `n_iters`.
37
+ can be supplied with `kern_args` or `kwargs` . The kernel is repeatedly applied
38
+ to each trace for `n_iters`.
36
39
"""
37
40
function pf_move_accept! (state:: ParticleFilterView ,
38
- kern, kern_args:: Tuple = (), n_iters:: Int = 1 )
41
+ kern, kern_args:: Tuple = (), n_iters:: Int = 1 ;
42
+ kwargs... )
39
43
# Potentially rejuvenate each trace
40
44
for (i, trace) in enumerate (state. traces)
41
45
for k = 1 : n_iters
42
- trace, accept = kern (trace, kern_args... )
46
+ trace, accept = kern (trace, kern_args... ; kwargs ... )
43
47
@debug " Accepted: $accept "
44
48
end
45
49
state. new_traces[i] = trace
50
54
51
55
"""
52
56
pf_move_reweight!(state::ParticleFilterState, kern,
53
- kern_args::Tuple=(), n_iters::Int=1)
57
+ kern_args::Tuple=(), n_iters::Int=1; kwargs... )
54
58
55
59
Rejuvenates and reweights particles by repeated application of a reweighting
56
60
kernel `kern`, as described in [1]. `kern` should be a callable which takes a
57
61
trace as its first argument, and returns a tuple `(trace, rel_weight)`,
58
62
where `trace` is the new trace, and `rel_weight` is the relative log-importance
59
- weight. Subsequent arguments to `kern` can be supplied with `kern_args`.
60
- The kernel is repeatedly applied to each trace for `n_iters`, and the weights
61
- accumulated accordingly. Both the [`move_reweight`](@ref) function and
63
+ weight. Subsequent arguments to `kern` can be supplied with `kern_args` or
64
+ `kwargs`. The kernel is repeatedly applied to each trace for `n_iters`, and the
65
+ weights accumulated accordingly.
66
+
67
+ Both the [`move_reweight`](@ref) function and
62
68
[symmetric trace translators](https://www.gen.dev/stable/ref/trace_translators/)
63
69
can serve as reweighting kernels.
64
70
65
71
[1] R. A. G. Marques and G. Storvik, "Particle move-reweighting strategies for
66
72
online inference," Preprint series. Statistical Research Report, 2013.
67
73
"""
68
74
function pf_move_reweight! (state:: ParticleFilterView ,
69
- kern, kern_args:: Tuple = (), n_iters:: Int = 1 )
75
+ kern, kern_args:: Tuple = (), n_iters:: Int = 1 ;
76
+ kwargs... )
70
77
# Move and reweight each trace
71
78
for (i, trace) in enumerate (state. traces)
72
79
weight = 0
73
80
for k = 1 : n_iters
74
- trace, rel_weight = kern (trace, kern_args... )
81
+ trace, rel_weight = kern (trace, kern_args... ; kwargs ... )
75
82
weight += rel_weight
76
83
@debug " Rel. Weight: $rel_weight "
77
84
end
@@ -83,15 +90,15 @@ function pf_move_reweight!(state::ParticleFilterView,
83
90
end
84
91
85
92
"""
86
- move_reweight(trace, selection)
87
- move_reweight(trace, proposal, proposal_args)
88
- move_reweight(trace, proposal, proposal_args, involution)
89
- move_reweight(trace, proposal_fwd, args_fwd,
90
- proposal_bwd, args_bwd, involution )
93
+ move_reweight(trace, selection; kwargs... )
94
+ move_reweight(trace, proposal, proposal_args; kwargs... )
95
+ move_reweight(trace, proposal, proposal_args, involution; kwargs... )
96
+ move_reweight(trace, proposal_fwd, args_fwd, proposal_bwd, args_bwd,
97
+ involution ; kwargs... )
91
98
92
- Move-reweight MCMC kernel, which takes in a `trace` and returns a new trace
93
- along with a relative importance weight. This can be used for rejuvenation
94
- within a particle filter, as described in [1].
99
+ Move-reweight rejuvenation kernel, which takes in a `trace` and returns a
100
+ new trace along with a relative importance weight. This can be used for
101
+ rejuvenation within a particle filter, as described in [1].
95
102
96
103
Several variants of `move_reweight` exist, differing in the complexity
97
104
involved in proposing and re-weighting random choices:
@@ -108,18 +115,25 @@ involved in proposing and re-weighting random choices:
108
115
adjusts the computation of the relative importance weight by scoring
109
116
the backward choices under the backward proposal.
110
117
118
+ Similar to `metropolis_hastings`, a `check` flag and `observations` choicemap
119
+ can be provided as keyword arguments to ensure that observed choices are
120
+ preserved in the new trace.
121
+
111
122
[1] R. A. G. Marques and G. Storvik, "Particle move-reweighting strategies for
112
123
online inference," Preprint series. Statistical Research Report, 2013.
113
124
"""
114
- function move_reweight (trace:: Trace , selection:: Selection )
125
+ function move_reweight (trace:: Trace , selection:: Selection ;
126
+ check= false , observations= EmptyChoiceMap ())
115
127
args = get_args (trace)
116
128
argdiffs = map ((_) -> NoChange (), args)
117
129
new_trace, rel_weight = regenerate (trace, args, argdiffs, selection)
130
+ check && check_observations (get_choices (new_trace), observations)
118
131
return new_trace, rel_weight
119
132
end
120
133
121
134
function move_reweight (trace:: Trace , proposal:: GenerativeFunction ,
122
- proposal_args:: Tuple )
135
+ proposal_args:: Tuple ;
136
+ check= false , observations= EmptyChoiceMap ())
123
137
model_args = Gen. get_args (trace)
124
138
argdiffs = map ((_) -> NoChange (), model_args)
125
139
fwd_choices, fwd_score, fwd_ret =
@@ -128,6 +142,7 @@ function move_reweight(trace::Trace, proposal::GenerativeFunction,
128
142
update (trace, model_args, argdiffs, fwd_choices)
129
143
bwd_score, bwd_ret =
130
144
assess (proposal, (new_trace, proposal_args... ), discard)
145
+ check && check_observations (get_choices (new_trace), observations)
131
146
rel_weight = weight - fwd_score + bwd_score
132
147
return new_trace, rel_weight
133
148
end
@@ -140,19 +155,22 @@ function move_reweight(trace::Trace, proposal::GenerativeFunction,
140
155
involution (trace, fwd_choices, fwd_ret, proposal_args)
141
156
bwd_score, bwd_ret =
142
157
assess (proposal, (new_trace, proposal_args... ), bwd_choices)
158
+ check && check_observations (get_choices (new_trace), observations)
143
159
rel_weight = weight - fwd_score + bwd_score
144
160
return new_trace, rel_weight
145
161
end
146
162
147
163
function move_reweight (trace:: Trace , proposal_fwd:: GenerativeFunction ,
148
164
args_fwd:: Tuple , proposal_bwd:: GenerativeFunction ,
149
- args_bwd:: Tuple , involution)
165
+ args_bwd:: Tuple , involution;
166
+ check= false , observations= EmptyChoiceMap ())
150
167
fwd_choices, fwd_score, fwd_ret =
151
168
propose (proposal_fwd, (trace, args_fwd... ,))
152
169
new_trace, bwd_choices, weight =
153
170
involution (trace, fwd_choices, fwd_ret, args_fwd)
154
171
bwd_score, bwd_ret =
155
172
assess (proposal_bwd, (new_trace, args_bwd... ), bwd_choices)
173
+ check && check_observations (get_choices (new_trace), observations)
156
174
rel_weight = weight - fwd_score + bwd_score
157
175
return new_trace, rel_weight
158
176
end
0 commit comments