Skip to content

Fix issues with signup and signin #151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: unit-tests
name: Unit tests

on:
push:
Expand Down
62 changes: 34 additions & 28 deletions src/surrealdb/connections/async_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,30 @@ def set_token(self, token: str) -> None:
"""
self.token = token

async def authenticate(self) -> None:
message = RequestMessage(
self.id,
RequestMethod.AUTHENTICATE,
token=token
)
return await self._send(message, "authenticating")

async def invalidate(self) -> None:
message = RequestMessage(self.id, RequestMethod.INVALIDATE)
await self._send(message, "invalidating")
self.token = None

async def signup(self, vars: Dict) -> str:
message = RequestMessage(
self.id,
RequestMethod.SIGN_UP,
data=vars
)
response = await self._send(message, "signup")
self.check_response_for_result(response, "signup")
self.token = response["result"]
return response["result"]

async def signin(self, vars: dict) -> dict:
message = RequestMessage(
self.id,
Expand All @@ -112,9 +136,16 @@ async def signin(self, vars: dict) -> dict:
response = await self._send(message, "signing in")
self.check_response_for_result(response, "signing in")
self.token = response["result"]
package = dict()
package["token"] = self.token
return package
return response["result"]

async def info(self) -> dict:
message = RequestMessage(
self.id,
RequestMethod.INFO
)
response = await self._send(message, "getting database information")
self.check_response_for_result(response, "getting database information")
return response["result"]

async def use(self, namespace: str, database: str) -> None:
message = RequestMessage(
Expand Down Expand Up @@ -187,15 +218,6 @@ async def delete(
self.check_response_for_result(response, "delete")
return response["result"]

async def info(self) -> dict:
message = RequestMessage(
self.id,
RequestMethod.INFO
)
response = await self._send(message, "getting database information")
self.check_response_for_result(response, "getting database information")
return response["result"]

async def insert(
self, table: Union[str, Table], data: Union[List[dict], dict]
) -> Union[List[dict], dict]:
Expand All @@ -222,11 +244,6 @@ async def insert_relation(
self.check_response_for_result(response, "insert_relation")
return response["result"]

async def invalidate(self) -> None:
message = RequestMessage(self.id, RequestMethod.INVALIDATE)
await self._send(message, "invalidating")
self.token = None

async def let(self, key: str, value: Any) -> None:
self.vars[key] = value

Expand Down Expand Up @@ -306,17 +323,6 @@ async def upsert(
self.check_response_for_result(response, "upsert")
return response["result"]

async def signup(self, vars: Dict) -> str:
message = RequestMessage(
self.id,
RequestMethod.SIGN_UP,
data=vars
)
response = await self._send(message, "signup")
self.check_response_for_result(response, "signup")
self.token = response["result"]
return response["result"]

async def __aenter__(self) -> "AsyncHttpSurrealConnection":
"""
Asynchronous context manager entry.
Expand Down
38 changes: 19 additions & 19 deletions src/surrealdb/connections/async_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,25 @@ async def use(self, namespace: str, database: str) -> None:
"""
raise NotImplementedError(f"query not implemented for: {self}")

async def authenticate(self, token: str) -> None:
"""Authenticate the current connection with a JWT token.

Args:
token: The JWT authentication token.

Example:
await db.authenticate('insert token here')
"""
raise NotImplementedError(f"authenticate not implemented for: {self}")

async def invalidate(self) -> None:
"""Invalidate the authentication for the current connection.

Example:
await db.invalidate()
"""
raise NotImplementedError(f"invalidate not implemented for: {self}")

async def signup(self, vars: Dict) -> str:
"""Sign this connection up to a specific authentication scope.
[See the docs](https://surrealdb.com/docs/sdk/python/methods/signup)
Expand Down Expand Up @@ -77,25 +96,6 @@ async def signin(self, vars: Dict) -> str:
"""
raise NotImplementedError(f"query not implemented for: {self}")

async def invalidate(self) -> None:
"""Invalidate the authentication for the current connection.

Example:
await db.invalidate()
"""
raise NotImplementedError(f"invalidate not implemented for: {self}")

async def authenticate(self, token: str) -> None:
"""Authenticate the current connection with a JWT token.

Args:
token: The JWT authentication token.

Example:
await db.authenticate('insert token here')
"""
raise NotImplementedError(f"authenticate not implemented for: {self}")

async def let(self, key: str, value: Any) -> None:
"""Assign a value as a variable for this connection.

Expand Down
85 changes: 41 additions & 44 deletions src/surrealdb/connections/async_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,28 @@ async def connect(self, url: Optional[str] = None, max_size: Optional[int] = Non
subprotocols=[websockets.Subprotocol("cbor")]
)

# async def signup(self, vars: Dict[str, Any]) -> str:
async def authenticate(self, token: str) -> dict:
message = RequestMessage(
self.id,
RequestMethod.AUTHENTICATE,
token=token
)
return await self._send(message, "authenticating")

async def invalidate(self) -> None:
message = RequestMessage(self.id, RequestMethod.INVALIDATE)
await self._send(message, "invalidating")
self.token = None

async def signup(self, vars: Dict) -> str:
message = RequestMessage(
self.id,
RequestMethod.SIGN_UP,
data=vars
)
response = await self._send(message, "signup")
self.check_response_for_result(response, "signup")
return response["result"]

async def signin(self, vars: Dict[str, Any]) -> str:
message = RequestMessage(
Expand All @@ -96,9 +117,25 @@ async def signin(self, vars: Dict[str, Any]) -> str:
response = await self._send(message, "signing in")
self.check_response_for_result(response, "signing in")
self.token = response["result"]
if response.get("id") is None:
raise Exception(f"no id signing in: {response}")
self.id = response["id"]
return response["result"]

async def info(self) -> Optional[dict]:
message = RequestMessage(
self.id,
RequestMethod.INFO
)
outcome = await self._send(message, "getting database information")
self.check_response_for_result(outcome, "getting database information")
return outcome["result"]

async def use(self, namespace: str, database: str) -> None:
message = RequestMessage(
self.id,
RequestMethod.USE,
namespace=namespace,
database=database,
)
await self._send(message, "use")

async def query(self, query: str, params: Optional[dict] = None) -> dict:
if params is None:
Expand All @@ -125,24 +162,6 @@ async def query_raw(self, query: str, params: Optional[dict] = None) -> dict:
response = await self._send(message, "query", bypass=True)
return response

async def use(self, namespace: str, database: str) -> None:
message = RequestMessage(
self.id,
RequestMethod.USE,
namespace=namespace,
database=database,
)
await self._send(message, "use")

async def info(self) -> Optional[dict]:
message = RequestMessage(
self.id,
RequestMethod.INFO
)
outcome = await self._send(message, "getting database information")
self.check_response_for_result(outcome, "getting database information")
return outcome["result"]

async def version(self) -> str:
message = RequestMessage(
self.id,
Expand All @@ -152,18 +171,6 @@ async def version(self) -> str:
self.check_response_for_result(response, "getting database version")
return response["result"]

async def authenticate(self, token: str) -> dict:
message = RequestMessage(
self.id,
RequestMethod.AUTHENTICATE,
token=token
)
return await self._send(message, "authenticating")

async def invalidate(self) -> None:
message = RequestMessage(self.id, RequestMethod.INVALIDATE)
await self._send(message, "invalidating")

async def let(self, key: str, value: Any) -> None:
message = RequestMessage(
self.id,
Expand Down Expand Up @@ -331,16 +338,6 @@ async def kill(self, query_uuid: Union[str, UUID]) -> None:
)
await self._send(message, "kill")

async def signup(self, vars: Dict) -> str:
message = RequestMessage(
self.id,
RequestMethod.SIGN_UP,
data=vars
)
response = await self._send(message, "signup")
self.check_response_for_result(response, "signup")
return response["result"]

async def upsert(
self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None
) -> Union[List[dict], dict]:
Expand Down
50 changes: 33 additions & 17 deletions src/surrealdb/connections/blocking_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,29 @@ def _send(self, message: RequestMessage, operation: str, bypass: bool = False) -
def set_token(self, token: str) -> None:
self.token = token

def authenticate(self, token: str) -> dict:
message = RequestMessage(
self.id,
RequestMethod.AUTHENTICATE,
token=token
)
return self._send(message, "authenticating")

def invalidate(self) -> None:
message = RequestMessage(self.id, RequestMethod.INVALIDATE)
self._send(message, "invalidating")
self.token = None

def signup(self, vars: Dict) -> str:
message = RequestMessage(
self.id,
RequestMethod.SIGN_UP,
data=vars
)
response = self._send(message, "signup")
self.check_response_for_result(response, "signup")
return response["result"]

def signin(self, vars: dict) -> dict:
message = RequestMessage(
self.id,
Expand All @@ -65,9 +88,16 @@ def signin(self, vars: dict) -> dict:
response = self._send(message, "signing in")
self.check_response_for_result(response, "signing in")
self.token = response["result"]
package = dict()
package["token"] = self.token
return package
return response["result"]

def info(self):
message = RequestMessage(
self.id,
RequestMethod.INFO
)
response = self._send(message, "getting database information")
self.check_response_for_result(response, "getting database information")
return response["result"]

def use(self, namespace: str, database: str) -> None:
message = RequestMessage(
Expand Down Expand Up @@ -140,15 +170,6 @@ def delete(
self.check_response_for_result(response, "delete")
return response["result"]

def info(self):
message = RequestMessage(
self.id,
RequestMethod.INFO
)
response = self._send(message, "getting database information")
self.check_response_for_result(response, "getting database information")
return response["result"]

def insert(
self, table: Union[str, Table], data: Union[List[dict], dict]
) -> Union[List[dict], dict]:
Expand All @@ -175,11 +196,6 @@ def insert_relation(
self.check_response_for_result(response, "insert_relation")
return response["result"]

def invalidate(self) -> None:
message = RequestMessage(self.id, RequestMethod.INVALIDATE)
self._send(message, "invalidating")
self.token = None

def let(self, key: str, value: Any) -> None:
self.vars[key] = value

Expand Down
Loading
Loading