Skip to content

Commit c12d4a9

Browse files
committed
lstm gen half
1 parent 46ee287 commit c12d4a9

File tree

7 files changed

+1106
-4
lines changed

7 files changed

+1106
-4
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .modeling_lstm import LSTMMolecularGenerator
2+
3+
# __all__ = ['LSTMMolecularGenerator']
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from typing import Type
2+
3+
import torch
4+
import torch.nn.functional as F
5+
from torch.distributions import Categorical, Distribution
6+
7+
from lstm import LSTM
8+
from utils import rnn_start_token_vector
9+
10+
class ActionSampler:
    """
    Sampler for SmilesRNN models.

    Does not return SMILES strings directly, but instead the actions (i.e. which SMILES character to select).
    Those values are more general and are for instance necessary for other RL algorithms.

    The class will sample the RNN model multiple times if the number of desired samples is larger than the
    maximal allowed batch size.
    """

    def __init__(self, max_batch_size, max_seq_length, device,
                 distribution_cls: Type[Distribution] = Categorical) -> None:
        """
        Args:
            max_batch_size: maximal batch size for the RNN model
            max_seq_length: max length for a sampled SMILES string
            device: cuda | cpu
            distribution_cls: distribution type to sample from. If None, falls back to ``Categorical``
                (a multinomial distribution). Useful for testing purposes.
        """
        self.max_batch_size = max_batch_size
        self.max_seq_length = max_seq_length
        self.device = device
        # Honour the documented "None means multinomial" contract instead of
        # failing later when the class is instantiated in _sample_batch.
        self.distribution_cls = Categorical if distribution_cls is None else distribution_cls

    def sample(self, model: "LSTM", num_samples: int, target: torch.Tensor) -> torch.Tensor:
        """
        Samples a specified number of actions from an RNN model based on a multinomial distribution.

        Args:
            model: Smiles RNN model to sample from
            num_samples: Number of samples to generate
            target: forwarded unchanged to ``model.init_hidden`` for every batch

        Returns:
            tensor of actions (num_samples x max_seq_length)

        Raises:
            ValueError: if ``max_batch_size`` is smaller than 1
        """
        if self.max_batch_size < 1:
            # A non-positive batch size would either divide by zero or skip the
            # loop entirely and hand back an uninitialised tensor.
            raise ValueError(f"max_batch_size must be >= 1, got {self.max_batch_size}")

        # Allocated directly on the target device; every row is overwritten below.
        actions = torch.empty((num_samples, self.max_seq_length), dtype=torch.long, device=self.device)

        # Walk over the output in chunks of at most max_batch_size rows; the
        # last chunk may be smaller.
        for batch_start in range(0, num_samples, self.max_batch_size):
            batch_end = min(batch_start + self.max_batch_size, num_samples)
            actions[batch_start:batch_end, :] = self._sample_batch(model, batch_end - batch_start, target)

        return actions

    def _sample_batch(self, model: "LSTM", batch_size: int, target: torch.Tensor) -> torch.Tensor:
        """
        Samples a batch of actions based on a multinomial distribution.

        Args:
            model: Smiles RNN model to sample from
            batch_size: Number of sequences to sample in this batch
            target: forwarded unchanged to ``model.init_hidden``

        Returns:
            tensor of actions (batch_size x max_seq_length)
        """
        actions = torch.zeros((batch_size, self.max_seq_length), dtype=torch.long, device=self.device)

        # Only integer actions are returned, so no autograd graph is needed for
        # the rollout; disabling it avoids retaining activations for every step.
        with torch.no_grad():
            hidden, cell = model.init_hidden(batch_size, target)
            inp = rnn_start_token_vector(batch_size, self.device)

            for step in range(self.max_seq_length):
                output, hidden, cell = model(inp, hidden, cell)
                # The class dimension is the third axis of the model output.
                prob = F.softmax(output, dim=2)
                distribution = self.distribution_cls(probs=prob)
                action = distribution.sample()

                # presumably action is (batch_size, 1); squeeze drops the
                # singleton axis so it fits one column -- TODO confirm
                actions[:, step] = action.squeeze()

                # The sampled token becomes the next step's input.
                inp = action

        return actions

0 commit comments

Comments
 (0)