Skip to content

Commit 0e8c062

Browse files
committed
add graphmae
1 parent 79af7b5 commit 0e8c062

File tree

13 files changed

+849
-27
lines changed

13 files changed

+849
-27
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ predictions = model.predict(smiles_list)
182182
| Model | Reference |
183183
|--------------|---------------------|
184184
| MoAMa | [Motif-aware Attribute Masking for Molecular Graph Pre-training. LoG 2024](https://arxiv.org/abs/2309.04589) |
185+
| GraphMAE | [GraphMAE: Self-Supervised Masked Graph Autoencoders. KDD 2022](https://arxiv.org/abs/2205.10803) |
185186
| AttrMasking | [Strategies for Pre-training Graph Neural Networks. ICLR 2020](https://arxiv.org/abs/1905.12265) |
186187
| ContextPred | [Strategies for Pre-training Graph Neural Networks. ICLR 2020](https://arxiv.org/abs/1905.12265) |
187188
| EdgePred | [Strategies for Pre-training Graph Neural Networks. ICLR 2020](https://arxiv.org/abs/1905.12265) |

docs/source/api/encoder.rst

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,23 +42,31 @@ Self-supervised Molecular Representation Learning
4242
:undoc-members:
4343
:show-inheritance:
4444

45-
.. rubric:: Context Prediction for Molecular Representation Learning
45+
.. rubric:: Graph masked autoencoder
46+
47+
.. autoclass:: torch_molecule.encoder.graphmae.modeling_graphmae.GraphMAEMolecularEncoder
48+
:members: fit, encode
49+
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class
50+
:undoc-members:
51+
:show-inheritance:
52+
53+
.. rubric:: Context Prediction
4654

4755
.. autoclass:: torch_molecule.encoder.contextpred.modeling_contextpred.ContextPredMolecularEncoder
4856
:members: fit, encode
4957
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class
5058
:undoc-members:
5159
:show-inheritance:
5260

53-
.. rubric:: Edge Prediction for Molecular Representation Learning
61+
.. rubric:: Edge Prediction
5462

5563
.. autoclass:: torch_molecule.encoder.edgepred.modeling_edgepred.EdgePredMolecularEncoder
5664
:members: fit, encode
5765
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class
5866
:undoc-members:
5967
:show-inheritance:
6068

61-
.. rubric:: InfoGraph for Molecular Representation Learning
69+
.. rubric:: InfoGraph
6270

6371
.. autoclass:: torch_molecule.encoder.infograph.modeling_infograph.InfoGraphMolecularEncoder
6472
:members: fit, encode
@@ -69,7 +77,7 @@ Self-supervised Molecular Representation Learning
6977
Supervised Pretraining for Molecules
7078
------------------------------------
7179

72-
.. rubric:: Supervised/Pseudolabeled Pretraining for Molecules
80+
.. rubric:: Pretraining with Supervised/Pseudolabeled Data
7381
.. autoclass:: torch_molecule.encoder.supervised.modeling_supervised.SupervisedMolecularEncoder
7482
:members: fit, encode
7583
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class

tests/encoder/run_graphmae.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import numpy as np
2+
import pandas as pd
3+
import os
4+
from torch_molecule import GraphMAEMolecularEncoder
5+
6+
def test_graphmae_encoder():
7+
# Load molecules from CSV file
8+
data_path = "data/molecule100.csv"
9+
if not os.path.exists(data_path):
10+
print(f"Data file not found: {data_path}")
11+
# Use simple molecules as fallback
12+
molecules = [
13+
"CC(=O)O", # Acetic acid
14+
"CCO", # Ethanol
15+
"CCCC", # Butane
16+
"c1ccccc1", # Benzene
17+
"CCN", # Ethylamine
18+
]
19+
else:
20+
df = pd.read_csv(data_path)
21+
molecules = df['smiles'].tolist()[:50] # Use first 50 molecules
22+
print(f"Loaded {len(molecules)} molecules from {data_path}")
23+
24+
# Initialize GraphMAE model
25+
model = GraphMAEMolecularEncoder(
26+
num_layer=3,
27+
hidden_size=128,
28+
batch_size=16,
29+
epochs=30, # Small number for testing
30+
mask_rate=0.15,
31+
verbose=True,
32+
# device="cpu"
33+
)
34+
print("GraphMAE model initialized successfully")
35+
36+
# Test fitting
37+
print("\n=== Testing GraphMAE model self-supervised fitting ===")
38+
model.fit(molecules)
39+
40+
# Test encoding
41+
print("\n=== Testing molecule encoding ===")
42+
encodings = model.encode(molecules[:5])
43+
print(f"Encoding shape: {encodings.shape}")
44+
45+
# Test saving and loading
46+
print("\n=== Testing model saving and loading ===")
47+
save_path = "graphmae_model.pt"
48+
model.save_to_local(save_path)
49+
print(f"Model saved to {save_path}")
50+
51+
new_model = GraphMAEMolecularEncoder()
52+
new_model.load_from_local(save_path)
53+
print("Model loaded successfully")
54+
55+
# Test encoding with loaded model
56+
new_encodings = new_model.encode(molecules[:5])
57+
print(f"New encoding shape: {new_encodings.shape}")
58+
59+
# Verify encodings are the same (or very close)
60+
encoding_diff = (encodings - new_encodings).abs().max().item()
61+
print(f"Max difference between encodings: {encoding_diff}")
62+
63+
# Clean up
64+
if os.path.exists(save_path):
65+
os.remove(save_path)
66+
print(f"Cleaned up {save_path}")
67+
68+
def test_graphmae_with_edge_masking():
69+
# Load molecules from CSV file
70+
data_path = "data/molecule100.csv"
71+
if not os.path.exists(data_path):
72+
print(f"Data file not found: {data_path}")
73+
# Use simple molecules as fallback
74+
molecules = [
75+
"CC(=O)O", # Acetic acid
76+
"CCO", # Ethanol
77+
"CCCC", # Butane
78+
"c1ccccc1", # Benzene
79+
"CCN", # Ethylamine
80+
]
81+
else:
82+
df = pd.read_csv(data_path)
83+
molecules = df['smiles'].tolist()[:50] # Use first 50 molecules
84+
print(f"Loaded {len(molecules)} molecules from {data_path}")
85+
86+
# Initialize GraphMAE model with edge masking enabled
87+
model = GraphMAEMolecularEncoder(
88+
num_layer=3,
89+
hidden_size=128,
90+
batch_size=16,
91+
epochs=30, # Small number for testing
92+
mask_rate=0.15,
93+
mask_edge=True, # Enable edge masking
94+
verbose=True,
95+
# device="cpu"
96+
)
97+
print("GraphMAE model with edge masking initialized successfully")
98+
99+
# Test fitting
100+
print("\n=== Testing GraphMAE model with edge masking ===")
101+
model.fit(molecules)
102+
103+
# Test encoding
104+
print("\n=== Testing molecule encoding with edge masking model ===")
105+
encodings = model.encode(molecules[:5])
106+
print(f"Encoding shape: {encodings.shape}")
107+
108+
# Test saving and loading
109+
print("\n=== Testing edge masking model saving and loading ===")
110+
save_path = "graphmae_edge_model.pt"
111+
model.save_to_local(save_path)
112+
print(f"Model saved to {save_path}")
113+
114+
new_model = GraphMAEMolecularEncoder()
115+
new_model.load_from_local(save_path)
116+
print("Model loaded successfully")
117+
118+
# Verify edge masking parameter was preserved
119+
print(f"Loaded model mask_edge parameter: {new_model.mask_edge}")
120+
121+
# Clean up
122+
if os.path.exists(save_path):
123+
os.remove(save_path)
124+
print(f"Cleaned up {save_path}")
125+
126+
if __name__ == "__main__":
127+
print("=== Testing GraphMAE Encoder (Default Configuration) ===")
128+
test_graphmae_encoder()
129+
130+
print("\n=== Testing GraphMAE Encoder with Edge Masking ===")
131+
test_graphmae_with_edge_masking()

torch_molecule/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
"""
1616
encoder module
1717
"""
18-
from .encoder.supervised import SupervisedMolecularEncoder
1918
from .encoder.attrmask import AttrMaskMolecularEncoder
19+
from .encoder.moama import MoamaMolecularEncoder
20+
from .encoder.graphmae import GraphMAEMolecularEncoder
21+
from .encoder.supervised import SupervisedMolecularEncoder
2022
from .encoder.contextpred import ContextPredMolecularEncoder
2123
from .encoder.edgepred import EdgePredMolecularEncoder
22-
from .encoder.moama import MoamaMolecularEncoder
2324
from .encoder.infograph import InfoGraphMolecularEncoder
2425
from .encoder.pretrained import HFPretrainedMolecularEncoder
2526
"""
@@ -46,9 +47,10 @@
4647
# encoders
4748
'SupervisedMolecularEncoder',
4849
'AttrMaskMolecularEncoder',
50+
'MoamaMolecularEncoder',
51+
'GraphMAEMolecularEncoder',
4952
'ContextPredMolecularEncoder',
5053
'EdgePredMolecularEncoder',
51-
'MoamaMolecularEncoder',
5254
'InfoGraphMolecularEncoder',
5355
'HFPretrainedMolecularEncoder',
5456
# generators

torch_molecule/encoder/attrmask/model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(
2323
):
2424
super(GNN, self).__init__()
2525
gnn_name = encoder_type.split("-")[0]
26-
self.num_atom_type = 118
26+
self.num_atom_type = 119
2727
self.hidden_size = hidden_size
2828
self.mask_num = mask_num
2929
self.mask_rate = mask_rate
@@ -96,11 +96,10 @@ def compute_loss(self, batched_data):
9696

9797
# mask nodes' features
9898
for node_idx in masked_node_indices:
99-
batched_data.x[node_idx] = torch.tensor([self.num_atom_type] + [0] * (batched_data.x.shape[1] - 1))
99+
batched_data.x[node_idx] = torch.tensor([self.num_atom_type - 1] + [0] * (batched_data.x.shape[1] - 1))
100100

101101
# generate predictions
102102
h_node, _ = self.graph_encoder(batched_data)
103-
#h_rep = self.pool(h_node, batched_data.batch)
104103
prediction_class = self.predictor(h_node[masked_node_indices])
105104

106105
target_class = batched_data.y.to(torch.float32)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .modeling_graphmae import GraphMAEMolecularEncoder
2+
3+
__all__ = ['GraphMAEMolecularEncoder']

0 commit comments

Comments
 (0)