generated from aesara-devs/aesara-repo
Open
Labels: enhancement (New feature or request), good first issue (Good for newcomers), help wanted (Extra attention is needed)
Description
Using the sampling steps built by AeMCMC in a `scan` loop is not straightforward:
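```python
from typing import List

import aemcmc
import aesara
import aesara.tensor as at
from aesara.compile.function.pfunc import rebuild_collect_shared
from aesara.tensor.var import TensorVariable

# `Y_rv` is the model's observed random variable, `y_tt` the variable holding
# its observed value, and `srng` the `RandomStream` the model was built with
sample_steps, sample_updates, initial_values = aemcmc.construct_sampler(
    {Y_rv: y_tt}, srng
)

# The random variables to sample, e.g. all those that have a sampling step
to_sample_rvs: List[TensorVariable] = list(sample_steps.keys())
inputs = [initial_values[rv] for rv in to_sample_rvs]
outputs = [sample_steps[rv] for rv in to_sample_rvs]


def step_fn(*values):
    # Substitute the previous iteration's values for the initial values in the
    # sampling steps, and rebuild the RNG updates against the new graph
    vv_to_values = dict(zip(inputs, values))
    _, new_values, [_, new_updates, _, _] = rebuild_collect_shared(
        outputs, inputs=inputs, replace=vv_to_values, updates=sample_updates
    )
    return new_values, new_updates


n_samples = at.iscalar("n_samples")
samples, updates = aesara.scan(step_fn, outputs_info=inputs, n_steps=n_samples)
sample_fn = aesara.function(inputs + [n_samples], samples, updates=updates)
```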
This pattern is, however, easily generalizable. We should implement a utility function, e.g. `aemcmc.sampling_loop`, which, given the outputs of `construct_sampler` and a number of iterations `n_samples`, returns a graph that generates `n_samples` samples.