Skip to content

Commit 28ad8fd

Browse files
committed
Fix extending notebook.
1 parent 07897ce commit 28ad8fd

File tree

3 files changed

+67
-133
lines changed

3 files changed

+67
-133
lines changed

extending_our_work.ipynb

Lines changed: 54 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
"id": "8eaebd2e-4091-494f-83fd-3b221dacd087",
1414
"metadata": {},
1515
"source": [
16-
"This notebook is intended as a tutorial: a guide to the usage of our system on new problems, which illustrates how several parts of the system work together."
16+
"This notebook shows how to use our library to solve a new inference task, beyond those considered in our experiments. It is intended to illustrate the usage of the library, but assumes some knowledge of variational inference. The inference problem comes from [Pyro's SVI Part I tutorial](https://pyro.ai/examples/svi_part_i.html#A-simple-example)."
1717
]
1818
},
1919
{
2020
"cell_type": "code",
21-
"execution_count": 1,
21+
"execution_count": null,
2222
"id": "dd955977-8d0a-4a68-81c7-f031f17c348e",
2323
"metadata": {},
2424
"outputs": [],
@@ -46,35 +46,28 @@
4646
"id": "fcdc73aa-dbec-4c10-adc7-5f1e60125c5e",
4747
"metadata": {},
4848
"source": [
49-
"Models and variational families (guides) in our system are probabilistic programs. We write these using a modeling language which can be accessed via the `genjax.gen` decorator. Below, _addresses (like `\"latent_fairness\"`) denote random variables. Although we don't show it here, deterministic (JAX traceable) code can be freely interwoven between random variable statements."
49+
"Models and variational families (guides) in our system are probabilistic programs. We write these using a modeling language which can be accessed via the `genjax.gen` decorator. \n",
50+
"\n",
51+
"In the code, random choices can be made using the syntax `dist(args) @ \"choice_name\"`, where `\"choice_name\"` is a unique name for the random variable being sampled. In the code below, our model defines a distribution over two random variables, and the variational family, or guide, defines a distribution over only one random variable.\n",
52+
"\n",
53+
"Although we don't show it here, deterministic (JAX traceable) code can be freely interwoven between random variable statements."
5054
]
5155
},
5256
{
5357
"cell_type": "code",
54-
"execution_count": 2,
58+
"execution_count": null,
5559
"id": "7330bcaf",
5660
"metadata": {},
57-
"outputs": [
58-
{
59-
"data": {
60-
"text/plain": [
61-
"BuiltinGenerativeFunction(source=<function model at 0x142c97370>)"
62-
]
63-
},
64-
"execution_count": 2,
65-
"metadata": {},
66-
"output_type": "execute_result"
67-
}
68-
],
61+
"outputs": [],
6962
"source": [
7063
"#####################\n",
7164
"# Model & Guide\n",
7265
"#####################\n",
7366
"\n",
7467
"@genjax.gen\n",
7568
"def model():\n",
76-
" f = genjax.beta(2.0, 2.0) @ \"latent_fairness\"\n",
77-
" _ = genjax.tfp_bernoulli(f) @ \"obs\"\n",
69+
" f = genjax.tfp_beta(10.0, 10.0) @ \"latent_fairness\"\n",
70+
" _ = genjax.tfp_flip(f) @ \"obs\"\n",
7871
"\n",
7972
"\n",
8073
"@genjax.gen\n",
@@ -112,7 +105,7 @@
112105
},
113106
{
114107
"cell_type": "code",
115-
"execution_count": 3,
108+
"execution_count": null,
116109
"id": "702e667a-ae2c-4bb5-8a93-7211862a567f",
117110
"metadata": {},
118111
"outputs": [],
@@ -156,27 +149,19 @@
156149
},
157150
{
158151
"cell_type": "code",
159-
"execution_count": 4,
152+
"execution_count": null,
160153
"id": "6858aa01-1fa5-42a7-8302-6085fec5c08c",
161154
"metadata": {},
162-
"outputs": [
163-
{
164-
"name": "stdout",
165-
"output_type": "stream",
166-
"text": [
167-
"[ True False False False False False False False False False]\n"
168-
]
169-
}
170-
],
155+
"outputs": [],
171156
"source": [
172157
"#####################\n",
173158
"# Data Generation\n",
174159
"#####################\n",
175160
"\n",
176161
"data = []\n",
177-
"for _ in range(1):\n",
162+
"for _ in range(6):\n",
178163
" data.append(True)\n",
179-
"for _ in range(9):\n",
164+
"for _ in range(4):\n",
180165
" data.append(False)\n",
181166
"\n",
182167
"data = jnp.array(data)\n",
@@ -186,44 +171,21 @@
186171
},
187172
{
188173
"cell_type": "code",
189-
"execution_count": 5,
174+
"execution_count": null,
190175
"id": "23962668-0a2f-423d-8a80-e2930d270f77",
191176
"metadata": {},
192-
"outputs": [
193-
{
194-
"data": {
195-
"text/plain": [
196-
"Expectation(prog=ADEVProgram(source=<function elbo.<locals>.elbo_loss at 0x142ce8c10>))"
197-
]
198-
},
199-
"execution_count": 5,
200-
"metadata": {},
201-
"output_type": "execute_result"
202-
}
203-
],
177+
"outputs": [],
204178
"source": [
205179
"objective = elbo(model, guide, genjax.choice_map({\"obs\": data}))\n",
206180
"objective"
207181
]
208182
},
209183
{
210184
"cell_type": "code",
211-
"execution_count": 6,
185+
"execution_count": null,
212186
"id": "aef685d1-e026-4801-961a-033b4c70b53a",
213187
"metadata": {},
214-
"outputs": [
215-
{
216-
"data": {
217-
"text/plain": [
218-
"(Array(-2.9085593, dtype=float32, weak_type=True),\n",
219-
" Array(3.0023632, dtype=float32, weak_type=True))"
220-
]
221-
},
222-
"execution_count": 6,
223-
"metadata": {},
224-
"output_type": "execute_result"
225-
}
226-
],
188+
"outputs": [],
227189
"source": [
228190
"key, sub_key = jax.random.split(key)\n",
229191
"_, q_grads = objective.grad_estimate(sub_key, ((), (1.0, 1.0)))\n",
@@ -235,7 +197,7 @@
235197
"id": "2dd4bf25-31a8-4f17-bcb5-fc01014bd6d3",
236198
"metadata": {},
237199
"source": [
238-
"That all works, like you'd expect it to."
200+
"The `objective.grad_estimate method` takes arguments `(key: PRNGKey, loss_args: Tuple)` and returns an unbiased estimate of the gradient of our objective. We can use these gradient estimates for stochastic optimization of the guide's parameters (see below)."
239201
]
240202
},
241203
{
@@ -256,44 +218,21 @@
256218
},
257219
{
258220
"cell_type": "code",
259-
"execution_count": 7,
221+
"execution_count": null,
260222
"id": "94feba7c-e704-4683-8da8-c7a8e1c37b75",
261223
"metadata": {},
262-
"outputs": [
263-
{
264-
"data": {
265-
"text/plain": [
266-
"Expectation(prog=ADEVProgram(source=<function elbo.<locals>.elbo_loss at 0x158be7d00>))"
267-
]
268-
},
269-
"execution_count": 7,
270-
"metadata": {},
271-
"output_type": "execute_result"
272-
}
273-
],
224+
"outputs": [],
274225
"source": [
275226
"objective = genjax.vi.elbo(model, guide, genjax.choice_map({\"obs\": data}))\n",
276227
"objective"
277228
]
278229
},
279230
{
280231
"cell_type": "code",
281-
"execution_count": 8,
232+
"execution_count": null,
282233
"id": "a330520e-93a5-49ef-b24f-c41e7b2f2dd9",
283234
"metadata": {},
284-
"outputs": [
285-
{
286-
"data": {
287-
"text/plain": [
288-
"(Array(-2.1008544, dtype=float32, weak_type=True),\n",
289-
" Array(-0.2123673, dtype=float32, weak_type=True))"
290-
]
291-
},
292-
"execution_count": 8,
293-
"metadata": {},
294-
"output_type": "execute_result"
295-
}
296-
],
235+
"outputs": [],
297236
"source": [
298237
"key, sub_key = jax.random.split(key)\n",
299238
"_, q_grads = objective.grad_estimate(sub_key, ((), (1.0, 1.0)))\n",
@@ -310,7 +249,7 @@
310249
},
311250
{
312251
"cell_type": "code",
313-
"execution_count": 9,
252+
"execution_count": null,
314253
"id": "6079ea7c",
315254
"metadata": {},
316255
"outputs": [],
@@ -347,19 +286,28 @@
347286
},
348287
{
349288
"cell_type": "code",
350-
"execution_count": 10,
289+
"execution_count": null,
351290
"id": "56a6b73f",
352291
"metadata": {},
353292
"outputs": [],
354293
"source": [
355294
"# setup the optimizer\n",
356-
"adam = optax.adam(5e-3)\n",
295+
"adam = optax.adam(5e-4)\n",
357296
"svi_updater = svi_update(model, guide, adam)\n",
358297
"\n",
359298
"# initialize parameters\n",
360299
"alpha = jnp.array(2.0)\n",
361300
"beta = jnp.array(2.0)\n",
362301
"\n",
302+
"# here we use some facts about the Beta distribution\n",
303+
"start_mean = alpha / (alpha + beta)\n",
304+
"factor = beta / (alpha * (1.0 + alpha + beta))\n",
305+
"start_std = start_mean * jnp.sqrt(factor)\n",
306+
"print(\n",
307+
" \"\\nStarting mean and std \"\n",
308+
" + \"is %.3f +- %.3f\" % (start_mean, start_std)\n",
309+
")\n",
310+
"\n",
363311
"params = (alpha, beta)\n",
364312
"opt_state = adam.init(params)\n",
365313
"\n",
@@ -378,7 +326,7 @@
378326
},
379327
{
380328
"cell_type": "code",
381-
"execution_count": 11,
329+
"execution_count": null,
382330
"id": "e0e2594a",
383331
"metadata": {},
384332
"outputs": [],
@@ -387,7 +335,7 @@
387335
"# Gradient Steps\n",
388336
"#####################\n",
389337
"\n",
390-
"for step in range(2000):\n",
338+
"for step in range(5000):\n",
391339
" key, sub_key = jax.random.split(key)\n",
392340
" params, loss, opt_state = svi_updater(key, data, params, opt_state)"
393341
]
@@ -402,19 +350,10 @@
402350
},
403351
{
404352
"cell_type": "code",
405-
"execution_count": 12,
353+
"execution_count": null,
406354
"id": "9d153225",
407355
"metadata": {},
408-
"outputs": [
409-
{
410-
"name": "stdout",
411-
"output_type": "stream",
412-
"text": [
413-
"\n",
414-
"Based on the data and our prior belief, the fairness of the coin is 0.293 +- 0.191\n"
415-
]
416-
}
417-
],
356+
"outputs": [],
418357
"source": [
419358
"#####################\n",
420359
"# Inferred parameters\n",
@@ -444,7 +383,7 @@
444383
},
445384
{
446385
"cell_type": "code",
447-
"execution_count": 13,
386+
"execution_count": null,
448387
"id": "0fa5e755-98ef-4a17-ae9c-923415a84765",
449388
"metadata": {},
450389
"outputs": [],
@@ -468,12 +407,12 @@
468407
" return updater\n",
469408
"\n",
470409
" # setup the optimizer\n",
471-
" adam = optax.adam(5e-3)\n",
410+
" adam = optax.adam(5e-4)\n",
472411
" svi_updater = svi_update(model, guide, adam)\n",
473412
" \n",
474413
" # initialize parameters\n",
475-
" alpha = jnp.array(2.0)\n",
476-
" beta = jnp.array(2.0)\n",
414+
" alpha = jnp.array(15.0)\n",
415+
" beta = jnp.array(15.0)\n",
477416
" \n",
478417
" params = (alpha, beta)\n",
479418
" opt_state = adam.init(params)\n",
@@ -483,7 +422,7 @@
483422
" _ = svi_updater(key, data, params, opt_state)\n",
484423
"\n",
485424
" losses = []\n",
486-
" for step in range(2000):\n",
425+
" for step in range(5000):\n",
487426
" key, sub_key = jax.random.split(key)\n",
488427
" params, loss, opt_state = svi_updater(key, data, params, opt_state)\n",
489428
" losses.append(loss)\n",
@@ -512,24 +451,15 @@
512451
},
513452
{
514453
"cell_type": "code",
515-
"execution_count": 14,
454+
"execution_count": null,
516455
"id": "3a3fe4d2-86b2-422b-9de4-fb2da6b1b95b",
517456
"metadata": {},
518-
"outputs": [
519-
{
520-
"name": "stdout",
521-
"output_type": "stream",
522-
"text": [
523-
"\n",
524-
"Based on the data and our prior belief, the fairness of the coin is 0.627 +- 0.208\n"
525-
]
526-
}
527-
],
457+
"outputs": [],
528458
"source": [
529459
"data = []\n",
530-
"for _ in range(9):\n",
460+
"for _ in range(8):\n",
531461
" data.append(True)\n",
532-
"for _ in range(1):\n",
462+
"for _ in range(2):\n",
533463
" data.append(False)\n",
534464
"\n",
535465
"data = jnp.array(data)\n",
@@ -555,7 +485,7 @@
555485
},
556486
{
557487
"cell_type": "code",
558-
"execution_count": 15,
488+
"execution_count": null,
559489
"id": "d56ad497-64b4-454f-b49f-86a7c5afe3a7",
560490
"metadata": {},
561491
"outputs": [],
@@ -595,19 +525,10 @@
595525
},
596526
{
597527
"cell_type": "code",
598-
"execution_count": 16,
528+
"execution_count": null,
599529
"id": "69ca6d58-defb-4fc3-85ae-ca3302015a8b",
600530
"metadata": {},
601-
"outputs": [
602-
{
603-
"name": "stdout",
604-
"output_type": "stream",
605-
"text": [
606-
"\n",
607-
"Based on the data and our prior belief, the fairness of the coin is 0.628 +- 0.209\n"
608-
]
609-
}
610-
],
531+
"outputs": [],
611532
"source": [
612533
"run_experiment(key, data, iwelbo)"
613534
]
@@ -638,7 +559,7 @@
638559
},
639560
{
640561
"cell_type": "code",
641-
"execution_count": 17,
562+
"execution_count": null,
642563
"id": "2ffec31d-6e9f-4cf1-ba2c-30c3c7275265",
643564
"metadata": {},
644565
"outputs": [],

0 commit comments

Comments
 (0)