Skip to content

Commit bd95fea

Browse files
committed
Update reqs and printings
1 parent 4db86cd commit bd95fea

File tree

2 files changed

+17
-49
lines changed

2 files changed

+17
-49
lines changed

examples/kge-distmult.py

Lines changed: 16 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,6 @@ def put_data_in_db(gds):
117117
for rel_type in dataset[rel_split]:
118118
edges = dataset[rel_split][rel_type]
119119

120-
# MERGE (n)-[:{rel_type} {{text:l.rel_text}}]->(m)
121-
# MERGE (n)-[:{rel_split}]->(m)
122120
gds.run_cypher(
123121
f"""
124122
UNWIND $ll as l
@@ -156,33 +154,6 @@ def project_train_graph(gds):
156154
return G_train
157155

158156

159-
def project_predict_graph(gds):
160-
all_rels = gds.run_cypher(
161-
"""
162-
CALL db.relationshipTypes() YIELD relationshipType
163-
"""
164-
)
165-
all_rels = all_rels["relationshipType"].to_list()
166-
rel_spec = {}
167-
for rel in all_rels:
168-
if rel.startswith("REL_"):
169-
rel_spec[rel] = {"properties": ["split"]}
170-
171-
gds.graph.drop("fullGraph", failIfMissing=False)
172-
gds.graph.drop("predictGraph", failIfMissing=False)
173-
174-
# {"REL": {"properties": ["relY"]}, "RELR": {"properties": ["relY"]}}
175-
# print(rel_spec)
176-
177-
G_full, result = gds.graph.project("fullGraph", ["Entity"], all_rels)
178-
179-
G_full, result = gds.graph.project("fullGraph", ["Entity"], rel_spec)
180-
# G_predict = gds.graph.filter('predictGraph', 'fullGraph', '*', 'r.split == 2')
181-
182-
inspect_graph(G_full)
183-
return G_full
184-
185-
186157
def inspect_graph(G):
187158
func_names = [
188159
"name",
@@ -200,47 +171,43 @@ def inspect_graph(G):
200171
create_constraint(gds)
201172
put_data_in_db(gds)
202173
G_train = project_train_graph(gds)
203-
# G_predict = project_predict_graph(gds)
204-
# inspect_graph(G_train)
205174

206175
gds.set_compute_cluster_ip("localhost")
207176

208177
print(gds.debug.arrow())
209178

210179
model_name = "dummyModelName_" + str(time.time())
211180

212-
gds.kge.model.train(
181+
node_id_text = gds.find_node_id(["Entity"], {"text": "/m/016wzw"})
182+
node_id_2 = gds.find_node_id(["Entity"], {"id": 2})
183+
node_id_3 = gds.find_node_id(["Entity"], {"id": 3})
184+
node_id_0 = gds.find_node_id(["Entity"], {"id": 0})
185+
186+
res = gds.kge.model.train(
213187
G_train,
214188
model_name=model_name,
215-
scoring_function="DistMult",
189+
scoring_function="distmult",
216190
num_epochs=1,
217191
embedding_dimension=10,
218192
epochs_per_checkpoint=0,
219193
)
194+
print(res['metrics'])
220195

221-
gds.kge.model.predict(
222-
G_train,
196+
res = gds.kge.model.predict(
223197
model_name=model_name,
224198
top_k=10,
225-
node_ids=[1, 2, 3],
199+
node_ids=[node_id_3, node_id_2, node_id_text],
226200
rel_types=["REL_1", "REL_2"],
227201
)
202+
print(res.to_string())
228203

229-
gds.kge.model.predict_tail(
230-
G_train,
231-
model_name=model_name,
232-
top_k=10,
233-
node_ids=[gds.find_node_id(["Entity"], {"text": "/m/016wzw"}), gds.find_node_id(["Entity"], {"id": 2})],
234-
rel_types=["REL_1", "REL_2"],
235-
)
236-
237-
gds.kge.model.score_triples(
238-
G_train,
204+
scores = gds.kge.model.score_triplets(
239205
model_name=model_name,
240-
triples=[
241-
(gds.find_node_id(["Entity"], {"text": "/m/016wzw"}), "REL_1", gds.find_node_id(["Entity"], {"id": 2})),
242-
(gds.find_node_id(["Entity"], {"id": 0}), "REL_123", gds.find_node_id(["Entity"], {"id": 3})),
206+
triplets=[
207+
(node_id_2, "REL_1", node_id_text),
208+
(node_id_0, "REL_123", node_id_3),
243209
],
244210
)
211+
print(scores)
245212

246213
print("Finished training")

requirements/base/base.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ textdistance >= 4.0, < 5.0
77
tqdm >= 4.0, < 5.0
88
typing-extensions >= 4.0, < 5.0
99
requests
10+
rsa

0 commit comments

Comments
 (0)