@@ -89,7 +89,8 @@ aggregate_neighbors(l::WithGraph, args...) = aggregate_neighbors(l.layer, l.grap
89
89
update_batch_vertex (l:: WithGraph , args... ) = update_batch_vertex (l. layer, l. graph, args... )
90
90
91
91
"""
92
- GraphParallel(; node_layer=identity, edge_layer=identity, global_layer=identity)
92
+ GraphParallel(; node_layer=identity, edge_layer=identity, global_layer=identity,
93
+ positional_layer=identity)
93
94
94
95
Passing features in `FeaturedGraph` in parallel. It takes `FeaturedGraph` as input
95
96
and it can be specified by assigning layers for specific (node, edge and global) features.
@@ -99,6 +100,7 @@ and it can be specified by assigning layers for specific (node, edge and global)
99
100
- `node_layer`: A regular Flux layer for passing node features.
100
101
- `edge_layer`: A regular Flux layer for passing edge features.
101
102
- `global_layer`: A regular Flux layer for passing global features.
103
+ - `positional_layer`: A regular Flux layer for passing positional features.
102
104
103
105
# Example
104
106
@@ -109,32 +111,36 @@ julia> l = GraphParallel(
109
111
node_layer=Dropout(0.5),
110
112
global_layer=Dense(10, 5)
111
113
)
112
- GraphParallel(node_layer=Dropout(0.5), edge_layer=identity, global_layer=Dense(10 => 5))
114
+ GraphParallel(node_layer=Dropout(0.5), edge_layer=identity, global_layer=Dense(10 => 5), positional_layer=identity )
113
115
```
114
116
"""
115
- struct GraphParallel{N,E,G}
117
+ struct GraphParallel{N,E,G,P }
116
118
node_layer:: N
117
119
edge_layer:: E
118
120
global_layer:: G
121
+ positional_layer:: P
119
122
end
120
123
121
124
@functor GraphParallel
122
125
123
- GraphParallel (; node_layer= identity, edge_layer= identity, global_layer= identity) =
124
- GraphParallel (node_layer, edge_layer, global_layer)
126
+ GraphParallel (; node_layer= identity, edge_layer= identity, global_layer= identity,
127
+ positional_layer= identity) =
128
+ GraphParallel (node_layer, edge_layer, global_layer, positional_layer)
125
129
126
130
function (l:: GraphParallel )(fg:: AbstractFeaturedGraph )
127
131
nf = l. node_layer (node_feature (fg))
128
132
ef = l. edge_layer (edge_feature (fg))
129
133
gf = l. global_layer (global_feature (fg))
130
- return ConcreteFeaturedGraph (fg, nf= nf, ef= ef, gf= gf)
134
+ pf = l. positional_layer (positional_feature (fg))
135
+ return ConcreteFeaturedGraph (fg, nf= nf, ef= ef, gf= gf, pf= pf)
131
136
end
132
137
133
138
function Base. show (io:: IO , l:: GraphParallel )
134
139
print (io, " GraphParallel(" )
135
140
print (io, " node_layer=" , l. node_layer)
136
141
print (io, " , edge_layer=" , l. edge_layer)
137
142
print (io, " , global_layer=" , l. global_layer)
143
+ print (io, " , positional_layer=" , l. positional_layer)
138
144
print (io, " )" )
139
145
end
140
146
0 commit comments