Commit 1e65718

committed: format
1 parent 0f83881 commit 1e65718

File tree: 7 files changed (+111, -81 lines changed)


src/linkml_store/api/collection.py
Lines changed: 20 additions & 5 deletions

@@ -4,8 +4,21 @@
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Iterator, List, Optional, TextIO, Tuple, Type, Union, \
-    Iterable
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    ClassVar,
+    Dict,
+    Generic,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    TextIO,
+    Tuple,
+    Type,
+    Union,
+)

 import numpy as np
 from linkml_runtime import SchemaView

@@ -985,7 +998,9 @@ def diff(self, other: "Collection", **kwargs) -> List[PatchDict]:
         patches_from_objects_lists(src_objs, tgt_objs, primary_key=primary_key)
         return patches_from_objects_lists(src_objs, tgt_objs, primary_key=primary_key)

-    def iter_validate_collection(self, objects: Optional[Iterable[OBJECT]] = None, **kwargs) -> Iterator["ValidationResult"]:
+    def iter_validate_collection(
+        self, objects: Optional[Iterable[OBJECT]] = None, **kwargs
+    ) -> Iterator["ValidationResult"]:
         """
         Validate the contents of the collection

@@ -1001,7 +1016,7 @@ def iter_validate_collection(self, objects: Optional[Iterable[OBJECT]] = None, *
         if not cd:
             raise ValueError(f"Cannot find class definition for {self.target_class_name}")
         type_designator = None
-        for att in cd.attributes.values():
+        for att in self.parent.schema_view.class_induced_slots(cd.name):
             if att.designates_type:
                 type_designator = att.name
         class_name = cd.name

@@ -1014,7 +1029,7 @@ def iter_validate_collection(self, objects: Optional[Iterable[OBJECT]] = None, *
                 # TODO: move type designator logic to core linkml
                 this_class_name = obj.get(type_designator)
                 if this_class_name:
-                    if ":"in this_class_name:
+                    if ":" in this_class_name:
                         this_class_name = this_class_name.split(":")[-1]
                     v_class_name = this_class_name
             yield from validator.iter_results(obj, v_class_name)
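
Note on the class_induced_slots change: unlike cd.attributes.values(), which only sees attributes declared directly on the class, class_induced_slots() also resolves inherited and slot-declared slots, so a type designator defined on an ancestor class is found. A minimal sketch, not part of this commit and using an invented demo schema:

from linkml_runtime import SchemaView

DEMO_SCHEMA = """
id: https://example.org/demo
name: demo
prefixes:
  linkml: https://w3id.org/linkml/
imports:
  - linkml:types
default_range: string
classes:
  NamedThing:
    attributes:
      category:
        designates_type: true
  Person:
    is_a: NamedThing
    attributes:
      name:
"""

sv = SchemaView(DEMO_SCHEMA)
person = sv.get_class("Person")

# Only locally declared attributes; the inherited type designator is invisible here.
print([a.name for a in person.attributes.values()])  # ['name']

# Induced slots include inherited ones, so the designator on NamedThing is found.
print([s.name for s in sv.class_induced_slots("Person") if s.designates_type])  # ['category']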

src/linkml_store/api/database.py
Lines changed: 2 additions & 3 deletions

@@ -504,8 +504,8 @@ def set_schema_view(self, schema_view: Union[str, Path, SchemaView]):
             schema_view = str(schema_view)
         if isinstance(schema_view, str):
             schema_view = SchemaView(schema_view)
-        # self._schema_view = schema_view
-        self._schema_view = SchemaView(schema_view.materialize_derived_schema())
+        self._schema_view = schema_view
+        # self._schema_view = SchemaView(schema_view.materialize_derived_schema())
         if not self._collections:
             return

@@ -531,7 +531,6 @@ def set_schema_view(self, schema_view: Union[str, Path, SchemaView]):
                 coll = self._collections[slot.name]
                 coll.metadata.type = slot.range

-
     def load_schema_view(self, path: Union[str, Path]):
         """
         Load a schema view from a file.
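
Note on the set_schema_view change: the database now stores the SchemaView it was given and leaves the derived-schema materialization commented out. A small sketch of the two options (the schema path is a placeholder):

from linkml_runtime import SchemaView

sv = SchemaView("my_schema.yaml")  # placeholder path

# New behaviour: keep the view as given; induced slots are computed on demand
# (e.g. via class_induced_slots, as in collection.py above).
schema_view_direct = sv

# Previous behaviour (now commented out): bake induced slots into a derived
# schema and wrap it in a fresh SchemaView.
schema_view_derived = SchemaView(sv.materialize_derived_schema())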

src/linkml_store/api/stores/neo4j/neo4j_collection.py
Lines changed: 41 additions & 34 deletions

@@ -7,7 +7,7 @@
 from linkml_store.api import Collection
 from linkml_store.api.collection import DEFAULT_FACET_LIMIT, OBJECT
 from linkml_store.api.queries import Query, QueryResult
-from linkml_store.graphs.graph_map import GraphProjection, EdgeProjection, NodeProjection
+from linkml_store.graphs.graph_map import EdgeProjection, GraphProjection, NodeProjection

 logger = logging.getLogger(__name__)

@@ -35,7 +35,7 @@ def session(self) -> Session:

     def _check_if_initialized(self) -> bool:
         with self.session() as session:
-            result = session.run(f"MATCH (n) RETURN count(n) > 0 as exists")
+            result = session.run("MATCH (n) RETURN count(n) > 0 as exists")
             return result.single()["exists"]

     @property

@@ -101,7 +101,9 @@ def set_is_node_collection(self, force=False):
             raise ValueError("Cannot reassign without force=True")
         self.metadata.graph_projection = NodeProjection()

-    def _prop_clause(self, obj: OBJECT, node_var: Optional[str] = None, exclude_attributes: Optional[List[str]]=None) -> str:
+    def _prop_clause(
+        self, obj: OBJECT, node_var: Optional[str] = None, exclude_attributes: Optional[List[str]] = None
+    ) -> str:
         if exclude_attributes is None:
             exclude_attributes = [self.category_labels_attribute]
         node_prefix = node_var + "." if node_var else ""

@@ -141,7 +143,9 @@ def _create_insert_cypher_query(self, obj: OBJECT) -> str:
             # check if nodes present; if not, make dangling stubs
             # TODO: decide on how this should be handled in validation if some fields are required
             for node_id in [obj[ep.subject_attribute], obj[ep.object_attribute]]:
-                check_query = f"MATCH (n {{{ep.identifier_attribute}: ${ep.identifier_attribute}}}) RETURN count(n) as count"
+                check_query = (
+                    f"MATCH (n {{{ep.identifier_attribute}: ${ep.identifier_attribute}}}) RETURN count(n) as count"
+                )
                 with self.session() as session:
                     result = session.run(check_query, **{ep.identifier_attribute: node_id})
                     if result.single()["count"] == 0:

@@ -150,11 +154,9 @@ def _create_insert_cypher_query(self, obj: OBJECT) -> str:
                             session.run(stub_query, **{ep.identifier_attribute: node_id})
                         else:
                             raise ValueError(f"Node with identifier {node_id} not found in the database.")
-            edge_props = self._prop_clause(obj, exclude_attributes=[
-                ep.subject_attribute,
-                ep.predicate_attribute,
-                ep.object_attribute
-            ])
+            edge_props = self._prop_clause(
+                obj, exclude_attributes=[ep.subject_attribute, ep.predicate_attribute, ep.object_attribute]
+            )
             return f"""
             MATCH (s {{{id_attribute}: ${ep.subject_attribute}}}), (o {{{id_attribute}: ${ep.object_attribute}}})
             CREATE (s)-[r:{pred} {{{edge_props}}}]->(o)

@@ -175,13 +177,15 @@ def query(self, query: Query, limit: Optional[int] = None, offset: Optional[int]
             if self.is_edge_collection:
                 rows = [self._edge_to_dict(record) for record in result]
             else:
+
                 def node_to_dict(n) -> dict:
                     d = dict(n.items())
                     if ca:
                         labels = list(n.labels)
                         if labels:
                             d[ca] = labels[0]
                     return d
+
                 rows = [node_to_dict(record["n"]) for record in result]

         # count_query = self._build_count_query(query, is_count=True)

@@ -191,7 +195,9 @@ def node_to_dict(n) -> dict:

         return QueryResult(query=query, num_rows=count, rows=rows)

-    def _build_cypher_query(self, query: Query, limit: Optional[int] = None, offset: Optional[int]=None, is_count=False) -> str:
+    def _build_cypher_query(
+        self, query: Query, limit: Optional[int] = None, offset: Optional[int] = None, is_count=False
+    ) -> str:
         if self.is_edge_collection:
             ep = self.edge_projection
             ia = ep.identifier_attribute

@@ -247,8 +253,7 @@ def _build_cypher_query(self, query: Query, limit: Optional[int] = None, offset:

         return cypher_query

-
-    def _build_where_clause(self, where_clause: Dict[str, Any], prefix: str = 'n') -> str:
+    def _build_where_clause(self, where_clause: Dict[str, Any], prefix: str = "n") -> str:
         conditions = []
         if where_clause is None:
             return ""

@@ -269,15 +274,15 @@ def _edge_to_dict(self, record: Dict) -> Dict[str, Any]:
             ep.subject_attribute: record["subject"],
             ep.predicate_attribute: record["predicate"],
             ep.object_attribute: record["object"],
-            **dict(r.items())
+            **dict(r.items()),
         }

     def query_facets(
-            self,
-            where: Dict = None,
-            facet_columns: List[Union[str, Tuple[str, ...]]] = None,
-            facet_limit=DEFAULT_FACET_LIMIT,
-            **kwargs,
+        self,
+        where: Dict = None,
+        facet_columns: List[Union[str, Tuple[str, ...]]] = None,
+        facet_limit=DEFAULT_FACET_LIMIT,
+        **kwargs,
     ) -> Dict[Union[str, Tuple[str, ...]], List[Tuple[Any, int]]]:
         results = {}
         if not facet_columns:

@@ -334,23 +339,24 @@ def delete(self, objs: Union[OBJECT, List[OBJECT]], **kwargs) -> int:

         return deleted_nodes

-    def delete_where(self, where: Optional[Dict[str, Any]] = None,
-                     missing_ok=True, **kwargs) -> int:
+    def delete_where(self, where: Optional[Dict[str, Any]] = None, missing_ok=True, **kwargs) -> int:
         delete_policy = self.delete_policy
         where_clause = self._build_where_clause(where) if where else ""
         node_pattern = self._node_pattern(where)

         with self.session() as session:
-            deleted_nodes, deleted_relationships = self._execute_delete(session, node_pattern, where_clause,
-                                                                        delete_policy)
+            deleted_nodes, deleted_relationships = self._execute_delete(
+                session, node_pattern, where_clause, delete_policy
+            )

         if deleted_nodes == 0 and not missing_ok:
             raise ValueError(f"No nodes found for {where}")

         return deleted_nodes

-    def _execute_delete(self, session, node_pattern: str, where_clause: str, delete_policy: DeletePolicy, **params) -> \
-            Tuple[int, int]:
+    def _execute_delete(
+        self, session, node_pattern: str, where_clause: str, delete_policy: DeletePolicy, **params
+    ) -> Tuple[int, int]:
         deleted_relationships = 0
         deleted_nodes = 0

@@ -376,7 +382,6 @@ def _execute_delete(self, session, node_pattern: str, where_clause: str, delete_

         return deleted_nodes, deleted_relationships

-
     def update(self, objs: Union[OBJECT, List[OBJECT]], **kwargs) -> int:
         if not isinstance(objs, list):
             objs = [objs]

@@ -393,30 +398,32 @@ def update(self, objs: Union[OBJECT, List[OBJECT]], **kwargs) -> int:
     def _create_update_cypher_query(self, obj: OBJECT) -> str:
         id_attribute = self.identifier_attribute
         category_labels_attribute = self.category_labels_attribute
-        node_pattern = self._node_pattern(obj)

         # Prepare SET clause
         set_items = [f"n.{k} = ${k}" for k in obj.keys() if k not in [id_attribute, category_labels_attribute]]
         set_clause = ", ".join(set_items)

         # Prepare labels update
         labels_to_add = []
-        labels_to_remove = []
+        # labels_to_remove = []
        if category_labels_attribute in obj:
-            new_labels = obj[category_labels_attribute] if isinstance(obj[category_labels_attribute], list) else [
-                obj[category_labels_attribute]]
+            new_labels = (
+                obj[category_labels_attribute]
+                if isinstance(obj[category_labels_attribute], list)
+                else [obj[category_labels_attribute]]
+            )
             labels_to_add = [f":{label}" for label in new_labels]
-            labels_to_remove = [f":Label" for _ in new_labels]  # Placeholder for labels to remove
+            # labels_to_remove = [":Label" for _ in new_labels]  # Placeholder for labels to remove

         # Construct the query
         query = f"MATCH (n {{{id_attribute}: ${id_attribute}}})\n"
-        #if labels_to_remove:
+        # f labels_to_remove:
         #    query += f"REMOVE n{' '.join(labels_to_remove)}\n"
         if labels_to_add:
             query += f"SET n{' '.join(labels_to_add)}\n"
-        #f"REMOVE n{' '.join(labels_to_remove)}' if labels_to_remove else ''}"
-        #f"{f'SET n{' '.join(labels_to_add)}' if labels_to_add else ''}"
+        # f"REMOVE n{' '.join(labels_to_remove)}' if labels_to_remove else ''}"
+        # f"{f'SET n{' '.join(labels_to_add)}' if labels_to_add else ''}"
         query += f"SET {set_clause}\n"
         query += "RETURN n"
         print(query)
-        return query
+        return query
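
Note on the Cypher construction in _create_insert_cypher_query: attributes other than the subject/predicate/object keys are rendered into a parameterized property map, and the predicate becomes the relationship type. A standalone sketch using the neo4j driver directly; the prop_clause helper, attribute names, and connection details are illustrative rather than the library's API:

from neo4j import GraphDatabase


def prop_clause(obj: dict, exclude: list) -> str:
    """Render non-excluded keys as a Cypher property map that references query parameters."""
    return ", ".join(f"{k}: ${k}" for k in obj if k not in exclude)


edge = {"subject": "P:1", "predicate": "KNOWS", "object": "P:2", "since": 2020}
edge_props = prop_clause(edge, exclude=["subject", "predicate", "object"])  # -> "since: $since"

# Relationship types cannot be query parameters, so the predicate is interpolated;
# all property values still travel as bound parameters.
cypher = (
    "MATCH (s {id: $subject}), (o {id: $object}) "
    f"CREATE (s)-[r:{edge['predicate']} {{{edge_props}}}]->(o) RETURN r"
)

driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "test"))  # placeholder credentials
with driver.session() as session:
    session.run(cypher, **edge)
driver.close()

Unused parameters (here, predicate) are ignored by Neo4j, which keeps the call site simple: the whole edge dict can be passed as parameters.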

src/linkml_store/api/stores/neo4j/neo4j_database.py
Lines changed: 21 additions & 10 deletions

@@ -1,10 +1,10 @@
 # neo4j_database.py

 import logging
-from typing import Optional, Union, List, Dict, Any
 from pathlib import Path
+from typing import Optional, Union

-from neo4j import GraphDatabase, Driver, Session
+from neo4j import Driver, GraphDatabase, Session

 from linkml_store.api import Database
 from linkml_store.api.queries import Query, QueryResult

@@ -107,22 +107,30 @@ def export_database(self, location: str, target_format: Optional[Union[str, Form

             result = session.run("MATCH ()-[r]->() RETURN r")
             relationships = [
-                {"type": record["r"].type, "start": record["r"].start_node.id, "end": record["r"].end_node.id,
-                 **dict(record["r"].items())} for record in result]
+                {
+                    "type": record["r"].type,
+                    "start": record["r"].start_node.id,
+                    "end": record["r"].end_node.id,
+                    **dict(record["r"].items()),
+                }
+                for record in result
+            ]

             data = {"nodes": nodes, "relationships": relationships}

             import json
-            with open(path, 'w') as f:
+
+            with open(path, "w") as f:
                 json.dump(data, f)
         else:
             super().export_database(location, target_format=target_format, **kwargs)

     def import_database(self, location: str, source_format: Optional[str] = None, **kwargs):
         if source_format == Format.JSON or source_format == "json":
             path = Path(location)
-            with open(path, 'r') as f:
+            with open(path, "r") as f:
                 import json
+
                 data = json.load(f)

             with self.driver.session() as session:

@@ -133,11 +141,14 @@ def import_database(self, location: str, source_format: Optional[str] = None, **
                         session.run(query, **node)

                 for rel in data["relationships"]:
-                    rel_type = rel.pop("type")
+                    # rel_type = rel.pop("type")
                     start = rel.pop("start")
                     end = rel.pop("end")
-                    props = ", ".join([f"{k}: ${k}" for k in rel.keys()])
-                    query = f"MATCH (a), (b) WHERE id(a) = {start} AND id(b) = {end} CREATE (a)-[r:{rel_type} {{{props}}}]->(b)"
+                    # props = ", ".join([f"{k}: ${k}" for k in rel.keys()])
+                    query = (
+                        f"MATCH (a), (b) WHERE id(a) = {start} AND id(b) = {end} "
+                        "CREATE (a)-[r:{rel_type} {{{props}}}]->(b)"
+                    )
                     session.run(query, **rel)
         else:
-            super().import_database(location, source_format=source_format, **kwargs)
+            super().import_database(location, source_format=source_format, **kwargs)
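
Note on the JSON export/import path: a sketch of the file layout inferred from export_database above (node and relationship contents are invented, the file name is a placeholder) — a dict with "nodes" and "relationships", where each relationship carries its type, the internal start/end node ids, and its remaining properties:

import json

data = {
    "nodes": [
        {"id": "P:1", "name": "Alice"},
        {"id": "P:2", "name": "Bob"},
    ],
    "relationships": [
        {"type": "KNOWS", "start": 0, "end": 1, "since": 2020},
    ],
}

# Write the dump, then read it back, mirroring export_database / import_database.
with open("neo4j_dump.json", "w") as f:
    json.dump(data, f)

with open("neo4j_dump.json", "r") as f:
    round_tripped = json.load(f)

assert round_tripped == data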

src/linkml_store/graphs/graph_map.py
Lines changed: 2 additions & 1 deletion

@@ -1,5 +1,5 @@
 from abc import ABC
-from typing import List, Optional
+from typing import Optional

 from pydantic import BaseModel

@@ -17,6 +17,7 @@ class GraphProjection(BaseModel, ABC):
 class NodeProjection(GraphProjection):
     category_labels_attribute: Optional[str] = DEFAULT_CATEGORY_LABELS_ATTRIBUTE

+
 class EdgeProjection(GraphProjection):
     subject_attribute: str = DEFAULT_SUBJECT_ATTRIBUTE
     predicate_attribute: str = DEFAULT_PREDICATE_ATTRIBUTE
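
Note on the projection models: a usage sketch of EdgeProjection, which tells the Neo4j collection which keys of an object act as the edge endpoints and predicate. The explicit values below mirror what the DEFAULT_* constants are assumed to be and are shown only for illustration:

from linkml_store.graphs.graph_map import EdgeProjection

ep = EdgeProjection(
    subject_attribute="subject",
    predicate_attribute="predicate",
    object_attribute="object",
)
edge = {"subject": "P:1", "predicate": "KNOWS", "object": "P:2", "since": 2020}
print(edge[ep.subject_attribute], edge[ep.predicate_attribute], edge[ep.object_attribute])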

src/linkml_store/utils/neo4j_utils.py
Lines changed: 13 additions & 12 deletions

@@ -1,5 +1,6 @@
-from py2neo import Graph
 import networkx as nx
+from py2neo import Graph
+

 def draw_neo4j_graph(handle="bolt://localhost:7687", auth=("neo4j", None)):
     # Connect to Neo4j

@@ -16,26 +17,26 @@ def draw_neo4j_graph(handle="bolt://localhost:7687", auth=("neo4j", None)):
     # Create a NetworkX graph
     G = nx.DiGraph()  # Use DiGraph for directed edges
     for record in result:
-        n = record['n']
-        m = record['m']
-        r = record['r']
-        G.add_node(n['name'], label=list(n.labels or ["-"])[0])
-        G.add_node(m['name'], label=list(m.labels or ["-"])[0])
-        G.add_edge(n['name'], m['name'], type=type(r).__name__)
+        n = record["n"]
+        m = record["m"]
+        r = record["r"]
+        G.add_node(n["name"], label=list(n.labels or ["-"])[0])
+        G.add_node(m["name"], label=list(m.labels or ["-"])[0])
+        G.add_edge(n["name"], m["name"], type=type(r).__name__)

     # Draw the graph
     pos = nx.spring_layout(G)

     # Draw nodes
-    nx.draw_networkx_nodes(G, pos, node_color='lightblue', node_size=10000)
+    nx.draw_networkx_nodes(G, pos, node_color="lightblue", node_size=10000)

     # Draw edges
-    nx.draw_networkx_edges(G, pos, edge_color='gray', arrows=True)
+    nx.draw_networkx_edges(G, pos, edge_color="gray", arrows=True)

     # Add node labels
-    node_labels = nx.get_node_attributes(G, 'label')
+    node_labels = nx.get_node_attributes(G, "label")
     nx.draw_networkx_labels(G, pos, {node: f"{node}\n({label})" for node, label in node_labels.items()}, font_size=16)

     # Add edge labels
-    edge_labels = nx.get_edge_attributes(G, 'type')
-    nx.draw_networkx_edge_labels(G, pos, edge_labels, font_size=16)
+    edge_labels = nx.get_edge_attributes(G, "type")
+    nx.draw_networkx_edge_labels(G, pos, edge_labels, font_size=16)
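
Usage sketch for draw_neo4j_graph (assumes a reachable local Neo4j instance and matplotlib for display; the password is a placeholder):

import matplotlib.pyplot as plt

from linkml_store.utils.neo4j_utils import draw_neo4j_graph

# Render the current graph onto the active matplotlib figure, then display it.
draw_neo4j_graph(handle="bolt://localhost:7687", auth=("neo4j", "change-me"))
plt.show()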
