Skip to content

Commit f16fc8f

Browse files
authored
Merge pull request #15 from bmsuisse/executesql
Executesql
2 parents 5bfc904 + 262639f commit f16fc8f

File tree

7 files changed

+310
-42
lines changed

7 files changed

+310
-42
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "lake2sql"
3-
version = "0.8.3"
3+
version = "0.9.0"
44
edition = "2021"
55

66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

lakeapi2sql/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .sql_connection import TdsConnection
2+
from .bulk_insert import insert_record_batch_to_sql

lakeapi2sql/bulk_insert.py

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import pyarrow as pa
55
from pyarrow.cffi import ffi as arrow_ffi
66

7+
from lakeapi2sql.utils import prepare_connection_string
8+
79

810
class BulkInfoField(TypedDict):
911
name: str
@@ -14,48 +16,14 @@ class BulkInfo(TypedDict):
1416
fields: list[BulkInfoField]
1517

1618

17-
async def _prepare_connection_string(connection_string: str, aad_token: str | None) -> tuple[str, str | None]:
18-
if "authentication" in connection_string.lower():
19-
parts = [(kv[0 : kv.index("=")], kv[kv.index("=") + 1 :]) for kv in connection_string.split(";")]
20-
auth_part = next((p for p in parts if p[0].casefold() == "Authentication".casefold()))
21-
parts.remove(auth_part)
22-
credential = None
23-
auth_method = auth_part[1].lower()
24-
if auth_method in ["ActiveDirectoryDefault".lower()]:
25-
from azure.identity.aio import DefaultAzureCredential
26-
27-
credential = DefaultAzureCredential()
28-
elif auth_method in ["ActiveDirectoryMSI".lower(), "ActiveDirectoryManagedIdentity".lower()]:
29-
from azure.identity.aio import ManagedIdentityCredential
30-
31-
client_part = next((p for p in parts if p[0].lower() in ["user", "msiclientid"]), None)
32-
if client_part:
33-
parts.remove(client_part)
34-
credential = ManagedIdentityCredential(client_id=client_part[1] if client_part else None)
35-
elif auth_method == "ActiveDirectoryInteractive".lower():
36-
from azure.identity import InteractiveBrowserCredential
37-
38-
credential = InteractiveBrowserCredential()
39-
elif auth_method == "SqlPassword": # that's kind of an no-op
40-
return ";".join((p[0] + "=" + p[1] for p in parts)), None
41-
if credential is not None:
42-
from azure.core.credentials import AccessToken
43-
44-
res = credential.get_token("https://database.windows.net/.default")
45-
token: AccessToken = await res if inspect.isawaitable(res) else res # type: ignore
46-
aad_token = token.token
47-
return ";".join((p[0] + "=" + p[1] for p in parts)), aad_token
48-
return connection_string, aad_token
49-
50-
5119
async def insert_record_batch_to_sql(
5220
connection_string: str,
5321
table_name: str,
5422
reader: pa.RecordBatchReader,
5523
col_names: list[str] | None = None,
5624
aad_token: str | None = None,
5725
):
58-
connection_string, aad_token = await _prepare_connection_string(connection_string, aad_token)
26+
connection_string, aad_token = await prepare_connection_string(connection_string, aad_token)
5927

6028
return await lvd.insert_arrow_reader_to_sql(connection_string, reader, table_name, col_names or [], aad_token)
6129

@@ -68,7 +36,7 @@ async def insert_http_arrow_stream_to_sql(
6836
aad_token: str | None = None,
6937
col_names: list[str] | None = None,
7038
) -> BulkInfo:
71-
connection_string, aad_token = await _prepare_connection_string(connection_string, aad_token)
39+
connection_string, aad_token = await prepare_connection_string(connection_string, aad_token)
7240

7341
return await lvd.insert_arrow_stream_to_sql(
7442
connection_string, table_name, col_names or [], url, basic_auth[0], basic_auth[1], aad_token

lakeapi2sql/sql_connection.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import lakeapi2sql._lowlevel as lvd
2+
from lakeapi2sql.utils import prepare_connection_string
3+
4+
5+
class TdsConnection:
6+
def __init__(self, connection_string: str, aad_token: str | None = None) -> None:
7+
connection_string, aad_token = await prepare_connection_string(connection_string, aad_token)
8+
self._connection_string = connection_string
9+
self._aad_token = aad_token
10+
11+
async def __aenter__(self) -> "TdsConnection":
12+
self._connection = await lvd.connect_sql(self.connection_string, self.aad_token)
13+
return self
14+
15+
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
16+
pass
17+
18+
async def execute_sql(self, sql: str, arguments: list[str | int | float | bool | None]) -> list[int]:
19+
return await lvd.execute_sql(self._connection, sql, arguments)
20+
21+
async def execute_sql_with_result(self, sql: str, arguments: list[str | int | float | bool | None]) -> list[int]:
22+
return await lvd.execute_sql_with_result(self._connection, sql, arguments)

lakeapi2sql/utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import inspect
2+
3+
4+
async def prepare_connection_string(connection_string: str, aad_token: str | None) -> tuple[str, str | None]:
5+
if "authentication" in connection_string.lower():
6+
parts = [(kv[0 : kv.index("=")], kv[kv.index("=") + 1 :]) for kv in connection_string.split(";")]
7+
auth_part = next((p for p in parts if p[0].casefold() == "Authentication".casefold()))
8+
parts.remove(auth_part)
9+
credential = None
10+
auth_method = auth_part[1].lower()
11+
if auth_method in ["ActiveDirectoryDefault".lower()]:
12+
from azure.identity.aio import DefaultAzureCredential
13+
14+
credential = DefaultAzureCredential()
15+
elif auth_method in ["ActiveDirectoryMSI".lower(), "ActiveDirectoryManagedIdentity".lower()]:
16+
from azure.identity.aio import ManagedIdentityCredential
17+
18+
client_part = next((p for p in parts if p[0].lower() in ["user", "msiclientid"]), None)
19+
if client_part:
20+
parts.remove(client_part)
21+
credential = ManagedIdentityCredential(client_id=client_part[1] if client_part else None)
22+
elif auth_method == "ActiveDirectoryInteractive".lower():
23+
from azure.identity import InteractiveBrowserCredential
24+
25+
credential = InteractiveBrowserCredential()
26+
elif auth_method == "SqlPassword": # that's kind of an no-op
27+
return ";".join((p[0] + "=" + p[1] for p in parts)), None
28+
if credential is not None:
29+
from azure.core.credentials import AccessToken
30+
31+
res = credential.get_token("https://database.windows.net/.default")
32+
token: AccessToken = await res if inspect.isawaitable(res) else res # type: ignore
33+
aad_token = token.token
34+
return ";".join((p[0] + "=" + p[1] for p in parts)), aad_token
35+
return connection_string, aad_token

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ build-backend = "maturin"
55
[project]
66
name = "lakeapi2sql"
77
requires-python = ">=3.10"
8-
version = "0.8.4"
8+
version = "0.9.0"
99
classifiers = [
1010
"Programming Language :: Rust",
1111
"Programming Language :: Python :: Implementation :: CPython",

0 commit comments

Comments
 (0)