@@ -11,7 +11,7 @@ using Gen: run_first_pass, jacobian_correction, check_round_trip, run_transform
11
11
new_observations::ChoiceMap = EmptyChoiceMap(),
12
12
q_forward::GenerativeFunction,
13
13
q_forward_args::Tuple = (),
14
- f ::Union{TraceTransformDSLProgram,Nothing} = nothing)
14
+ transform ::Union{TraceTransformDSLProgram,Nothing} = nothing)
15
15
Constructor for a extending trace translator.
16
16
Run the translator with:
17
17
(output_trace, log_weight) = translator(input_trace)
@@ -22,7 +22,7 @@ Run the translator with:
22
22
new_observations:: ChoiceMap = EmptyChoiceMap ()
23
23
q_forward:: GenerativeFunction
24
24
q_forward_args:: Tuple = ()
25
- f :: Union{TraceTransformDSLProgram,Nothing} = nothing # a bijection
25
+ transform :: Union{TraceTransformDSLProgram,Nothing} = nothing # a bijection
26
26
end
27
27
28
28
function (translator:: ExtendingTraceTranslator )(prev_model_trace:: Trace )
@@ -33,14 +33,14 @@ function (translator::ExtendingTraceTranslator)(prev_model_trace::Trace)
33
33
forward_proposal_score = get_score (forward_proposal_trace)
34
34
35
35
# transform forward proposal
36
- if translator. f === nothing
36
+ if translator. transform === nothing
37
37
constraints = get_choices (forward_proposal_trace)
38
38
log_abs_determinant = 0.0
39
39
else
40
40
first_pass_results =
41
- run_first_pass (translator. f , forward_proposal_trace, nothing )
41
+ run_first_pass (translator. transform , forward_proposal_trace, nothing )
42
42
log_abs_determinant =
43
- jacobian_correction (translator. f , forward_proposal_trace,
43
+ jacobian_correction (translator. transform , forward_proposal_trace,
44
44
nothing , first_pass_results, nothing )
45
45
constraints = first_pass_results. constraints
46
46
end
@@ -97,7 +97,7 @@ the observed random choices in the previous trace.
97
97
q_forward_args:: Tuple = ()
98
98
q_backward:: GenerativeFunction
99
99
q_backward_args:: Tuple = ()
100
- f :: TraceTransformDSLProgram
100
+ transform :: TraceTransformDSLProgram
101
101
end
102
102
103
103
function Gen. inverse (translator:: UpdatingTraceTranslator , prev_model_trace:: Trace ,
@@ -106,23 +106,22 @@ function Gen.inverse(translator::UpdatingTraceTranslator, prev_model_trace::Trac
106
106
get_args (prev_model_trace), map ((_)-> UnknownChange (), get_args (prev_model_trace)),
107
107
prev_observations, translator. q_backward, translator. q_backward_args,
108
108
translator. q_forward, translator. q_forward_args,
109
- inverse (translator. f ))
109
+ inverse (translator. transform ))
110
110
end
111
111
112
112
function Gen. run_transform (translator:: UpdatingTraceTranslator ,
113
- prev_model_trace:: Trace , forward_proposal_trace:: Trace ,
114
- check:: Bool = false )
115
- @unpack f, new_observations = translator
113
+ prev_model_trace:: Trace , forward_proposal_trace:: Trace )
114
+ @unpack transform, new_observations = translator
116
115
@unpack p_new_args, p_argdiffs, q_backward, q_backward_args = translator
117
- first_pass_results =
118
- Gen . run_first_pass (f , prev_model_trace, forward_proposal_trace)
116
+ first_pass_results = run_first_pass (
117
+ transform , prev_model_trace, forward_proposal_trace)
119
118
constraints = merge (first_pass_results. constraints, new_observations)
120
- ( new_model_trace, _, _, discard) = update (
119
+ new_model_trace, _, _, discard = update (
121
120
prev_model_trace, p_new_args, p_argdiffs, constraints)
122
- log_abs_determinant = jacobian_correction (f, prev_model_trace,
123
- forward_proposal_trace, first_pass_results, discard)
124
- backward_proposal_trace, = generate (q_backward,
125
- (new_model_trace, q_backward_args... ), first_pass_results. u_back)
121
+ log_abs_determinant = jacobian_correction (
122
+ transform, prev_model_trace, forward_proposal_trace, first_pass_results, discard)
123
+ backward_proposal_trace, _ = generate (
124
+ q_backward, (new_model_trace, q_backward_args... ), first_pass_results. u_back)
126
125
return (new_model_trace, backward_proposal_trace, log_abs_determinant)
127
126
end
128
127
@@ -135,7 +134,7 @@ function (translator::UpdatingTraceTranslator)(
135
134
136
135
# apply trace transform
137
136
(new_model_trace, backward_proposal_trace, log_abs_determinant) =
138
- run_transform (translator, prev_model_trace, forward_proposal_trace, check )
137
+ run_transform (translator, prev_model_trace, forward_proposal_trace)
139
138
140
139
# compute log weight
141
140
prev_model_score = get_score (prev_model_trace)
@@ -149,7 +148,7 @@ function (translator::UpdatingTraceTranslator)(
149
148
inverter = inverse (translator, prev_model_trace, prev_observations)
150
149
argdiffs = map ((_) -> UnknownChange (), get_args (prev_model_trace))
151
150
(prev_model_trace_rt, forward_proposal_trace_rt, _) =
152
- run_transform (inverter, new_model_trace, backward_proposal_trace, check )
151
+ run_transform (inverter, new_model_trace, backward_proposal_trace)
153
152
check_round_trip (prev_model_trace, prev_model_trace_rt,
154
153
forward_proposal_trace, forward_proposal_trace_rt)
155
154
end
0 commit comments