Skip to content

Commit b8e7432

Browse files
committed
Fix gds setup in notebook
1 parent edcbf9c commit b8e7432

File tree

1 file changed

+91
-94
lines changed

1 file changed

+91
-94
lines changed

examples/kge-predict-transe-pyg-train.ipynb

Lines changed: 91 additions & 94 deletions
Original file line number | Diff line number | Diff line change
@@ -3,37 +3,43 @@
33
{
44
"cell_type": "code",
55
"execution_count": null,
6+
"metadata": {},
67
"outputs": [],
78
"source": [
89
"%pip install graphdatascience torch torch_geometric"
9-
],
10-
"metadata": {
11-
"collapsed": false
12-
}
10+
]
1311
},
1412
{
1513
"cell_type": "code",
1614
"execution_count": null,
15+
"metadata": {},
1716
"outputs": [],
1817
"source": [
18+
"import os\n",
1919
"from graphdatascience import GraphDataScience\n",
2020
"import torch\n",
2121
"import torch.optim as optim\n",
2222
"from torch_geometric.data import Data, download_url\n",
2323
"from torch_geometric.nn import TransE\n",
2424
"import collections"
25-
],
26-
"metadata": {
27-
"collapsed": false
28-
}
25+
]
2926
},
3027
{
3128
"cell_type": "code",
3229
"execution_count": null,
3330
"metadata": {},
3431
"outputs": [],
3532
"source": [
36-
"gds = GraphDataScience(\"bolt://localhost:7687\", auth=('neo4j', 'neo4jneo4j'), database=\"ttt\")"
33+
"# Get Neo4j DB URI, credentials and name from environment if applicable\n",
34+
"NEO4J_URI = os.environ.get(\"NEO4J_URI\", \"bolt://localhost:7687\")\n",
35+
"NEO4J_AUTH = None\n",
36+
"NEO4J_DB = os.environ.get(\"NEO4J_DB\", \"neo4j\")\n",
37+
"if os.environ.get(\"NEO4J_USER\") and os.environ.get(\"NEO4J_PASSWORD\"):\n",
38+
" NEO4J_AUTH = (\n",
39+
" os.environ.get(\"NEO4J_USER\"),\n",
40+
" os.environ.get(\"NEO4J_PASSWORD\"),\n",
41+
" )\n",
42+
"gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB)"
3743
]
3844
},
3945
{
@@ -42,11 +48,11 @@
4248
"metadata": {},
4349
"outputs": [],
4450
"source": [
45-
"url = ('https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/FB15k-237')\n",
46-
"raw_file_names = ['train.txt', 'valid.txt', 'test.txt']\n",
47-
"raw_dir = './data_from_url'\n",
51+
"url = \"https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/FB15k-237\"\n",
52+
"raw_file_names = [\"train.txt\", \"valid.txt\", \"test.txt\"]\n",
53+
"raw_dir = \"./data_from_url\"\n",
4854
"for filename in raw_file_names:\n",
49-
" download_url(f'{url}/{filename}', raw_dir)"
55+
" download_url(f\"{url}/{filename}\", raw_dir)"
5056
]
5157
},
5258
{
@@ -65,20 +71,21 @@
6571
"outputs": [],
6672
"source": [
6773
"rel_types = {\n",
68-
" \"train.txt\":\"TRAIN\",\n",
69-
" \"valid.txt\":\"VALID\",\n",
70-
" \"test.txt\":\"TEST\",\n",
74+
" \"train.txt\": \"TRAIN\",\n",
75+
" \"valid.txt\": \"VALID\",\n",
76+
" \"test.txt\": \"TEST\",\n",
7177
"}\n",
7278
"rel_id_to_text_dict = {}\n",
7379
"rel_type_dict = collections.defaultdict(list)\n",
7480
"\n",
81+
"\n",
7582
"def process():\n",
7683
" node_dict_, rel_dict_ = {}, {}\n",
7784
" for file_name in raw_file_names:\n",
78-
" file_name_path = raw_dir + '/' + file_name\n",
85+
" file_name_path = raw_dir + \"/\" + file_name\n",
7986
"\n",
80-
" with open(file_name_path, 'r') as f:\n",
81-
" data = [x.split('\\t') for x in f.read().split('\\n')[:-1]]\n",
87+
" with open(file_name_path, \"r\") as f:\n",
88+
" data = [x.split(\"\\t\") for x in f.read().split(\"\\n\")[:-1]]\n",
8289
"\n",
8390
" list_of_dicts = []\n",
8491
" for i, (src, rel, dst) in enumerate(data):\n",
@@ -94,42 +101,45 @@
94101
" target = node_dict_[dst]\n",
95102
" edge_type = rel_dict_[rel]\n",
96103
"\n",
97-
" rel_type_dict[edge_type].append({\n",
98-
" \"source\":source,\n",
99-
" \"target\":target,\n",
100-
" })\n",
101-
" list_of_dicts.append({\n",
102-
" \"source\": source,\n",
103-
" \"source_text\": src,\n",
104-
" \"target\": target,\n",
105-
" \"target_text\": dst,\n",
106-
" \"rel_id\": edge_type,\n",
107-
" })\n",
104+
" rel_type_dict[edge_type].append(\n",
105+
" {\n",
106+
" \"source\": source,\n",
107+
" \"target\": target,\n",
108+
" }\n",
109+
" )\n",
110+
" list_of_dicts.append(\n",
111+
" {\n",
112+
" \"source\": source,\n",
113+
" \"source_text\": src,\n",
114+
" \"target\": target,\n",
115+
" \"target_text\": dst,\n",
116+
" \"rel_id\": edge_type,\n",
117+
" }\n",
118+
" )\n",
108119
"\n",
109120
" rel_type = rel_types[file_name]\n",
110121
" print(f\"Writing {len(list_of_dicts)} entities of {rel_type}\")\n",
111122
" gds.run_cypher(\n",
112-
" \"UNWIND $ll as l \"+\n",
113-
" \"MERGE (n:Entity {id:l.source, text:l.source_text}) \"+\n",
114-
" \"MERGE (m:Entity {id:l.target, text:l.target_text}) \"+\n",
115-
" \"MERGE (n)-[:\"+rel_type+\" {rel_id:l.rel_id}]->(m) \",\n",
116-
" params={\n",
117-
" \"ll\": list_of_dicts\n",
118-
" },\n",
119-
" )\n",
123+
" \"UNWIND $ll as l \"\n",
124+
" + \"MERGE (n:Entity {id:l.source, text:l.source_text}) \"\n",
125+
" + \"MERGE (m:Entity {id:l.target, text:l.target_text}) \"\n",
126+
" + \"MERGE (n)-[:\"\n",
127+
" + rel_type\n",
128+
" + \" {rel_id:l.rel_id}]->(m) \",\n",
129+
" params={\"ll\": list_of_dicts},\n",
130+
" )\n",
120131
"\n",
121132
" for rel_id in rel_type_dict:\n",
122133
" REL_TYPE = f\"REL_{rel_id}\"\n",
123134
" print(f\"Writing {len(rel_type_dict[rel_id])} entities of {REL_TYPE}\")\n",
124135
" gds.run_cypher(\n",
125-
" \"UNWIND $ll AS l MATCH (n:Entity {id:l.source}), (m:Entity {id:l.target}) \"+\n",
126-
" \"MERGE (n)-[:\"+REL_TYPE+\" {rel_id:$rel_id, text:$text}]->(m) \",\n",
127-
" params={\n",
128-
" \"ll\": rel_type_dict[rel_id],\n",
129-
" \"rel_id\": rel_id,\n",
130-
" \"text\": rel_id_to_text_dict[rel_id]\n",
131-
" },\n",
132-
" )\n",
136+
" \"UNWIND $ll AS l MATCH (n:Entity {id:l.source}), (m:Entity {id:l.target}) \"\n",
137+
" + \"MERGE (n)-[:\"\n",
138+
" + REL_TYPE\n",
139+
" + \" {rel_id:$rel_id, text:$text}]->(m) \",\n",
140+
" params={\"ll\": rel_type_dict[rel_id], \"rel_id\": rel_id, \"text\": rel_id_to_text_dict[rel_id]},\n",
141+
" )\n",
142+
"\n",
133143
"\n",
134144
"process()"
135145
]
@@ -157,13 +167,14 @@
157167
" print(f\"Graph '{G.name()}' relationship types: {G.relationship_types()}\")\n",
158168
" print(f\"Graph '{G.name()}' relationship count: {G.relationship_count()}\")\n",
159169
"\n",
170+
"\n",
160171
"def project_graph():\n",
161172
" node_projection = {\"Entity\": {\"properties\": \"id\"}}\n",
162173
" relationship_projection = [\n",
163-
" {\"TRAIN\" : {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n",
164-
" {\"TEST\" : {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n",
165-
" {\"VALID\" : {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n",
166-
" ]\n",
174+
" {\"TRAIN\": {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n",
175+
" {\"TEST\": {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n",
176+
" {\"VALID\": {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n",
177+
" ]\n",
167178
" G, result = gds.graph.project(\n",
168179
" \"fb15k-graph-ttv\",\n",
169180
" node_projection,\n",
@@ -174,6 +185,7 @@
174185
"\n",
175186
" return G\n",
176187
"\n",
188+
"\n",
177189
"ttv_G = project_graph()\n",
178190
"\n",
179191
"node_properties = gds.graph.nodeProperties.stream(\n",
@@ -193,15 +205,21 @@
193205
"outputs": [],
194206
"source": [
195207
"def create_data_from_graph(relationship_type):\n",
196-
" rels_tmp = gds.graph.relationshipProperties.stream(ttv_G, [\"rel_id\"], relationship_type, separate_property_columns=True)\n",
197-
" topology = [rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]), rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x])]\n",
208+
" rels_tmp = gds.graph.relationshipProperties.stream(\n",
209+
" ttv_G, [\"rel_id\"], relationship_type, separate_property_columns=True\n",
210+
" )\n",
211+
" topology = [\n",
212+
" rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]),\n",
213+
" rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]),\n",
214+
" ]\n",
198215
" edge_index = torch.tensor(topology, dtype=torch.long)\n",
199216
" edge_type = torch.tensor(rels_tmp.rel_id.astype(int), dtype=torch.long)\n",
200217
" data = Data(edge_index=edge_index, edge_type=edge_type)\n",
201218
" data.num_nodes = len(nodeId_to_id)\n",
202219
" display(data)\n",
203220
" return data\n",
204221
"\n",
222+
"\n",
205223
"train_tensor_data = create_data_from_graph(\"TRAIN\")\n",
206224
"test_tensor_data = create_data_from_graph(\"TEST\")\n",
207225
"val_tensor_data = create_data_from_graph(\"VALID\")"
@@ -223,7 +241,7 @@
223241
"outputs": [],
224242
"source": [
225243
"def train_model_with_pyg():\n",
226-
" device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
244+
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
227245
"\n",
228246
" model = TransE(\n",
229247
" num_nodes=train_tensor_data.num_nodes,\n",
@@ -241,7 +259,6 @@
241259
"\n",
242260
" optimizer = optim.Adam(model.parameters(), lr=0.01)\n",
243261
"\n",
244-
"\n",
245262
" def train():\n",
246263
" model.train()\n",
247264
" total_loss = total_examples = 0\n",
@@ -254,7 +271,6 @@
254271
" total_examples += head_index.numel()\n",
255272
" return total_loss / total_examples\n",
256273
"\n",
257-
"\n",
258274
" @torch.no_grad()\n",
259275
" def test(data):\n",
260276
" model.eval()\n",
@@ -266,16 +282,14 @@
266282
" k=10,\n",
267283
" )\n",
268284
"\n",
269-
"\n",
270285
" # epoch_count = 501\n",
271286
" epoch_count = 2\n",
272287
" for epoch in range(1, epoch_count):\n",
273288
" loss = train()\n",
274-
" print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')\n",
289+
" print(f\"Epoch: {epoch:03d}, Loss: {loss:.4f}\")\n",
275290
" if epoch % 75 == 0:\n",
276291
" rank, hits = test(val_tensor_data)\n",
277-
" print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, '\n",
278-
" f'Val Hits@10: {hits:.4f}')\n",
292+
" print(f\"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, \" f\"Val Hits@10: {hits:.4f}\")\n",
279293
"\n",
280294
" print(model)\n",
281295
" # rank, hits_at_10 = test(test_tensor_data)\n",
@@ -307,12 +321,9 @@
307321
" if i % 100 == 0:\n",
308322
" print(f\"Node embeddings uploading: {i} of {len(nodeId_to_id)}\", end=\"\\r\")\n",
309323
" gds.run_cypher(\n",
310-
" \"MATCH (n:Entity {id: $i}) SET n.emb = $EMBEDDING\",\n",
311-
" params={\n",
312-
" \"i\": i,\n",
313-
" \"EMBEDDING\": model.node_emb.weight[i].tolist()\n",
314-
" },\n",
315-
" )\n",
324+
" \"MATCH (n:Entity {id: $i}) SET n.emb = $EMBEDDING\",\n",
325+
" params={\"i\": i, \"EMBEDDING\": model.node_emb.weight[i].tolist()},\n",
326+
" )\n",
316327
"print(f\"Node embeddings uploading has been finished\")"
317328
]
318329
},
@@ -355,10 +366,10 @@
355366
"source": [
356367
"# 3. Project graph to test\n",
357368
"G_test, result = gds.graph.project(\n",
358-
" \"graph_to_test\",\n",
359-
" {\"Entity\": {\"properties\": [\"id\", \"emb\"] }},\n",
360-
" rel_label_to_predict,\n",
361-
" )\n",
369+
" \"graph_to_test\",\n",
370+
" {\"Entity\": {\"properties\": [\"id\", \"emb\"]}},\n",
371+
" rel_label_to_predict,\n",
372+
")\n",
362373
"print_graph_info(G_test)"
363374
]
364375
},
@@ -369,9 +380,7 @@
369380
"outputs": [],
370381
"source": [
371382
"# 4. Set the model to predict\n",
372-
"transe_model = gds.model.transe.create(\n",
373-
" G_test, \"emb\", {rel_label_to_predict: target_emb}\n",
374-
")"
383+
"transe_model = gds.model.transe.create(G_test, \"emb\", {rel_label_to_predict: target_emb})"
375384
]
376385
},
377386
{
@@ -386,7 +395,7 @@
386395
" target_node_filter=\"Entity\",\n",
387396
" relationship_type=rel_label_to_predict,\n",
388397
" top_k=3,\n",
389-
" concurrency=4\n",
398+
" concurrency=4,\n",
390399
")\n",
391400
"print(result)"
392401
]
@@ -403,7 +412,7 @@
403412
" target_node_filter=\"Entity\",\n",
404413
" relationship_type=rel_label_to_predict,\n",
405414
" top_k=3,\n",
406-
" concurrency=4\n",
415+
" concurrency=4,\n",
407416
")\n",
408417
"print(result)"
409418
]
@@ -420,7 +429,7 @@
420429
" target_node_filter=\"Entity\",\n",
421430
" relationship_type=rel_label_to_predict,\n",
422431
" top_k=3,\n",
423-
" concurrency=4\n",
432+
" concurrency=4,\n",
424433
")\n",
425434
"print(result)"
426435
]
@@ -436,10 +445,10 @@
436445
" source_node_filter=[id_to_nodeId[5], id_to_nodeId[10]],\n",
437446
" target_node_filter=\"Entity\",\n",
438447
" relationship_type=rel_label_to_predict,\n",
439-
" write_relationship_type=\"WRITTEN_2_\"+rel_label_to_predict,\n",
448+
" write_relationship_type=\"WRITTEN_2_\" + rel_label_to_predict,\n",
440449
" write_property=\"transe_score\",\n",
441450
" top_k=3,\n",
442-
" concurrency=4\n",
451+
" concurrency=4,\n",
443452
")\n",
444453
"print(result)"
445454
]
@@ -458,7 +467,7 @@
458467
" mutate_relationship_type=\"MUT_WRITTEN_\" + rel_label_to_predict,\n",
459468
" mutate_property=\"mut_transe_score\",\n",
460469
" top_k=3,\n",
461-
" concurrency=4\n",
470+
" concurrency=4,\n",
462471
")\n",
463472
"print(result)"
464473
]
@@ -469,7 +478,9 @@
469478
"metadata": {},
470479
"outputs": [],
471480
"source": [
472-
"rr = gds.graph.relationshipProperties.stream(G_test, ['mut_transe_score'], \"MUT_WRITTEN_\" + rel_label_to_predict, separate_property_columns=True)\n",
481+
"rr = gds.graph.relationshipProperties.stream(\n",
482+
" G_test, [\"mut_transe_score\"], \"MUT_WRITTEN_\" + rel_label_to_predict, separate_property_columns=True\n",
483+
")\n",
473484
"print(rr)"
474485
]
475486
},
@@ -484,22 +495,8 @@
484495
}
485496
],
486497
"metadata": {
487-
"kernelspec": {
488-
"display_name": "Python 3",
489-
"language": "python",
490-
"name": "python3"
491-
},
492498
"language_info": {
493-
"codemirror_mode": {
494-
"name": "ipython",
495-
"version": 2
496-
},
497-
"file_extension": ".py",
498-
"mimetype": "text/x-python",
499-
"name": "python",
500-
"nbconvert_exporter": "python",
501-
"pygments_lexer": "ipython2",
502-
"version": "2.7.6"
499+
"name": "python"
503500
}
504501
},
505502
"nbformat": 4,

0 commit comments

Comments
 (0)