|
1 | 1 | from tqdm import tqdm
|
2 | 2 | from dataclasses import dataclass, field
|
3 |
| -from typing import Optional, Union, Dict, Any, Tuple, List, Callable, Type |
| 3 | +from typing import Optional, Dict, Any, Tuple, List, Type |
4 | 4 | import numpy as np
|
5 | 5 |
|
6 | 6 | import torch
|
7 |
| -from torch.utils.data import DataLoader, TensorDataset |
8 | 7 |
|
9 | 8 | from .jtnn_vae import JTNNVAE
|
10 | 9 | from .jtnn.mol_tree import MolTree
|
|
15 | 14 | @dataclass
|
16 | 15 | class JTVAEMolecularGenerator(BaseMolecularGenerator):
|
17 | 16 | """
|
18 |
| - JT-VAE-based molecular generator. |
19 |
| - |
20 |
| - This generator implements a JT-VAE architecture for molecular generation. |
21 |
| - |
| 17 | +    JT-VAE-based molecular generator. Implemented for unconditional molecular generation. |
22 | 18 |
|
23 | 19 | References
|
24 | 20 | ----------
|
25 | 21 | - Junction Tree Variational Autoencoder for Molecular Graph Generation. ICML 2018. https://arxiv.org/pdf/1802.04364
|
26 | 22 | - Code: https://github.com/kamikaze0923/jtvae
|
27 | 23 |
|
28 |
| -
|
29 | 24 | Parameters
|
30 | 25 | ----------
|
31 |
| - TODO. |
32 |
| -
|
| 26 | + hidden_size : int, default=450 |
| 27 | + Dimension of hidden layers in the model. |
| 28 | + latent_size : int, default=56 |
| 29 | + Dimension of the latent space. |
| 30 | + depthT : int, default=20 |
| 31 | + Depth of the tree encoder. |
| 32 | + depthG : int, default=3 |
| 33 | + Depth of the graph decoder. |
| 34 | + batch_size : int, default=32 |
| 35 | + Number of samples per batch during training. |
| 36 | + epochs : int, default=20 |
| 37 | + Number of epochs to train the model. |
| 38 | + learning_rate : float, default=0.003 |
| 39 | + Initial learning rate for the optimizer. |
| 40 | + weight_decay : float, default=0.0 |
| 41 | + L2 regularization factor. |
| 42 | + grad_norm_clip : Optional[float], default=None |
| 43 | + Maximum norm for gradient clipping. None means no clipping. |
| 44 | + beta : float, default=0.0 |
| 45 | + Initial KL divergence weight for VAE training. |
| 46 | + step_beta : float, default=0.002 |
| 47 | + Step size for KL annealing. |
| 48 | + max_beta : float, default=1.0 |
| 49 | + Maximum value for KL weight. |
| 50 | + warmup : int, default=40000 |
| 51 | + Number of steps for KL annealing warmup. |
| 52 | + use_lr_scheduler : bool, default=True |
| 53 | + Whether to use learning rate scheduling. |
| 54 | + anneal_rate : float, default=0.9 |
| 55 | + Learning rate annealing factor. |
| 56 | + anneal_iter : int, default=40000 |
| 57 | + Number of iterations between learning rate updates. |
| 58 | + kl_anneal_iter : int, default=2000 |
| 59 | + Number of iterations between KL weight updates. |
| 60 | + verbose : bool, default=False |
| 61 | + Whether to print detailed training information. |
33 | 62 | """
|
34 | 63 | # Model parameters
|
35 | 64 | hidden_size: int = 450
|
|
0 commit comments