Skip to content

Commit 485903f

Browse files
committed
ssr truncate ssr loss to 10 and fix verbose
1 parent e2bc8fa commit 485903f

File tree

2 files changed

+1
-4
lines changed

2 files changed

+1
-4
lines changed

torch_molecule/predictor/ssr/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def compute_loss(self, batched_data, criterion, coarse_ratios=[0.8, 0.9], cmd_co
273273
import warnings
274274
warnings.warn(f"SSR loss is too large: {ssr_loss}, truncating to 10")
275275
ssr_loss = 10
276-
total_loss = pred_loss + cmd_coeff * ssr_loss
276+
total_loss = pred_loss + ssr_loss
277277

278278
return total_loss, pred_loss, ssr_loss
279279

torch_molecule/predictor/ssr/modeling_ssr.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,6 @@ def _convert_to_pytorch_data(self, X, y=None):
128128
setattr(g, f"coarsened_edge_attr_{coarse_ratio_postfix}", coarse_edge_attr)
129129
setattr(g, f"num_coarse_nodes_{coarse_ratio_postfix}", torch.tensor(num_clusters))
130130
setattr(g, f"clusters_{coarse_ratio_postfix}", clusters)
131-
132-
if self.verbose:
133-
print(f"Processed molecule {idx}: {g.num_nodes} nodes, coarsened versions added")
134131

135132
pyg_graph_list.append(g)
136133

0 commit comments

Comments
 (0)