Skip to content

Commit 4d02389

Browse files
authored
Merge pull request #319 from FluxML/develop
GraphParallel support positional_layer
2 parents 14796f6 + ab642f6 commit 4d02389

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

src/layers/graphlayers.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ aggregate_neighbors(l::WithGraph, args...) = aggregate_neighbors(l.layer, l.grap
8989
update_batch_vertex(l::WithGraph, args...) = update_batch_vertex(l.layer, l.graph, args...)
9090

9191
"""
92-
GraphParallel(; node_layer=identity, edge_layer=identity, global_layer=identity)
92+
GraphParallel(; node_layer=identity, edge_layer=identity, global_layer=identity,
93+
positional_layer=identity)
9394
9495
Passing features in `FeaturedGraph` in parallel. It takes `FeaturedGraph` as input
9596
and it can be specified by assigning layers for specific (node, edge and global) features.
@@ -99,6 +100,7 @@ and it can be specified by assigning layers for specific (node, edge and global)
99100
- `node_layer`: A regular Flux layer for passing node features.
100101
- `edge_layer`: A regular Flux layer for passing edge features.
101102
- `global_layer`: A regular Flux layer for passing global features.
103+
- `positional_layer`: A regular Flux layer for passing positional features.
102104
103105
# Example
104106
@@ -109,32 +111,36 @@ julia> l = GraphParallel(
109111
node_layer=Dropout(0.5),
110112
global_layer=Dense(10, 5)
111113
)
112-
GraphParallel(node_layer=Dropout(0.5), edge_layer=identity, global_layer=Dense(10 => 5))
114+
GraphParallel(node_layer=Dropout(0.5), edge_layer=identity, global_layer=Dense(10 => 5), positional_layer=identity)
113115
```
114116
"""
115-
struct GraphParallel{N,E,G}
117+
struct GraphParallel{N,E,G,P}
116118
node_layer::N
117119
edge_layer::E
118120
global_layer::G
121+
positional_layer::P
119122
end
120123

121124
@functor GraphParallel
122125

123-
GraphParallel(; node_layer=identity, edge_layer=identity, global_layer=identity) =
124-
GraphParallel(node_layer, edge_layer, global_layer)
126+
GraphParallel(; node_layer=identity, edge_layer=identity, global_layer=identity,
127+
positional_layer=identity) =
128+
GraphParallel(node_layer, edge_layer, global_layer, positional_layer)
125129

126130
function (l::GraphParallel)(fg::AbstractFeaturedGraph)
127131
nf = l.node_layer(node_feature(fg))
128132
ef = l.edge_layer(edge_feature(fg))
129133
gf = l.global_layer(global_feature(fg))
130-
return ConcreteFeaturedGraph(fg, nf=nf, ef=ef, gf=gf)
134+
pf = l.positional_layer(positional_feature(fg))
135+
return ConcreteFeaturedGraph(fg, nf=nf, ef=ef, gf=gf, pf=pf)
131136
end
132137

133138
function Base.show(io::IO, l::GraphParallel)
134139
print(io, "GraphParallel(")
135140
print(io, "node_layer=", l.node_layer)
136141
print(io, ", edge_layer=", l.edge_layer)
137142
print(io, ", global_layer=", l.global_layer)
143+
print(io, ", positional_layer=", l.positional_layer)
138144
print(io, ")")
139145
end
140146

0 commit comments

Comments
 (0)