@@ -104,55 +104,142 @@ print(ssim_score_0.shape, ssim_score_1.shape)
104
104
105
105
## As A Loss
106
106
107
- ![ prediction] ( https://user-images.githubusercontent.com/26847524/174814849-f80ec67c-5397-4ce6-bf4e-8b0aa568ed6f.png )
107
+ As you can see from the respective thresholds of the two cases below, it is easier to optimize towards MSSIM=1 than MSSIM=-1.
108
+
109
+ ### Optimize towards MSSIM=1
110
+
111
+ ![ prediction] ( https://user-images.githubusercontent.com/26847524/174930091-9d7f7505-1752-423a-b7c3-d4dbfeb8d336.png )
108
112
109
113
``` python
110
114
import matplotlib.pyplot as plt
111
115
import torch
112
116
from pytorch_ssim import SSIM
113
117
from skimage import data
114
- from torch.optim import Adam
118
+ from torch import optim
115
119
116
-
117
- original_image = data.camera() / 255
120
+ original_image = data.moon() / 255
118
121
target_image = torch.from_numpy(original_image).unsqueeze(0 ).unsqueeze(0 ).float().cuda()
119
- predicted_image = torch.rand_like (
122
+ predicted_image = torch.zeros_like (
120
123
target_image, device = target_image.device, dtype = target_image.dtype, requires_grad = True
121
124
)
122
125
initial_image = predicted_image.clone()
123
126
124
127
ssim = SSIM().cuda()
125
128
initial_ssim_value = ssim(predicted_image, target_image)
126
- print (f " Initial ssim: { initial_ssim_value.item():.4f } " )
127
- ssim_value = initial_ssim_value
128
129
129
- optimizer = Adam([predicted_image], lr = 0.01 )
130
+ ssim_value = initial_ssim_value
131
+ optimizer = optim.Adam([predicted_image], lr = 0.01 )
130
132
loss_curves = []
131
- while ssim_value < 0.95 :
133
+ while ssim_value < 0.999 :
132
134
ssim_out = 1 - ssim(predicted_image, target_image)
133
135
loss_curves.append(ssim_out.item())
134
136
ssim_value = 1 - ssim_out.item()
137
+ print (ssim_value)
135
138
ssim_out.backward()
136
139
optimizer.step()
137
140
optimizer.zero_grad()
138
141
139
- fig, axes = plt.subplots(nrows = 1 , ncols = 4 , figsize = (8 , 2 ))
142
+ fig, axes = plt.subplots(nrows = 2 , ncols = 4 , figsize = (8 , 4 ))
140
143
ax = axes.ravel()
141
144
142
145
ax[0 ].imshow(original_image, cmap = plt.cm.gray, vmin = 0 , vmax = 1 )
143
146
ax[0 ].set_title(" Original Image" )
144
147
145
148
ax[1 ].imshow(initial_image.squeeze().detach().cpu().numpy(), cmap = plt.cm.gray, vmin = 0 , vmax = 1 )
146
- ax[1 ].set_xlabel(f " SSIM: { initial_ssim_value:.4f } " )
149
+ ax[1 ].set_xlabel(f " SSIM: { initial_ssim_value:.5f } " )
147
150
ax[1 ].set_title(" Initial Image" )
148
151
149
152
ax[2 ].imshow(predicted_image.squeeze().detach().cpu().numpy(), cmap = plt.cm.gray, vmin = 0 , vmax = 1 )
150
- ax[2 ].set_xlabel(f " SSIM: { ssim_value:.4f } " )
153
+ ax[2 ].set_xlabel(f " SSIM: { ssim_value:.5f } " )
151
154
ax[2 ].set_title(" Predicted Image" )
152
155
153
156
ax[3 ].plot(loss_curves)
154
157
ax[3 ].set_title(" SSIM Loss Curve" )
155
158
159
+ ax[4 ].set_title(" Original Image" )
160
+ ax[4 ].hist(original_image.ravel(), bins = 256 )
161
+ ax[4 ].ticklabel_format(axis = " y" , style = " scientific" , scilimits = (0 , 0 ))
162
+ ax[4 ].set_xlabel(" Pixel Intensity" )
163
+
164
+ ax[5 ].set_title(" Initial Image" )
165
+ ax[5 ].hist(initial_image.squeeze().detach().cpu().numpy().ravel(), bins = 256 )
166
+ ax[5 ].ticklabel_format(axis = " y" , style = " scientific" , scilimits = (0 , 0 ))
167
+ ax[5 ].set_xlabel(" Pixel Intensity" )
168
+
169
+ ax[6 ].set_title(" Predicted Image" )
170
+ ax[6 ].hist(predicted_image.squeeze().detach().cpu().numpy().ravel(), bins = 256 )
171
+ ax[6 ].ticklabel_format(axis = " y" , style = " scientific" , scilimits = (0 , 0 ))
172
+ ax[6 ].set_xlabel(" Pixel Intensity" )
173
+
174
+ plt.tight_layout()
175
+ plt.savefig(" prediction.png" )
176
+ ```
177
+
178
+ ### Optimize towards MSSIM=-1
179
+
180
+ ![ prediction] ( https://user-images.githubusercontent.com/26847524/174929574-5332cab2-104f-4aab-a4e5-35e7635a793f.png )
181
+
182
+ ``` python
183
+ import matplotlib.pyplot as plt
184
+ import torch
185
+ from pytorch_ssim import SSIM
186
+ from skimage import data
187
+ from torch import optim
188
+
189
+ original_image = data.moon() / 255
190
+ target_image = torch.from_numpy(original_image).unsqueeze(0 ).unsqueeze(0 ).float().cuda()
191
+ predicted_image = torch.zeros_like(
192
+ target_image, device = target_image.device, dtype = target_image.dtype, requires_grad = True
193
+ )
194
+ initial_image = predicted_image.clone()
195
+
196
+ ssim = SSIM(L = original_image.max() - original_image.min()).cuda()
197
+ initial_ssim_value = ssim(predicted_image, target_image)
198
+
199
+ ssim_value = initial_ssim_value
200
+ optimizer = optim.Adam([predicted_image], lr = 0.01 )
201
+ loss_curves = []
202
+ while ssim_value > - 0.94 :
203
+ ssim_out = ssim(predicted_image, target_image)
204
+ loss_curves.append(ssim_out.item())
205
+ ssim_value = ssim_out.item()
206
+ print (ssim_value)
207
+ ssim_out.backward()
208
+ optimizer.step()
209
+ optimizer.zero_grad()
210
+
211
+ fig, axes = plt.subplots(nrows = 2 , ncols = 4 , figsize = (8 , 4 ))
212
+ ax = axes.ravel()
213
+
214
+ ax[0 ].imshow(original_image, cmap = plt.cm.gray, vmin = 0 , vmax = 1 )
215
+ ax[0 ].set_title(" Original Image" )
216
+
217
+ ax[1 ].imshow(initial_image.squeeze().detach().cpu().numpy(), cmap = plt.cm.gray, vmin = 0 , vmax = 1 )
218
+ ax[1 ].set_xlabel(f " SSIM: { initial_ssim_value:.5f } " )
219
+ ax[1 ].set_title(" Initial Image" )
220
+
221
+ ax[2 ].imshow(predicted_image.squeeze().detach().cpu().numpy(), cmap = plt.cm.gray, vmin = 0 , vmax = 1 )
222
+ ax[2 ].set_xlabel(f " SSIM: { ssim_value:.5f } " )
223
+ ax[2 ].set_title(" Predicted Image" )
224
+
225
+ ax[3 ].plot(loss_curves)
226
+ ax[3 ].set_title(" SSIM Loss Curve" )
227
+
228
+ ax[4 ].set_title(" Original Image" )
229
+ ax[4 ].hist(original_image.ravel(), bins = 256 )
230
+ ax[4 ].ticklabel_format(axis = " y" , style = " scientific" , scilimits = (0 , 0 ))
231
+ ax[4 ].set_xlabel(" Pixel Intensity" )
232
+
233
+ ax[5 ].set_title(" Initial Image" )
234
+ ax[5 ].hist(initial_image.squeeze().detach().cpu().numpy().ravel(), bins = 256 )
235
+ ax[5 ].ticklabel_format(axis = " y" , style = " scientific" , scilimits = (0 , 0 ))
236
+ ax[5 ].set_xlabel(" Pixel Intensity" )
237
+
238
+ ax[6 ].set_title(" Predicted Image" )
239
+ ax[6 ].hist(predicted_image.squeeze().detach().cpu().numpy().ravel(), bins = 256 )
240
+ ax[6 ].ticklabel_format(axis = " y" , style = " scientific" , scilimits = (0 , 0 ))
241
+ ax[6 ].set_xlabel(" Pixel Intensity" )
242
+
156
243
plt.tight_layout()
157
244
plt.savefig(" prediction.png" )
158
245
```
0 commit comments