Skip to content

Commit 612d53e

Browse files
committed
deploy: 0e8c062
1 parent d46f365 commit 612d53e

File tree

12 files changed

+820
-22
lines changed

12 files changed

+820
-22
lines changed

_modules/index.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ <h1>All modules for which code is available</h1>
267267
<li><a href="torch_molecule/encoder/attrmask/modeling_attrmask.html">torch_molecule.encoder.attrmask.modeling_attrmask</a></li>
268268
<li><a href="torch_molecule/encoder/contextpred/modeling_contextpred.html">torch_molecule.encoder.contextpred.modeling_contextpred</a></li>
269269
<li><a href="torch_molecule/encoder/edgepred/modeling_edgepred.html">torch_molecule.encoder.edgepred.modeling_edgepred</a></li>
270+
<li><a href="torch_molecule/encoder/graphmae/modeling_graphmae.html">torch_molecule.encoder.graphmae.modeling_graphmae</a></li>
270271
<li><a href="torch_molecule/encoder/infograph/modeling_infograph.html">torch_molecule.encoder.infograph.modeling_infograph</a></li>
271272
<li><a href="torch_molecule/encoder/moama/modeling_moama.html">torch_molecule.encoder.moama.modeling_moama</a></li>
272273
<li><a href="torch_molecule/encoder/pretrained/modeling_pretrained.html">torch_molecule.encoder.pretrained.modeling_pretrained</a></li>

_modules/torch_molecule/encoder/graphmae/modeling_graphmae.html

Lines changed: 642 additions & 0 deletions
Large diffs are not rendered by default.

_modules/torch_molecule/nn/gnn.html

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -383,14 +383,15 @@ <h1>Source code for torch_molecule.nn.gnn</h1><div class="highlight"><pre>
383383
<div class="viewcode-block" id="GINConv">
384384
<a class="viewcode-back" href="../../../api/nn.html#torch_molecule.nn.gnn.GINConv">[docs]</a>
385385
<span class="k">class</span><span class="w"> </span><span class="nc">GINConv</span><span class="p">(</span><span class="n">MessagePassing</span><span class="p">):</span>
386-
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">):</span>
386+
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">output_size</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
387387
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
388388
<span class="sd"> hidden_size (int): node embedding dimensionality</span>
389389
<span class="sd"> &#39;&#39;&#39;</span>
390-
391390
<span class="nb">super</span><span class="p">(</span><span class="n">GINConv</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">aggr</span> <span class="o">=</span> <span class="s2">&quot;add&quot;</span><span class="p">)</span>
391+
<span class="k">if</span> <span class="n">output_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
392+
<span class="n">output_size</span> <span class="o">=</span> <span class="n">hidden_size</span>
392393

393-
<span class="bp">self</span><span class="o">.</span><span class="n">mlp</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="mi">2</span><span class="o">*</span><span class="n">hidden_size</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm1d</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">hidden_size</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">))</span>
394+
<span class="bp">self</span><span class="o">.</span><span class="n">mlp</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="mi">2</span><span class="o">*</span><span class="n">hidden_size</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm1d</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">hidden_size</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">output_size</span><span class="p">))</span>
394395
<span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">]))</span>
395396

396397
<span class="bp">self</span><span class="o">.</span><span class="n">bond_encoder</span> <span class="o">=</span> <span class="n">BondEncoder</span><span class="p">(</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">hidden_size</span><span class="p">)</span>
@@ -420,12 +421,14 @@ <h1>Source code for torch_molecule.nn.gnn</h1><div class="highlight"><pre>
420421
<div class="viewcode-block" id="GCNConv">
421422
<a class="viewcode-back" href="../../../api/nn.html#torch_molecule.nn.gnn.GCNConv">[docs]</a>
422423
<span class="k">class</span><span class="w"> </span><span class="nc">GCNConv</span><span class="p">(</span><span class="n">MessagePassing</span><span class="p">):</span>
423-
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">):</span>
424+
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">output_size</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
424425
<span class="nb">super</span><span class="p">(</span><span class="n">GCNConv</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">aggr</span><span class="o">=</span><span class="s1">&#39;add&#39;</span><span class="p">)</span>
426+
<span class="k">if</span> <span class="n">output_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
427+
<span class="n">output_size</span> <span class="o">=</span> <span class="n">hidden_size</span>
425428

426-
<span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>
427-
<span class="bp">self</span><span class="o">.</span><span class="n">root_emb</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>
428-
<span class="bp">self</span><span class="o">.</span><span class="n">bond_encoder</span> <span class="o">=</span> <span class="n">BondEncoder</span><span class="p">(</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">hidden_size</span><span class="p">)</span>
429+
<span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">output_size</span><span class="p">)</span>
430+
<span class="bp">self</span><span class="o">.</span><span class="n">root_emb</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">output_size</span><span class="p">)</span>
431+
<span class="bp">self</span><span class="o">.</span><span class="n">bond_encoder</span> <span class="o">=</span> <span class="n">BondEncoder</span><span class="p">(</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">output_size</span><span class="p">)</span>
429432

430433
<div class="viewcode-block" id="GCNConv.forward">
431434
<a class="viewcode-back" href="../../../api/nn.html#torch_molecule.nn.gnn.GCNConv.forward">[docs]</a>
@@ -434,15 +437,14 @@ <h1>Source code for torch_molecule.nn.gnn</h1><div class="highlight"><pre>
434437
<span class="n">edge_embedding</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bond_encoder</span><span class="p">(</span><span class="n">edge_attr</span><span class="p">)</span>
435438

436439
<span class="n">row</span><span class="p">,</span> <span class="n">col</span> <span class="o">=</span> <span class="n">edge_index</span>
437-
438440
<span class="c1">#edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device)</span>
439441
<span class="n">deg</span> <span class="o">=</span> <span class="n">degree</span><span class="p">(</span><span class="n">row</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">dtype</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span>
440442
<span class="n">deg_inv_sqrt</span> <span class="o">=</span> <span class="n">deg</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="o">-</span><span class="mf">0.5</span><span class="p">)</span>
441443
<span class="n">deg_inv_sqrt</span><span class="p">[</span><span class="n">deg_inv_sqrt</span> <span class="o">==</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;inf&#39;</span><span class="p">)]</span> <span class="o">=</span> <span class="mi">0</span>
442444

443445
<span class="n">norm</span> <span class="o">=</span> <span class="n">deg_inv_sqrt</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">*</span> <span class="n">deg_inv_sqrt</span><span class="p">[</span><span class="n">col</span><span class="p">]</span>
444446

445-
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">propagate</span><span class="p">(</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_attr</span> <span class="o">=</span> <span class="n">edge_embedding</span><span class="p">,</span> <span class="n">norm</span><span class="o">=</span><span class="n">norm</span><span class="p">)</span> <span class="o">+</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">root_emb</span><span class="o">.</span><span class="n">weight</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1.</span><span class="o">/</span><span class="n">deg</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">)</span></div>
447+
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">propagate</span><span class="p">(</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_attr</span><span class="o">=</span><span class="n">edge_embedding</span><span class="p">,</span> <span class="n">norm</span><span class="o">=</span><span class="n">norm</span><span class="p">)</span> <span class="o">+</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">root_emb</span><span class="o">.</span><span class="n">weight</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1.</span><span class="o">/</span><span class="n">deg</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">)</span></div>
446448

447449

448450
<div class="viewcode-block" id="GCNConv.message">

_modules/torch_molecule/utils/graph/features.html

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,8 @@
262262
<h1>Source code for torch_molecule.utils.graph.features</h1><div class="highlight"><pre>
263263
<span></span><span class="c1"># allowable multiple choice node and edge features </span>
264264
<span class="n">allowable_features</span> <span class="o">=</span> <span class="p">{</span>
265+
<span class="c1"># atom types: 1-118, 119 is masked atom, 120 is misc (e.g. * for polymers)</span>
266+
<span class="c1"># index: 0-117, 118, 119</span>
265267
<span class="s1">&#39;possible_atomic_num_list&#39;</span> <span class="p">:</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">120</span><span class="p">))</span> <span class="o">+</span> <span class="p">[</span><span class="s1">&#39;misc&#39;</span><span class="p">],</span>
266268
<span class="s1">&#39;possible_chirality_list&#39;</span> <span class="p">:</span> <span class="p">[</span>
267269
<span class="s1">&#39;CHI_UNSPECIFIED&#39;</span><span class="p">,</span>

_modules/torch_molecule/utils/graph/graph_from_smiles.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ <h1>Source code for torch_molecule.utils.graph.graph_from_smiles</h1><div class=
331331
<span class="c1"># atoms</span>
332332
<span class="n">atom_features_list</span> <span class="o">=</span> <span class="p">[]</span>
333333
<span class="k">for</span> <span class="n">atom</span> <span class="ow">in</span> <span class="n">mol</span><span class="o">.</span><span class="n">GetAtoms</span><span class="p">():</span>
334+
<span class="c1"># print(atom.GetSymbol(), atom_to_feature_vector(atom)[0])</span>
334335
<span class="n">atom_features_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">atom_to_feature_vector</span><span class="p">(</span><span class="n">atom</span><span class="p">))</span>
335336

336337
<span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">atom_features_list</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>

_sources/api/encoder.rst.txt

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,23 +42,31 @@ Self-supervised Molecular Representation Learning
4242
:undoc-members:
4343
:show-inheritance:
4444

45-
.. rubric:: Context Prediction for Molecular Representation Learning
45+
.. rubric:: Graph masked autoencoder
46+
47+
.. autoclass:: torch_molecule.encoder.graphmae.modeling_graphmae.GraphMAEMolecularEncoder
48+
:members: fit, encode
49+
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class
50+
:undoc-members:
51+
:show-inheritance:
52+
53+
.. rubric:: Context Prediction
4654

4755
.. autoclass:: torch_molecule.encoder.contextpred.modeling_contextpred.ContextPredMolecularEncoder
4856
:members: fit, encode
4957
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class
5058
:undoc-members:
5159
:show-inheritance:
5260

53-
.. rubric:: Edge Prediction for Molecular Representation Learning
61+
.. rubric:: Edge Prediction
5462

5563
.. autoclass:: torch_molecule.encoder.edgepred.modeling_edgepred.EdgePredMolecularEncoder
5664
:members: fit, encode
5765
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class
5866
:undoc-members:
5967
:show-inheritance:
6068

61-
.. rubric:: InfoGraph for Molecular Representation Learning
69+
.. rubric:: InfoGraph
6270

6371
.. autoclass:: torch_molecule.encoder.infograph.modeling_infograph.InfoGraphMolecularEncoder
6472
:members: fit, encode
@@ -69,7 +77,7 @@ Self-supervised Molecular Representation Learning
6977
Supervised Pretraining for Molecules
7078
------------------------------------
7179

72-
.. rubric:: Supervised/Pseudolabeled Pretraining for Molecules
80+
.. rubric:: Pretraining with Supervised/Pseudolabeled Data
7381
.. autoclass:: torch_molecule.encoder.supervised.modeling_supervised.SupervisedMolecularEncoder
7482
:members: fit, encode
7583
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class

0 commit comments

Comments
 (0)