@@ -383,14 +383,15 @@ <h1>Source code for torch_molecule.nn.gnn</h1><div class="highlight"><pre>
383
383
< div class ="viewcode-block " id ="GINConv ">
384
384
< a class ="viewcode-back " href ="../../../api/nn.html#torch_molecule.nn.gnn.GINConv "> [docs]</ a >
385
385
< span class ="k "> class</ span > < span class ="w "> </ span > < span class ="nc "> GINConv</ span > < span class ="p "> (</ span > < span class ="n "> MessagePassing</ span > < span class ="p "> ):</ span >
386
- < span class ="k "> def</ span > < span class ="w "> </ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> hidden_size</ span > < span class ="p "> ):</ span >
386
+ < span class ="k "> def</ span > < span class ="w "> </ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> hidden_size</ span > < span class ="p "> , </ span > < span class =" n " > output_size </ span > < span class =" o " > = </ span > < span class =" kc " > None </ span > < span class =" p " > ):</ span >
387
387
< span class ="w "> </ span > < span class ="sd "> '''</ span >
388
388
< span class ="sd "> hidden_size (int): node embedding dimensionality</ span >
389
389
< span class ="sd "> '''</ span >
390
-
391
390
< span class ="nb "> super</ span > < span class ="p "> (</ span > < span class ="n "> GINConv</ span > < span class ="p "> ,</ span > < span class ="bp "> self</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span > < span class ="n "> aggr</ span > < span class ="o "> =</ span > < span class ="s2 "> "add"</ span > < span class ="p "> )</ span >
391
+ < span class ="k "> if</ span > < span class ="n "> output_size</ span > < span class ="ow "> is</ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
392
+ < span class ="n "> output_size</ span > < span class ="o "> =</ span > < span class ="n "> hidden_size</ span >
392
393
393
- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> mlp</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Sequential</ span > < span class ="p "> (</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Linear</ span > < span class ="p "> (</ span > < span class ="n "> hidden_size</ span > < span class ="p "> ,</ span > < span class ="mi "> 2</ span > < span class ="o "> *</ span > < span class ="n "> hidden_size</ span > < span class ="p "> ),</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> BatchNorm1d</ span > < span class ="p "> (</ span > < span class ="mi "> 2</ span > < span class ="o "> *</ span > < span class ="n "> hidden_size</ span > < span class ="p "> ),</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> ReLU</ span > < span class ="p "> (),</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Linear</ span > < span class ="p "> (</ span > < span class ="mi "> 2</ span > < span class ="o "> *</ span > < span class ="n "> hidden_size</ span > < span class ="p "> ,</ span > < span class ="n "> hidden_size</ span > < span class ="p "> ))</ span >
394
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> mlp</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Sequential</ span > < span class ="p "> (</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Linear</ span > < span class ="p "> (</ span > < span class ="n "> hidden_size</ span > < span class ="p "> ,</ span > < span class ="mi "> 2</ span > < span class ="o "> *</ span > < span class ="n "> hidden_size</ span > < span class ="p "> ),</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> BatchNorm1d</ span > < span class ="p "> (</ span > < span class ="mi "> 2</ span > < span class ="o "> *</ span > < span class ="n "> hidden_size</ span > < span class ="p "> ),</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> ReLU</ span > < span class ="p "> (),</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Linear</ span > < span class ="p "> (</ span > < span class ="mi "> 2</ span > < span class ="o "> *</ span > < span class ="n "> hidden_size</ span > < span class ="p "> ,</ span > < span class ="n "> output_size</ span > < span class ="p "> ))</ span >
394
395
< span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> eps</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Parameter</ span > < span class ="p "> (</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ([</ span > < span class ="mi "> 0</ span > < span class ="p "> ]))</ span >
395
396
396
397
< span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> bond_encoder</ span > < span class ="o "> =</ span > < span class ="n "> BondEncoder</ span > < span class ="p "> (</ span > < span class ="n "> hidden_size</ span > < span class ="o "> =</ span > < span class ="n "> hidden_size</ span > < span class ="p "> )</ span >
@@ -420,12 +421,14 @@ <h1>Source code for torch_molecule.nn.gnn</h1><div class="highlight"><pre>
420
421
< div class ="viewcode-block " id ="GCNConv ">
421
422
< a class ="viewcode-back " href ="../../../api/nn.html#torch_molecule.nn.gnn.GCNConv "> [docs]</ a >
422
423
< span class ="k "> class</ span > < span class ="w "> </ span > < span class ="nc "> GCNConv</ span > < span class ="p "> (</ span > < span class ="n "> MessagePassing</ span > < span class ="p "> ):</ span >
423
- < span class ="k "> def</ span > < span class ="w "> </ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> hidden_size</ span > < span class ="p "> ):</ span >
424
+ < span class ="k "> def</ span > < span class ="w "> </ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> hidden_size</ span > < span class ="p "> , </ span > < span class =" n " > output_size </ span > < span class =" o " > = </ span > < span class =" kc " > None </ span > < span class =" p " > ):</ span >
424
425
< span class ="nb "> super</ span > < span class ="p "> (</ span > < span class ="n "> GCNConv</ span > < span class ="p "> ,</ span > < span class ="bp "> self</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span > < span class ="n "> aggr</ span > < span class ="o "> =</ span > < span class ="s1 "> 'add'</ span > < span class ="p "> )</ span >
426
+ < span class ="k "> if</ span > < span class ="n "> output_size</ span > < span class ="ow "> is</ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
427
+ < span class ="n "> output_size</ span > < span class ="o "> =</ span > < span class ="n "> hidden_size</ span >
425
428
426
- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> linear</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Linear</ span > < span class ="p "> (</ span > < span class ="n "> hidden_size</ span > < span class ="p "> ,</ span > < span class ="n "> hidden_size </ span > < span class ="p "> )</ span >
427
- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> root_emb</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Embedding</ span > < span class ="p "> (</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="n "> hidden_size </ span > < span class ="p "> )</ span >
428
- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> bond_encoder</ span > < span class ="o "> =</ span > < span class ="n "> BondEncoder</ span > < span class ="p "> (</ span > < span class ="n "> hidden_size</ span > < span class ="o "> =</ span > < span class ="n "> hidden_size </ span > < span class ="p "> )</ span >
429
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> linear</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Linear</ span > < span class ="p "> (</ span > < span class ="n "> hidden_size</ span > < span class ="p "> ,</ span > < span class ="n "> output_size </ span > < span class ="p "> )</ span >
430
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> root_emb</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Embedding</ span > < span class ="p "> (</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="n "> output_size </ span > < span class ="p "> )</ span >
431
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> bond_encoder</ span > < span class ="o "> =</ span > < span class ="n "> BondEncoder</ span > < span class ="p "> (</ span > < span class ="n "> hidden_size</ span > < span class ="o "> =</ span > < span class ="n "> output_size </ span > < span class ="p "> )</ span >
429
432
430
433
< div class ="viewcode-block " id ="GCNConv.forward ">
431
434
< a class ="viewcode-back " href ="../../../api/nn.html#torch_molecule.nn.gnn.GCNConv.forward "> [docs]</ a >
@@ -434,15 +437,14 @@ <h1>Source code for torch_molecule.nn.gnn</h1><div class="highlight"><pre>
434
437
< span class ="n "> edge_embedding</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> bond_encoder</ span > < span class ="p "> (</ span > < span class ="n "> edge_attr</ span > < span class ="p "> )</ span >
435
438
436
439
< span class ="n "> row</ span > < span class ="p "> ,</ span > < span class ="n "> col</ span > < span class ="o "> =</ span > < span class ="n "> edge_index</ span >
437
-
438
440
< span class ="c1 "> #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device)</ span >
439
441
< span class ="n "> deg</ span > < span class ="o "> =</ span > < span class ="n "> degree</ span > < span class ="p "> (</ span > < span class ="n "> row</ span > < span class ="p "> ,</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> size</ span > < span class ="p "> (</ span > < span class ="mi "> 0</ span > < span class ="p "> ),</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="p "> )</ span > < span class ="o "> +</ span > < span class ="mi "> 1</ span >
440
442
< span class ="n "> deg_inv_sqrt</ span > < span class ="o "> =</ span > < span class ="n "> deg</ span > < span class ="o "> .</ span > < span class ="n "> pow</ span > < span class ="p "> (</ span > < span class ="o "> -</ span > < span class ="mf "> 0.5</ span > < span class ="p "> )</ span >
441
443
< span class ="n "> deg_inv_sqrt</ span > < span class ="p "> [</ span > < span class ="n "> deg_inv_sqrt</ span > < span class ="o "> ==</ span > < span class ="nb "> float</ span > < span class ="p "> (</ span > < span class ="s1 "> 'inf'</ span > < span class ="p "> )]</ span > < span class ="o "> =</ span > < span class ="mi "> 0</ span >
442
444
443
445
< span class ="n "> norm</ span > < span class ="o "> =</ span > < span class ="n "> deg_inv_sqrt</ span > < span class ="p "> [</ span > < span class ="n "> row</ span > < span class ="p "> ]</ span > < span class ="o "> *</ span > < span class ="n "> deg_inv_sqrt</ span > < span class ="p "> [</ span > < span class ="n "> col</ span > < span class ="p "> ]</ span >
444
446
445
- < span class ="k "> return</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> propagate</ span > < span class ="p "> (</ span > < span class ="n "> edge_index</ span > < span class ="p "> ,</ span > < span class ="n "> x</ span > < span class ="o "> =</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> edge_attr</ span > < span class ="o "> =</ span > < span class ="n "> edge_embedding</ span > < span class ="p "> ,</ span > < span class ="n "> norm</ span > < span class ="o "> =</ span > < span class ="n "> norm</ span > < span class ="p "> )</ span > < span class ="o "> +</ span > < span class ="n "> F</ span > < span class ="o "> .</ span > < span class ="n "> relu</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="o "> +</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> root_emb</ span > < span class ="o "> .</ span > < span class ="n "> weight</ span > < span class ="p "> )</ span > < span class ="o "> *</ span > < span class ="mf "> 1.</ span > < span class ="o "> /</ span > < span class ="n "> deg</ span > < span class ="o "> .</ span > < span class ="n "> view</ span > < span class ="p "> (</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="mi "> 1</ span > < span class ="p "> )</ span > </ div >
447
+ < span class ="k "> return</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> propagate</ span > < span class ="p "> (</ span > < span class ="n "> edge_index</ span > < span class ="p "> ,</ span > < span class ="n "> x</ span > < span class ="o "> =</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> edge_attr</ span > < span class ="o "> =</ span > < span class ="n "> edge_embedding</ span > < span class ="p "> ,</ span > < span class ="n "> norm</ span > < span class ="o "> =</ span > < span class ="n "> norm</ span > < span class ="p "> )</ span > < span class ="o "> +</ span > < span class ="n "> F</ span > < span class ="o "> .</ span > < span class ="n "> relu</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="o "> +</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> root_emb</ span > < span class ="o "> .</ span > < span class ="n "> weight</ span > < span class ="p "> )</ span > < span class ="o "> *</ span > < span class ="mf "> 1.</ span > < span class ="o "> /</ span > < span class ="n "> deg</ span > < span class ="o "> .</ span > < span class ="n "> view</ span > < span class ="p "> (</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="mi "> 1</ span > < span class ="p "> )</ span > </ div >
446
448
447
449
448
450
< div class ="viewcode-block " id ="GCNConv.message ">
0 commit comments