
Commit 25fc111: add gdss generator

1 parent 951935d · commit 25fc111

File tree: 15 files changed, +2083 −9 lines

docs/source/api/generator.rst

Lines changed: 8 additions & 1 deletion

@@ -44,12 +44,19 @@ Modeling Molecules as Graphs with GNN / Transformer-based Generators
    :undoc-members:
    :show-inheritance:

+.. rubric:: GDSS for score-based molecular generation
+.. autoclass:: torch_molecule.generator.gdss.modeling_gdss.GDSSMolecularGenerator
+   :exclude-members: fitting_epoch, fitting_loss, save_to_hf, load_from_hf
+   :members: fit, generate
+   :undoc-members:
+   :show-inheritance:
+
 Modeling Molecules as Graphs with Heuristic-based Generators
 ------------------------------------------------------------

 .. rubric:: Graph Genetic Algorithm for Un/Multi-conditional Molecular Generation
 .. autoclass:: torch_molecule.generator.graph_ga.modeling_graph_ga.GraphGAMolecularGenerator
-   :exclude-members: fitting_epoch, fitting_loss, push_to_huggingface, load_from_huggingface
+   :exclude-members: fitting_epoch, fitting_loss, save_to_hf, load_from_hf
    :members: fit, generate
    :undoc-members:
    :show-inheritance:
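For orientation, a minimal usage sketch of the newly documented generator, mirroring the test added in tests/generator/run_gdss.py below (the hyperparameter values are illustrative, not library defaults):

    from torch_molecule import GDSSMolecularGenerator

    model = GDSSMolecularGenerator(batch_size=16, epochs=500)  # illustrative settings
    model.fit(train_smiles)                  # train_smiles: a list of SMILES strings
    samples = model.generate(batch_size=16)  # returns a list of generated SMILES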

docs/source/overview.rst

Lines changed: 2 additions & 2 deletions

@@ -7,8 +7,8 @@ Model Persistence
 ^^^^^^^^^^^^^^^^^
 - ``load_from_local``: Load a saved model from a local file
 - ``save_to_local``: Save the current model to a local file
-- ``load_from_huggingface``: Load a model from a Hugging Face repository
-- ``push_to_huggingface``: Push the current model to a Hugging Face repository
+- ``load_from_hf``: Load a model from a Hugging Face repository
+- ``save_to_hf``: Push the current model to a Hugging Face repository
 - ``load``: Load a model from either local storage or Hugging Face
 - ``save``: Save the model to either local storage or Hugging Face
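A minimal sketch of the renamed persistence API. The local round-trip mirrors tests/generator/run_gdss.py below; the Hugging Face calls are assumptions based only on the method names and descriptions above, and "user/gdss-demo" is a hypothetical repo id:

    model.save_to_local("gdss_model.pt")    # save weights to a local file
    model.load_from_local("gdss_model.pt")  # restore them

    model.save_to_hf("user/gdss-demo")      # hypothetical Hugging Face repo id
    model.load_from_hf("user/gdss-demo")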

tests/generator/run_gdss.py

Lines changed: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
import os
import numpy as np
from tqdm import tqdm

import torch
from torch_molecule import GDSSMolecularGenerator

EPOCHS = 500
BATCH_SIZE = 16

def test_gdss_generator():
    # Test data
    smiles_list = [
        'CNC[C@H]1OCc2cnnn2CCCC(=O)N([C@H](C)CO)C[C@@H]1C',
        'CNC[C@@H]1OCc2cnnn2CCCC(=O)N([C@H](C)CO)C[C@H]1C',
        'C[C@H]1CN([C@@H](C)CO)C(=O)CCCn2cc(nn2)CO[C@@H]1CN(C)C(=O)CCC(F)(F)F',
        'CC1=CC=C(C=C1)C2=CC(=NN2C3=CC=C(C=C3)S(=O)(=O)N)C(F)(F)F'
    ]
    smiles_list = smiles_list * 25  # Create 100 molecules for training

    # 1. Basic initialization test
    print("\n=== Testing GDSS model initialization ===")
    model = GDSSMolecularGenerator(
        num_layer=3,
        hidden_size_adj=8,
        hidden_size=16,
        attention_dim=16,
        num_head=4,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        learning_rate=0.005,
        verbose=True
    )
    print("GDSS Model initialized successfully")

    # 2. Basic fitting test
    print("\n=== Testing GDSS model fitting ===")
    model.fit(smiles_list)
    print("GDSS Model fitting completed")

    # 3. Generation test
    print("\n=== Testing GDSS model generation ===")
    generated_smiles = model.generate(batch_size=BATCH_SIZE)
    print(f"Generated {len(generated_smiles)} molecules")
    print("Example generated SMILES:", generated_smiles[:5])

    # 4. Model saving and loading test
    print("\n=== Testing GDSS model saving and loading ===")
    save_path = "gdss_test_model.pt"
    model.save_to_local(save_path)
    print(f"GDSS Model saved to {save_path}")

    new_model = GDSSMolecularGenerator()
    new_model.load_from_local(save_path)
    print("GDSS Model loaded successfully")

    # Test generation with loaded model
    generated_smiles = new_model.generate(batch_size=BATCH_SIZE)
    print(f"Generated molecules with loaded model: {len(generated_smiles)}")
    print("Example generated SMILES:", generated_smiles[:5])

    # 5. Test generation with specific number of nodes
    print("\n=== Testing generation with specific node counts ===")
    num_nodes = np.array([[20], [25], [30], [35]])  # Specify different node counts
    generated_smiles = model.generate(num_nodes=num_nodes)
    print(f"Generated molecules with specific node counts: {len(generated_smiles)}")
    print("Example generated SMILES:", generated_smiles)

    # Clean up
    if os.path.exists(save_path):
        os.remove(save_path)
        print(f"Cleaned up {save_path}")

if __name__ == "__main__":
    test_gdss_generator()

torch_molecule/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -29,6 +29,7 @@
 from .generator.graph_ga import GraphGAMolecularGenerator
 from .generator.digress import DigressMolecularGenerator
 from .generator.molgpt import MolGPTMolecularGenerator
+from .generator.gdss import GDSSMolecularGenerator

 __all__ = [
     # 'BaseMolecularPredictor',
@@ -55,4 +56,5 @@
     'GraphGAMolecularGenerator',
     'DigressMolecularGenerator',
     'MolGPTMolecularGenerator',
+    'GDSSMolecularGenerator',
 ]

torch_molecule/generator/digress/modeling_digress.py

Lines changed: 2 additions & 2 deletions

@@ -274,9 +274,9 @@ def fit(self, X_train: List[str]) -> "DigressMolecularGenerator":
         self.fitting_epoch = 0
         for epoch in range(self.epochs):
             train_losses = self._train_epoch(train_loader, optimizer, epoch)
-            self.fitting_loss.append(np.mean(train_losses))
+            self.fitting_loss.append(np.mean(train_losses).item())
             if scheduler:
-                scheduler.step(np.mean(train_losses))
+                scheduler.step(np.mean(train_losses).item())

         self.fitting_epoch = epoch
         self.is_fitted_ = True
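Here np.mean returns a NumPy scalar (numpy.float64); calling .item() converts it to a plain Python float, presumably to keep fitting_loss free of NumPy types (e.g. for logging or serialization) and to hand the scheduler a native float.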
torch_molecule/generator/gdss/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
from .modeling_gdss import GDSSMolecularGenerator

__all__ = ['GDSSMolecularGenerator']
Lines changed: 257 additions & 0 deletions

@@ -0,0 +1,257 @@
import math
import torch
from torch.nn import Parameter
import torch.nn.functional as F
from typing import Any
from .utils import mask_adjs, mask_x

def glorot(tensor):
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)

def zeros(tensor):
    if tensor is not None:
        tensor.data.fill_(0)

def reset(value: Any):
    if hasattr(value, 'reset_parameters'):
        value.reset_parameters()
    else:
        for child in value.children() if hasattr(value, 'children') else []:
            reset(child)
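
# Note: glorot() above implements Glorot/Xavier uniform initialisation,
# drawing entries from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)),
# the fan values being the last two dimensions of the weight tensor.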

# -------- GCN layer --------
class DenseGCNConv(torch.nn.Module):
    r"""See :class:`torch_geometric.nn.conv.GCNConv`.
    """
    def __init__(self, in_channels, out_channels, improved=False, bias=True):
        super(DenseGCNConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved

        self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x, adj, mask=None, add_loop=True):
        r"""
        Args:
            x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B
                \times N \times F}`, with batch-size :math:`B`, (maximum)
                number of nodes :math:`N` for each graph, and feature
                dimension :math:`F`.
            adj (Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B
                \times N \times N}`. The adjacency tensor is broadcastable in
                the batch dimension, resulting in a shared adjacency matrix for
                the complete batch.
            mask (BoolTensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
            add_loop (bool, optional): If set to :obj:`False`, the layer will
                not automatically add self-loops to the adjacency matrices.
                (default: :obj:`True`)
        """
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
        B, N, _ = adj.size()

        if add_loop:
            adj = adj.clone()
            idx = torch.arange(N, dtype=torch.long, device=adj.device)
            adj[:, idx, idx] = 1 if not self.improved else 2

        out = torch.matmul(x, self.weight)
        deg_inv_sqrt = adj.sum(dim=-1).clamp(min=1).pow(-0.5)

        adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
        out = torch.matmul(adj, out)

        if self.bias is not None:
            out = out + self.bias

        if mask is not None:
            out = out * mask.view(B, N, 1).to(x.dtype)

        return out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
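
# For reference, the forward pass above is the dense, batched form of the
# standard GCN propagation rule (Kipf & Welling, 2017):
#
#     X' = D_hat^{-1/2} A_hat D_hat^{-1/2} X W + b
#
# where A_hat is `adj` with its diagonal set to 1 (2 when improved=True)
# and D_hat is the corresponding degree matrix, with degrees clamped to a
# minimum of 1 before taking the inverse square root.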

# -------- MLP layer --------
class MLP(torch.nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False, activate_func=F.relu):
        """
        num_layers: number of layers in the neural networks (EXCLUDING the input layer). If num_layers=1, this reduces to a linear model.
        input_dim: dimensionality of input features
        hidden_dim: dimensionality of hidden units at ALL layers
        output_dim: number of classes for prediction
        """

        super(MLP, self).__init__()

        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.use_bn = use_bn
        self.activate_func = activate_func

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = torch.nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()

            self.linears.append(torch.nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(torch.nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(torch.nn.Linear(hidden_dim, output_dim))

            if self.use_bn:
                self.batch_norms = torch.nn.ModuleList()
                for layer in range(num_layers - 1):
                    self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))

    def forward(self, x):
        """
        :param x: [batch_size, N, F_i], batch of node features
        """
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        else:
            # If MLP
            h = x
            for layer in range(self.num_layers - 1):
                h = self.linears[layer](h)
                if self.use_bn:
                    h = self.batch_norms[layer](h)
                h = self.activate_func(h)
            return self.linears[self.num_layers - 1](h)

# -------- Graph Multi-Head Attention (GMH) --------
# -------- From Baek et al. (2021) --------
class Attention(torch.nn.Module):
    def __init__(self, in_dim, attn_dim, out_dim, num_heads=4, conv='GCN'):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.attn_dim = attn_dim
        self.out_dim = out_dim
        self.conv = conv

        self.gnn_q, self.gnn_k, self.gnn_v = self.get_gnn(in_dim, attn_dim, out_dim, conv)
        self.activation = torch.tanh
        self.softmax_dim = 2

    def forward(self, x, adj, flags, attention_mask=None):
        if self.conv == 'GCN':
            Q = self.gnn_q(x, adj)
            K = self.gnn_k(x, adj)
        else:
            Q = self.gnn_q(x)
            K = self.gnn_k(x)

        V = self.gnn_v(x, adj)
        dim_split = self.attn_dim // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)

        if attention_mask is not None:
            attention_mask = torch.cat([attention_mask for _ in range(self.num_heads)], 0)
            attention_score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.out_dim)
            A = self.activation(attention_mask + attention_score)
        else:
            A = self.activation(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.out_dim))  # (B x num_heads) x N x N

        # -------- (B x num_heads) x N x N --------
        A = A.view(-1, *adj.shape)
        A = A.mean(dim=0)
        A = (A + A.transpose(-1, -2)) / 2

        return V, A

    def get_gnn(self, in_dim, attn_dim, out_dim, conv='GCN'):
        if conv == 'GCN':
            gnn_q = DenseGCNConv(in_dim, attn_dim)
            gnn_k = DenseGCNConv(in_dim, attn_dim)
            gnn_v = DenseGCNConv(in_dim, out_dim)

            return gnn_q, gnn_k, gnn_v

        elif conv == 'MLP':
            num_layers = 2
            gnn_q = MLP(num_layers, in_dim, 2 * attn_dim, attn_dim, activate_func=torch.tanh)
            gnn_k = MLP(num_layers, in_dim, 2 * attn_dim, attn_dim, activate_func=torch.tanh)
            gnn_v = DenseGCNConv(in_dim, out_dim)

            return gnn_q, gnn_k, gnn_v

        else:
            raise NotImplementedError(f'{conv} not implemented.')
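
# In symbols: each head h computes a dense pre-attention map from its
# GNN-derived queries and keys,
#
#     A_h = tanh(Q_h K_h^T / sqrt(d_out)),
#
# after which the heads are averaged, A = (1/H) * sum_h A_h, and the result
# is symmetrised, A <- (A + A^T) / 2. Note that the scale is sqrt(out_dim),
# exactly as written above, not the per-head dimension, and that tanh (not
# softmax) is used: the map is consumed as an adjacency-like channel by
# AttentionLayer below, not as a probability distribution.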

# -------- Layer of ScoreNetworkA --------
class AttentionLayer(torch.nn.Module):
    def __init__(self, num_linears, conv_input_dim, attn_dim, conv_output_dim, input_dim, output_dim,
                 num_heads=4, conv='GCN'):
        super(AttentionLayer, self).__init__()
        self.attn = torch.nn.ModuleList()
        for _ in range(input_dim):
            self.attn_dim = attn_dim
            self.attn.append(Attention(conv_input_dim, self.attn_dim, conv_output_dim,
                                       num_heads=num_heads, conv=conv))

        self.hidden_dim = 2 * max(input_dim, output_dim)
        self.mlp = MLP(num_linears, 2 * input_dim, self.hidden_dim, output_dim, use_bn=False, activate_func=F.elu)
        self.multi_channel = MLP(2, input_dim * conv_output_dim, self.hidden_dim, conv_output_dim,
                                 use_bn=False, activate_func=F.elu)

    def forward(self, x, adj, flags):
        """
        :param x: B x N x F_i
        :param adj: B x C_i x N x N
        :return: x_out: B x N x F_o, adj_out: B x C_o x N x N
        """
        mask_list = []
        x_list = []
        for _ in range(len(self.attn)):
            _x, mask = self.attn[_](x, adj[:, _, :, :], flags)
            mask_list.append(mask.unsqueeze(-1))
            x_list.append(_x)
        x_out = mask_x(self.multi_channel(torch.cat(x_list, dim=-1)), flags)
        x_out = torch.tanh(x_out)

        mlp_in = torch.cat([torch.cat(mask_list, dim=-1), adj.permute(0, 2, 3, 1)], dim=-1)
        shape = mlp_in.shape
        mlp_out = self.mlp(mlp_in.view(-1, shape[-1]))
        _adj = mlp_out.view(shape[0], shape[1], shape[2], -1).permute(0, 3, 1, 2)
        _adj = _adj + _adj.transpose(-1, -2)
        adj_out = mask_adjs(_adj, flags)

        return x_out, adj_out
