|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | +import torch.nn.functional as F |
| 4 | + |
| 5 | + |
class GaussianFilter2D(nn.Module):
    """Depthwise 2D Gaussian filter applied to 4D tensors via ``F.conv2d``."""

    def __init__(self, window_size=11, in_channels=1, sigma=1.5) -> None:
        """2D Gaussian Filter

        Args:
            window_size (int, optional): The window size of the gaussian filter.
                Must be a positive odd number. Defaults to 11.
            in_channels (int, optional): The number of channels of the 4d tensor;
                the kernel is repeated once per channel. Defaults to 1.
            sigma (float, optional): The sigma of the gaussian filter. Must be
                non-zero. Defaults to 1.5.

        Raises:
            ValueError: If ``window_size`` is not a positive odd number, or if
                ``sigma`` is zero.
        """
        super().__init__()
        self.window_size = window_size
        if not (window_size % 2 == 1):
            raise ValueError("Window size must be odd.")
        # In Python ``-1 % 2 == 1``, so negative odd sizes slip through the
        # parity check above and would only fail later with an obscure
        # reshape error on an empty tensor.
        if window_size < 1:
            raise ValueError("Window size must be positive.")
        # sigma == 0 makes the centre tap 0 / 0 and yields a NaN kernel.
        if sigma == 0:
            raise ValueError("Sigma must be non-zero.")
        self.in_channels = in_channels
        self.padding = window_size // 2
        self.sigma = sigma
        # Registered as a buffer so the kernel follows the module across
        # ``.to()`` / ``.cuda()`` calls without being a trainable parameter.
        self.register_buffer(name="gaussian_window2d", tensor=self._get_gaussian_window2d())

    def _get_gaussian_window1d(self):
        """Return the normalised 1D Gaussian window with shape ``(1, 1, window_size, 1)``."""
        sigma2 = self.sigma * self.sigma
        # Integer offsets centred on zero: -(W // 2), ..., 0, ..., W // 2.
        x = torch.arange(-(self.window_size // 2), self.window_size // 2 + 1)
        w = torch.exp(-0.5 * x ** 2 / sigma2)
        # Normalise so the taps sum to one.
        w = w / w.sum()
        return w.reshape(1, 1, self.window_size, 1)

    def _get_gaussian_window2d(self):
        """Return the 2D kernel with shape ``(in_channels, 1, window_size, window_size)``."""
        gaussian_window_1d = self._get_gaussian_window1d()
        # Outer product of the column window with its transpose. The result
        # already has shape (1, 1, window_size, window_size), so no further
        # reshape is needed.
        w = torch.matmul(gaussian_window_1d, gaussian_window_1d.transpose(dim0=-1, dim1=-2))
        return w.repeat(self.in_channels, 1, 1, 1)

    def forward(self, x):
        """Convolve ``x`` with the Gaussian kernel.

        Uses ``padding = window_size // 2`` and ``groups = x.shape[1]``.
        """
        x = F.conv2d(input=x, weight=self.gaussian_window2d, padding=self.padding, groups=x.shape[1])
        return x
| 40 | + |
| 41 | + |
class SSIM(nn.Module):
    def __init__(
        self,
        window_size=11,
        in_channels=1,
        sigma=1.5,
        K1=0.01,
        K2=0.03,
        L=1,
        keep_batch_dim=False,
        return_log=False,
    ):
        """Module wrapper that computes the mean SSIM (MSSIM) of two 4D tensors.

        Args:
            window_size (int, optional): The window size of the gaussian filter. Defaults to 11.
            in_channels (int, optional): The number of channels of the 4d tensor. Defaults to 1.
            sigma (float, optional): The sigma of the gaussian filter. Defaults to 1.5.
            K1 (float, optional): K1 of MSSIM. Defaults to 0.01.
            K2 (float, optional): K2 of MSSIM. Defaults to 0.03.
            L (int, optional): The dynamic range of the pixel values (255 for 8-bit grayscale images). Defaults to 1.
            keep_batch_dim (bool, optional): Whether to keep the batch dim. Defaults to False.
            return_log (bool, optional): Whether to return the logarithmic form. Defaults to False.

        ```
            # setting 0: 4d float tensors, data range [0, 1], 1 channel
            ssim_caller = SSIM().cuda()
            # setting 1: 4d float tensors, data range [0, 1], 3 channels
            ssim_caller = SSIM(in_channels=3).cuda()
            # setting 2: 4d float tensors, data range [0, 255], 3 channels
            ssim_caller = SSIM(L=255, in_channels=3).cuda()
            # setting 3: as setting 2, but return the logarithmic form
            ssim_caller = SSIM(L=255, in_channels=3, return_log=True).cuda()
            # setting 4: data range [0, 1], 1 channel, logarithmic form, batch dim kept
            ssim_caller = SSIM(return_log=True, keep_batch_dim=True).cuda()

            # two 4d tensors
            x = torch.randn(3, 1, 100, 100).cuda()
            y = torch.randn(3, 1, 100, 100).cuda()
            ssim_score_0 = ssim_caller(x, y)
            # or inside an autocast region (the underlying ssim() function is
            # decorated to run with CUDA autocast disabled)
            with torch.cuda.amp.autocast(enabled=True):
                ssim_score_1 = ssim_caller(x, y)
            assert torch.isclose(ssim_score_0, ssim_score_1)
        ```
        """
        super().__init__()
        # Smoothing filter shared by every local statistic.
        self.gaussian_filer = GaussianFilter2D(window_size=window_size, in_channels=in_channels, sigma=sigma)

        # Output options forwarded to ssim() on every call.
        self.return_log = return_log
        self.keep_batch_dim = keep_batch_dim

        # Stabilising constants, equ 7 in ref1.
        c1_root = K1 * L
        c2_root = K2 * L
        self.C1 = c1_root ** 2
        self.C2 = c2_root ** 2

        self.window_size = window_size

    def forward(self, x, y):
        """Return ``ssim(x, y, ...)`` using the filter and options stored on this module."""
        options = dict(
            gaussian_filter=self.gaussian_filer,
            C1=self.C1,
            C2=self.C2,
            keep_batch_dim=self.keep_batch_dim,
            return_log=self.return_log,
        )
        return ssim(x, y, **options)
| 99 | + |
| 100 | + |
@torch.cuda.amp.autocast(enabled=False)
def ssim(x, y, gaussian_filter, C1, C2, keep_batch_dim=False, return_log=False):
    """Calculate the mean SSIM (MSSIM) between two 4d tensors.

    The function body runs with CUDA autocast disabled (see the decorator).

    Args:
        x (Tensor): 4d tensor
        y (Tensor): 4d tensor with the same shape and tensor type as ``x``
        gaussian_filter (GaussianFilter2D): the gaussian filter object; it is
            only ever called as ``gaussian_filter(tensor)``
        C1 (float): the constant to avoid instability
        C2 (float): the constant to avoid instability
        keep_batch_dim (bool, optional): Whether to keep the batch dim. Defaults to False.
        return_log (bool, optional): Whether to return the logarithmic form
            ``-log(S_norm + 1e-8)``, where ``S_norm`` is the SSIM map min-max
            normalised over the whole input. Defaults to False.

    Returns:
        Tensor: MSSIM. A tensor with one value per batch element when
        ``keep_batch_dim`` is True (mean over dims 1, 2, 3), otherwise the
        mean over all elements.

    Raises:
        AssertionError: If ``x`` and ``y`` differ in shape or tensor type, or
            are not 4-dimensional.
    """
    assert x.shape == y.shape, f"x: {x.shape} != y: {y.shape}"
    assert x.ndim == y.ndim == 4, f"x: {x.ndim} != y: {y.ndim} != 4"
    assert x.type() == y.type(), f"x: {x.type()} != y: {y.type()}"

    # Local statistics: filtered means, variances and covariance.
    mu_x = gaussian_filter(x)  # equ 14
    mu_y = gaussian_filter(y)  # equ 14
    sigma2_x = gaussian_filter(x * x) - mu_x * mu_x  # equ 15
    sigma2_y = gaussian_filter(y * y) - mu_y * mu_y  # equ 15
    sigma_xy = gaussian_filter(x * y) - mu_x * mu_y  # equ 16

    # equ 13 in ref1
    A1 = 2 * mu_x * mu_y + C1
    A2 = 2 * sigma_xy + C2
    B1 = mu_x * mu_x + mu_y * mu_y + C1
    B2 = sigma2_x + sigma2_y + C2
    S = (A1 * A2) / (B1 * B2)

    if return_log:
        S = S - S.min()
        # A constant SSIM map (e.g. x identical to y) is all zeros after the
        # shift, so dividing by its max would be 0 / 0 and turn the result
        # into NaN. Divide by 1 in that case; any positive peak is used
        # unchanged, so non-degenerate inputs give the same values as before.
        peak = S.max()
        S = S / torch.where(peak > 0, peak, torch.ones_like(peak))
        S = -torch.log(S + 1e-8)

    if keep_batch_dim:
        return S.mean(dim=(1, 2, 3))
    else:
        return S.mean()
0 commit comments