Skip to content

Commit 49341a6

Browse files
authored
[#45] remove type checking for mixed precision training
1 parent 5ea6dc1 commit 49341a6

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

pytorch_msssim/ssim.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,8 @@ def ssim(
138138
if len(X.shape) not in (4, 5):
139139
raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")
140140

141-
if not X.type() == Y.type():
142-
raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")
141+
#if not X.type() == Y.type():
142+
# raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")
143143

144144
if win is not None: # set win_size
145145
win_size = win.shape[-1]
@@ -193,8 +193,8 @@ def ms_ssim(
193193
X = X.squeeze(dim=d)
194194
Y = Y.squeeze(dim=d)
195195

196-
if not X.type() == Y.type():
197-
raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")
196+
#if not X.type() == Y.type():
197+
# raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")
198198

199199
if len(X.shape) == 4:
200200
avg_pool = F.avg_pool2d

0 commit comments

Comments
 (0)