Commit aab23bc

Fix bug in CategoricalEmbedder for GraphDIT

1 parent 9a0ef80

3 files changed (+21 −30 lines)

tests/generator/graphdit.py
Lines changed: 6 additions & 4 deletions

@@ -6,7 +6,7 @@
 from torch_molecule import GraphDITMolecularGenerator
 from torch_molecule.utils.search import ParameterType, ParameterSpec
 
-EPOCHS = 10
+EPOCHS = 2
 BATCH_SIZE = 32
 
 def test_graph_dit_generator():
@@ -18,13 +18,15 @@ def test_graph_dit_generator():
         'CC1=CC=C(C=C1)C2=CC(=NN2C3=CC=C(C=C3)S(=O)(=O)N)C(F)(F)F'
     ]
     smiles_list = smiles_list * 25 # Create 100 molecules for training
-    properties = [1.0, 2.0, 3.0, 4.0] * 25 # Create 100 properties for training
+    # properties = [1.0, 2.0, 3.0, 4.0] * 25 # Create 100 properties for training
+    properties = [0, 0, 1, 1] * 25 # Create 100 properties for training
 
     # 1. Basic initialization test - Conditional Model
     print('smiles_list', len(smiles_list), smiles_list[:5], 'properties', len(properties), properties[:5])
     print("\n=== Testing Conditional GraphDIT model initialization ===")
     conditional_model = GraphDITMolecularGenerator(
-        task_type=['regression'],
+        task_type=['classification'],
+        drop_condition=0.1,
         timesteps=500,
         batch_size=BATCH_SIZE,
         epochs=EPOCHS,
@@ -42,7 +44,7 @@ def test_graph_dit_generator():
 
     # 3. Conditional generation test
     print("\n=== Testing Conditional GraphDIT generation ===")
-    target_properties = [1.0, 2.0, 3.0, 4.0]
+    target_properties = [0, 0, 1, 1]
     generated_smiles = conditional_model.generate(target_properties, batch_size=BATCH_SIZE)
     print(f"Conditionally generated {len(generated_smiles)} molecules")
     print("Example conditionally generated SMILES:", generated_smiles[:2])

torch_molecule/generator/graph_dit/modeling_graph_dit.py
Lines changed: 2 additions & 0 deletions

@@ -468,6 +468,8 @@ def generate(self, labels: Optional[Union[List[List], np.ndarray, torch.Tensor]]
             y = labels.to(self.device).float()
         else:
             y = None
+
+        self.model.eval()
         for s_int in reversed(range(0, self.timesteps)):
             s_array = s_int * torch.ones((batch_size, 1)).float().to(self.device)
             t_array = s_array + 1
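The new self.model.eval() call makes the denoising network run deterministically during sampling: modules such as dropout are active in train mode and disabled in eval mode, so without it every step of the reverse diffusion loop would inject dropout noise. A standalone PyTorch illustration of that mode difference (unrelated to this repository's classes):

import torch
import torch.nn as nn

layer = nn.Dropout(p=0.5)
x = torch.ones(4)

layer.train()
print(layer(x))  # stochastic: roughly half the entries zeroed, survivors scaled to 2.0

layer.eval()
print(layer(x))  # deterministic: tensor([1., 1., 1., 1.])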

torch_molecule/nn/embedder.py
Lines changed: 13 additions & 26 deletions

@@ -86,29 +86,6 @@ def __init__(self, num_classes, hidden_size, dropout_prob):
         self.num_classes = num_classes
         self.dropout_prob = dropout_prob
 
-    def token_drop(self, labels, force_drop_ids=None):
-        """
-        Drops labels to enable classifier-free guidance.
-
-        Parameters
-        ----------
-        labels : torch.Tensor
-            Tensor of integer labels.
-        force_drop_ids : torch.Tensor or None, optional
-            Boolean mask to force specific labels to be dropped.
-
-        Returns
-        -------
-        torch.Tensor
-            Labels with some entries replaced by a dropout token.
-        """
-        if force_drop_ids is None:
-            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
-        else:
-            drop_ids = force_drop_ids == 1
-        labels = torch.where(drop_ids, self.num_classes, labels)
-        return labels
-
     def forward(self, labels, train, force_drop_ids=None):
         """
         Forward pass for categorical embedding with optional label dropout.
@@ -128,11 +105,21 @@ def forward(self, labels, train, force_drop_ids=None):
             Embedded label representations, with optional noise added during training.
         """
         labels = labels.long().view(-1)
+
         use_dropout = self.dropout_prob > 0
-        if (train and use_dropout) or (force_drop_ids is not None):
-            labels = self.token_drop(labels, force_drop_ids)
+        drop_ids = force_drop_ids == 1
+
+        if (train and use_dropout):
+            drop_ids_rand = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+            if force_drop_ids is not None:
+                drop_ids = torch.logical_or(drop_ids, drop_ids_rand)
+            else:
+                drop_ids = drop_ids_rand
+
+        if use_dropout:
+            labels = torch.where(drop_ids, self.num_classes, labels)
         embeddings = self.embedding_table(labels)
-        if True and train:
+        if train:
             noise = torch.randn_like(embeddings)
             embeddings = embeddings + noise
         return embeddings
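This is the core of the fix: the removed token_drop helper let a caller-supplied force_drop_ids mask and the random training-time mask overwrite one another, whereas the inlined logic ORs them together. A self-contained sketch of the corrected behavior (the explicit all-False mask for the force_drop_ids=None case is a defensive addition here, not a verbatim copy of the committed code):

import torch
import torch.nn as nn

class CategoricalEmbedder(nn.Module):
    """Minimal sketch: row num_classes of the embedding table acts as the
    'dropped label' token used for classifier-free guidance."""

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        self.embedding_table = nn.Embedding(num_classes + 1, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def forward(self, labels, train, force_drop_ids=None):
        labels = labels.long().view(-1)

        # Start from the forced mask (all False when none is given)...
        if force_drop_ids is not None:
            drop_ids = force_drop_ids == 1
        else:
            drop_ids = torch.zeros_like(labels, dtype=torch.bool)

        # ...and union it with the random training-time mask, instead of
        # letting one mask silently replace the other as token_drop did.
        if train and self.dropout_prob > 0:
            rand_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
            drop_ids = torch.logical_or(drop_ids, rand_ids)

        # Dropped labels are remapped to the extra 'unconditional' row.
        labels = torch.where(drop_ids, torch.full_like(labels, self.num_classes), labels)
        embeddings = self.embedding_table(labels)
        if train:
            embeddings = embeddings + torch.randn_like(embeddings)
        return embeddings

emb = CategoricalEmbedder(num_classes=2, hidden_size=16, dropout_prob=0.1)
out = emb(torch.tensor([0, 1, 1, 0]), train=True)  # shape (4, 16); some labels may be dropped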
