|
1 | 1 | # MSSIM.pytorch |
| 2 | + |
2 | 3 | A better PyTorch-based implementation of the mean structural similarity (MSSIM). |
| 4 | + |
| 5 | +Compared to the widely used implementation at <https://github.com/Po-Hsun-Su/pytorch-ssim>, the code here has been further optimized and refactored. |
| 6 | + |
| 7 | +This implementation also addresses the problem that results computed in fp16 mode are not consistent with those computed in fp32 mode. Typecasting is used to ensure that the computation is always done in fp32. This might also avoid unexpected results when using it as a loss. |
| 8 | + |
| 9 | +## Structural similarity index |
| 10 | + |
| 11 | +> When comparing images, the mean squared error (MSE)—while simple to implement—is not highly indicative of perceived similarity. Structural similarity aims to address this shortcoming by taking texture into account. For more details, see <https://scikit-image.org/docs/dev/auto_examples/transform/plot_ssim.html?highlight=structure+similarity>. |
| 12 | +
|
| 13 | + |
| 14 | + |
| 15 | +```python |
| 16 | +import matplotlib.pyplot as plt |
| 17 | +import numpy as np |
| 18 | +import torch |
| 19 | +import torch.nn.functional as F |
| 20 | +from pytorch_ssim import SSIM, ssim |
| 21 | +from skimage import data, img_as_float |
| 22 | + |
| 23 | +img = img_as_float(data.camera()) |
| 24 | +rows, cols = img.shape |
| 25 | + |
| 26 | +noise = np.ones_like(img) * 0.2 * (img.max() - img.min()) |
| 27 | +rng = np.random.default_rng() |
| 28 | +noise[rng.random(size=noise.shape) > 0.5] *= -1 |
| 29 | + |
| 30 | +img_noise = img + noise |
| 31 | +img_const = img + abs(noise) |
| 32 | + |
| 33 | +img_tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float() |
| 34 | +img_noise_tensor = torch.from_numpy(img_noise).unsqueeze(0).unsqueeze(0).float() |
| 35 | +img_const_tensor = torch.from_numpy(img_const).unsqueeze(0).unsqueeze(0).float() |
| 36 | + |
| 37 | +fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(10, 8), sharex=True, sharey=True) |
| 38 | +ax = axes.ravel() |
| 39 | + |
| 40 | +mse_none = F.mse_loss(img_tensor, img_tensor, reduction="mean") |
| 41 | +ssim_none = ssim(img_tensor, img_tensor, L=img_tensor.max() - img_tensor.min()) |
| 42 | + |
| 43 | +mse_noise = F.mse_loss(img_tensor, img_noise_tensor, reduction="mean") |
| 44 | +ssim_noise = ssim(img_tensor, img_noise_tensor, L=img_noise_tensor.max() - img_noise_tensor.min()) |
| 45 | + |
| 46 | +mse_const = F.mse_loss(img_tensor, img_const_tensor, reduction="mean") |
| 47 | +ssim_const = ssim(img_tensor, img_const_tensor, L=img_const_tensor.max() - img_const_tensor.min()) |
| 48 | + |
| 49 | +ax[0].imshow(img, cmap=plt.cm.gray, vmin=0, vmax=1) |
| 50 | +ax[0].set_xlabel(f"MSE: {mse_none:.2f}, SSIM: {ssim_none:.2f}") |
| 51 | +ax[0].set_title("Original image") |
| 52 | + |
| 53 | +ax[1].imshow(img_noise, cmap=plt.cm.gray, vmin=0, vmax=1) |
| 54 | +ax[1].set_xlabel(f"MSE: {mse_noise:.2f}, SSIM: {ssim_noise:.2f}") |
| 55 | +ax[1].set_title("Image with noise") |
| 56 | + |
| 57 | +ax[2].imshow(img_const, cmap=plt.cm.gray, vmin=0, vmax=1) |
| 58 | +ax[2].set_xlabel(f"MSE: {mse_const:.2f}, SSIM: {ssim_const:.2f}") |
| 59 | +ax[2].set_title("Image plus constant") |
| 60 | + |
| 61 | +mse_none = F.mse_loss(img_tensor, img_tensor, reduction="mean") |
| 62 | +ssim_none = SSIM(L=img_tensor.max() - img_tensor.min())(img_tensor, img_tensor) |
| 63 | + |
| 64 | +mse_noise = F.mse_loss(img_tensor, img_noise_tensor, reduction="mean") |
| 65 | +ssim_noise = SSIM(L=img_noise_tensor.max() - img_noise_tensor.min())(img_tensor, img_noise_tensor) |
| 66 | + |
| 67 | +mse_const = F.mse_loss(img_tensor, img_const_tensor, reduction="mean") |
| 68 | +ssim_const = SSIM(L=img_const_tensor.max() - img_const_tensor.min())(img_tensor, img_const_tensor) |
| 69 | + |
| 70 | +ax[3].imshow(img, cmap=plt.cm.gray, vmin=0, vmax=1) |
| 71 | +ax[3].set_xlabel(f"MSE: {mse_none:.2f}, SSIM: {ssim_none:.2f}") |
| 72 | +ax[3].set_title("Original image") |
| 73 | + |
| 74 | +ax[4].imshow(img_noise, cmap=plt.cm.gray, vmin=0, vmax=1) |
| 75 | +ax[4].set_xlabel(f"MSE: {mse_noise:.2f}, SSIM: {ssim_noise:.2f}") |
| 76 | +ax[4].set_title("Image with noise") |
| 77 | + |
| 78 | +ax[5].imshow(img_const, cmap=plt.cm.gray, vmin=0, vmax=1) |
| 79 | +ax[5].set_xlabel(f"MSE: {mse_const:.2f}, SSIM: {ssim_const:.2f}") |
| 80 | +ax[5].set_title("Image plus constant") |
| 81 | + |
| 82 | +[ax[i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) for i in range(len(axes))] |
| 83 | + |
| 84 | +plt.tight_layout() |
| 85 | +plt.savefig("results.png") |
| 86 | +``` |
| 87 | + |
| 88 | +## More Examples |
| 89 | + |
| 90 | +```python |
| 91 | +# setting 4: for 4d float tensors with the data range [0, 1] and 1 channel, return the logarithmic form, and keep the batch dim |
| 92 | +ssim_caller = SSIM(return_log=True, keep_batch_dim=True).cuda() |
| 93 | + |
| 94 | +# two 4d tensors |
| 95 | +x = torch.randn(3, 1, 100, 100).cuda() |
| 96 | +y = torch.randn(3, 1, 100, 100).cuda() |
| 97 | +ssim_score_0 = ssim_caller(x, y) |
| 98 | +# or in the fp16 mode (the computation is forced to run in float32 to avoid unexpected results) |
| 99 | +with torch.cuda.amp.autocast(enabled=True): |
| 100 | + ssim_score_1 = ssim_caller(x, y) |
| 101 | +assert torch.allclose(ssim_score_0, ssim_score_1) |
| 102 | +print(ssim_score_0.shape, ssim_score_1.shape) |
| 103 | +``` |
| 104 | + |
| 105 | +## Reference |
| 106 | + |
| 107 | +- https://github.com/Po-Hsun-Su/pytorch-ssim |
| 108 | +- https://scikit-image.org/docs/dev/auto_examples/transform/plot_ssim.html?highlight=structure+similarity |
| 109 | +- Z. Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli, “Image quality assessment: From error visibility to structural similarity,” IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, Apr. 2004. |
0 commit comments