@@ -56,320 +56,4 @@ def compute_loss(self, batch_data, criterion):
         output, hidden, cell = self.forward(ipt, hidden, cell)
         output = output.view(output.size(0) * output.size(1), -1)
         loss = criterion(output, tgt.view(-1))
-        return loss
-
-## Define SmilesRnnSampler
-# class SMILESSampler:
-#     """
-#     Samples molecules from an RNN smiles language model
-#     """
-#     def __init__(self, device: str, batch_size=64) -> None:
-#         """
-#         Args:
-#             device: cpu | cuda
-#             batch_size: number of concurrent samples to generate
-#         """
-#         self.device = device
-#         self.batch_size = batch_size
-#         self.sd = SmilesCharDictionary()
-
-#     def sample(self, model: LSTM, num_to_sample: int, max_seq_len=100):
-#         """
-
-#         Args:
-#             model: RNN to sample from
-#             num_to_sample: number of samples to produce
-#             max_seq_len: maximum length of the samples
-#             batch_size: number of concurrent samples to generate
-
-#         Returns: a list of SMILES string, with no beginning nor end symbols
-
-#         """
-#         sampler = ActionSampler(max_batch_size=self.batch_size, max_seq_length=max_seq_len, device=self.device)
-
-#         model.eval()
-#         with torch.no_grad():
-#             indices = sampler.sample(model, num_samples=num_to_sample)
-#             return self.sd.matrix_to_smiles(indices)
-
-# define SmilesRnnTrainer
-
-# class SmilesRnnTrainer:
-#     def __init__(self, model, criteria, optimizer, device, log_dir=None, clip_gradients=True) -> None:
-#         self.model = model.to(device)
-#         self.criteria = [c.to(device) for c in criteria]
-#         self.optimizer = optimizer
-#         self.device = device
-#         self.log_dir = log_dir
-#         self.clip_gradients = clip_gradients
-
-#     def process_batch(self, batch):
-
-#         # ship data to device
-#         inp, tgt = batch
-#         inp = inp.to(self.device)
-#         tgt = tgt.to(self.device)
-
-#         # process data
-#         batch_size = inp.size(0)
-#         hidden = self.model.init_hidden(inp.size(0), self.device)
-#         output, hidden = self.model(inp, hidden)
-#         output = output.view(output.size(0) * output.size(1), -1)
-#         loss = self.criteria[0](output, tgt.view(-1))
-#         return loss, batch_size
-
-#     def train_on_batch(self, batch):
-
-#         # setup model for training
-#         self.model.train()
-#         self.model.zero_grad()
-
-#         # forward / backward
-#         loss, size = self.process_batch(batch)
-#         loss.backward()
-
-#         # optimize
-#         if self.clip_gradients:
-#             nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
-#         self.optimizer.step()
-
-#         return loss.item(), size
-
-#     def test_on_batch(self, batch):
-
-#         # setup model for evaluation
-#         self.model.eval()
-
-#         # forward
-#         loss, size = self.process_batch(batch)
-
-#         return loss.item(), size
-
-#     def validate(self, data_loader, n_molecule):
-#         """Runs validation and reports the average loss"""
-#         valid_losses = []
-#         with torch.no_grad():
-#             for batch in data_loader:
-#                 loss, size = self.test_on_batch(batch)
-#                 valid_losses += [loss]
-#         return np.array(valid_losses).mean()
-
-#     def train_extra_log(self, n_molecules):
-#         pass
-
-#     def valid_extra_log(self, n_molecules):
-#         pass
-
-#     def fit(self, training_data, test_data, n_epochs, batch_size, print_every,
-#             valid_every, num_workers=0):
-#         training_round = _ModelTrainingRound(self, training_data, test_data, n_epochs, batch_size, print_every,
-#                                              valid_every, num_workers)
-#         return training_round.run()
-
-
-# class _ModelTrainingRound:
-#     """
-#     Performs one round of model training.
-
-#     Is a separate class from ModelTrainer to allow for more modular functions without too many parameters.
-#     This class is not to be used outside of ModelTrainer.
-#     """
-#     class EarlyStopNecessary(Exception):
-#         pass
-
-#     def __init__(self, model_trainer: SmilesRnnTrainer, training_data, test_data, n_epochs, batch_size, print_every,
-#                  valid_every, num_workers=0) -> None:
-#         self.model_trainer = model_trainer
-#         self.training_data = training_data
-#         self.test_data = test_data
-#         self.n_epochs = n_epochs
-#         self.batch_size = batch_size
-#         self.print_every = print_every
-#         self.valid_every = valid_every
-#         self.num_workers = num_workers
-
-#         self.start_time = time.time()
-#         self.unprocessed_train_losses: List[float] = []
-#         self.all_train_losses: List[float] = []
-#         self.all_valid_losses: List[float] = []
-#         self.n_molecules_so_far = 0
-#         self.has_run = False
-#         self.min_valid_loss = np.inf
-#         self.min_avg_train_loss = np.inf
-
-#     def run(self):
-#         if self.has_run:
-#             raise Exception('_ModelTrainingRound.train() can be called only once.')
-
-#         try:
-#             for epoch_index in range(1, self.n_epochs + 1):
-#                 self._train_one_epoch(epoch_index)
-
-#             self._validation_on_final_model()
-#         except _ModelTrainingRound.EarlyStopNecessary:
-#             logger.error('Probable explosion during training. Stopping now.')
-
-#         self.has_run = True
-#         return self.all_train_losses, self.all_valid_losses
-
-#     def _train_one_epoch(self, epoch_index: int):
-#         logger.info(f'EPOCH {epoch_index}')
-
-#         # shuffle at every epoch
-#         data_loader = DataLoader(self.training_data,
-#                                  batch_size=self.batch_size,
-#                                  shuffle=True,
-#                                  num_workers=self.num_workers,
-#                                  pin_memory=True)
-
-#         epoch_t0 = time.time()
-#         self.unprocessed_train_losses.clear()
-
-#         for batch_index, batch in enumerate(data_loader):
-#             self._train_one_batch(batch_index, batch, epoch_index, epoch_t0)
-
-#     def _train_one_batch(self, batch_index, batch, epoch_index, train_t0):
-#         loss, size = self.model_trainer.train_on_batch(batch)
-
-#         self.unprocessed_train_losses += [loss]
-#         self.n_molecules_so_far += size
-
-#         # report training progress?
-#         if batch_index > 0 and batch_index % self.print_every == 0:
-#             self._report_training_progress(batch_index, epoch_index, epoch_start=train_t0)
-
-#         # report validation progress?
-#         if batch_index >= 0 and batch_index % self.valid_every == 0:
-#             self._report_validation_progress(epoch_index)
-
-#     def _report_training_progress(self, batch_index, epoch_index, epoch_start):
-#         mols_sec = self._calculate_mols_per_second(batch_index, epoch_start)
-
-#         # Update train losses by processing all losses since last time this function was executed
-#         avg_train_loss = np.array(self.unprocessed_train_losses).mean()
-#         self.all_train_losses += avg_train_loss
-#         self.unprocessed_train_losses.clear()
-
-#         logger.info(
-#             'TRAIN | '
-#             f'elapsed: {time_since(self.start_time)} | '
-#             f'epoch|batch : {epoch_index}|{batch_index} ({self._get_overall_progress():.1f}%) | '
-#             f'molecules: {self.n_molecules_so_far} | '
-#             f'mols/sec: {mols_sec:.2f} | '
-#             f'train_loss: {avg_train_loss:.4f}')
-#         self.model_trainer.train_extra_log(self.n_molecules_so_far)
-
-#         self._check_early_stopping_train_loss(avg_train_loss)
-
-#     def _calculate_mols_per_second(self, batch_index, epoch_start):
-#         """
-#         Calculates the speed so far in the current epoch.
-#         """
-#         train_time_in_current_epoch = time.time() - epoch_start
-#         processed_batches = batch_index + 1
-#         molecules_in_current_epoch = self.batch_size * processed_batches
-#         return molecules_in_current_epoch / train_time_in_current_epoch
-
-#     def _report_validation_progress(self, epoch_index):
-#         avg_valid_loss = self._validate_current_model()
-
-#         self._log_validation_step(epoch_index, avg_valid_loss)
-#         self._check_early_stopping_validation(avg_valid_loss)
-
-#         # save model?
-#         if self.model_trainer.log_dir:
-#             if avg_valid_loss <= min(self.all_valid_losses):
-#                 self._save_current_model(self.model_trainer.log_dir, epoch_index, avg_valid_loss)
-
-#     def _validate_current_model(self):
-#         """
-#         Validate the current model.
-
-#         Returns: Validation loss.
-#         """
-#         test_loader = DataLoader(self.test_data,
-#                                  batch_size=self.batch_size,
-#                                  shuffle=False,
-#                                  num_workers=self.num_workers,
-#                                  pin_memory=True)
-#         avg_valid_loss = self.model_trainer.validate(test_loader, self.n_molecules_so_far)
-#         self.all_valid_losses += [avg_valid_loss]
-#         return avg_valid_loss
-
-#     def _log_validation_step(self, epoch_index, avg_valid_loss):
-#         """
-#         Log the information about the validation step.
-#         """
-#         logger.info(
-#             'VALID | '
-#             f'elapsed: {time_since(self.start_time)} | '
-#             f'epoch: {epoch_index}/{self.n_epochs} ({self._get_overall_progress():.1f}%) | '
-#             f'molecules: {self.n_molecules_so_far} | '
-#             f'valid_loss: {avg_valid_loss:.4f}')
-#         self.model_trainer.valid_extra_log(self.n_molecules_so_far)
-#         logger.info('')
-
-#     def _get_overall_progress(self):
-#         total_mols = self.n_epochs * len(self.training_data)
-#         return 100. * self.n_molecules_so_far / total_mols
-
-#     def _validation_on_final_model(self):
-#         """
-#         Run validation for the final model and save it.
-#         """
-#         valid_loss = self._validate_current_model()
-#         logger.info(
-#             'VALID | FINAL_MODEL | '
-#             f'elapsed: {time_since(self.start_time)} | '
-#             f'molecules: {self.n_molecules_so_far} | '
-#             f'valid_loss: {valid_loss:.4f}')
-
-#         if self.model_trainer.log_dir:
-#             self._save_model(self.model_trainer.log_dir, 'final', valid_loss)
-
-#     def _save_current_model(self, base_dir, epoch, valid_loss):
-#         """
-#         Delete previous versions of the model and save the current one.
-#         """
-#         for f in glob(os.path.join(base_dir, 'model_*')):
-#             os.remove(f)
-
-#         self._save_model(base_dir, epoch, valid_loss)
-
-#     def _save_model(self, base_dir, info, valid_loss):
-#         """
-#         Save a copy of the model with format:
-#         model_{info}_{valid_loss}
-#         """
-#         base_name = f'model_{info}_{valid_loss:.3f}'
-#         logger.info(base_name)
-#         save_model(self.model_trainer.model, base_dir, base_name)
-
-#     def _check_early_stopping_train_loss(self, avg_train_loss):
-#         """
-#         This function checks whether the training has exploded by verifying if the avg training loss
-#         is more than 10 times the minimal loss so far.
-
-#         If this is the case, a EarlyStopNecessary exception is raised.
-#         """
-#         threshold = 10 * self.min_avg_train_loss
-#         if avg_train_loss > threshold:
-#             raise _ModelTrainingRound.EarlyStopNecessary()
-
-#         # update the min train loss if necessary
-#         if avg_train_loss < self.min_avg_train_loss:
-#             self.min_avg_train_loss = avg_train_loss
-
-#     def _check_early_stopping_validation(self, avg_valid_loss):
-#         """
-#         This function checks whether the training has exploded by verifying if the validation loss
-#         has more than doubled compared to the minimum validation loss so far.
-
-#         If this is the case, a EarlyStopNecessary exception is raised.
-#         """
-#         threshold = 2 * self.min_valid_loss
-#         if avg_valid_loss > threshold:
-#             raise _ModelTrainingRound.EarlyStopNecessary()
-
-#         if avg_valid_loss < self.min_valid_loss:
-#             self.min_valid_loss = avg_valid_loss
+        return loss