|
4 | 4 | from torch_geometric.nn.norm import GraphNorm, PairNorm, DiffGroupNorm, InstanceNorm, LayerNorm, GraphSizeNorm
|
5 | 5 | from torch_geometric.nn import MessagePassing
|
6 | 6 | from torch_geometric.nn import global_add_pool
|
7 |
| -from torch_scatter import scatter_min, scatter_max, scatter_mean,scatter_add |
| 7 | +try: |
| 8 | + from torch_scatter import scatter_min, scatter_max, scatter_mean, scatter_add |
| 9 | + _has_torch_scatter = True |
| 10 | +except ImportError: |
| 11 | + scatter_min = scatter_max = scatter_mean = scatter_add = None |
| 12 | + _has_torch_scatter = False |
8 | 13 |
|
9 | 14 | from ..utils import get_atom_feature_dims, get_bond_feature_dims
|
10 | 15 |
|
@@ -192,6 +197,9 @@ def message(self, x_j, edge_attr):
|
192 | 197 | return self.f_agg(torch.cat([x_j, edge_attr], dim=1))
|
193 | 198 |
|
194 | 199 | def aggregate(self, inputs, index, dim_size=None):
|
| 200 | + if not _has_torch_scatter or scatter_min is None: |
| 201 | + raise ImportError("BFGNN requires `torch_scatter` package. Please install it via `pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html`.") |
| 202 | + |
195 | 203 | out, _ = scatter_min(inputs, index, dim=0, dim_size=dim_size)
|
196 | 204 | out[out == float('inf')] = 0.0
|
197 | 205 | return out
|
@@ -240,6 +248,9 @@ def message(self, x_j, edge_attr, norm):
|
240 | 248 | return norm.view(-1, 1) * F.relu(m)
|
241 | 249 |
|
242 | 250 | def aggregate(self, inputs, index, dim_size=None):
|
| 251 | + if not _has_torch_scatter or scatter_min is None: |
| 252 | + raise ImportError("BFGNN requires `torch_scatter` package. Please install it via `pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html`.") |
| 253 | + |
243 | 254 | out, _ = scatter_min(inputs, index, dim=0, dim_size=dim_size)
|
244 | 255 | out[out == float('inf')] = 0.0
|
245 | 256 | return out
|
@@ -285,6 +296,9 @@ def message(self, x_j, edge_attr):
|
285 | 296 | return F.relu(x_j + edge_attr)
|
286 | 297 |
|
287 | 298 | def aggregate(self, inputs, index, dim_size=None):
|
| 299 | + if not _has_torch_scatter or scatter_max is None: |
| 300 | + raise ImportError("GRIN requires `torch_scatter` package. Please install it via `pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html`.") |
| 301 | + |
288 | 302 | out, _ = scatter_max(inputs, index, dim=0, dim_size=dim_size)
|
289 | 303 | out[out == float('inf')] = 0.0
|
290 | 304 | return out
|
@@ -339,6 +353,9 @@ def message(self, x_j, edge_attr, norm):
|
339 | 353 | return norm.view(-1,1) * F.relu(x_j + edge_attr)
|
340 | 354 |
|
341 | 355 | def aggregate(self, inputs, index, dim_size=None):
|
| 356 | + if not _has_torch_scatter or scatter_max is None: |
| 357 | + raise ImportError("GRIN requires `torch_scatter` package. Please install it via `pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html`.") |
| 358 | + |
342 | 359 | out, _ = scatter_max(inputs, index, dim=0, dim_size=dim_size)
|
343 | 360 | out[out == float('inf')] = 0.0
|
344 | 361 | return out
|
|
0 commit comments