Skip to content

Commit a7473b4

Browse files
committed
revise import issue for torch-scatter in grin and bfgnn
1 parent aab23bc commit a7473b4

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
| Model | Required Packages |
5757
|-------|-------------------|
5858
| HFPretrainedMolecularEncoder | transformers |
59+
| BFGNNMolecularPredictor | torch-scatter |
60+
| GRINMolecularPredictor | torch-scatter |
5961

6062
## Usage
6163

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1 @@
1-
# TODO
2-
3-
# safe-100m: https://huggingface.co/anrilombard/safe-100m
1+
# TODO

torch_molecule/nn/gnn.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44
from torch_geometric.nn.norm import GraphNorm, PairNorm, DiffGroupNorm, InstanceNorm, LayerNorm, GraphSizeNorm
55
from torch_geometric.nn import MessagePassing
66
from torch_geometric.nn import global_add_pool
7-
from torch_scatter import scatter_min, scatter_max, scatter_mean,scatter_add
7+
try:
8+
from torch_scatter import scatter_min, scatter_max, scatter_mean, scatter_add
9+
_has_torch_scatter = True
10+
except ImportError:
11+
scatter_min = scatter_max = scatter_mean = scatter_add = None
12+
_has_torch_scatter = False
813

914
from ..utils import get_atom_feature_dims, get_bond_feature_dims
1015

@@ -192,6 +197,9 @@ def message(self, x_j, edge_attr):
192197
return self.f_agg(torch.cat([x_j, edge_attr], dim=1))
193198

194199
def aggregate(self, inputs, index, dim_size=None):
200+
if not _has_torch_scatter or scatter_min is None:
201+
raise ImportError("BFGNN requires `torch_scatter` package. Please install it via `pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html`.")
202+
195203
out, _ = scatter_min(inputs, index, dim=0, dim_size=dim_size)
196204
out[out == float('inf')] = 0.0
197205
return out
@@ -240,6 +248,9 @@ def message(self, x_j, edge_attr, norm):
240248
return norm.view(-1, 1) * F.relu(m)
241249

242250
def aggregate(self, inputs, index, dim_size=None):
251+
if not _has_torch_scatter or scatter_min is None:
252+
raise ImportError("BFGNN requires `torch_scatter` package. Please install it via `pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html`.")
253+
243254
out, _ = scatter_min(inputs, index, dim=0, dim_size=dim_size)
244255
out[out == float('inf')] = 0.0
245256
return out
@@ -285,6 +296,9 @@ def message(self, x_j, edge_attr):
285296
return F.relu(x_j + edge_attr)
286297

287298
def aggregate(self, inputs, index, dim_size=None):
299+
if not _has_torch_scatter or scatter_max is None:
300+
raise ImportError("GRIN requires `torch_scatter` package. Please install it via `pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html`.")
301+
288302
out, _ = scatter_max(inputs, index, dim=0, dim_size=dim_size)
289303
out[out == float('inf')] = 0.0
290304
return out
@@ -339,6 +353,9 @@ def message(self, x_j, edge_attr, norm):
339353
return norm.view(-1,1) * F.relu(x_j + edge_attr)
340354

341355
def aggregate(self, inputs, index, dim_size=None):
356+
if not _has_torch_scatter or scatter_max is None:
357+
raise ImportError("GRIN requires `torch_scatter` package. Please install it via `pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html`.")
358+
342359
out, _ = scatter_max(inputs, index, dim=0, dim_size=dim_size)
343360
out[out == float('inf')] = 0.0
344361
return out

0 commit comments

Comments
 (0)