Skip to content

Commit 8957131

Browse files
authored
Fixed the type of x and y.
Force to keep the type of x and y consistent with the Gaussian filter weights. Since the model output in fp16 mode is half type, this helps to avoid problems caused by inconsistent types of x and y. After testing, it is now possible to train the model with fp16 mode, although the operation process of the mssim is still in fp32 mode.
1 parent fd3d1b5 commit 8957131

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

ssim.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,10 @@ def forward(self, x, y):
9999
"""
100100
assert x.shape == y.shape, f"x: {x.shape} and y: {y.shape} must be the same"
101101
assert x.ndim == y.ndim == 4, f"x: {x.ndim} and y: {y.ndim} must be 4"
102-
assert (
103-
x.type() == y.type() == self.gaussian_filter.gaussian_window2d.type()
104-
), f"x: {x.type()} and y: {y.type()} must be {self.gaussian_filter.gaussian_window2d.type()}"
102+
if x.type() != self.gaussian_filter.gaussian_window2d.type():
103+
x = x.type_as(self.gaussian_filter.gaussian_window2d)
104+
if y.type() != self.gaussian_filter.gaussian_window2d.type():
105+
y = y.type_as(self.gaussian_filter.gaussian_window2d)
105106

106107
mu_x = self.gaussian_filter(x) # equ 14
107108
mu_y = self.gaussian_filter(y) # equ 14
@@ -157,5 +158,5 @@ def ssim(
157158
L=L,
158159
keep_batch_dim=keep_batch_dim,
159160
return_log=return_log,
160-
)
161+
).to(device=x.device)
161162
return ssim_obj(x, y)

0 commit comments

Comments
 (0)