@@ -117,8 +117,6 @@ def put_data_in_db(gds):
117
117
for rel_type in dataset [rel_split ]:
118
118
edges = dataset [rel_split ][rel_type ]
119
119
120
- # MERGE (n)-[:{rel_type} {{text:l.rel_text}}]->(m)
121
- # MERGE (n)-[:{rel_split}]->(m)
122
120
gds .run_cypher (
123
121
f"""
124
122
UNWIND $ll as l
@@ -156,33 +154,6 @@ def project_train_graph(gds):
156
154
return G_train
157
155
158
156
159
- def project_predict_graph (gds ):
160
- all_rels = gds .run_cypher (
161
- """
162
- CALL db.relationshipTypes() YIELD relationshipType
163
- """
164
- )
165
- all_rels = all_rels ["relationshipType" ].to_list ()
166
- rel_spec = {}
167
- for rel in all_rels :
168
- if rel .startswith ("REL_" ):
169
- rel_spec [rel ] = {"properties" : ["split" ]}
170
-
171
- gds .graph .drop ("fullGraph" , failIfMissing = False )
172
- gds .graph .drop ("predictGraph" , failIfMissing = False )
173
-
174
- # {"REL": {"properties": ["relY"]}, "RELR": {"properties": ["relY"]}}
175
- # print(rel_spec)
176
-
177
- G_full , result = gds .graph .project ("fullGraph" , ["Entity" ], all_rels )
178
-
179
- G_full , result = gds .graph .project ("fullGraph" , ["Entity" ], rel_spec )
180
- # G_predict = gds.graph.filter('predictGraph', 'fullGraph', '*', 'r.split == 2')
181
-
182
- inspect_graph (G_full )
183
- return G_full
184
-
185
-
186
157
def inspect_graph (G ):
187
158
func_names = [
188
159
"name" ,
@@ -200,47 +171,43 @@ def inspect_graph(G):
200
171
create_constraint (gds )
201
172
put_data_in_db (gds )
202
173
G_train = project_train_graph (gds )
203
- # G_predict = project_predict_graph(gds)
204
- # inspect_graph(G_train)
205
174
206
175
gds .set_compute_cluster_ip ("localhost" )
207
176
208
177
print (gds .debug .arrow ())
209
178
210
179
model_name = "dummyModelName_" + str (time .time ())
211
180
212
- gds .kge .model .train (
181
+ node_id_text = gds .find_node_id (["Entity" ], {"text" : "/m/016wzw" })
182
+ node_id_2 = gds .find_node_id (["Entity" ], {"id" : 2 })
183
+ node_id_3 = gds .find_node_id (["Entity" ], {"id" : 3 })
184
+ node_id_0 = gds .find_node_id (["Entity" ], {"id" : 0 })
185
+
186
+ res = gds .kge .model .train (
213
187
G_train ,
214
188
model_name = model_name ,
215
- scoring_function = "DistMult " ,
189
+ scoring_function = "distmult " ,
216
190
num_epochs = 1 ,
217
191
embedding_dimension = 10 ,
218
192
epochs_per_checkpoint = 0 ,
219
193
)
194
+ print (res ['metrics' ])
220
195
221
- gds .kge .model .predict (
222
- G_train ,
196
+ res = gds .kge .model .predict (
223
197
model_name = model_name ,
224
198
top_k = 10 ,
225
- node_ids = [1 , 2 , 3 ],
199
+ node_ids = [node_id_3 , node_id_2 , node_id_text ],
226
200
rel_types = ["REL_1" , "REL_2" ],
227
201
)
202
+ print (res .to_string ())
228
203
229
- gds .kge .model .predict_tail (
230
- G_train ,
231
- model_name = model_name ,
232
- top_k = 10 ,
233
- node_ids = [gds .find_node_id (["Entity" ], {"text" : "/m/016wzw" }), gds .find_node_id (["Entity" ], {"id" : 2 })],
234
- rel_types = ["REL_1" , "REL_2" ],
235
- )
236
-
237
- gds .kge .model .score_triples (
238
- G_train ,
204
+ scores = gds .kge .model .score_triplets (
239
205
model_name = model_name ,
240
- triples = [
241
- (gds . find_node_id ([ "Entity" ], { "text" : "/m/016wzw" }), " REL_1" , gds . find_node_id ([ "Entity" ], { "id" : 2 }) ),
242
- (gds . find_node_id ([ "Entity" ], { "id" : 0 }), " REL_123" , gds . find_node_id ([ "Entity" ], { "id" : 3 }) ),
206
+ triplets = [
207
+ (node_id_2 , " REL_1" , node_id_text ),
208
+ (node_id_0 , " REL_123" , node_id_3 ),
243
209
],
244
210
)
211
+ print (scores )
245
212
246
213
print ("Finished training" )
0 commit comments