Skip to content

Commit 733ce34

Browse files
Examples tested and updated, NSF example added
1 parent 2c481c4 commit 733ce34

File tree

3 files changed

+207
-13
lines changed

3 files changed

+207
-13
lines changed

example/neural-spline-flow.ipynb

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Neural Spline Flow"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {},
14+
"outputs": [],
15+
"source": [
16+
"# Import required packages\n",
17+
"import torch\n",
18+
"import numpy as np\n",
19+
"import normflow as nf\n",
20+
"\n",
21+
"from sklearn.datasets import make_moons\n",
22+
"\n",
23+
"from matplotlib import pyplot as plt\n",
24+
"\n",
25+
"from tqdm import tqdm"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": null,
31+
"metadata": {
32+
"scrolled": false
33+
},
34+
"outputs": [],
35+
"source": [
36+
"# Set up model\n",
37+
"\n",
38+
"# Define flows\n",
39+
"K = 16\n",
40+
"torch.manual_seed(0)\n",
41+
"\n",
42+
"latent_size = 2\n",
43+
"hidden_units = 128\n",
44+
"hidden_layers = 2\n",
45+
"\n",
46+
"flows = []\n",
47+
"for i in range(K):\n",
48+
" flows += [nf.flows.AutoregressiveRationalQuadraticSpline(latent_size, hidden_layers, hidden_units)]\n",
49+
" flows += [nf.flows.InvertibleAffine(latent_size)]\n",
50+
"\n",
51+
"# Set prior and q0\n",
52+
"q0 = nf.distributions.DiagGaussian(2, trainable=False)\n",
53+
" \n",
54+
"# Construct flow model\n",
55+
"nfm = nf.NormalizingFlow(q0=q0, flows=flows)\n",
56+
"\n",
57+
"# Move model on GPU if available\n",
58+
"enable_cuda = True\n",
59+
"device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')\n",
60+
"nfm = nfm.to(device)\n",
61+
"\n",
62+
"# Initialize ActNorm\n",
63+
"x_np, _ = make_moons(2 ** 9, noise=0.1)\n",
64+
"x = torch.tensor(x_np).float().to(device)\n",
65+
"_ = nfm.log_prob(x)"
66+
]
67+
},
68+
{
69+
"cell_type": "code",
70+
"execution_count": null,
71+
"metadata": {
72+
"scrolled": false
73+
},
74+
"outputs": [],
75+
"source": [
76+
"# Plot prior distribution\n",
77+
"x_np, _ = make_moons(2 ** 20, noise=0.1)\n",
78+
"plt.figure(figsize=(15, 15))\n",
79+
"plt.hist2d(x_np[:, 0], x_np[:, 1], bins=200)\n",
80+
"plt.show()\n",
81+
"\n",
82+
"# Plot initial posterior distribution\n",
83+
"grid_size = 100\n",
84+
"xx, yy = torch.meshgrid(torch.linspace(-1.5, 2.5, grid_size), torch.linspace(-2, 2, grid_size))\n",
85+
"zz = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).view(-1, 2)\n",
86+
"zz = zz.to(device)\n",
87+
"\n",
88+
"nfm.eval()\n",
89+
"log_prob = nfm.log_prob(zz).to('cpu').view(*xx.shape)\n",
90+
"nfm.train()\n",
91+
"prob = torch.exp(log_prob)\n",
92+
"prob[torch.isnan(prob)] = 0\n",
93+
"\n",
94+
"plt.figure(figsize=(15, 15))\n",
95+
"plt.pcolormesh(xx, yy, prob.data.numpy())\n",
96+
"plt.gca().set_aspect('equal', 'box')\n",
97+
"plt.show()"
98+
]
99+
},
100+
{
101+
"cell_type": "code",
102+
"execution_count": null,
103+
"metadata": {
104+
"scrolled": false
105+
},
106+
"outputs": [],
107+
"source": [
108+
"# Train model\n",
109+
"max_iter = 10000\n",
110+
"num_samples = 2 ** 9\n",
111+
"show_iter = 500\n",
112+
"\n",
113+
"\n",
114+
"loss_hist = np.array([])\n",
115+
"\n",
116+
"optimizer = torch.optim.Adam(nfm.parameters(), lr=1e-3, weight_decay=1e-5)\n",
117+
"for it in tqdm(range(max_iter)):\n",
118+
" optimizer.zero_grad()\n",
119+
" \n",
120+
" # Get training samples\n",
121+
" x_np, _ = make_moons(num_samples, noise=0.1)\n",
122+
" x = torch.tensor(x_np).float().to(device)\n",
123+
" \n",
124+
" # Compute loss\n",
125+
" loss = nfm.forward_kld(x)\n",
126+
" \n",
127+
" # Do backprop and optimizer step\n",
128+
" if ~(torch.isnan(loss) | torch.isinf(loss)):\n",
129+
" loss.backward()\n",
130+
" optimizer.step()\n",
131+
" \n",
132+
" # Make layers Lipschitz continuous\n",
133+
" nf.utils.update_lipschitz(nfm, 5)\n",
134+
" \n",
135+
" # Log loss\n",
136+
" loss_hist = np.append(loss_hist, loss.to('cpu').data.numpy())\n",
137+
" \n",
138+
" # Plot learned posterior\n",
139+
" if (it + 1) % show_iter == 0:\n",
140+
" nfm.eval()\n",
141+
" log_prob = nfm.log_prob(zz)\n",
142+
" nfm.train()\n",
143+
" prob = torch.exp(log_prob.to('cpu').view(*xx.shape))\n",
144+
" prob[torch.isnan(prob)] = 0\n",
145+
"\n",
146+
" plt.figure(figsize=(15, 15))\n",
147+
" plt.pcolormesh(xx, yy, prob.data.numpy())\n",
148+
" plt.gca().set_aspect('equal', 'box')\n",
149+
" plt.show()\n",
150+
"\n",
151+
"# Plot loss\n",
152+
"plt.figure(figsize=(10, 10))\n",
153+
"plt.plot(loss_hist, label='loss')\n",
154+
"plt.legend()\n",
155+
"plt.show()"
156+
]
157+
},
158+
{
159+
"cell_type": "code",
160+
"execution_count": null,
161+
"metadata": {},
162+
"outputs": [],
163+
"source": [
164+
"# Plot learned posterior distribution\n",
165+
"nfm.eval()\n",
166+
"log_prob = nfm.log_prob(zz).to('cpu').view(*xx.shape)\n",
167+
"nfm.train()\n",
168+
"prob = torch.exp(log_prob)\n",
169+
"prob[torch.isnan(prob)] = 0\n",
170+
"\n",
171+
"plt.figure(figsize=(15, 15))\n",
172+
"plt.pcolormesh(xx, yy, prob.data.numpy())\n",
173+
"plt.gca().set_aspect('equal', 'box')\n",
174+
"plt.show()"
175+
]
176+
}
177+
],
178+
"metadata": {
179+
"kernelspec": {
180+
"display_name": "Python 3 (ipykernel)",
181+
"language": "python",
182+
"name": "python3"
183+
},
184+
"language_info": {
185+
"codemirror_mode": {
186+
"name": "ipython",
187+
"version": 3
188+
},
189+
"file_extension": ".py",
190+
"mimetype": "text/x-python",
191+
"name": "python",
192+
"nbconvert_exporter": "python",
193+
"pygments_lexer": "ipython3",
194+
"version": "3.8.11"
195+
}
196+
},
197+
"nbformat": 4,
198+
"nbformat_minor": 4
199+
}

example/planar.ipynb

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,10 @@
3838
"flows = []\n",
3939
"for i in range(K):\n",
4040
" flows += [nf.flows.Planar((2,))]\n",
41-
"#prior = nf.distributions.Sinusoidal(0.2, 4)\n",
42-
"prior = nf.distributions.TwoModes(2, 0.1)\n",
43-
"#prior = torch.distributions.MultivariateNormal(torch.tensor(np.zeros(2), device=device), \n",
44-
"# torch.tensor(np.eye(2), device=device))\n",
45-
"#hidden_units_encoder = np.array([2, 4])\n",
46-
"#encoder_nn = nf.nets.MLP(hidden_units_encoder)\n",
47-
"#q0 = nf.distributions.NNDiagGaussian(encoder_nn)\n",
41+
"target = nf.distributions.TwoModes(2, 0.1)\n",
42+
"\n",
4843
"q0 = nf.distributions.DiagGaussian(2)\n",
49-
"nfm = nf.NormalizingFlow(q0=q0, flows=flows, p=prior)\n",
44+
"nfm = nf.NormalizingFlow(q0=q0, flows=flows, p=target)\n",
5045
"nfm.to(device)"
5146
]
5247
},
@@ -58,11 +53,11 @@
5853
},
5954
"outputs": [],
6055
"source": [
61-
"# Plot prior distribution\n",
56+
"# Plot target distribution\n",
6257
"grid_size = 200\n",
6358
"xx, yy = torch.meshgrid(torch.linspace(-3, 3, grid_size), torch.linspace(-3, 3, grid_size))\n",
6459
"z = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).view(-1, 2)\n",
65-
"log_prob = prior.log_prob(z.to(device)).to('cpu').view(*xx.shape)\n",
60+
"log_prob = target.log_prob(z.to(device)).to('cpu').view(*xx.shape)\n",
6661
"prob = torch.exp(log_prob)\n",
6762
"\n",
6863
"plt.figure(figsize=(10, 10))\n",
@@ -142,7 +137,7 @@
142137
],
143138
"metadata": {
144139
"kernelspec": {
145-
"display_name": "Python 3",
140+
"display_name": "Python 3 (ipykernel)",
146141
"language": "python",
147142
"name": "python3"
148143
},
@@ -156,7 +151,7 @@
156151
"name": "python",
157152
"nbconvert_exporter": "python",
158153
"pygments_lexer": "ipython3",
159-
"version": "3.7.6"
154+
"version": "3.8.11"
160155
},
161156
"pycharm": {
162157
"stem_cell": {

example/residual.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@
193193
"name": "python",
194194
"nbconvert_exporter": "python",
195195
"pygments_lexer": "ipython3",
196-
"version": "3.8.12"
196+
"version": "3.8.11"
197197
}
198198
},
199199
"nbformat": 4,

0 commit comments

Comments
 (0)