1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ from tqdm import tqdm
5
+
6
+ import torch
7
+ from torch_molecule .generator .lstm import LSTMMolecularGenerator
8
+
9
+ EPOCHS = 1000 # Reduced for faster testing
10
+ BATCH_SIZE = 24
11
+
12
+ def test_lstm_generator ():
13
+ # Load data from polymer100.csv
14
+ data_path = os .path .join (os .path .dirname (os .path .dirname (os .path .dirname (__file__ ))),
15
+ "data" , "polymer100.csv" )
16
+ print (f"Loading data from: { data_path } " )
17
+
18
+ df = pd .read_csv (data_path )
19
+ smiles_list = df ['smiles' ].tolist ()
20
+
21
+ # Extract property columns (all columns except 'smiles')
22
+ property_columns = [col for col in df .columns if col != 'smiles' ]
23
+ properties = df [property_columns ].values .tolist ()
24
+
25
+ print (f"Loaded { len (smiles_list )} molecules with { len (property_columns )} properties" )
26
+ print (f"Property columns: { property_columns } " )
27
+ print (f"First 3 SMILES: { smiles_list [:3 ]} " )
28
+ print (f"First 3 properties: { properties [:3 ]} " )
29
+
30
+ # 1. Basic initialization test - Unconditional Model
31
+ print ("\n === Testing Unconditional LSTM model initialization ===" )
32
+ unconditional_model = LSTMMolecularGenerator (
33
+ num_layer = 3 ,
34
+ hidden_size = 128 ,
35
+ max_len = 64 ,
36
+ batch_size = BATCH_SIZE ,
37
+ epochs = EPOCHS ,
38
+ verbose = True
39
+ )
40
+ print ("Unconditional LSTM Model initialized successfully" )
41
+
42
+ # 2. Basic fitting test - Unconditional Model
43
+ print ("\n === Testing Unconditional LSTM model fitting ===" )
44
+ unconditional_model .fit (smiles_list )
45
+ print ("Unconditional LSTM Model fitting completed" )
46
+
47
+ # 3. Unconditional generation test
48
+ print ("\n === Testing Unconditional LSTM generation ===" )
49
+ generated_smiles_uncond = unconditional_model .generate (batch_size = BATCH_SIZE )
50
+ print (f"Unconditionally generated { len (generated_smiles_uncond )} molecules" )
51
+ print ("Example unconditionally generated SMILES:" , generated_smiles_uncond [:10 ])
52
+
53
+ # 4. Model saving and loading test - Unconditional Model
54
+ print ("\n === Testing Unconditional LSTM model saving and loading ===" )
55
+ save_path = "unconditional_lstm_test_model.pt"
56
+ unconditional_model .save_to_local (save_path )
57
+ print (f"Unconditional LSTM Model saved to { save_path } " )
58
+
59
+ new_unconditional_model = LSTMMolecularGenerator ()
60
+ new_unconditional_model .load_from_local (save_path )
61
+ print ("Unconditional LSTM Model loaded successfully" )
62
+
63
+ # Test generation with loaded unconditional model
64
+ generated_smiles_uncond = new_unconditional_model .generate (batch_size = 5 )
65
+ print ("Generated molecules with loaded unconditional model:" , len (generated_smiles_uncond ))
66
+ print ("Example generated SMILES:" , generated_smiles_uncond [:10 ])
67
+
68
+ # Clean up unconditional model
69
+ if os .path .exists (save_path ):
70
+ os .remove (save_path )
71
+ print (f"Cleaned up { save_path } " )
72
+
73
+ # 5. Basic initialization test - Property Conditional Model
74
+ print ("\n === Testing Property Conditional LSTM model initialization ===" )
75
+ prop_conditional_model = LSTMMolecularGenerator (
76
+ num_layer = 2 ,
77
+ hidden_size = 128 ,
78
+ max_len = 64 ,
79
+ num_task = len (property_columns ), # Set number of properties
80
+ batch_size = BATCH_SIZE ,
81
+ epochs = EPOCHS ,
82
+ verbose = True
83
+ )
84
+ print ("Property Conditional LSTM Model initialized successfully" )
85
+
86
+ # 6. Basic fitting test - Property Conditional Model
87
+ print ("\n === Testing Property Conditional LSTM model fitting ===" )
88
+ prop_conditional_model .fit (smiles_list , properties )
89
+ print ("Property Conditional LSTM Model fitting completed" )
90
+
91
+ # 7. Property conditional generation test
92
+ print ("\n === Testing Property Conditional LSTM generation ===" )
93
+ # Create some target properties (using mean values from the dataset as a starting point)
94
+ mean_properties = np .mean (properties , axis = 0 ).tolist ()
95
+ target_properties = []
96
+ for i in range (5 ):
97
+ # Create variations around the mean
98
+ target_prop = [p * (0.8 + 0.4 * np .random .random ()) for p in mean_properties ]
99
+ target_properties .append (target_prop )
100
+
101
+ print (f"Target properties for generation: { target_properties } " )
102
+ generated_smiles = prop_conditional_model .generate (labels = target_properties )
103
+ print (f"Property conditionally generated { len (generated_smiles )} molecules" )
104
+ print ("Example property conditionally generated SMILES:" , generated_smiles [:2 ])
105
+
106
+ # 8. Model saving and loading test - Property Conditional Model
107
+ print ("\n === Testing Property Conditional LSTM model saving and loading ===" )
108
+ save_path = "prop_conditional_lstm_test_model.pt"
109
+ prop_conditional_model .save_to_local (save_path )
110
+ print (f"Property Conditional LSTM Model saved to { save_path } " )
111
+
112
+ new_prop_conditional_model = LSTMMolecularGenerator ()
113
+ new_prop_conditional_model .load_from_local (save_path )
114
+ print ("Property Conditional LSTM Model loaded successfully" )
115
+
116
+ # Test generation with loaded property conditional model
117
+ generated_smiles = new_prop_conditional_model .generate (labels = target_properties )
118
+ print ("Generated molecules with loaded property conditional model:" , len (generated_smiles ))
119
+ print ("Example generated SMILES:" , generated_smiles [:2 ])
120
+
121
+ # Clean up property conditional model
122
+ if os .path .exists (save_path ):
123
+ os .remove (save_path )
124
+ print (f"Cleaned up { save_path } " )
125
+
126
+ if __name__ == "__main__" :
127
+ test_lstm_generator ()
0 commit comments