Skip to content

Commit 92bf0ec

Browse files
authored
Add some better examples.
1 parent 8957131 commit 92bf0ec

File tree

1 file changed

+99
-12
lines changed

1 file changed

+99
-12
lines changed

README.md

Lines changed: 99 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -104,55 +104,142 @@ print(ssim_score_0.shape, ssim_score_1.shape)
104104

105105
## As A Loss
106106

107-
![prediction](https://user-images.githubusercontent.com/26847524/174814849-f80ec67c-5397-4ce6-bf4e-8b0aa568ed6f.png)
107+
As the stopping thresholds of the two examples below show (the first loop runs until MSSIM reaches 0.999, while the second stops at -0.94), it is easier to optimize towards MSSIM=1 than towards MSSIM=-1.
108+
109+
### Optimize towards MSSIM=1
110+
111+
![prediction](https://user-images.githubusercontent.com/26847524/174930091-9d7f7505-1752-423a-b7c3-d4dbfeb8d336.png)
108112

109113
```python
110114
import matplotlib.pyplot as plt
111115
import torch
112116
from pytorch_ssim import SSIM
113117
from skimage import data
114-
from torch.optim import Adam
118+
from torch import optim
115119

116-
117-
original_image = data.camera() / 255
120+
original_image = data.moon() / 255
118121
target_image = torch.from_numpy(original_image).unsqueeze(0).unsqueeze(0).float().cuda()
119-
predicted_image = torch.rand_like(
122+
predicted_image = torch.zeros_like(
120123
target_image, device=target_image.device, dtype=target_image.dtype, requires_grad=True
121124
)
122125
initial_image = predicted_image.clone()
123126

124127
ssim = SSIM().cuda()
125128
initial_ssim_value = ssim(predicted_image, target_image)
126-
print(f"Initial ssim: {initial_ssim_value.item():.4f}")
127-
ssim_value = initial_ssim_value
128129

129-
optimizer = Adam([predicted_image], lr=0.01)
130+
ssim_value = initial_ssim_value
131+
optimizer = optim.Adam([predicted_image], lr=0.01)
130132
loss_curves = []
131-
while ssim_value < 0.95:
133+
while ssim_value < 0.999:
132134
ssim_out = 1 - ssim(predicted_image, target_image)
133135
loss_curves.append(ssim_out.item())
134136
ssim_value = 1 - ssim_out.item()
137+
print(ssim_value)
135138
ssim_out.backward()
136139
optimizer.step()
137140
optimizer.zero_grad()
138141

139-
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(8, 2))
142+
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(8, 4))
140143
ax = axes.ravel()
141144

142145
ax[0].imshow(original_image, cmap=plt.cm.gray, vmin=0, vmax=1)
143146
ax[0].set_title("Original Image")
144147

145148
ax[1].imshow(initial_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)
146-
ax[1].set_xlabel(f"SSIM: {initial_ssim_value:.4f}")
149+
ax[1].set_xlabel(f"SSIM: {initial_ssim_value:.5f}")
147150
ax[1].set_title("Initial Image")
148151

149152
ax[2].imshow(predicted_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)
150-
ax[2].set_xlabel(f"SSIM: {ssim_value:.4f}")
153+
ax[2].set_xlabel(f"SSIM: {ssim_value:.5f}")
151154
ax[2].set_title("Predicted Image")
152155

153156
ax[3].plot(loss_curves)
154157
ax[3].set_title("SSIM Loss Curve")
155158

159+
ax[4].set_title("Original Image")
160+
ax[4].hist(original_image.ravel(), bins=256)
161+
ax[4].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
162+
ax[4].set_xlabel("Pixel Intensity")
163+
164+
ax[5].set_title("Initial Image")
165+
ax[5].hist(initial_image.squeeze().detach().cpu().numpy().ravel(), bins=256)
166+
ax[5].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
167+
ax[5].set_xlabel("Pixel Intensity")
168+
169+
ax[6].set_title("Predicted Image")
170+
ax[6].hist(predicted_image.squeeze().detach().cpu().numpy().ravel(), bins=256)
171+
ax[6].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
172+
ax[6].set_xlabel("Pixel Intensity")
173+
174+
plt.tight_layout()
175+
plt.savefig("prediction.png")
176+
```
177+
178+
### Optimize towards MSSIM=-1
179+
180+
![prediction](https://user-images.githubusercontent.com/26847524/174929574-5332cab2-104f-4aab-a4e5-35e7635a793f.png)
181+
182+
```python
183+
import matplotlib.pyplot as plt
184+
import torch
185+
from pytorch_ssim import SSIM
186+
from skimage import data
187+
from torch import optim
188+
189+
original_image = data.moon() / 255
190+
target_image = torch.from_numpy(original_image).unsqueeze(0).unsqueeze(0).float().cuda()
191+
predicted_image = torch.zeros_like(
192+
target_image, device=target_image.device, dtype=target_image.dtype, requires_grad=True
193+
)
194+
initial_image = predicted_image.clone()
195+
196+
ssim = SSIM(L=original_image.max() - original_image.min()).cuda()
197+
initial_ssim_value = ssim(predicted_image, target_image)
198+
199+
ssim_value = initial_ssim_value
200+
optimizer = optim.Adam([predicted_image], lr=0.01)
201+
loss_curves = []
202+
while ssim_value > -0.94:
203+
ssim_out = ssim(predicted_image, target_image)
204+
loss_curves.append(ssim_out.item())
205+
ssim_value = ssim_out.item()
206+
print(ssim_value)
207+
ssim_out.backward()
208+
optimizer.step()
209+
optimizer.zero_grad()
210+
211+
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(8, 4))
212+
ax = axes.ravel()
213+
214+
ax[0].imshow(original_image, cmap=plt.cm.gray, vmin=0, vmax=1)
215+
ax[0].set_title("Original Image")
216+
217+
ax[1].imshow(initial_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)
218+
ax[1].set_xlabel(f"SSIM: {initial_ssim_value:.5f}")
219+
ax[1].set_title("Initial Image")
220+
221+
ax[2].imshow(predicted_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)
222+
ax[2].set_xlabel(f"SSIM: {ssim_value:.5f}")
223+
ax[2].set_title("Predicted Image")
224+
225+
ax[3].plot(loss_curves)
226+
ax[3].set_title("SSIM Loss Curve")
227+
228+
ax[4].set_title("Original Image")
229+
ax[4].hist(original_image.ravel(), bins=256)
230+
ax[4].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
231+
ax[4].set_xlabel("Pixel Intensity")
232+
233+
ax[5].set_title("Initial Image")
234+
ax[5].hist(initial_image.squeeze().detach().cpu().numpy().ravel(), bins=256)
235+
ax[5].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
236+
ax[5].set_xlabel("Pixel Intensity")
237+
238+
ax[6].set_title("Predicted Image")
239+
ax[6].hist(predicted_image.squeeze().detach().cpu().numpy().ravel(), bins=256)
240+
ax[6].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
241+
ax[6].set_xlabel("Pixel Intensity")
242+
156243
plt.tight_layout()
157244
plt.savefig("prediction.png")
158245
```

0 commit comments

Comments
 (0)