import math
import torch
from torch.nn import Parameter
import torch.nn.functional as F
from typing import Any
from .utils import mask_adjs, mask_x

def glorot(tensor):
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)

def zeros(tensor):
    if tensor is not None:
        tensor.data.fill_(0)

def reset(value: Any):
    if hasattr(value, 'reset_parameters'):
        value.reset_parameters()
    else:
        for child in value.children() if hasattr(value, 'children') else []:
            reset(child)

# -------- GCN layer --------
class DenseGCNConv(torch.nn.Module):
    r"""See :class:`torch_geometric.nn.conv.GCNConv`.
    """
    def __init__(self, in_channels, out_channels, improved=False, bias=True):
        super(DenseGCNConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved

        self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x, adj, mask=None, add_loop=True):
        r"""
        Args:
            x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B
                \times N \times F}`, with batch-size :math:`B`, (maximum)
                number of nodes :math:`N` for each graph, and feature
                dimension :math:`F`.
            adj (Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B
                \times N \times N}`. The adjacency tensor is broadcastable in
                the batch dimension, resulting in a shared adjacency matrix for
                the complete batch.
            mask (BoolTensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
            add_loop (bool, optional): If set to :obj:`False`, the layer will
                not automatically add self-loops to the adjacency matrices.
                (default: :obj:`True`)
        """
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
        B, N, _ = adj.size()

        if add_loop:
            adj = adj.clone()
            idx = torch.arange(N, dtype=torch.long, device=adj.device)
            adj[:, idx, idx] = 1 if not self.improved else 2

        out = torch.matmul(x, self.weight)
        deg_inv_sqrt = adj.sum(dim=-1).clamp(min=1).pow(-0.5)

        adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
        out = torch.matmul(adj, out)

        if self.bias is not None:
            out = out + self.bias

        if mask is not None:
            out = out * mask.view(B, N, 1).to(x.dtype)

        return out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

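# Shape sketch for DenseGCNConv (an illustrative comment, not part of the
# original file; all sizes below are arbitrary). The layer expects dense,
# batched inputs, adds self-loops, and applies symmetric degree normalization:
#
#   conv = DenseGCNConv(in_channels=16, out_channels=32)
#   x = torch.randn(8, 20, 16)                            # B x N x F node features
#   adj = torch.rand(8, 20, 20)                           # B x N x N dense adjacency
#   out = conv(x, adj)                                    # -> 8 x 20 x 32
#   out = conv(x, adj, mask=adj.new_ones(8, 20).bool())   # zeroes out padded nodes
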
# -------- MLP layer --------
class MLP(torch.nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False, activate_func=F.relu):
        """
        num_layers: number of layers in the neural network (EXCLUDING the input layer).
            If num_layers=1, this reduces to a linear model.
        input_dim: dimensionality of input features
        hidden_dim: dimensionality of hidden units at ALL layers
        output_dim: dimensionality of the output
        use_bn: whether to apply BatchNorm1d after each hidden layer
        activate_func: activation applied after each hidden layer (default: F.relu)
        """

        super(MLP, self).__init__()

        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.use_bn = use_bn
        self.activate_func = activate_func

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = torch.nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()

            self.linears.append(torch.nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(torch.nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(torch.nn.Linear(hidden_dim, output_dim))

            if self.use_bn:
                self.batch_norms = torch.nn.ModuleList()
                for layer in range(num_layers - 1):
                    self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))

    def forward(self, x):
        """
        :param x: [batch_size, N, F_i] (or [batch_size, F_i]), batch of node features
        :return: output of shape [..., output_dim]
        """
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        else:
            # If MLP
            h = x
            for layer in range(self.num_layers - 1):
                h = self.linears[layer](h)
                if self.use_bn:
                    h = self.batch_norms[layer](h)
                h = self.activate_func(h)
            return self.linears[self.num_layers - 1](h)

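# Usage sketch for MLP (an illustrative comment, not part of the original file;
# sizes are arbitrary). With num_layers=1 the module is a single Linear layer;
# otherwise every hidden layer applies `activate_func` and, when use_bn=True,
# a BatchNorm1d that expects a 2D (batch, hidden_dim) input:
#
#   mlp = MLP(num_layers=3, input_dim=16, hidden_dim=64, output_dim=8)
#   out = mlp(torch.randn(32, 20, 16))    # applied node-wise -> 32 x 20 x 8
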
# -------- Graph Multi-Head Attention (GMH) --------
# -------- From Baek et al. (2021) --------
class Attention(torch.nn.Module):
    def __init__(self, in_dim, attn_dim, out_dim, num_heads=4, conv='GCN'):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.attn_dim = attn_dim
        self.out_dim = out_dim
        self.conv = conv

        self.gnn_q, self.gnn_k, self.gnn_v = self.get_gnn(in_dim, attn_dim, out_dim, conv)
        self.activation = torch.tanh
        self.softmax_dim = 2

    def forward(self, x, adj, flags, attention_mask=None):
        if self.conv == 'GCN':
            Q = self.gnn_q(x, adj)
            K = self.gnn_k(x, adj)
        else:
            Q = self.gnn_q(x)
            K = self.gnn_k(x)

        V = self.gnn_v(x, adj)
        dim_split = self.attn_dim // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)

        if attention_mask is not None:
            attention_mask = torch.cat([attention_mask for _ in range(self.num_heads)], 0)
            attention_score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.out_dim)
            A = self.activation(attention_mask + attention_score)
        else:
            A = self.activation(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.out_dim))  # (B x num_heads) x N x N

        # -------- (B x num_heads) x N x N --------
        A = A.view(-1, *adj.shape)
        A = A.mean(dim=0)
        A = (A + A.transpose(-1, -2)) / 2

        return V, A

    def get_gnn(self, in_dim, attn_dim, out_dim, conv='GCN'):

        if conv == 'GCN':
            gnn_q = DenseGCNConv(in_dim, attn_dim)
            gnn_k = DenseGCNConv(in_dim, attn_dim)
            gnn_v = DenseGCNConv(in_dim, out_dim)

            return gnn_q, gnn_k, gnn_v

        elif conv == 'MLP':
            num_layers = 2
            gnn_q = MLP(num_layers, in_dim, 2 * attn_dim, attn_dim, activate_func=torch.tanh)
            gnn_k = MLP(num_layers, in_dim, 2 * attn_dim, attn_dim, activate_func=torch.tanh)
            gnn_v = DenseGCNConv(in_dim, out_dim)

            return gnn_q, gnn_k, gnn_v

        else:
            raise NotImplementedError(f'{conv} not implemented.')

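# Usage sketch for Attention (an illustrative comment, not part of the original
# file; sizes are arbitrary). Queries/keys come from DenseGCNConv (conv='GCN')
# or an MLP over node features (conv='MLP'); values always come from
# DenseGCNConv. attn_dim must be divisible by num_heads. `flags` is not used
# inside this forward pass, so None is acceptable here:
#
#   attn = Attention(in_dim=16, attn_dim=32, out_dim=16, num_heads=4, conv='GCN')
#   V, A = attn(torch.randn(8, 20, 16), torch.rand(8, 20, 20), flags=None)
#   # V: 8 x 20 x 16 node values, A: 8 x 20 x 20 symmetrized attention matrix
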
# -------- Layer of ScoreNetworkA --------
class AttentionLayer(torch.nn.Module):
    def __init__(self, num_linears, conv_input_dim, attn_dim, conv_output_dim, input_dim, output_dim,
                 num_heads=4, conv='GCN'):
        super(AttentionLayer, self).__init__()
        self.attn = torch.nn.ModuleList()
        for _ in range(input_dim):
            self.attn_dim = attn_dim
            self.attn.append(Attention(conv_input_dim, self.attn_dim, conv_output_dim,
                                       num_heads=num_heads, conv=conv))

        self.hidden_dim = 2 * max(input_dim, output_dim)
        self.mlp = MLP(num_linears, 2 * input_dim, self.hidden_dim, output_dim, use_bn=False, activate_func=F.elu)
        self.multi_channel = MLP(2, input_dim * conv_output_dim, self.hidden_dim, conv_output_dim,
                                 use_bn=False, activate_func=F.elu)

    def forward(self, x, adj, flags):
        """
        :param x: B x N x F_i
        :param adj: B x C_i x N x N
        :param flags: B x N node flags, forwarded to mask_x / mask_adjs
        :return: x_out: B x N x F_o, adj_out: B x C_o x N x N
        """
        mask_list = []
        x_list = []
        for _ in range(len(self.attn)):
            _x, mask = self.attn[_](x, adj[:, _, :, :], flags)
            mask_list.append(mask.unsqueeze(-1))
            x_list.append(_x)
        x_out = mask_x(self.multi_channel(torch.cat(x_list, dim=-1)), flags)
        x_out = torch.tanh(x_out)

        mlp_in = torch.cat([torch.cat(mask_list, dim=-1), adj.permute(0, 2, 3, 1)], dim=-1)
        shape = mlp_in.shape
        mlp_out = self.mlp(mlp_in.view(-1, shape[-1]))
        _adj = mlp_out.view(shape[0], shape[1], shape[2], -1).permute(0, 3, 1, 2)
        _adj = _adj + _adj.transpose(-1, -2)
        adj_out = mask_adjs(_adj, flags)

        return x_out, adj_out

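# A minimal smoke test (an illustrative sketch, not part of the original file).
# All sizes are arbitrary, and it assumes the `mask_x` / `mask_adjs` helpers
# imported from `.utils` accept a (B, N) node-flag tensor; because of the
# relative import above, run it from inside the package, e.g.
# `python -m <package>.layers`.
if __name__ == "__main__":
    B, N, F_in, F_out, C_in, C_out = 2, 9, 10, 16, 3, 4

    x = torch.randn(B, N, F_in)             # B x N x F_in node features
    adjs = torch.rand(B, C_in, N, N)        # B x C_in x N x N adjacency channels
    flags = torch.ones(B, N)                # all nodes valid

    gcn = DenseGCNConv(F_in, F_out)
    print(gcn(x, adjs[:, 0]).shape)         # torch.Size([2, 9, 16])

    mlp = MLP(num_layers=3, input_dim=F_in, hidden_dim=32, output_dim=F_out)
    print(mlp(x).shape)                     # torch.Size([2, 9, 16])

    attn_layer = AttentionLayer(num_linears=2, conv_input_dim=F_in, attn_dim=16,
                                conv_output_dim=F_out, input_dim=C_in,
                                output_dim=C_out, num_heads=4, conv='GCN')
    x_out, adj_out = attn_layer(x, adjs, flags)
    print(x_out.shape, adj_out.shape)       # torch.Size([2, 9, 16]) torch.Size([2, 4, 9, 9])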