44import pyarrow as pa
55from pyarrow .cffi import ffi as arrow_ffi
66
7+ from lakeapi2sql .utils import prepare_connection_string
8+
79
810class BulkInfoField (TypedDict ):
911 name : str
@@ -14,48 +16,14 @@ class BulkInfo(TypedDict):
1416 fields : list [BulkInfoField ]
1517
1618
17- async def _prepare_connection_string (connection_string : str , aad_token : str | None ) -> tuple [str , str | None ]:
18- if "authentication" in connection_string .lower ():
19- parts = [(kv [0 : kv .index ("=" )], kv [kv .index ("=" ) + 1 :]) for kv in connection_string .split (";" )]
20- auth_part = next ((p for p in parts if p [0 ].casefold () == "Authentication" .casefold ()))
21- parts .remove (auth_part )
22- credential = None
23- auth_method = auth_part [1 ].lower ()
24- if auth_method in ["ActiveDirectoryDefault" .lower ()]:
25- from azure .identity .aio import DefaultAzureCredential
26-
27- credential = DefaultAzureCredential ()
28- elif auth_method in ["ActiveDirectoryMSI" .lower (), "ActiveDirectoryManagedIdentity" .lower ()]:
29- from azure .identity .aio import ManagedIdentityCredential
30-
31- client_part = next ((p for p in parts if p [0 ].lower () in ["user" , "msiclientid" ]), None )
32- if client_part :
33- parts .remove (client_part )
34- credential = ManagedIdentityCredential (client_id = client_part [1 ] if client_part else None )
35- elif auth_method == "ActiveDirectoryInteractive" .lower ():
36- from azure .identity import InteractiveBrowserCredential
37-
38- credential = InteractiveBrowserCredential ()
39- elif auth_method == "SqlPassword" : # that's kind of an no-op
40- return ";" .join ((p [0 ] + "=" + p [1 ] for p in parts )), None
41- if credential is not None :
42- from azure .core .credentials import AccessToken
43-
44- res = credential .get_token ("https://database.windows.net/.default" )
45- token : AccessToken = await res if inspect .isawaitable (res ) else res # type: ignore
46- aad_token = token .token
47- return ";" .join ((p [0 ] + "=" + p [1 ] for p in parts )), aad_token
48- return connection_string , aad_token
49-
50-
5119async def insert_record_batch_to_sql (
5220 connection_string : str ,
5321 table_name : str ,
5422 reader : pa .RecordBatchReader ,
5523 col_names : list [str ] | None = None ,
5624 aad_token : str | None = None ,
5725):
58- connection_string , aad_token = await _prepare_connection_string (connection_string , aad_token )
26+ connection_string , aad_token = await prepare_connection_string (connection_string , aad_token )
5927
6028 return await lvd .insert_arrow_reader_to_sql (connection_string , reader , table_name , col_names or [], aad_token )
6129
@@ -68,7 +36,7 @@ async def insert_http_arrow_stream_to_sql(
6836 aad_token : str | None = None ,
6937 col_names : list [str ] | None = None ,
7038) -> BulkInfo :
71- connection_string , aad_token = await _prepare_connection_string (connection_string , aad_token )
39+ connection_string , aad_token = await prepare_connection_string (connection_string , aad_token )
7240
7341 return await lvd .insert_arrow_stream_to_sql (
7442 connection_string , table_name , col_names or [], url , basic_auth [0 ], basic_auth [1 ], aad_token
0 commit comments