Skip to content

Commit 3394832

Browse files
committed
More informative translator fieldname for trace transform.
1 parent fbf8b4b commit 3394832

File tree

2 files changed

+20
-21
lines changed

2 files changed

+20
-21
lines changed

src/translate.jl

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using Gen: run_first_pass, jacobian_correction, check_round_trip, run_transform
1111
new_observations::ChoiceMap = EmptyChoiceMap(),
1212
q_forward::GenerativeFunction,
1313
q_forward_args::Tuple = (),
14-
f::Union{TraceTransformDSLProgram,Nothing} = nothing)
14+
transform::Union{TraceTransformDSLProgram,Nothing} = nothing)
1515
Constructor for a extending trace translator.
1616
Run the translator with:
1717
(output_trace, log_weight) = translator(input_trace)
@@ -22,7 +22,7 @@ Run the translator with:
2222
new_observations::ChoiceMap = EmptyChoiceMap()
2323
q_forward::GenerativeFunction
2424
q_forward_args::Tuple = ()
25-
f::Union{TraceTransformDSLProgram,Nothing} = nothing # a bijection
25+
transform::Union{TraceTransformDSLProgram,Nothing} = nothing # a bijection
2626
end
2727

2828
function (translator::ExtendingTraceTranslator)(prev_model_trace::Trace)
@@ -33,14 +33,14 @@ function (translator::ExtendingTraceTranslator)(prev_model_trace::Trace)
3333
forward_proposal_score = get_score(forward_proposal_trace)
3434

3535
# transform forward proposal
36-
if translator.f === nothing
36+
if translator.transform === nothing
3737
constraints = get_choices(forward_proposal_trace)
3838
log_abs_determinant = 0.0
3939
else
4040
first_pass_results =
41-
run_first_pass(translator.f, forward_proposal_trace, nothing)
41+
run_first_pass(translator.transform, forward_proposal_trace, nothing)
4242
log_abs_determinant =
43-
jacobian_correction(translator.f, forward_proposal_trace,
43+
jacobian_correction(translator.transform, forward_proposal_trace,
4444
nothing, first_pass_results, nothing)
4545
constraints = first_pass_results.constraints
4646
end
@@ -97,7 +97,7 @@ the observed random choices in the previous trace.
9797
q_forward_args::Tuple = ()
9898
q_backward::GenerativeFunction
9999
q_backward_args::Tuple = ()
100-
f::TraceTransformDSLProgram
100+
transform::TraceTransformDSLProgram
101101
end
102102

103103
function Gen.inverse(translator::UpdatingTraceTranslator, prev_model_trace::Trace,
@@ -106,23 +106,22 @@ function Gen.inverse(translator::UpdatingTraceTranslator, prev_model_trace::Trac
106106
get_args(prev_model_trace), map((_)->UnknownChange(), get_args(prev_model_trace)),
107107
prev_observations, translator.q_backward, translator.q_backward_args,
108108
translator.q_forward, translator.q_forward_args,
109-
inverse(translator.f))
109+
inverse(translator.transform))
110110
end
111111

112112
function Gen.run_transform(translator::UpdatingTraceTranslator,
113-
prev_model_trace::Trace, forward_proposal_trace::Trace,
114-
check::Bool=false)
115-
@unpack f, new_observations = translator
113+
prev_model_trace::Trace, forward_proposal_trace::Trace)
114+
@unpack transform, new_observations = translator
116115
@unpack p_new_args, p_argdiffs, q_backward, q_backward_args = translator
117-
first_pass_results =
118-
Gen.run_first_pass(f, prev_model_trace, forward_proposal_trace)
116+
first_pass_results = run_first_pass(
117+
transform, prev_model_trace, forward_proposal_trace)
119118
constraints = merge(first_pass_results.constraints, new_observations)
120-
(new_model_trace, _, _, discard) = update(
119+
new_model_trace, _, _, discard = update(
121120
prev_model_trace, p_new_args, p_argdiffs, constraints)
122-
log_abs_determinant = jacobian_correction(f, prev_model_trace,
123-
forward_proposal_trace, first_pass_results, discard)
124-
backward_proposal_trace, = generate(q_backward,
125-
(new_model_trace, q_backward_args...), first_pass_results.u_back)
121+
log_abs_determinant = jacobian_correction(
122+
transform, prev_model_trace, forward_proposal_trace, first_pass_results, discard)
123+
backward_proposal_trace, _ = generate(
124+
q_backward, (new_model_trace, q_backward_args...), first_pass_results.u_back)
126125
return (new_model_trace, backward_proposal_trace, log_abs_determinant)
127126
end
128127

@@ -135,7 +134,7 @@ function (translator::UpdatingTraceTranslator)(
135134

136135
# apply trace transform
137136
(new_model_trace, backward_proposal_trace, log_abs_determinant) =
138-
run_transform(translator, prev_model_trace, forward_proposal_trace, check)
137+
run_transform(translator, prev_model_trace, forward_proposal_trace)
139138

140139
# compute log weight
141140
prev_model_score = get_score(prev_model_trace)
@@ -149,7 +148,7 @@ function (translator::UpdatingTraceTranslator)(
149148
inverter = inverse(translator, prev_model_trace, prev_observations)
150149
argdiffs = map((_) -> UnknownChange(), get_args(prev_model_trace))
151150
(prev_model_trace_rt, forward_proposal_trace_rt, _) =
152-
run_transform(inverter, new_model_trace, backward_proposal_trace, check)
151+
run_transform(inverter, new_model_trace, backward_proposal_trace)
153152
check_round_trip(prev_model_trace, prev_model_trace_rt,
154153
forward_proposal_trace, forward_proposal_trace_rt)
155154
end

src/update.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ function pf_update!(state::ParticleFilterView, new_args::Tuple,
9797
n_particles = length(state.traces)
9898
translator = GenParticleFilters.ExtendingTraceTranslator(
9999
p_new_args=new_args, p_argdiffs=argdiffs, new_observations=observations,
100-
q_forward=proposal, q_forward_args=proposal_args, f=transform)
100+
q_forward=proposal, q_forward_args=proposal_args, transform=transform)
101101
return pf_update!(state, translator)
102102
end
103103

@@ -169,6 +169,6 @@ function pf_update!(state::ParticleFilterView, new_args::Tuple,
169169
translator = GenParticleFilters.UpdatingTraceTranslator(
170170
p_new_args=new_args, p_argdiffs=argdiffs, new_observations=observations,
171171
q_forward=fwd_proposal, q_forward_args=fwd_args,
172-
q_backward=bwd_proposal, q_backward_args=bwd_args, f=transform)
172+
q_backward=bwd_proposal, q_backward_args=bwd_args, transform=transform)
173173
return pf_update!(state, translator; check=check)
174174
end

0 commit comments

Comments
 (0)