Skip to content

Commit bf7fee2

Browse files
committed
Adding tests and updating clients, vendors
1 parent 7cc252a commit bf7fee2

File tree

5 files changed

+44
-32
lines changed

5 files changed

+44
-32
lines changed

gqlalchemy/models.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from collections import defaultdict
1717
from dataclasses import dataclass
1818
from datetime import datetime, date, time, timedelta
19-
from enum import Enum
2019
import json
2120
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
2221
from enum import Enum, EnumMeta
@@ -59,34 +58,37 @@ def _format_timedelta(duration: timedelta) -> str:
5958

6059
return f"P{days}DT{hours}H{minutes}M{remainder_sec}S"
6160

61+
6262
class GraphEnum(ABC):
6363
def __init__(self, enum):
6464

6565
if not isinstance(enum, (Enum, EnumMeta)):
6666
raise TypeError()
67-
67+
6868
self.enum = enum if isinstance(enum, Enum) else None
6969
self.cls = enum.__class__ if isinstance(enum, Enum) else enum
70-
70+
7171
@property
7272
def name(self):
7373
return self.cls.__name__
74-
74+
7575
@property
7676
def members(self):
7777
return self.cls.__members__
78-
78+
7979
@abstractmethod
8080
def _to_cypher(self):
8181
pass
8282

83+
8384
class MemgraphEnum(GraphEnum):
8485
def _to_cypher(self):
8586
return f"{{ {', '.join(self.cls._member_names_)} }}"
86-
87+
8788
def __repr__(self):
88-
return f"<enum '{self.name}'>" if self.enum is None else f'{self.name}::{self.enum.name}'
89-
89+
return f"<enum '{self.name}'>" if self.enum is None else f"{self.name}::{self.enum.name}"
90+
91+
9092
class TriggerEventType:
9193
"""An enum representing types of trigger events."""
9294

@@ -343,7 +345,7 @@ def __init__(self, **data):
343345
if issubclass(cls, Enum) and not attrs.get("enum", False):
344346
value = data.get(field)
345347
if isinstance(value, dict):
346-
member = value.get("__value").split('::')[1]
348+
member = value.get("__value").split("::")[1]
347349
data[field] = cls[member].value
348350
super().__init__(**data)
349351

@@ -583,7 +585,7 @@ def get_base_labels() -> Set[str]:
583585
if cls.enums is None:
584586
cls.enums = db.get_enums()
585587
enum_names = [x.name for x in cls.enums]
586-
if(field_cls.__name__ in enum_names):
588+
if field_cls.__name__ in enum_names:
587589
existing = cls.enums[enum_names.index(field_cls.__name__)]
588590
db.sync_enum(existing, MemgraphEnum(field_cls))
589591
else:
@@ -742,7 +744,7 @@ def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901
742744
if cls.enums is None:
743745
cls.enums = db.get_enums()
744746
enum_names = [x.name for x in cls.enums]
745-
if(field_type in enum_names):
747+
if field_type in enum_names:
746748
existing = cls.enums[enum_names.index(field_type)]
747749
db.sync_enum(existing, MemgraphEnum(field_cls))
748750
else:

gqlalchemy/vendors/database_client.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,8 @@
1717

1818
from gqlalchemy.connection import Connection
1919
from gqlalchemy.exceptions import GQLAlchemyError
20-
from gqlalchemy.models import (
21-
Constraint,
22-
Index,
23-
GraphEnum,
24-
Node,
25-
Relationship
26-
)
20+
from gqlalchemy.models import Constraint, Index, GraphEnum, Node, Relationship
21+
2722

2823
class DatabaseClient(ABC):
2924
def __init__(
@@ -127,7 +122,7 @@ def ensure_constraints(
127122
self.drop_constraint(obsolete_constraints)
128123
for missing_constraint in new_constraints.difference(old_constraints):
129124
self.create_constraint(missing_constraint)
130-
125+
131126
@abstractmethod
132127
def create_enum(self, enum: GraphEnum) -> None:
133128
pass
@@ -136,7 +131,7 @@ def create_enum(self, enum: GraphEnum) -> None:
136131
def get_enums(self) -> List[GraphEnum]:
137132
"""Returns a list of all enums defined in the database."""
138133
pass
139-
134+
140135
@abstractmethod
141136
def sync_enum(self, existing: GraphEnum, new: GraphEnum) -> None:
142137
"""Ensures that database enum matches input enum."""
@@ -146,7 +141,7 @@ def sync_enum(self, existing: GraphEnum, new: GraphEnum) -> None:
146141
def drop_enum(self, enum: GraphEnum) -> None:
147142
"""Drops a single enum in the database."""
148143
pass
149-
144+
150145
@abstractmethod
151146
def drop_enums(self) -> None:
152147
"""Drops all enums in the database"""

gqlalchemy/vendors/memgraph.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import os
1717
import sqlite3
1818
from typing import List, Optional, Union
19-
import warnings
2019

2120
from gqlalchemy.connection import Connection, MemgraphConnection
2221
from gqlalchemy.disk_storage import OnDiskPropertyDatabase
@@ -25,7 +24,6 @@
2524
GQLAlchemyFileNotFoundError,
2625
GQLAlchemyOnDiskPropertyDatabaseNotDefinedError,
2726
GQLAlchemyUniquenessConstraintError,
28-
GQLAlchemyWarning,
2927
)
3028
from gqlalchemy.models import (
3129
MemgraphConstraintExists,
@@ -169,7 +167,7 @@ def get_constraints(
169167
)
170168
)
171169
return constraints
172-
170+
173171
def create_enum(self, graph_enum: MemgraphEnum) -> None:
174172
query = f"CREATE ENUM {graph_enum.name} VALUES {graph_enum._to_cypher()};"
175173
self.execute(query)
@@ -178,9 +176,9 @@ def get_enums(self) -> List[MemgraphEnum]:
178176
"""Returns a list of all enums defined in the database."""
179177
enums: List[MemgraphEnum] = []
180178
for result in self.execute_and_fetch("SHOW ENUMS;"):
181-
enums.append(MemgraphEnum(Enum(result['Enum Name'], result['Enum Values'])))
179+
enums.append(MemgraphEnum(Enum(result["Enum Name"], result["Enum Values"])))
182180
return enums
183-
181+
184182
def sync_enum(self, existing: MemgraphEnum, new: MemgraphEnum) -> None:
185183
"""Ensures that database enum matches input enum."""
186184
for value in new.members:
@@ -190,9 +188,11 @@ def sync_enum(self, existing: MemgraphEnum, new: MemgraphEnum) -> None:
190188

191189
def drop_enum(self, graph_enum: MemgraphEnum):
192190
raise GQLAlchemyError(f"DROP ENUM not yet implemented. Enum {graph_enum.name} is persisted in the database.")
193-
191+
194192
def drop_enums(self, graph_enums: List[MemgraphEnum]):
195-
raise GQLAlchemyError(f"DROP ENUM not yet implemented. Enums {', '.join(graph_enums)} are persisted in the database.")
193+
raise GQLAlchemyError(
194+
f"DROP ENUM not yet implemented. Enums {', '.join(graph_enums)} are persisted in the database."
195+
)
196196

197197
def get_exists_constraints(
198198
self,

gqlalchemy/vendors/neo4j.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import os
1616
from typing import List, Optional, Union
17-
from enum import Enum
1817

1918
from gqlalchemy.connection import Connection, Neo4jConnection
2019
from gqlalchemy.exceptions import (
@@ -107,14 +106,14 @@ def create_enum(self, graph_enum: GraphEnum) -> None:
107106
def get_enums(self) -> List[GraphEnum]:
108107
"""Returns a list of all enums defined in the database."""
109108
raise GQLAlchemyError(f"SHOW ENUMS not yet implemented in Neo4j.")
110-
109+
111110
def sync_enum(self, existing: GraphEnum, new: GraphEnum) -> None:
112111
"""Ensures that database enum matches input enum."""
113112
raise GQLAlchemyError(f"ALTER ENUM not yet implemented in Neo4j.")
114-
113+
115114
def drop_enum(self, graph_enum: GraphEnum):
116115
raise GQLAlchemyError(f"DROP ENUM not yet implemented in Neo4j.")
117-
116+
118117
def drop_enums(self, graph_enums: List[GraphEnum]):
119118
raise GQLAlchemyError(f"DROP ENUM not yet implemented in Neo4j.")
120119

tests/ogm/test_custom_fields.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313

1414
from pydantic.v1 import Field
1515

16+
from enum import Enum
17+
1618
from gqlalchemy import (
1719
MemgraphConstraintExists,
1820
MemgraphConstraintUnique,
1921
MemgraphIndex,
22+
MemgraphEnum,
2023
Neo4jConstraintUnique,
2124
Neo4jIndex,
2225
Node,
@@ -56,6 +59,19 @@ def test_create_index(memgraph):
5659
assert actual_constraints == [memgraph_index]
5760

5861

62+
def test_create_graph_enum(memgraph):
63+
enum1 = Enum("MgEnum", (("MEMBER1", "value1"), ("MEMBER2", "value2"), ("MEMBER3", "value3")))
64+
65+
class Node3(Node):
66+
type: enum1
67+
68+
memgraph_enum = MemgraphEnum(enum1)
69+
70+
actual_enums = memgraph.get_enums()
71+
72+
assert actual_enums == [memgraph_enum]
73+
74+
5975
def test_create_constraint_unique_neo4j(neo4j):
6076
class Node2(Node):
6177
id: int = Field(db=neo4j)

0 commit comments

Comments
 (0)