Skip to content

Commit e2433e3

Browse files
committed
Update PLDI artifact.
2 parents a90088f + a76a097 commit e2433e3

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

experiments/table_1_minibatch_gradient_benchmark/genjax_vae_overhead.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,16 @@ def batch_elbo_grad_estimate(key, encoder, decoder, data_batch):
9595
def _inner(key, encoder, decoder, data):
9696
chm = choice_map({"image": data.flatten()})
9797
objective = vi.elbo(model, guide, chm)
98-
return objective.grad_estimate(key, ((decoder,), (chm, encoder,)))
98+
return objective.grad_estimate(
99+
key,
100+
(
101+
(decoder,),
102+
(
103+
chm,
104+
encoder,
105+
),
106+
),
107+
)
99108

100109
sub_keys = jax.random.split(key, len(data_batch))
101110
return jax.vmap(_inner, in_axes=(0, None, None, 0))(

0 commit comments

Comments
 (0)