Skip to content

Commit 18101f2

Browse files
committed
Update to PLDI artifact.
1 parent 81e7590 commit 18101f2

File tree

3 files changed

+37
-7
lines changed

3 files changed

+37
-7
lines changed

experiments.ipynb

Lines changed: 20 additions & 2 deletions
Large diffs are not rendered by default.

experiments/fig_7_air_estimator_evaluation/air_analysis.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
hybrid_iwae_mvd_air = pd.read_csv(
2929
"./training_runs/genjax_air_iwae_2_hybrid_mvd_enum_epochs_41.csv",
3030
)
31-
rws_air_mvd = pd.read_csv("./training_runs/genjax_air_rws_10_mvd_epochs_41.csv")
31+
rws_air_mvd = pd.read_csv("./training_runs/genjax_air_rws_10_mvd_epochs_6.csv")
32+
rws_air_mvd_bs1 = pd.read_csv("./training_runs/genjax_air_rws_10_mvd_epochs_6_bs1.csv")
3233
pyro_reinforce_air = pd.read_csv(
3334
"./training_runs/pyro_air_reinforce_epochs_41.csv",
3435
)
@@ -336,6 +337,17 @@ def go_plot_rws(ax, df, x, mean, label, cmap, color_idx, marker):
336337
num_lines = 2
337338
cmap = plt.cm.get_cmap("cividis", num_lines)
338339

340+
rws_air_l = go_plot_rws(
341+
ax3,
342+
rws_air_mvd_bs1,
343+
"Epoch wall clock times",
344+
"Accuracy",
345+
"Ours (batch size = 1, RWS(K = 10))",
346+
cmap,
347+
0,
348+
"x",
349+
)
350+
339351
rws_air_l = go_plot_rws(
340352
ax3,
341353
rws_air_mvd,

experiments/fig_7_air_estimator_evaluation/genjax_rws_air.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -789,22 +789,22 @@ def body_fn(carry, xs):
789789

790790
key, sub_key = jax.random.split(key)
791791
(p_losses, q_losses), accuracy, wall_clock_times, params = train(
792-
sub_key, learning_rate=1.0e-3, n=10, batch_size=64, num_epochs=40
792+
sub_key, learning_rate=1.0e-3, n=10, batch_size=64, num_epochs=5
793793
)
794794

795795
arr = np.array([p_losses, q_losses, accuracy, wall_clock_times])
796796
df = pd.DataFrame(
797797
arr.T, columns=["P Loss", "Q Loss", "Accuracy", "Epoch wall clock times"]
798798
)
799-
df.to_csv("./training_runs/genjax_air_rws_10_mvd_epochs_41.csv", index=False)
799+
df.to_csv("./training_runs/genjax_air_rws_10_mvd_epochs_6.csv", index=False)
800800

801801
key, sub_key = jax.random.split(key)
802802
(p_losses, q_losses), accuracy, wall_clock_times, params = train(
803-
sub_key, learning_rate=1.0e-4, n=10, batch_size=1, num_epochs=40
803+
sub_key, learning_rate=1.0e-4, n=10, batch_size=1, num_epochs=5
804804
)
805805

806806
arr = np.array([p_losses, q_losses, accuracy, wall_clock_times])
807807
df = pd.DataFrame(
808808
arr.T, columns=["P Loss", "Q Loss", "Accuracy", "Epoch wall clock times"]
809809
)
810-
df.to_csv("./training_runs/genjax_air_rws_10_mvd_epochs_41_bs1.csv", index=False)
810+
df.to_csv("./training_runs/genjax_air_rws_10_mvd_epochs_6_bs1.csv", index=False)

0 commit comments

Comments
 (0)