3 | 3 | {
4 | 4 | "cell_type": "code",
5 | 5 | "execution_count": null,
  | 6 | + "metadata": {},
6 | 7 | "outputs": [],
7 | 8 | "source": [
8 | 9 | "%pip install graphdatascience torch torch_geometric"
9 |  | - ],
10 |  | - "metadata": {
11 |  | - "collapsed": false
12 |  | - }
  | 10 | + ]
13 | 11 | },
14 | 12 | {
15 | 13 | "cell_type": "code",
16 | 14 | "execution_count": null,
  | 15 | + "metadata": {},
17 | 16 | "outputs": [],
18 | 17 | "source": [
  | 18 | + "import os\n",
19 | 19 | "from graphdatascience import GraphDataScience\n",
20 | 20 | "import torch\n",
21 | 21 | "import torch.optim as optim\n",
22 | 22 | "from torch_geometric.data import Data, download_url\n",
23 | 23 | "from torch_geometric.nn import TransE\n",
24 | 24 | "import collections"
25 |  | - ],
26 |  | - "metadata": {
27 |  | - "collapsed": false
28 |  | - }
  | 25 | + ]
29 | 26 | },
30 | 27 | {
31 | 28 | "cell_type": "code",
32 | 29 | "execution_count": null,
33 | 30 | "metadata": {},
34 | 31 | "outputs": [],
35 | 32 | "source": [
36 |  | - "gds = GraphDataScience(\"bolt://localhost:7687\", auth=('neo4j', 'neo4jneo4j'), database=\"ttt\")"
  | 33 | + "# Get Neo4j DB URI, credentials and name from environment if applicable\n",
  | 34 | + "NEO4J_URI = os.environ.get(\"NEO4J_URI\", \"bolt://localhost:7687\")\n",
  | 35 | + "NEO4J_AUTH = None\n",
  | 36 | + "NEO4J_DB = os.environ.get(\"NEO4J_DB\", \"neo4j\")\n",
  | 37 | + "if os.environ.get(\"NEO4J_USER\") and os.environ.get(\"NEO4J_PASSWORD\"):\n",
  | 38 | + "    NEO4J_AUTH = (\n",
  | 39 | + "        os.environ.get(\"NEO4J_USER\"),\n",
  | 40 | + "        os.environ.get(\"NEO4J_PASSWORD\"),\n",
  | 41 | + "    )\n",
  | 42 | + "gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB)"
37 | 43 | ]
38 | 44 | },
39 | 45 | {
…
42 | 48 | "metadata": {},
43 | 49 | "outputs": [],
44 | 50 | "source": [
45 |  | - "url = ('https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/FB15k-237')\n",
46 |  | - "raw_file_names = ['train.txt', 'valid.txt', 'test.txt']\n",
47 |  | - "raw_dir = './data_from_url'\n",
  | 51 | + "url = \"https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/FB15k-237\"\n",
  | 52 | + "raw_file_names = [\"train.txt\", \"valid.txt\", \"test.txt\"]\n",
  | 53 | + "raw_dir = \"./data_from_url\"\n",
48 | 54 | "for filename in raw_file_names:\n",
49 |  | - "    download_url(f'{url}/{filename}', raw_dir)"
  | 55 | + "    download_url(f\"{url}/{filename}\", raw_dir)"
50 | 56 | ]
51 | 57 | },
52 | 58 | {
…
65 | 71 | "outputs": [],
66 | 72 | "source": [
67 | 73 | "rel_types = {\n",
68 |  | - "    \"train.txt\":\"TRAIN\",\n",
69 |  | - "    \"valid.txt\":\"VALID\",\n",
70 |  | - "    \"test.txt\":\"TEST\",\n",
  | 74 | + "    \"train.txt\": \"TRAIN\",\n",
  | 75 | + "    \"valid.txt\": \"VALID\",\n",
  | 76 | + "    \"test.txt\": \"TEST\",\n",
71 | 77 | "}\n",
72 | 78 | "rel_id_to_text_dict = {}\n",
73 | 79 | "rel_type_dict = collections.defaultdict(list)\n",
74 | 80 | "\n",
  | 81 | + "\n",
75 | 82 | "def process():\n",
76 | 83 | "    node_dict_, rel_dict_ = {}, {}\n",
77 | 84 | "    for file_name in raw_file_names:\n",
78 |  | - "        file_name_path = raw_dir + '/' + file_name\n",
  | 85 | + "        file_name_path = raw_dir + \"/\" + file_name\n",
79 | 86 | "\n",
80 |  | - "        with open(file_name_path, 'r') as f:\n",
81 |  | - "            data = [x.split('\\t') for x in f.read().split('\\n')[:-1]]\n",
  | 87 | + "        with open(file_name_path, \"r\") as f:\n",
  | 88 | + "            data = [x.split(\"\\t\") for x in f.read().split(\"\\n\")[:-1]]\n",
82 | 89 | "\n",
83 | 90 | "        list_of_dicts = []\n",
84 | 91 | "        for i, (src, rel, dst) in enumerate(data):\n",
…
94 | 101 | "            target = node_dict_[dst]\n",
95 | 102 | "            edge_type = rel_dict_[rel]\n",
96 | 103 | "\n",
97 |  | - "            rel_type_dict[edge_type].append({\n",
98 |  | - "                \"source\":source,\n",
99 |  | - "                \"target\":target,\n",
100 |  | - "            })\n",
101 |  | - "            list_of_dicts.append({\n",
102 |  | - "                \"source\": source,\n",
103 |  | - "                \"source_text\": src,\n",
104 |  | - "                \"target\": target,\n",
105 |  | - "                \"target_text\": dst,\n",
106 |  | - "                \"rel_id\": edge_type,\n",
107 |  | - "            })\n",
  | 104 | + "            rel_type_dict[edge_type].append(\n",
  | 105 | + "                {\n",
  | 106 | + "                    \"source\": source,\n",
  | 107 | + "                    \"target\": target,\n",
  | 108 | + "                }\n",
  | 109 | + "            )\n",
  | 110 | + "            list_of_dicts.append(\n",
  | 111 | + "                {\n",
  | 112 | + "                    \"source\": source,\n",
  | 113 | + "                    \"source_text\": src,\n",
  | 114 | + "                    \"target\": target,\n",
  | 115 | + "                    \"target_text\": dst,\n",
  | 116 | + "                    \"rel_id\": edge_type,\n",
  | 117 | + "                }\n",
  | 118 | + "            )\n",
108 | 119 | "\n",
109 | 120 | "        rel_type = rel_types[file_name]\n",
110 | 121 | "        print(f\"Writing {len(list_of_dicts)} entities of {rel_type}\")\n",
111 | 122 | "        gds.run_cypher(\n",
112 |  | - "            \"UNWIND $ll as l \"+\n",
113 |  | - "            \"MERGE (n:Entity {id:l.source, text:l.source_text}) \"+\n",
114 |  | - "            \"MERGE (m:Entity {id:l.target, text:l.target_text}) \"+\n",
115 |  | - "            \"MERGE (n)-[:\"+rel_type+\" {rel_id:l.rel_id}]->(m) \",\n",
116 |  | - "            params={\n",
117 |  | - "                \"ll\": list_of_dicts\n",
118 |  | - "            },\n",
119 |  | - "        )\n",
  | 123 | + "            \"UNWIND $ll as l \"\n",
  | 124 | + "            + \"MERGE (n:Entity {id:l.source, text:l.source_text}) \"\n",
  | 125 | + "            + \"MERGE (m:Entity {id:l.target, text:l.target_text}) \"\n",
  | 126 | + "            + \"MERGE (n)-[:\"\n",
  | 127 | + "            + rel_type\n",
  | 128 | + "            + \" {rel_id:l.rel_id}]->(m) \",\n",
  | 129 | + "            params={\"ll\": list_of_dicts},\n",
  | 130 | + "        )\n",
120 | 131 | "\n",
121 | 132 | "    for rel_id in rel_type_dict:\n",
122 | 133 | "        REL_TYPE = f\"REL_{rel_id}\"\n",
123 | 134 | "        print(f\"Writing {len(rel_type_dict[rel_id])} entities of {REL_TYPE}\")\n",
124 | 135 | "        gds.run_cypher(\n",
125 |  | - "            \"UNWIND $ll AS l MATCH (n:Entity {id:l.source}), (m:Entity {id:l.target}) \"+\n",
126 |  | - "            \"MERGE (n)-[:\"+REL_TYPE+\" {rel_id:$rel_id, text:$text}]->(m) \",\n",
127 |  | - "            params={\n",
128 |  | - "                \"ll\": rel_type_dict[rel_id],\n",
129 |  | - "                \"rel_id\": rel_id,\n",
130 |  | - "                \"text\": rel_id_to_text_dict[rel_id]\n",
131 |  | - "            },\n",
132 |  | - "        )\n",
  | 136 | + "            \"UNWIND $ll AS l MATCH (n:Entity {id:l.source}), (m:Entity {id:l.target}) \"\n",
  | 137 | + "            + \"MERGE (n)-[:\"\n",
  | 138 | + "            + REL_TYPE\n",
  | 139 | + "            + \" {rel_id:$rel_id, text:$text}]->(m) \",\n",
  | 140 | + "            params={\"ll\": rel_type_dict[rel_id], \"rel_id\": rel_id, \"text\": rel_id_to_text_dict[rel_id]},\n",
  | 141 | + "        )\n",
  | 142 | + "\n",
133 | 143 | "\n",
134 | 144 | "process()"
135 | 145 | ]
…
157 | 167 | "    print(f\"Graph '{G.name()}' relationship types: {G.relationship_types()}\")\n",
158 | 168 | "    print(f\"Graph '{G.name()}' relationship count: {G.relationship_count()}\")\n",
159 | 169 | "\n",
  | 170 | + "\n",
160 | 171 | "def project_graph():\n",
161 | 172 | "    node_projection = {\"Entity\": {\"properties\": \"id\"}}\n",
162 | 173 | "    relationship_projection = [\n",
163 |  | - "        {\"TRAIN\" : {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n",
164 |  | - "        {\"TEST\" : {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n",
165 |  | - "        {\"VALID\" : {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n",
166 |  | - "        ]\n",
  | 174 | + "        {\"TRAIN\": {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n",
  | 175 | + "        {\"TEST\": {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n",
  | 176 | + "        {\"VALID\": {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n",
  | 177 | + "    ]\n",
167 | 178 | "    G, result = gds.graph.project(\n",
168 | 179 | "        \"fb15k-graph-ttv\",\n",
169 | 180 | "        node_projection,\n",
…
174 | 185 | "\n",
175 | 186 | "    return G\n",
176 | 187 | "\n",
  | 188 | + "\n",
177 | 189 | "ttv_G = project_graph()\n",
178 | 190 | "\n",
179 | 191 | "node_properties = gds.graph.nodeProperties.stream(\n",
…
193 | 205 | "outputs": [],
194 | 206 | "source": [
195 | 207 | "def create_data_from_graph(relationship_type):\n",
196 |  | - "    rels_tmp = gds.graph.relationshipProperties.stream(ttv_G, [\"rel_id\"], relationship_type, separate_property_columns=True)\n",
197 |  | - "    topology = [rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]), rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x])]\n",
  | 208 | + "    rels_tmp = gds.graph.relationshipProperties.stream(\n",
  | 209 | + "        ttv_G, [\"rel_id\"], relationship_type, separate_property_columns=True\n",
  | 210 | + "    )\n",
  | 211 | + "    topology = [\n",
  | 212 | + "        rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]),\n",
  | 213 | + "        rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]),\n",
  | 214 | + "    ]\n",
198 | 215 | "    edge_index = torch.tensor(topology, dtype=torch.long)\n",
199 | 216 | "    edge_type = torch.tensor(rels_tmp.rel_id.astype(int), dtype=torch.long)\n",
200 | 217 | "    data = Data(edge_index=edge_index, edge_type=edge_type)\n",
201 | 218 | "    data.num_nodes = len(nodeId_to_id)\n",
202 | 219 | "    display(data)\n",
203 | 220 | "    return data\n",
204 | 221 | "\n",
  | 222 | + "\n",
205 | 223 | "train_tensor_data = create_data_from_graph(\"TRAIN\")\n",
206 | 224 | "test_tensor_data = create_data_from_graph(\"TEST\")\n",
207 | 225 | "val_tensor_data = create_data_from_graph(\"VALID\")"
…
223 | 241 | "outputs": [],
224 | 242 | "source": [
225 | 243 | "def train_model_with_pyg():\n",
226 |  | - "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
  | 244 | + "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
227 | 245 | "\n",
228 | 246 | "    model = TransE(\n",
229 | 247 | "        num_nodes=train_tensor_data.num_nodes,\n",
…
241 | 259 | "\n",
242 | 260 | "    optimizer = optim.Adam(model.parameters(), lr=0.01)\n",
243 | 261 | "\n",
244 |  | - "\n",
245 | 262 | "    def train():\n",
246 | 263 | "        model.train()\n",
247 | 264 | "        total_loss = total_examples = 0\n",
…
254 | 271 | "            total_examples += head_index.numel()\n",
255 | 272 | "        return total_loss / total_examples\n",
256 | 273 | "\n",
257 |  | - "\n",
258 | 274 | "    @torch.no_grad()\n",
259 | 275 | "    def test(data):\n",
260 | 276 | "        model.eval()\n",
…
266 | 282 | "            k=10,\n",
267 | 283 | "        )\n",
268 | 284 | "\n",
269 |  | - "\n",
270 | 285 | "    # epoch_count = 501\n",
271 | 286 | "    epoch_count = 2\n",
272 | 287 | "    for epoch in range(1, epoch_count):\n",
273 | 288 | "        loss = train()\n",
274 |  | - "        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')\n",
  | 289 | + "        print(f\"Epoch: {epoch:03d}, Loss: {loss:.4f}\")\n",
275 | 290 | "        if epoch % 75 == 0:\n",
276 | 291 | "            rank, hits = test(val_tensor_data)\n",
277 |  | - "            print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, '\n",
278 |  | - "                  f'Val Hits@10: {hits:.4f}')\n",
  | 292 | + "            print(f\"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, \" f\"Val Hits@10: {hits:.4f}\")\n",
279 | 293 | "\n",
280 | 294 | "    print(model)\n",
281 | 295 | "    # rank, hits_at_10 = test(test_tensor_data)\n",
…
307 | 321 | "    if i % 100 == 0:\n",
308 | 322 | "        print(f\"Node embeddings uploading: {i} of {len(nodeId_to_id)}\", end=\"\\r\")\n",
309 | 323 | "    gds.run_cypher(\n",
310 |  | - "            \"MATCH (n:Entity {id: $i}) SET n.emb = $EMBEDDING\",\n",
311 |  | - "            params={\n",
312 |  | - "                \"i\": i,\n",
313 |  | - "                \"EMBEDDING\": model.node_emb.weight[i].tolist()\n",
314 |  | - "            },\n",
315 |  | - "        )\n",
  | 324 | + "        \"MATCH (n:Entity {id: $i}) SET n.emb = $EMBEDDING\",\n",
  | 325 | + "        params={\"i\": i, \"EMBEDDING\": model.node_emb.weight[i].tolist()},\n",
  | 326 | + "    )\n",
316 | 327 | "print(f\"Node embeddings uploading has been finished\")"
317 | 328 | ]
318 | 329 | },
…
355 | 366 | "source": [
356 | 367 | "# 3. Project graph to test\n",
357 | 368 | "G_test, result = gds.graph.project(\n",
358 |  | - "    \"graph_to_test\",\n",
359 |  | - "    {\"Entity\": {\"properties\": [\"id\", \"emb\"] }},\n",
360 |  | - "    rel_label_to_predict,\n",
361 |  | - "    )\n",
  | 369 | + "    \"graph_to_test\",\n",
  | 370 | + "    {\"Entity\": {\"properties\": [\"id\", \"emb\"]}},\n",
  | 371 | + "    rel_label_to_predict,\n",
  | 372 | + ")\n",
362 | 373 | "print_graph_info(G_test)"
363 | 374 | ]
364 | 375 | },
…
369 | 380 | "outputs": [],
370 | 381 | "source": [
371 | 382 | "# 4. Set the model to predict\n",
372 |  | - "transe_model = gds.model.transe.create(\n",
373 |  | - "    G_test, \"emb\", {rel_label_to_predict: target_emb}\n",
374 |  | - ")"
  | 383 | + "transe_model = gds.model.transe.create(G_test, \"emb\", {rel_label_to_predict: target_emb})"
375 | 384 | ]
376 | 385 | },
377 | 386 | {
…
386 | 395 | "    target_node_filter=\"Entity\",\n",
387 | 396 | "    relationship_type=rel_label_to_predict,\n",
388 | 397 | "    top_k=3,\n",
389 |  | - "    concurrency=4\n",
  | 398 | + "    concurrency=4,\n",
390 | 399 | ")\n",
391 | 400 | "print(result)"
392 | 401 | ]
…
403 | 412 | "    target_node_filter=\"Entity\",\n",
404 | 413 | "    relationship_type=rel_label_to_predict,\n",
405 | 414 | "    top_k=3,\n",
406 |  | - "    concurrency=4\n",
  | 415 | + "    concurrency=4,\n",
407 | 416 | ")\n",
408 | 417 | "print(result)"
409 | 418 | ]
…
420 | 429 | "    target_node_filter=\"Entity\",\n",
421 | 430 | "    relationship_type=rel_label_to_predict,\n",
422 | 431 | "    top_k=3,\n",
423 |  | - "    concurrency=4\n",
  | 432 | + "    concurrency=4,\n",
424 | 433 | ")\n",
425 | 434 | "print(result)"
426 | 435 | ]
…
436 | 445 | "    source_node_filter=[id_to_nodeId[5], id_to_nodeId[10]],\n",
437 | 446 | "    target_node_filter=\"Entity\",\n",
438 | 447 | "    relationship_type=rel_label_to_predict,\n",
439 |  | - "    write_relationship_type=\"WRITTEN_2_\"+rel_label_to_predict,\n",
  | 448 | + "    write_relationship_type=\"WRITTEN_2_\" + rel_label_to_predict,\n",
440 | 449 | "    write_property=\"transe_score\",\n",
441 | 450 | "    top_k=3,\n",
442 |  | - "    concurrency=4\n",
  | 451 | + "    concurrency=4,\n",
443 | 452 | ")\n",
444 | 453 | "print(result)"
445 | 454 | ]
…
458 | 467 | "    mutate_relationship_type=\"MUT_WRITTEN_\" + rel_label_to_predict,\n",
459 | 468 | "    mutate_property=\"mut_transe_score\",\n",
460 | 469 | "    top_k=3,\n",
461 |  | - "    concurrency=4\n",
  | 470 | + "    concurrency=4,\n",
462 | 471 | ")\n",
463 | 472 | "print(result)"
464 | 473 | ]
…
469 | 478 | "metadata": {},
470 | 479 | "outputs": [],
471 | 480 | "source": [
472 |  | - "rr = gds.graph.relationshipProperties.stream(G_test, ['mut_transe_score'], \"MUT_WRITTEN_\" + rel_label_to_predict, separate_property_columns=True)\n",
  | 481 | + "rr = gds.graph.relationshipProperties.stream(\n",
  | 482 | + "    G_test, [\"mut_transe_score\"], \"MUT_WRITTEN_\" + rel_label_to_predict, separate_property_columns=True\n",
  | 483 | + ")\n",
473 | 484 | "print(rr)"
474 | 485 | ]
475 | 486 | },
…
484 | 495 | }
485 | 496 | ],
486 | 497 | "metadata": {
487 |  | - "kernelspec": {
488 |  | - "display_name": "Python 3",
489 |  | - "language": "python",
490 |  | - "name": "python3"
491 |  | - },
492 | 498 | "language_info": {
493 |  | - "codemirror_mode": {
494 |  | - "name": "ipython",
495 |  | - "version": 2
496 |  | - },
497 |  | - "file_extension": ".py",
498 |  | - "mimetype": "text/x-python",
499 |  | - "name": "python",
500 |  | - "nbconvert_exporter": "python",
501 |  | - "pygments_lexer": "ipython2",
502 |  | - "version": "2.7.6"
  | 499 | + "name": "python"
503 | 500 | }
504 | 501 | },
505 | 502 | "nbformat": 4,