Commit d9b9016

Fix masked atom issue. Fixes #11
1 parent e294403 commit d9b9016

4 files changed, +57 -8 lines changed

tests/encoder/attrmask.py

Lines changed: 22 additions & 1 deletion
@@ -41,5 +41,26 @@ def test_attrmask_encoder():
     os.remove(save_path)
     print(f"Cleaned up {save_path}")

+def test_attrmask_encoder_polymers():
+    # Test molecules (simple examples)
+    polymers = [
+        "*Nc1ccc([C@H](CCC)c2ccc(C3(c4ccc([C@@H](CCC)c5ccc(N*)cc5)cc4)CCC(CCCCC)CC3)cc2)cc1",
+        "*Nc1ccc(-c2c(-c3ccc(C)cc3)c(-c3ccc(C)cc3)c(N*)c(-c3ccc(C)cc3)c2-c2ccc(C)cc2)cc1",
+        "*CC(*)(C)C(=O)OCCCCCCCCCOc1ccc2cc(C(=O)Oc3ccccc3)ccc2c1"
+    ]
+    model = AttrMaskMolecularEncoder(
+        num_layer=3,
+        hidden_size=300,
+        batch_size=5,
+        epochs=5,  # Small number for testing
+        verbose=True
+    )
+    model.fit(polymers)
+    vectors = model.encode(polymers)
+    print(f"Representation shape: {vectors.shape}")
+    print(f"Representation for new molecule: {vectors[0]}")
+
+
 if __name__ == "__main__":
-    test_attrmask_encoder()
+    test_attrmask_encoder_polymers()
+    test_attrmask_encoder()

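The polymers above carry `*` wildcard atoms marking the attachment points of each repeat unit. In RDKit such atoms parse with atomic number 0, outside the ordinary 1-118 element range, which is presumably the corner case the masked-atom fix below has to accommodate. A quick standalone check (the shortened SMILES is illustrative, not taken from the test):

from rdkit import Chem

# Shortened repeat unit with two '*' attachment points (illustrative SMILES).
mol = Chem.MolFromSmiles("*CC(*)(C)C(=O)OC")
print([(atom.GetSymbol(), atom.GetAtomicNum()) for atom in mol.GetAtoms()])
# The '*' atoms report symbol '*' and atomic number 0.
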
tests/encoder/moama.py

Lines changed: 21 additions & 1 deletion
@@ -47,5 +47,25 @@ def test_moama_encoder():
     os.remove(save_path)
     print(f"Cleaned up {save_path}")

+def test_moama_encoder_polymers():
+    # Test molecules (simple examples)
+    polymers = [
+        "*Nc1ccc([C@H](CCC)c2ccc(C3(c4ccc([C@@H](CCC)c5ccc(N*)cc5)cc4)CCC(CCCCC)CC3)cc2)cc1",
+        "*Nc1ccc(-c2c(-c3ccc(C)cc3)c(-c3ccc(C)cc3)c(N*)c(-c3ccc(C)cc3)c2-c2ccc(C)cc2)cc1",
+        "*CC(*)(C)C(=O)OCCCCCCCCCOc1ccc2cc(C(=O)Oc3ccccc3)ccc2c1"
+    ]
+    model = MoamaMolecularEncoder(
+        num_layer=3,
+        hidden_size=300,
+        batch_size=5,
+        epochs=5,  # Small number for testing
+        verbose=True
+    )
+    model.fit(polymers)
+    vectors = model.encode(polymers)
+    print(f"Representation shape: {vectors.shape}")
+    print(f"Representation for new molecule: {vectors[0]}")
+
 if __name__ == "__main__":
-    test_moama_encoder()
+    test_moama_encoder_polymers()
+    test_moama_encoder()

torch_molecule/encoder/attrmask/model.py

Lines changed: 5 additions & 3 deletions
@@ -4,6 +4,7 @@

 from ...nn import GNN_node, GNN_node_Virtualnode, MLP
 from ...utils import init_weights
+from ...utils.graph.features import allowable_features

 import random

@@ -23,11 +24,12 @@ def __init__(
     ):
         super(GNN, self).__init__()
         gnn_name = encoder_type.split("-")[0]
-        self.num_atom_type = 119
+        decoding_size = len(allowable_features['possible_atomic_num_list'])
         self.hidden_size = hidden_size
         self.mask_num = mask_num
         self.mask_rate = mask_rate

+        self.mask_atom_id = 119
         encoder_params = {
             "num_layer": num_layer,
             "hidden_size": hidden_size,
@@ -50,7 +52,7 @@ def __init__(
         if self.pool is None:
             raise ValueError(f"Invalid graph pooling type {readout}.")

-        self.predictor = MLP(hidden_size, hidden_features=2 * hidden_size, out_features=self.num_atom_type)
+        self.predictor = MLP(hidden_size, hidden_features=2 * hidden_size, out_features=decoding_size)

     def initialize_parameters(self, seed=None):
         """
@@ -96,7 +98,7 @@ def compute_loss(self, batched_data):

         # mask nodes' features
         for node_idx in masked_node_indices:
-            batched_data.x[node_idx] = torch.tensor([self.num_atom_type - 1] + [0] * (batched_data.x.shape[1] - 1))
+            batched_data.x[node_idx] = torch.tensor([self.mask_atom_id - 1] + [0] * (batched_data.x.shape[1] - 1))

         # generate predictions
         h_node, _ = self.graph_encoder(batched_data)

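For orientation, a minimal standalone sketch of the masking step in the hunks above, assuming an OGB-style feature layout in which the first column of `batched_data.x` indexes the atom-type vocabulary. The vocabulary contents, the toy matrix, and its 9-column width are assumptions for illustration; only `allowable_features['possible_atomic_num_list']`, `decoding_size`, `mask_atom_id`, and the masking expression mirror the diff:

import torch

# Assumed atom-type vocabulary; the real one lives in torch_molecule/utils/graph/features.py.
possible_atomic_num_list = list(range(1, 119)) + ['misc']
decoding_size = len(possible_atomic_num_list)   # number of classes the predictor MLP decodes
mask_atom_id = 119                              # sentinel written into masked atom rows

# Toy node-feature matrix: 4 atoms, 9 integer features (layout assumed).
x = torch.zeros(4, 9, dtype=torch.long)
masked_node_indices = [1, 3]
for node_idx in masked_node_indices:
    # Same expression as the diff: sentinel index in the atom-type column, zeros elsewhere.
    x[node_idx] = torch.tensor([mask_atom_id - 1] + [0] * (x.shape[1] - 1))

print(decoding_size, x[masked_node_indices])

Keeping the mask sentinel (`mask_atom_id`) separate from the decoder width (`decoding_size`) means the predictor's output size follows the feature vocabulary instead of a hard-coded 119.
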
torch_molecule/encoder/moama/model.py

Lines changed: 9 additions & 3 deletions
@@ -6,6 +6,7 @@
 from ...utils import init_weights

 from .utils import get_mask_indices, get_fingerprint_loss
+from ...utils.graph.features import allowable_features

 class_criterion = torch.nn.CrossEntropyLoss()

@@ -23,7 +24,9 @@ def __init__(
     ):
         super(GNN, self).__init__()
         gnn_name = encoder_type.split("-")[0]
-        self.num_atom_type = 119
+        decoding_size = len(allowable_features['possible_atomic_num_list'])
+
+        self.mask_atom_id = 119
         self.hidden_size = hidden_size
         self.mask_rate = mask_rate
         self.lw_rec = lw_rec
@@ -50,7 +53,7 @@ def __init__(
         if self.pool is None:
             raise ValueError(f"Invalid graph pooling type {readout}.")

-        self.predictor = GNN_Decoder(hidden_size, self.num_atom_type)
+        self.predictor = GNN_Decoder(hidden_size, decoding_size)

     def initialize_parameters(self, seed=None):
         """
@@ -82,13 +85,16 @@ def compute_loss(self, batched_data):

         # mask nodes' features
         for node_idx in masked_node_indices:
-            batched_data.x[node_idx] = torch.tensor([self.num_atom_type - 1] + [0] * (batched_data.x.shape[1] - 1))
+            batched_data.x[node_idx] = torch.tensor([self.mask_atom_id - 1] + [0] * (batched_data.x.shape[1] - 1))

         # generate predictions
         h_node, _ = self.graph_encoder(batched_data)
         h_rep = self.pool(h_node, batched_data.batch)
         batched_data.x = h_node
         prediction_class = self.predictor(batched_data)[masked_node_indices]
+        print('prediction_class', prediction_class.max(), prediction_class.min())
+        print('batched_data.y', batched_data.y.max(), batched_data.y.min())
+

         # target_class = batched_data.y.to(torch.float32)
         loss_class = class_criterion(prediction_class.to(torch.float32), batched_data.y.long())

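The final line of this hunk scores the masked atoms with a cross-entropy: decoder logits at the masked positions against the atom-type labels held in `batched_data.y`, while the two added `print` calls report the ranges of those logits and labels. A minimal sketch of the loss step with toy tensors standing in for the model outputs (shapes and random values are assumptions):

import torch

class_criterion = torch.nn.CrossEntropyLoss()

num_masked, decoding_size = 3, 119                          # toy sizes
prediction_class = torch.randn(num_masked, decoding_size)   # stands in for the GNN_Decoder logits
labels = torch.randint(0, decoding_size, (num_masked,))     # stands in for batched_data.y

loss_class = class_criterion(prediction_class.to(torch.float32), labels.long())
print(loss_class.item())
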