Skip to content

Commit 71a46dc

Browse files
committed
add doc for jtvae
1 parent d00b71a commit 71a46dc

File tree

2 files changed

+49
-13
lines changed

2 files changed

+49
-13
lines changed

docs/source/api/generator.rst

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ inherited from :class:`torch_molecule.base.base.BaseModel`
2727
- ``save(path, repo_id)``: Save the model to either local storage or Hugging Face
2828
- ``load(path, repo_id)``: Load a model from either local storage or Hugging Face
2929

30-
Modeling Molecules as Graphs with GNN / Transformer-based Generators
30+
Modeling Molecules as Graphs
3131
---------------------------------------------------------------------
3232

3333
.. rubric:: GraphDiT for Un/Multi-conditional Molecular Generation
@@ -46,13 +46,20 @@ Modeling Molecules as Graphs with GNN / Transformer-based Generators
4646

4747
.. rubric:: GDSS for score-based molecular generation
4848
.. autoclass:: torch_molecule.generator.gdss.modeling_gdss.GDSSMolecularGenerator
49-
:exclude-members: fitting_epoch, fitting_loss, save_to_hf, load_from_hf
49+
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class
50+
:members: fit, generate
51+
:undoc-members:
52+
:show-inheritance:
53+
54+
.. rubric:: JT-VAE for Unconditional Molecular Generation
55+
.. autoclass:: torch_molecule.generator.jtvae.modeling_jtvae.JTVAEMolecularGenerator
56+
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class
5057
:members: fit, generate
5158
:undoc-members:
5259
:show-inheritance:
5360

54-
Modeling Molecules as Graphs with Heuristic-based Generators
55-
------------------------------------------------------------
61+
Modeling Molecules as Graphs with Heuristic Methods
62+
---------------------------------------------------
5663

5764
.. rubric:: Graph Genetic Algorithm for Un/Multi-conditional Molecular Generation
5865
.. autoclass:: torch_molecule.generator.graph_ga.modeling_graph_ga.GraphGAMolecularGenerator

torch_molecule/generator/jtvae/modeling_jtvae.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from tqdm import tqdm
22
from dataclasses import dataclass, field
3-
from typing import Optional, Union, Dict, Any, Tuple, List, Callable, Type
3+
from typing import Optional, Dict, Any, Tuple, List, Type
44
import numpy as np
55

66
import torch
7-
from torch.utils.data import DataLoader, TensorDataset
87

98
from .jtnn_vae import JTNNVAE
109
from .jtnn.mol_tree import MolTree
@@ -15,21 +14,51 @@
1514
@dataclass
1615
class JTVAEMolecularGenerator(BaseMolecularGenerator):
1716
"""
18-
JT-VAE-based molecular generator.
19-
20-
This generator implements a JT-VAE architecture for molecular generation.
21-
17+
JT-VAE-based molecular generator. Implemented for unconditional moleculargeneration.
2218
2319
References
2420
----------
2521
- Junction Tree Variational Autoencoder for Molecular Graph Generation. ICML 2018. https://arxiv.org/pdf/1802.04364
2622
- Code: https://github.com/kamikaze0923/jtvae
2723
28-
2924
Parameters
3025
----------
31-
TODO.
32-
26+
hidden_size : int, default=450
27+
Dimension of hidden layers in the model.
28+
latent_size : int, default=56
29+
Dimension of the latent space.
30+
depthT : int, default=20
31+
Depth of the tree encoder.
32+
depthG : int, default=3
33+
Depth of the graph decoder.
34+
batch_size : int, default=32
35+
Number of samples per batch during training.
36+
epochs : int, default=20
37+
Number of epochs to train the model.
38+
learning_rate : float, default=0.003
39+
Initial learning rate for the optimizer.
40+
weight_decay : float, default=0.0
41+
L2 regularization factor.
42+
grad_norm_clip : Optional[float], default=None
43+
Maximum norm for gradient clipping. None means no clipping.
44+
beta : float, default=0.0
45+
Initial KL divergence weight for VAE training.
46+
step_beta : float, default=0.002
47+
Step size for KL annealing.
48+
max_beta : float, default=1.0
49+
Maximum value for KL weight.
50+
warmup : int, default=40000
51+
Number of steps for KL annealing warmup.
52+
use_lr_scheduler : bool, default=True
53+
Whether to use learning rate scheduling.
54+
anneal_rate : float, default=0.9
55+
Learning rate annealing factor.
56+
anneal_iter : int, default=40000
57+
Number of iterations between learning rate updates.
58+
kl_anneal_iter : int, default=2000
59+
Number of iterations between KL weight updates.
60+
verbose : bool, default=False
61+
Whether to print detailed training information.
3362
"""
3463
# Model parameters
3564
hidden_size: int = 450

0 commit comments

Comments
 (0)