Skip to content

Commit af10302

Browse files
committed
Update to PLDI artifact.
1 parent 9cc371c commit af10302

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

experiments.ipynb

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
},
5555
{
5656
"cell_type": "code",
57-
"execution_count": null,
57+
"execution_count": 2,
5858
"id": "0a82fac8-f6c1-454b-bbf5-2a9274612c1b",
5959
"metadata": {},
6060
"outputs": [
@@ -65,6 +65,22 @@
6565
"/home/femtomc/.cache/pypoetry/virtualenvs/programmable-vi-pldi-2024-wktT2A4B-py3.10/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
6666
" from .autonotebook import tqdm as notebook_tqdm\n"
6767
]
68+
},
69+
{
70+
"name": "stdout",
71+
"output_type": "stream",
72+
"text": [
73+
"GenJAX VI timings:\n",
74+
"Batch sizes: [64, 128, 256, 512, 1024]\n",
75+
"(array([ 3.2510662, 5.572334 , 10.231137 , 17.31219 , 31.963547 ],\n",
76+
" dtype=float32), array([0.21272574, 2.0414262 , 2.1288395 , 1.691629 , 5.045877 ],\n",
77+
" dtype=float32))\n",
78+
"Handcoded timings:\n",
79+
"Batch sizes: [64, 128, 256, 512, 1024]\n",
80+
"(array([ 3.345306 , 5.3966055, 9.804233 , 15.3170595, 31.775038 ],\n",
81+
" dtype=float32), array([0.06193171, 0.04953469, 2.064177 , 4.5305977 , 6.054197 ],\n",
82+
" dtype=float32))\n"
83+
]
6884
}
6985
],
7086
"source": [
@@ -100,7 +116,23 @@
100116
"execution_count": null,
101117
"id": "27525362-6d68-4d8a-841a-ec0c3011e756",
102118
"metadata": {},
103-
"outputs": [],
119+
"outputs": [
120+
{
121+
"name": "stdout",
122+
"output_type": "stream",
123+
"text": [
124+
"\u001b[1mpoetry run pytest experiments/table_2_benchmark_timings --benchmark-disable-gc\u001b[0m\n",
125+
"\u001b[1m============================= test session starts ==============================\u001b[0m\n",
126+
"platform linux -- Python 3.10.13, pytest-8.0.2, pluggy-1.4.0\n",
127+
"benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=True min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)\n",
128+
"rootdir: /home/femtomc/programmable-vi-pldi-2024\n",
129+
"plugins: jaxtyping-0.2.28, benchmark-4.0.0, anyio-4.3.0, typeguard-2.13.3\n",
130+
"collected 6 items \u001b[0m\u001b[1m\u001b[1m\u001b[1m\n",
131+
"\n",
132+
"experiments/table_2_benchmark_timings/test_genjax_enum_air_benchmark.py "
133+
]
134+
}
135+
],
104136
"source": [
105137
"! just table_2"
106138
]

experiments/fig_7_air_estimator_evaluation/genjax_rws_air.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ def draw_many(imgs, zs, title):
670670
evaluate_accuracy = count_accuracy(mnist, true_counts, guide, batch_size=1000)
671671

672672

673-
def train(key, n=1, num_epochs=40, batch_size=64, learning_rate=1.0e-3):
673+
def train(key, n=2, num_epochs=40, batch_size=64, learning_rate=1.0e-3):
674674
def svi_update(model, guide, optimiser):
675675
def batch_updater(key, params, opt_state, data_batch):
676676
def p_grads(key, params, data):
@@ -788,7 +788,7 @@ def body_fn(carry, xs):
788788

789789
key, sub_key = jax.random.split(key)
790790
(p_losses, q_losses), accuracy, wall_clock_times, params = train(
791-
sub_key, learning_rate=1.0e-3, n=10, batch_size=64, num_epochs=40
791+
sub_key, learning_rate=3.0e-3, n=10, batch_size=64, num_epochs=40
792792
)
793793

794794
arr = np.array([p_losses, q_losses, accuracy, wall_clock_times])

0 commit comments

Comments
 (0)