Skip to content

Commit 409f601

Browse files
authored
Update README.md
1 parent b6a3020 commit 409f601

File tree

1 file changed

+107
-0
lines changed

1 file changed

+107
-0
lines changed

README.md

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,109 @@
11
# MSSIM.pytorch
2+
23
A better PyTorch-based implementation of the mean structural similarity (MSSIM) index.
4+
5+
Compared with the widely used implementation at <https://github.com/Po-Hsun-Su/pytorch-ssim>, the code here has been further optimized and refactored.
6+
7+
This implementation also addresses the problem that results computed in fp16 mode are not consistent with results computed in fp32 mode. Typecasting is used to ensure that the computation is always done in fp32. This may also avoid unexpected results when the SSIM is used as a loss.
8+
9+
## Structural similarity index
10+
11+
> When comparing images, the mean squared error (MSE) — while simple to implement — is not highly indicative of perceived similarity. Structural similarity aims to address this shortcoming by taking texture into account. More details can be found at <https://scikit-image.org/docs/dev/auto_examples/transform/plot_ssim.html?highlight=structure+similarity>.
12+
13+
![results](https://user-images.githubusercontent.com/26847524/174805728-81e8502b-2ecb-4b40-a2c4-b4f1e2361ea9.png)
14+
15+
```python
16+
import matplotlib.pyplot as plt
17+
import numpy as np
18+
import torch
19+
import torch.nn.functional as F
20+
from pytorch_ssim import SSIM, ssim
21+
from skimage import data, img_as_float
22+
23+
img = img_as_float(data.camera())
24+
rows, cols = img.shape
25+
26+
noise = np.ones_like(img) * 0.2 * (img.max() - img.min())
27+
rng = np.random.default_rng()
28+
noise[rng.random(size=noise.shape) > 0.5] *= -1
29+
30+
img_noise = img + noise
31+
img_const = img + abs(noise)
32+
33+
img_tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
34+
img_noise_tensor = torch.from_numpy(img_noise).unsqueeze(0).unsqueeze(0).float()
35+
img_const_tensor = torch.from_numpy(img_const).unsqueeze(0).unsqueeze(0).float()
36+
37+
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(10, 8), sharex=True, sharey=True)
38+
ax = axes.ravel()
39+
40+
mse_none = F.mse_loss(img_tensor, img_tensor, reduction="mean")
41+
ssim_none = ssim(img_tensor, img_tensor, L=img_tensor.max() - img_tensor.min())
42+
43+
mse_noise = F.mse_loss(img_tensor, img_noise_tensor, reduction="mean")
44+
ssim_noise = ssim(img_tensor, img_noise_tensor, L=img_noise_tensor.max() - img_noise_tensor.min())
45+
46+
mse_const = F.mse_loss(img_tensor, img_const_tensor, reduction="mean")
47+
ssim_const = ssim(img_tensor, img_const_tensor, L=img_const_tensor.max() - img_const_tensor.min())
48+
49+
ax[0].imshow(img, cmap=plt.cm.gray, vmin=0, vmax=1)
50+
ax[0].set_xlabel(f"MSE: {mse_none:.2f}, SSIM: {ssim_none:.2f}")
51+
ax[0].set_title("Original image")
52+
53+
ax[1].imshow(img_noise, cmap=plt.cm.gray, vmin=0, vmax=1)
54+
ax[1].set_xlabel(f"MSE: {mse_noise:.2f}, SSIM: {ssim_noise:.2f}")
55+
ax[1].set_title("Image with noise")
56+
57+
ax[2].imshow(img_const, cmap=plt.cm.gray, vmin=0, vmax=1)
58+
ax[2].set_xlabel(f"MSE: {mse_const:.2f}, SSIM: {ssim_const:.2f}")
59+
ax[2].set_title("Image plus constant")
60+
61+
mse_none = F.mse_loss(img_tensor, img_tensor, reduction="mean")
62+
ssim_none = SSIM(L=img_tensor.max() - img_tensor.min())(img_tensor, img_tensor)
63+
64+
mse_noise = F.mse_loss(img_tensor, img_noise_tensor, reduction="mean")
65+
ssim_noise = SSIM(L=img_noise_tensor.max() - img_noise_tensor.min())(img_tensor, img_noise_tensor)
66+
67+
mse_const = F.mse_loss(img_tensor, img_const_tensor, reduction="mean")
68+
ssim_const = SSIM(L=img_const_tensor.max() - img_const_tensor.min())(img_tensor, img_const_tensor)
69+
70+
ax[3].imshow(img, cmap=plt.cm.gray, vmin=0, vmax=1)
71+
ax[3].set_xlabel(f"MSE: {mse_none:.2f}, SSIM: {ssim_none:.2f}")
72+
ax[3].set_title("Original image")
73+
74+
ax[4].imshow(img_noise, cmap=plt.cm.gray, vmin=0, vmax=1)
75+
ax[4].set_xlabel(f"MSE: {mse_noise:.2f}, SSIM: {ssim_noise:.2f}")
76+
ax[4].set_title("Image with noise")
77+
78+
ax[5].imshow(img_const, cmap=plt.cm.gray, vmin=0, vmax=1)
79+
ax[5].set_xlabel(f"MSE: {mse_const:.2f}, SSIM: {ssim_const:.2f}")
80+
ax[5].set_title("Image plus constant")
81+
82+
[ax[i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) for i in range(len(axes))]
83+
84+
plt.tight_layout()
85+
plt.savefig("results.png")
86+
```
87+
88+
## More Examples
89+
90+
```python
91+
# setting 4: for 4d float tensors with the data range [0, 1] and 1 channel, return the logarithmic form and keep the batch dim
92+
ssim_caller = SSIM(return_log=True, keep_batch_dim=True).cuda()
93+
94+
# two 4d tensors
95+
x = torch.randn(3, 1, 100, 100).cuda()
96+
y = torch.randn(3, 1, 100, 100).cuda()
97+
ssim_score_0 = ssim_caller(x, y)
98+
# or in fp16 mode (the computation process is fixed to float32 to avoid unexpected results)
99+
with torch.cuda.amp.autocast(enabled=True):
100+
ssim_score_1 = ssim_caller(x, y)
101+
assert torch.allclose(ssim_score_0, ssim_score_1)
102+
print(ssim_score_0.shape, ssim_score_1.shape)
103+
```
104+
105+
## Reference
106+
107+
- https://github.com/Po-Hsun-Su/pytorch-ssim
108+
- https://scikit-image.org/docs/dev/auto_examples/transform/plot_ssim.html?highlight=structure+similarity
109+
- Z. Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli, “Image quality assessment: From error visibility to structural similarity,” IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, Apr. 2004.

0 commit comments

Comments
 (0)