Skip to content

Commit b6a3020

Browse files
authored
Create ssim.py
1 parent a969425 commit b6a3020

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed

ssim.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
6+
class GaussianFilter2D(nn.Module):
    """Depthwise 2D gaussian smoothing filter for 4d tensors."""

    def __init__(self, window_size=11, in_channels=1, sigma=1.5) -> None:
        """2D Gaussian Filter

        Args:
            window_size (int, optional): The window size of the gaussian filter.
                Must be a positive odd number. Defaults to 11.
            in_channels (int, optional): The number of channels of the 4d tensor. Defaults to 1.
            sigma (float, optional): The sigma of the gaussian filter. Must be non-zero. Defaults to 1.5.

        Raises:
            ValueError: If window_size is not a positive odd number, or if sigma is zero.
        """
        super().__init__()
        # Validate before touching any state so a bad configuration fails early
        # with a clear message instead of an opaque tensor-shape error later on.
        if not (window_size % 2 == 1):
            raise ValueError("Window size must be odd.")
        if window_size < 1:
            # Python's modulo makes negative odd numbers pass the check above.
            raise ValueError("Window size must be positive.")
        if sigma == 0:
            # A zero sigma divides by zero and yields an all-NaN kernel.
            raise ValueError("Sigma must be non-zero.")

        self.window_size = window_size
        self.in_channels = in_channels
        # "Same" padding: with an odd window the output keeps the input's H and W.
        self.padding = window_size // 2
        self.sigma = sigma
        # A buffer (not a parameter): it follows .to()/.cuda() but is not trained.
        self.register_buffer(name="gaussian_window2d", tensor=self._get_gaussian_window2d())

    def _get_gaussian_window1d(self):
        """Return the normalized 1D gaussian kernel as a column of shape (1, 1, window_size, 1)."""
        sigma2 = self.sigma * self.sigma
        # Integer sample positions centered on zero: -k, ..., 0, ..., k.
        x = torch.arange(-(self.window_size // 2), self.window_size // 2 + 1)
        w = torch.exp(-0.5 * x ** 2 / sigma2)
        w = w / w.sum()  # normalize so the weights sum to 1
        return w.reshape(1, 1, self.window_size, 1)

    def _get_gaussian_window2d(self):
        """Return the depthwise conv weight of shape (in_channels, 1, window_size, window_size)."""
        gaussian_window_1d = self._get_gaussian_window1d()
        # Outer product of the column kernel with its transpose; the result is
        # already shaped (1, 1, window_size, window_size), so no reshape is needed.
        w = torch.matmul(gaussian_window_1d, gaussian_window_1d.transpose(dim0=-1, dim1=-2))
        # One identical kernel per channel for the grouped (depthwise) convolution.
        return w.repeat(self.in_channels, 1, 1, 1)

    def forward(self, x):
        """Smooth each channel of ``x`` independently with the gaussian window.

        Args:
            x (Tensor): 4d tensor whose channel dim (dim 1) equals ``in_channels``;
                it is used as the number of convolution groups.

        Returns:
            Tensor: The filtered tensor, with the same spatial size as ``x``.
        """
        x = F.conv2d(input=x, weight=self.gaussian_window2d, padding=self.padding, groups=x.shape[1])
        return x
40+
41+
42+
class SSIM(nn.Module):
    def __init__(
        self, window_size=11, in_channels=1, sigma=1.5, K1=0.01, K2=0.03, L=1, keep_batch_dim=False, return_log=False
    ):
        """Module wrapper that computes the mean SSIM (MSSIM) of two 4D tensors.

        The constructor builds the gaussian filter and the two stabilizing
        constants once; ``forward`` then delegates to the module-level ``ssim``.

        Args:
            window_size (int, optional): The window size of the gaussian filter. Defaults to 11.
            in_channels (int, optional): The number of channels of the 4d tensor. Defaults to 1.
            sigma (float, optional): The sigma of the gaussian filter. Defaults to 1.5.
            K1 (float, optional): K1 of MSSIM. Defaults to 0.01.
            K2 (float, optional): K2 of MSSIM. Defaults to 0.03.
            L (int, optional): The dynamic range of the pixel values (255 for 8-bit grayscale images). Defaults to 1.
            keep_batch_dim (bool, optional): Whether to keep the batch dim. Defaults to False.
            return_log (bool, optional): Whether to return the logarithmic form. Defaults to False.

        ```
        # 1-channel float tensors in [0, 1]
        ssim_caller = SSIM().cuda()
        # 3-channel float tensors in [0, 1]
        ssim_caller = SSIM(in_channels=3).cuda()
        # 3-channel float tensors in [0, 255]
        ssim_caller = SSIM(L=255, in_channels=3).cuda()
        # 3-channel float tensors in [0, 255], logarithmic form
        ssim_caller = SSIM(L=255, in_channels=3, return_log=True).cuda()
        # 1-channel float tensors in [0, 1], logarithmic form, one score per sample
        ssim_caller = SSIM(return_log=True, keep_batch_dim=True).cuda()

        x = torch.randn(3, 1, 100, 100).cuda()
        y = torch.randn(3, 1, 100, 100).cuda()
        ssim_score_0 = ssim_caller(x, y)
        # the underlying ssim function runs with CUDA autocast disabled
        with torch.cuda.amp.autocast(enabled=True):
            ssim_score_1 = ssim_caller(x, y)
        assert torch.isclose(ssim_score_0, ssim_score_1)
        ```
        """
        super().__init__()
        # NOTE: the attribute name (spelling included) is kept unchanged because
        # it names the registered submodule.
        self.gaussian_filer = GaussianFilter2D(window_size=window_size, in_channels=in_channels, sigma=sigma)

        self.return_log = return_log
        self.keep_batch_dim = keep_batch_dim
        self.window_size = window_size
        # Stabilizing constants, equ 7 in ref1: C_i = (K_i * L) ** 2
        self.C1, self.C2 = ((k * L) ** 2 for k in (K1, K2))

    def forward(self, x, y):
        """Return the MSSIM of ``x`` and ``y`` using this module's configuration."""
        options = dict(
            gaussian_filter=self.gaussian_filer,
            C1=self.C1,
            C2=self.C2,
            keep_batch_dim=self.keep_batch_dim,
            return_log=self.return_log,
        )
        return ssim(x, y, **options)
99+
100+
101+
@torch.cuda.amp.autocast(enabled=False)
def ssim(x, y, gaussian_filter, C1, C2, keep_batch_dim=False, return_log=False):
    """Calculate the mean SSIM (MSSIM) between two 4d tensors.

    CUDA autocast is disabled for the whole function. float16/bfloat16 inputs
    are upcast to float32 first, so the statistics below are not computed in
    reduced precision.

    Args:
        x (Tensor): 4d tensor
        y (Tensor): 4d tensor with the same shape and type as ``x``
        gaussian_filter (GaussianFilter2D): the gaussian filter object
        C1 (float): the constant to avoid instability
        C2 (float): the constant to avoid instability
        keep_batch_dim (bool, optional): Whether to keep the batch dim. Defaults to False.
        return_log (bool, optional): Whether to return the logarithmic form. Defaults to False.

    Returns:
        Tensor: MSSIM. A scalar tensor, or one value per sample (dim 0) when
        ``keep_batch_dim`` is True.
    """
    assert x.shape == y.shape, f"x: {x.shape} != y: {y.shape}"
    assert x.ndim == y.ndim == 4, f"x: {x.ndim} != y: {y.ndim} != 4"
    assert x.type() == y.type(), f"x: {x.type()} != y: {y.type()}"

    # With autocast disabled nothing upcasts the inputs for us: a half-precision
    # tensor (e.g. produced inside an autocast region) would otherwise be passed
    # to the filter unchanged and clash with a float32 kernel.
    low_precision = (torch.float16, torch.bfloat16)
    if x.dtype in low_precision:
        x = x.float()
    if y.dtype in low_precision:
        y = y.float()

    mu_x = gaussian_filter(x)  # equ 14
    mu_y = gaussian_filter(y)  # equ 14
    sigma2_x = gaussian_filter(x * x) - mu_x * mu_x  # equ 15
    sigma2_y = gaussian_filter(y * y) - mu_y * mu_y  # equ 15
    sigma_xy = gaussian_filter(x * y) - mu_x * mu_y  # equ 16

    # equ 13 in ref1
    A1 = 2 * mu_x * mu_y + C1
    A2 = 2 * sigma_xy + C2
    B1 = mu_x * mu_x + mu_y * mu_y + C1
    B2 = sigma2_x + sigma2_y + C2
    S = (A1 * A2) / (B1 * B2)

    if return_log:
        # Min-max normalize the SSIM map to [0, 1] over the whole tensor (all
        # samples together), then take the negative log.
        # NOTE(review): a constant map (e.g. x equal to y) has zero range, so
        # the division below is 0 / 0 and the result is NaN.
        S = S - S.min()
        S = S / S.max()
        S = -torch.log(S + 1e-8)

    if keep_batch_dim:
        return S.mean(dim=(1, 2, 3))
    else:
        return S.mean()

0 commit comments

Comments
 (0)