
Commit e7606b9

clean lstm and update readme
1 parent 35767e1 commit e7606b9

File tree

3 files changed: +3, -533 lines

README.md

Lines changed: 2 additions & 49 deletions
```diff
@@ -176,6 +176,7 @@ predictions = model.predict(smiles_list)
 | GDSS | [Score-based Generative Modeling of Graphs via the System of Stochastic Differential Equations. ICML 2022](https://proceedings.mlr.press/v162/jo22a/jo22a.pdf) |
 | MolGPT | [MolGPT: Molecular Generation Using a Transformer-Decoder Model. Journal of Chemical Information and Modeling 2021](https://pubs.acs.org/doi/10.1021/acs.jcim.1c00600) |
 | GraphGA | [A Graph-Based Genetic Algorithm and Its Application to the Multiobjective Evolution of Median Molecules. Journal of Chemical Information and Computer Sciences 2004](https://pubs.acs.org/doi/10.1021/ci034290p) |
+| LSTM (SMILES) | [Long short-term memory (Neural Computation 1997)](https://ieeexplore.ieee.org/abstract/document/6795963) based on SMILES strings |
 
 ### Representation Models
 
@@ -192,55 +193,7 @@ predictions = model.predict(smiles_list)
 
 ## Project Structure
 
-The structure of `torch_molecule` is as follows:
-
-`tree -L 2 torch_molecule -I '__pycache__|*.pyc|*.pyo|.git|old*'`
-
-```
-torch_molecule
-├── base
-│   ├── base.py
-│   ├── encoder.py
-│   ├── generator.py
-│   ├── __init__.py
-│   └── predictor.py
-├── encoder
-│   ├── attrmask
-│   ├── constant.py
-│   ├── contextpred
-│   ├── edgepred
-│   ├── moama
-│   └── supervised
-├── generator
-│   ├── digress
-│   ├── graph_dit
-│   └── graphga
-├── __init__.py
-├── nn
-│   ├── attention.py
-│   ├── embedder.py
-│   ├── gnn.py
-│   ├── __init__.py
-│   └── mlp.py
-├── predictor
-│   ├── dir
-│   ├── gnn
-│   ├── grea
-│   ├── irm
-│   ├── lstm
-│   ├── rpgnn
-│   ├── sgir
-│   └── ssr
-└── utils
-    ├── checker.py
-    ├── checkpoint.py
-    ├── format.py
-    ├── generic
-    ├── graph
-    ├── hf.py
-    ├── __init__.py
-    └── search.py
-```
+See the structure of `torch_molecule` with the command `tree -L 2 torch_molecule -I '__pycache__|*.pyc|*.pyo|.git|old*'`
 
 ## Acknowledgements
 
```
torch_molecule/generator/lstm/lstm.py

Lines changed: 1 addition & 317 deletions
```diff
@@ -56,320 +56,4 @@ def compute_loss(self, batch_data, criterion):
         output, hidden, cell = self.forward(ipt, hidden, cell)
         output = output.view(output.size(0) * output.size(1), -1)
         loss = criterion(output, tgt.view(-1))
-        return loss
-
-## Define SmilesRnnSampler
-# class SMILESSampler:
-#     """
-#     Samples molecules from an RNN smiles language model
-#     """
-#     def __init__(self, device: str, batch_size=64) -> None:
-#         """
-#         Args:
-#             device: cpu | cuda
-#             batch_size: number of concurrent samples to generate
-#         """
-#         self.device = device
-#         self.batch_size = batch_size
-#         self.sd = SmilesCharDictionary()
-
-#     def sample(self, model: LSTM, num_to_sample: int, max_seq_len=100):
-#         """
-
-#         Args:
-#             model: RNN to sample from
-#             num_to_sample: number of samples to produce
-#             max_seq_len: maximum length of the samples
-#             batch_size: number of concurrent samples to generate
-
-#         Returns: a list of SMILES string, with no beginning nor end symbols
-
-#         """
-#         sampler = ActionSampler(max_batch_size=self.batch_size, max_seq_length=max_seq_len, device=self.device)
-
-#         model.eval()
-#         with torch.no_grad():
-#             indices = sampler.sample(model, num_samples=num_to_sample)
-#             return self.sd.matrix_to_smiles(indices)
-
-# define SmilesRnnTrainer
-
-# class SmilesRnnTrainer:
-#     def __init__(self, model, criteria, optimizer, device, log_dir=None, clip_gradients=True) -> None:
-#         self.model = model.to(device)
-#         self.criteria = [c.to(device) for c in criteria]
-#         self.optimizer = optimizer
-#         self.device = device
-#         self.log_dir = log_dir
-#         self.clip_gradients = clip_gradients
-
-#     def process_batch(self, batch):
-
-#         # ship data to device
-#         inp, tgt = batch
-#         inp = inp.to(self.device)
-#         tgt = tgt.to(self.device)
-
-#         # process data
-#         batch_size = inp.size(0)
-#         hidden = self.model.init_hidden(inp.size(0), self.device)
-#         output, hidden = self.model(inp, hidden)
-#         output = output.view(output.size(0) * output.size(1), -1)
-#         loss = self.criteria[0](output, tgt.view(-1))
-#         return loss, batch_size
-
-#     def train_on_batch(self, batch):
-
-#         # setup model for training
-#         self.model.train()
-#         self.model.zero_grad()
-
-#         # forward / backward
-#         loss, size = self.process_batch(batch)
-#         loss.backward()
-
-#         # optimize
-#         if self.clip_gradients:
-#             nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
-#         self.optimizer.step()
-
-#         return loss.item(), size
-
-#     def test_on_batch(self, batch):
-
-#         # setup model for evaluation
-#         self.model.eval()
-
-#         # forward
-#         loss, size = self.process_batch(batch)
-
-#         return loss.item(), size
-
-#     def validate(self, data_loader, n_molecule):
-#         """Runs validation and reports the average loss"""
-#         valid_losses = []
-#         with torch.no_grad():
-#             for batch in data_loader:
-#                 loss, size = self.test_on_batch(batch)
-#                 valid_losses += [loss]
-#         return np.array(valid_losses).mean()
-
-#     def train_extra_log(self, n_molecules):
-#         pass
-
-#     def valid_extra_log(self, n_molecules):
-#         pass
-
-#     def fit(self, training_data, test_data, n_epochs, batch_size, print_every,
-#             valid_every, num_workers=0):
-#         training_round = _ModelTrainingRound(self, training_data, test_data, n_epochs, batch_size, print_every,
-#                                              valid_every, num_workers)
-#         return training_round.run()
-
-
-# class _ModelTrainingRound:
-#     """
-#     Performs one round of model training.
-
-#     Is a separate class from ModelTrainer to allow for more modular functions without too many parameters.
-#     This class is not to be used outside of ModelTrainer.
-#     """
-#     class EarlyStopNecessary(Exception):
-#         pass
-
-#     def __init__(self, model_trainer: SmilesRnnTrainer, training_data, test_data, n_epochs, batch_size, print_every,
-#                  valid_every, num_workers=0) -> None:
-#         self.model_trainer = model_trainer
-#         self.training_data = training_data
-#         self.test_data = test_data
-#         self.n_epochs = n_epochs
-#         self.batch_size = batch_size
-#         self.print_every = print_every
-#         self.valid_every = valid_every
-#         self.num_workers = num_workers
-
-#         self.start_time = time.time()
-#         self.unprocessed_train_losses: List[float] = []
-#         self.all_train_losses: List[float] = []
-#         self.all_valid_losses: List[float] = []
-#         self.n_molecules_so_far = 0
-#         self.has_run = False
-#         self.min_valid_loss = np.inf
-#         self.min_avg_train_loss = np.inf
-
-#     def run(self):
-#         if self.has_run:
-#             raise Exception('_ModelTrainingRound.train() can be called only once.')
-
-#         try:
-#             for epoch_index in range(1, self.n_epochs + 1):
-#                 self._train_one_epoch(epoch_index)
-
-#             self._validation_on_final_model()
-#         except _ModelTrainingRound.EarlyStopNecessary:
-#             logger.error('Probable explosion during training. Stopping now.')
-
-#         self.has_run = True
-#         return self.all_train_losses, self.all_valid_losses
-
-#     def _train_one_epoch(self, epoch_index: int):
-#         logger.info(f'EPOCH {epoch_index}')
-
-#         # shuffle at every epoch
-#         data_loader = DataLoader(self.training_data,
-#                                  batch_size=self.batch_size,
-#                                  shuffle=True,
-#                                  num_workers=self.num_workers,
-#                                  pin_memory=True)
-
-#         epoch_t0 = time.time()
-#         self.unprocessed_train_losses.clear()
-
-#         for batch_index, batch in enumerate(data_loader):
-#             self._train_one_batch(batch_index, batch, epoch_index, epoch_t0)
-
-#     def _train_one_batch(self, batch_index, batch, epoch_index, train_t0):
-#         loss, size = self.model_trainer.train_on_batch(batch)
-
-#         self.unprocessed_train_losses += [loss]
-#         self.n_molecules_so_far += size
-
-#         # report training progress?
-#         if batch_index > 0 and batch_index % self.print_every == 0:
-#             self._report_training_progress(batch_index, epoch_index, epoch_start=train_t0)
-
-#         # report validation progress?
-#         if batch_index >= 0 and batch_index % self.valid_every == 0:
-#             self._report_validation_progress(epoch_index)
-
-#     def _report_training_progress(self, batch_index, epoch_index, epoch_start):
-#         mols_sec = self._calculate_mols_per_second(batch_index, epoch_start)
-
-#         # Update train losses by processing all losses since last time this function was executed
-#         avg_train_loss = np.array(self.unprocessed_train_losses).mean()
-#         self.all_train_losses += avg_train_loss
-#         self.unprocessed_train_losses.clear()
-
-#         logger.info(
-#             'TRAIN | '
-#             f'elapsed: {time_since(self.start_time)} | '
-#             f'epoch|batch : {epoch_index}|{batch_index} ({self._get_overall_progress():.1f}%) | '
-#             f'molecules: {self.n_molecules_so_far} | '
-#             f'mols/sec: {mols_sec:.2f} | '
-#             f'train_loss: {avg_train_loss:.4f}')
-#         self.model_trainer.train_extra_log(self.n_molecules_so_far)
-
-#         self._check_early_stopping_train_loss(avg_train_loss)
-
-#     def _calculate_mols_per_second(self, batch_index, epoch_start):
-#         """
-#         Calculates the speed so far in the current epoch.
-#         """
-#         train_time_in_current_epoch = time.time() - epoch_start
-#         processed_batches = batch_index + 1
-#         molecules_in_current_epoch = self.batch_size * processed_batches
-#         return molecules_in_current_epoch / train_time_in_current_epoch
-
-#     def _report_validation_progress(self, epoch_index):
-#         avg_valid_loss = self._validate_current_model()
-
-#         self._log_validation_step(epoch_index, avg_valid_loss)
-#         self._check_early_stopping_validation(avg_valid_loss)
-
-#         # save model?
-#         if self.model_trainer.log_dir:
-#             if avg_valid_loss <= min(self.all_valid_losses):
-#                 self._save_current_model(self.model_trainer.log_dir, epoch_index, avg_valid_loss)
-
-#     def _validate_current_model(self):
-#         """
-#         Validate the current model.
-
-#         Returns: Validation loss.
-#         """
-#         test_loader = DataLoader(self.test_data,
-#                                  batch_size=self.batch_size,
-#                                  shuffle=False,
-#                                  num_workers=self.num_workers,
-#                                  pin_memory=True)
-#         avg_valid_loss = self.model_trainer.validate(test_loader, self.n_molecules_so_far)
-#         self.all_valid_losses += [avg_valid_loss]
-#         return avg_valid_loss
-
-#     def _log_validation_step(self, epoch_index, avg_valid_loss):
-#         """
-#         Log the information about the validation step.
-#         """
-#         logger.info(
-#             'VALID | '
-#             f'elapsed: {time_since(self.start_time)} | '
-#             f'epoch: {epoch_index}/{self.n_epochs} ({self._get_overall_progress():.1f}%) | '
-#             f'molecules: {self.n_molecules_so_far} | '
-#             f'valid_loss: {avg_valid_loss:.4f}')
-#         self.model_trainer.valid_extra_log(self.n_molecules_so_far)
-#         logger.info('')
-
-#     def _get_overall_progress(self):
-#         total_mols = self.n_epochs * len(self.training_data)
-#         return 100. * self.n_molecules_so_far / total_mols
-
-#     def _validation_on_final_model(self):
-#         """
-#         Run validation for the final model and save it.
-#         """
-#         valid_loss = self._validate_current_model()
-#         logger.info(
-#             'VALID | FINAL_MODEL | '
-#             f'elapsed: {time_since(self.start_time)} | '
-#             f'molecules: {self.n_molecules_so_far} | '
-#             f'valid_loss: {valid_loss:.4f}')
-
-#         if self.model_trainer.log_dir:
-#             self._save_model(self.model_trainer.log_dir, 'final', valid_loss)
-
-#     def _save_current_model(self, base_dir, epoch, valid_loss):
-#         """
-#         Delete previous versions of the model and save the current one.
-#         """
-#         for f in glob(os.path.join(base_dir, 'model_*')):
-#             os.remove(f)
-
-#         self._save_model(base_dir, epoch, valid_loss)
-
-#     def _save_model(self, base_dir, info, valid_loss):
-#         """
-#         Save a copy of the model with format:
-#         model_{info}_{valid_loss}
-#         """
-#         base_name = f'model_{info}_{valid_loss:.3f}'
-#         logger.info(base_name)
-#         save_model(self.model_trainer.model, base_dir, base_name)
-
-#     def _check_early_stopping_train_loss(self, avg_train_loss):
-#         """
-#         This function checks whether the training has exploded by verifying if the avg training loss
-#         is more than 10 times the minimal loss so far.
-
-#         If this is the case, a EarlyStopNecessary exception is raised.
-#         """
-#         threshold = 10 * self.min_avg_train_loss
-#         if avg_train_loss > threshold:
-#             raise _ModelTrainingRound.EarlyStopNecessary()
-
-#         # update the min train loss if necessary
-#         if avg_train_loss < self.min_avg_train_loss:
-#             self.min_avg_train_loss = avg_train_loss
-
-#     def _check_early_stopping_validation(self, avg_valid_loss):
-#         """
-#         This function checks whether the training has exploded by verifying if the validation loss
-#         has more than doubled compared to the minimum validation loss so far.
-
-#         If this is the case, a EarlyStopNecessary exception is raised.
-#         """
-#         threshold = 2 * self.min_valid_loss
-#         if avg_valid_loss > threshold:
-#             raise _ModelTrainingRound.EarlyStopNecessary()
-
-#         if avg_valid_loss < self.min_valid_loss:
-#             self.min_valid_loss = avg_valid_loss
+        return loss
```
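For context on what survives the cleanup: the retained `compute_loss` lines implement standard next-token language modeling over tokenized SMILES, and the deleted comments described autoregressive sampling from the trained model. The sketch below restates both patterns in plain PyTorch. It is an illustration only: the class name `SmilesLSTM`, the `sample_tokens` helper, and all shapes and defaults are assumptions for the example, not the actual `torch_molecule` implementation.

```python
import torch
import torch.nn as nn


class SmilesLSTM(nn.Module):
    """Minimal character-level LSTM language model over tokenized SMILES.

    Hypothetical stand-in for the model that owns compute_loss above; the
    reshape-and-criterion pattern mirrors the lines retained by this commit.
    """

    def __init__(self, vocab_size: int, embed_dim: int = 64,
                 hidden_dim: int = 256, num_layers: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.num_layers, self.hidden_dim = num_layers, hidden_dim

    def init_hidden(self, batch_size: int, device):
        shape = (self.num_layers, batch_size, self.hidden_dim)
        return torch.zeros(shape, device=device), torch.zeros(shape, device=device)

    def forward(self, ipt, hidden, cell):
        emb = self.embedding(ipt)                             # (B, T, E)
        out, (hidden, cell) = self.lstm(emb, (hidden, cell))  # (B, T, H)
        return self.fc(out), hidden, cell                     # logits: (B, T, V)

    def compute_loss(self, batch_data, criterion):
        # Next-token prediction: inputs are tokens 0..T-2, targets are 1..T-1.
        ipt, tgt = batch_data[:, :-1], batch_data[:, 1:]
        hidden, cell = self.init_hidden(ipt.size(0), ipt.device)
        output, hidden, cell = self.forward(ipt, hidden, cell)
        # Flatten (B, T-1, V) -> (B*(T-1), V) to match CrossEntropyLoss.
        output = output.view(output.size(0) * output.size(1), -1)
        return criterion(output, tgt.reshape(-1))


@torch.no_grad()
def sample_tokens(model, start_token: int, max_len: int, batch_size: int, device):
    """Autoregressive multinomial sampling, the pattern the deleted
    SMILESSampler comments described (minus the SMILES decoding step)."""
    model.eval()
    hidden, cell = model.init_hidden(batch_size, device)
    ipt = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)
    steps = []
    for _ in range(max_len):
        logits, hidden, cell = model(ipt, hidden, cell)   # feed one token per step
        probs = torch.softmax(logits[:, -1, :], dim=-1)   # (B, V)
        ipt = torch.multinomial(probs, num_samples=1)     # (B, 1)
        steps.append(ipt)
    return torch.cat(steps, dim=1)                        # (B, max_len) indices
```

With `criterion = nn.CrossEntropyLoss()` and integer-encoded SMILES batches of shape `(B, T)`, `model.compute_loss(batch, criterion)` reproduces the retained lines; mapping sampled indices back to SMILES strings would require the character dictionary the deleted comments reference as `SmilesCharDictionary`.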
