Skip to content

Commit 35767e1

Browse files
committed
add lstm generation with docs
1 parent c12d4a9 commit 35767e1

File tree

9 files changed

+211
-95
lines changed

9 files changed

+211
-95
lines changed

docs/source/api/generator.rst

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,18 @@ Modeling Molecules as Graphs with Heuristic-based Generators
6767
:undoc-members:
6868
:show-inheritance:
6969

70-
Modeling Molecules as Sequences with Transformer-based Generators
71-
-----------------------------------------------------------------
70+
Modeling Molecules as Sequences
71+
--------------------------------
7272

7373
.. rubric:: MolGPT for Unconditional Molecular Generation
7474
.. autoclass:: torch_molecule.generator.molgpt.modeling_molgpt.MolGPTMolecularGenerator
75+
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class
76+
:members: fit, generate
77+
:undoc-members:
78+
:show-inheritance:
79+
80+
.. rubric:: LSTM for Unconditional/Conditional Molecular Generation
81+
.. autoclass:: torch_molecule.generator.lstm.modeling_lstm.LSTMMolecularGenerator
7582
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class
7683
:members: fit, generate
7784
:undoc-members:

tests/generator/run_lstm.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
import os

import numpy as np
import pandas as pd

from torch_molecule.generator.lstm import LSTMMolecularGenerator

# Training length for the smoke test; lower this further for faster CI runs.
EPOCHS = 1000
BATCH_SIZE = 24


def _load_dataset():
    """Load polymer100.csv and return (smiles_list, property_columns, properties).

    The CSV is expected to have a 'smiles' column; every other column is
    treated as a numeric target property.
    """
    data_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
        "data", "polymer100.csv",
    )
    print(f"Loading data from: {data_path}")

    df = pd.read_csv(data_path)
    smiles_list = df['smiles'].tolist()

    # Every non-SMILES column is a conditioning property.
    property_columns = [col for col in df.columns if col != 'smiles']
    properties = df[property_columns].values.tolist()

    print(f"Loaded {len(smiles_list)} molecules with {len(property_columns)} properties")
    print(f"Property columns: {property_columns}")
    print(f"First 3 SMILES: {smiles_list[:3]}")
    print(f"First 3 properties: {properties[:3]}")
    return smiles_list, property_columns, properties


def _test_unconditional(smiles_list):
    """Fit, sample, save/load, and re-sample an unconditional LSTM generator."""
    print("\n=== Testing Unconditional LSTM model initialization ===")
    model = LSTMMolecularGenerator(
        num_layer=3,
        hidden_size=128,
        max_len=64,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        verbose=True,
    )
    print("Unconditional LSTM Model initialized successfully")

    print("\n=== Testing Unconditional LSTM model fitting ===")
    model.fit(smiles_list)
    print("Unconditional LSTM Model fitting completed")

    print("\n=== Testing Unconditional LSTM generation ===")
    generated = model.generate(batch_size=BATCH_SIZE)
    print(f"Unconditionally generated {len(generated)} molecules")
    print("Example unconditionally generated SMILES:", generated[:10])

    print("\n=== Testing Unconditional LSTM model saving and loading ===")
    save_path = "unconditional_lstm_test_model.pt"
    try:
        model.save_to_local(save_path)
        print(f"Unconditional LSTM Model saved to {save_path}")

        loaded_model = LSTMMolecularGenerator()
        loaded_model.load_from_local(save_path)
        print("Unconditional LSTM Model loaded successfully")

        generated = loaded_model.generate(batch_size=5)
        print("Generated molecules with loaded unconditional model:", len(generated))
        print("Example generated SMILES:", generated[:10])
    finally:
        # Always remove the checkpoint, even if load/generate raised.
        if os.path.exists(save_path):
            os.remove(save_path)
            print(f"Cleaned up {save_path}")


def _test_conditional(smiles_list, property_columns, properties):
    """Fit, sample, save/load, and re-sample a property-conditional LSTM generator."""
    print("\n=== Testing Property Conditional LSTM model initialization ===")
    model = LSTMMolecularGenerator(
        num_layer=2,
        hidden_size=128,
        max_len=64,
        num_task=len(property_columns),  # Set number of properties
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        verbose=True,
    )
    print("Property Conditional LSTM Model initialized successfully")

    print("\n=== Testing Property Conditional LSTM model fitting ===")
    model.fit(smiles_list, properties)
    print("Property Conditional LSTM Model fitting completed")

    print("\n=== Testing Property Conditional LSTM generation ===")
    # Build target properties as random perturbations (80%-120%) around the
    # dataset means, so the targets stay inside a plausible range.
    mean_properties = np.mean(properties, axis=0).tolist()
    target_properties = [
        [p * (0.8 + 0.4 * np.random.random()) for p in mean_properties]
        for _ in range(5)
    ]

    print(f"Target properties for generation: {target_properties}")
    generated = model.generate(labels=target_properties)
    print(f"Property conditionally generated {len(generated)} molecules")
    print("Example property conditionally generated SMILES:", generated[:2])

    print("\n=== Testing Property Conditional LSTM model saving and loading ===")
    save_path = "prop_conditional_lstm_test_model.pt"
    try:
        model.save_to_local(save_path)
        print(f"Property Conditional LSTM Model saved to {save_path}")

        loaded_model = LSTMMolecularGenerator()
        loaded_model.load_from_local(save_path)
        print("Property Conditional LSTM Model loaded successfully")

        generated = loaded_model.generate(labels=target_properties)
        print("Generated molecules with loaded property conditional model:", len(generated))
        print("Example generated SMILES:", generated[:2])
    finally:
        # Always remove the checkpoint, even if load/generate raised.
        if os.path.exists(save_path):
            os.remove(save_path)
            print(f"Cleaned up {save_path}")


def test_lstm_generator():
    """End-to-end smoke test of LSTMMolecularGenerator.

    Exercises both the unconditional and the property-conditional paths:
    initialization, fitting, generation, and save/load round-trips.
    """
    smiles_list, property_columns, properties = _load_dataset()
    _test_unconditional(smiles_list)
    _test_conditional(smiles_list, property_columns, properties)


if __name__ == "__main__":
    test_lstm_generator()

torch_molecule/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from .generator.digress import DigressMolecularGenerator
3232
from .generator.molgpt import MolGPTMolecularGenerator
3333
from .generator.gdss import GDSSMolecularGenerator
34+
from .generator.lstm import LSTMMolecularGenerator
3435

3536
__all__ = [
3637
# 'BaseMolecularPredictor',
@@ -59,4 +60,5 @@
5960
'DigressMolecularGenerator',
6061
'MolGPTMolecularGenerator',
6162
'GDSSMolecularGenerator',
63+
'LSTMMolecularGenerator',
6264
]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .modeling_lstm import LSTMMolecularGenerator
22

3-
# __all__ = ['LSTMMolecularGenerator']
3+
__all__ = ['LSTMMolecularGenerator']

torch_molecule/generator/lstm/action_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import torch.nn.functional as F
55
from torch.distributions import Categorical, Distribution
66

7-
from lstm import LSTM
8-
from utils import rnn_start_token_vector
7+
from .lstm import LSTM
8+
from .utils import rnn_start_token_vector
99

1010
class ActionSampler:
1111
"""

torch_molecule/generator/lstm/lstm.py

Lines changed: 10 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,26 @@
1-
# import os
2-
# import time
3-
# from glob import glob
4-
# from functools import total_ordering
5-
# from typing import List, Set
6-
# from tqdm import tqdm
7-
81
import torch
92
import torch.nn as nn
10-
# from torch.utils.data import DataLoader
11-
# import numpy as np
12-
13-
# from .utils import canonicalize_list
14-
# from .utils import get_tensor_dataset, load_smiles_from_list
15-
# from .utils import save_model, time_since
16-
17-
# from .action_sampler import ActionSampler
18-
# from .smiles_char_dict import SmilesCharDictionary
19-
20-
# import logging
21-
# logger = logging.getLogger(__name__)
22-
# logger.addHandler(logging.NullHandler())
23-
24-
# this file contains:
25-
# SmilesRnn
26-
# SmilesRnnTrainer
27-
# SmilesRnnSampler
28-
29-
# @total_ordering
30-
# class OptResult:
31-
# def __init__(self, smiles: str, score: float) -> None:
32-
# self.smiles = smiles
33-
# self.score = score
34-
35-
# def __eq__(self, other):
36-
# return (self.score, self.smiles) == (other.score, other.smiles)
37-
38-
# def __lt__(self, other):
39-
# return (self.score, self.smiles) < (other.score, other.smiles)
403

414
class LSTM(nn.Module):
    """Character-level LSTM language model with conditional initial states.

    A conditioning property vector (or a single dummy feature when
    ``num_task == 0``, i.e. the unconditional case) is mapped by linear
    layers to the initial hidden and cell states of every LSTM layer.

    Parameters
    ----------
    num_task : int
        Number of conditioning properties; 0 means unconditional generation.
    input_size : int
        Vocabulary size of the token embedding.
    hidden_size : int
        Width of the embedding, recurrent state, and decoder input.
    output_size : int
        Size of the output logits (normally the vocabulary size).
    num_layer : int
        Number of stacked LSTM layers.
    dropout : float
        Inter-layer dropout probability passed to ``nn.LSTM``.
    """

    def __init__(self, num_task, input_size, hidden_size, output_size, num_layer, dropout):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layer = num_layer
        self.dropout = dropout

        # Unconditional models still feed one dummy feature through the
        # state transforms so the layers below have a valid input width.
        self.input_dim = 1 if num_task == 0 else num_task

        # Project the conditioning vector to the per-layer initial states.
        self.hidden_transform = nn.Linear(self.input_dim, num_layer * hidden_size)
        self.cell_transform = nn.Linear(self.input_dim, num_layer * hidden_size)

        self.encoder = nn.Embedding(input_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, output_size)

        self.rnn = nn.LSTM(hidden_size, hidden_size, batch_first=True,
                           num_layers=num_layer, dropout=dropout)
        self.initialize_parameters()
6324
def initialize_parameters(self):
6425
# encoder / decoder
6526
nn.init.xavier_uniform_(self.encoder.weight)
@@ -78,7 +39,7 @@ def initialize_parameters(self):
7839

7940
def forward(self, input, hidden, cell):
    """Advance the language model by one (batched) token sequence.

    Embeds the integer token indices, runs them through the stacked LSTM
    with the given recurrent state, and projects the outputs to logits.
    Returns ``(logits, hidden, cell)`` with the updated state tensors.
    """
    token_embeddings = self.encoder(input)
    # nn.LSTM returns (output, (h_n, c_n)); unpack the state tuple.
    rnn_output, (hidden, cell) = self.rnn(token_embeddings, (hidden, cell))
    logits = self.decoder(rnn_output)
    return logits, hidden, cell
8445

0 commit comments

Comments
 (0)