
Commit 262639f

expose sql interface
1 parent bb3e27f commit 262639f

File tree

3 files changed: +213, -33 lines changed

lakeapi2sql/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+from .sql_connection import TdsConnection
+from .bulk_insert import insert_record_batch_to_sql

lakeapi2sql/sql_connection.py

Lines changed: 13 additions & 6 deletions

@@ -3,13 +3,20 @@
 
 
 class TdsConnection:
-    def __init__(self, connection) -> None:
-        self._connection = connection
+    def __init__(self, connection_string: str, aad_token: str | None = None) -> None:
+        connection_string, aad_token = await prepare_connection_string(connection_string, aad_token)
+        self._connection_string = connection_string
+        self._aad_token = aad_token
+
+    async def __aenter__(self) -> "TdsConnection":
+        self._connection = await lvd.connect_sql(self.connection_string, self.aad_token)
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
+        pass
 
     async def execute_sql(self, sql: str, arguments: list[str | int | float | bool | None]) -> list[int]:
         return await lvd.execute_sql(self._connection, sql, arguments)
 
-
-async def connect_sql(connection_string: str, aad_token: str | None = None) -> TdsConnection:
-    connection_string, aad_token = await prepare_connection_string(connection_string, aad_token)
-    return TdsConnection(await lvd.connect_sql(connection_string, aad_token))
+    async def execute_sql_with_result(self, sql: str, arguments: list[str | int | float | bool | None]) -> list[int]:
+        return await lvd.execute_sql_with_result(self._connection, sql, arguments)
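
For orientation, here is a minimal usage sketch of the new Python interface. It is not part of the commit; the connection string, table and statements are placeholders, and AAD token handling is left at its default:

import asyncio

from lakeapi2sql import TdsConnection  # re-exported via the new __init__.py


async def main() -> None:
    # Placeholder ADO.NET-style connection string; adjust server/database/auth as needed.
    conn_str = "Server=tcp:myserver.database.windows.net,1433;Database=mydb;TrustServerCertificate=true"

    # __aenter__ opens the underlying TDS connection via lvd.connect_sql.
    async with TdsConnection(conn_str, aad_token=None) as con:
        # Statement without a result set; parameters are bound as @P1, @P2, ...
        await con.execute_sql("UPDATE t SET name = @P1 WHERE id = @P2", ["hello", 42])

        # Query with a result set; returns a dict built by into_dict_result in src/lib.rs.
        result = await con.execute_sql_with_result("SELECT TOP 10 id, name FROM t", [])
        print(result["columns"], result["rows"])


asyncio.run(main())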

src/lib.rs

Lines changed: 198 additions & 27 deletions
@@ -4,14 +4,16 @@ use std::sync::Arc;
 use arrow::datatypes::{Field, Schema};
 use arrow::ffi_stream::ArrowArrayStreamReader;
 use arrow::pyarrow::FromPyArrow;
+use error::LakeApi2SqlError;
+use futures::{StreamExt, TryStreamExt};
 use pyo3::exceptions::{PyConnectionError, PyIOError, PyTypeError};
 use pyo3::prelude::*;
 use pyo3::types::{PyDict, PyInt, PyList, PyString};
 mod arrow_convert;
 pub mod bulk_insert;
 pub mod connect;
 pub mod error;
-use tiberius::ToSql;
+use tiberius::{FromSql, QueryItem, QueryStream, ResultMetadata, Row, ToSql};
 use tokio::net::TcpStream;
 
 fn field_into_dict<'a>(py: Python<'a>, field: &'a Field) -> &'a PyDict {
@@ -40,6 +42,105 @@ fn into_dict<'a>(py: Python<'a>, schema: Arc<Schema>) -> &PyDict {
     d.set_item("metadata", metadata.unwrap()).unwrap();
     d
 }
+fn into_dict_result<'a>(py: Python<'a>, meta: Option<ResultMetadata>, rows: Vec<Row>) -> &PyDict {
+    let d = PyDict::new(py);
+    if let Some(meta) = meta {
+        let fields: Vec<&PyDict> = meta
+            .columns()
+            .iter()
+            .map(|f| {
+                let mut d = PyDict::new(py);
+                d.set_item("name", f.name().clone()).unwrap();
+                d.set_item("column_type", format!("{0:?}", f.column_type()))
+                    .unwrap();
+
+                d
+            })
+            .collect();
+
+        d.set_item("columns", fields).unwrap();
+    }
+    let mut py_rows = PyList::new(
+        py,
+        rows.iter().map(|row| {
+            PyList::new(
+                py,
+                row.cells()
+                    .map(|(c, val)| match val {
+                        tiberius::ColumnData::U8(o) => o.into_py(py),
+                        tiberius::ColumnData::I16(o) => o.into_py(py),
+                        tiberius::ColumnData::I32(o) => o.into_py(py),
+                        tiberius::ColumnData::I64(o) => o.into_py(py),
+                        tiberius::ColumnData::F32(o) => o.into_py(py),
+                        tiberius::ColumnData::F64(o) => o.into_py(py),
+                        tiberius::ColumnData::Bit(o) => o.into_py(py),
+                        tiberius::ColumnData::String(o) => {
+                            o.as_ref().map(|x| x.clone().into_owned()).into_py(py)
+                        }
+                        tiberius::ColumnData::Guid(o) => o.map(|x| x.to_string()).into_py(py),
+                        tiberius::ColumnData::Binary(o) => {
+                            o.as_ref().map(|x| x.clone().into_owned()).into_py(py)
+                        }
+                        tiberius::ColumnData::Numeric(o) => o.map(|x| x.to_string()).into_py(py),
+                        tiberius::ColumnData::Xml(o) => {
+                            o.as_ref().map(|x| x.clone().to_string()).into_py(py)
+                        }
+                        tiberius::ColumnData::DateTime(o) => o
+                            .map(|x| {
+                                tiberius::time::time::PrimitiveDateTime::from_sql(&val)
+                                    .unwrap()
+                                    .unwrap()
+                                    .to_string()
+                            })
+                            .into_py(py),
+                        tiberius::ColumnData::SmallDateTime(o) => o
+                            .map(|x| {
+                                tiberius::time::time::PrimitiveDateTime::from_sql(&val)
+                                    .unwrap()
+                                    .unwrap()
+                                    .to_string()
+                            })
+                            .into_py(py),
+                        tiberius::ColumnData::Time(o) => o
+                            .map(|x| {
+                                tiberius::time::time::Time::from_sql(&val)
+                                    .unwrap()
+                                    .unwrap()
+                                    .to_string()
+                            })
+                            .into_py(py),
+                        tiberius::ColumnData::Date(o) => o
+                            .map(|x| {
+                                tiberius::time::time::Date::from_sql(&val)
+                                    .unwrap()
+                                    .unwrap()
+                                    .to_string()
+                            })
+                            .into_py(py),
+                        tiberius::ColumnData::DateTime2(o) => o
+                            .map(|x| {
+                                tiberius::time::time::PrimitiveDateTime::from_sql(&val)
+                                    .unwrap()
+                                    .unwrap()
+                                    .to_string()
+                            })
+                            .into_py(py),
+                        tiberius::ColumnData::DateTimeOffset(o) => o
+                            .map(|x| {
+                                tiberius::time::time::PrimitiveDateTime::from_sql(&val)
+                                    .unwrap()
+                                    .unwrap()
+                                    .to_string()
+                            })
+                            .into_py(py),
+                    })
+                    .collect::<Vec<PyObject>>(),
+            )
+        }),
+    );
+    d.set_item("rows", py_rows);
+    d
+}
 
 async fn insert_arrow_stream_to_sql_rs(
     connection_string: String,
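
To make the conversion above concrete, this is roughly the shape of the dictionary that into_dict_result hands back to the Python caller. The column names, types and values here are invented for illustration; GUIDs, decimals, XML and date/time values are rendered as strings by the match arms above:

# Illustrative only: example of the dict shape produced by into_dict_result.
result = {
    "columns": [
        {"name": "id", "column_type": "Int4"},
        {"name": "name", "column_type": "NVarchar"},
        {"name": "created", "column_type": "Datetime2"},
    ],
    "rows": [
        [1, "first", "2024-01-01 12:00:00.0"],
        [2, None, None],  # SQL NULLs arrive as Python None
    ],
}
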
@@ -132,6 +233,36 @@ fn connect_sql<'a>(
         }
     })
 }
+
+struct ValueWrap(Box<dyn ToSql>);
+
+impl ToSql for ValueWrap {
+    fn to_sql(&self) -> tiberius::ColumnData<'_> {
+        self.0.to_sql()
+    }
+}
+
+fn to_exec_args(args: Vec<&PyAny>) -> Result<Vec<ValueWrap>, PyErr> {
+    let mut res: Vec<ValueWrap> = Vec::new();
+    for i in 0..args.len() - 1 {
+        let x = args[i];
+        res.push(ValueWrap(if x.is_none() {
+            Box::new(Option::<i64>::None) as Box<dyn ToSql>
+        } else if let Ok(v) = x.extract::<i64>() {
+            Box::new(v) as Box<dyn ToSql>
+        } else if let Ok(v) = x.extract::<f64>() {
+            Box::new(v) as Box<dyn ToSql>
+        } else if let Ok(v) = x.extract::<String>() {
+            Box::new(v) as Box<dyn ToSql>
+        } else if let Ok(v) = x.extract::<bool>() {
+            Box::new(v) as Box<dyn ToSql>
+        } else {
+            return Err(PyErr::new::<PyTypeError, _>("Unsupported type"));
+        }))
+    }
+    Ok(res)
+}
+
 #[pyfunction]
 fn execute_sql<'a>(
     py: Python<'a>,
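
Since to_exec_args only recognizes a fixed set of Python types, here is a short sketch of what survives the conversion; the workaround shown is my own suggestion, not part of the commit:

from datetime import date

# Mirrors the extract::<...>() chain in to_exec_args:
# None -> SQL NULL, int -> i64, float -> f64, str -> String, bool -> bit.
ok_args = [None, 42, 3.14, "text", True]

# Anything else currently comes back as TypeError("Unsupported type") from the
# extension, so e.g. a date would have to be passed as a string instead:
workaround = [date(2024, 1, 1).isoformat()]
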
@@ -149,41 +280,19 @@ fn execute_sql<'a>(
             list2
         });
     }
-    let tds_args = args
-        .iter()
-        .map(|x| {
-            if x.is_none() {
-                let b_box: Box<dyn ToSql> = Box::new(Option::<i64>::None);
-                Ok(b_box)
-            } else if let Ok(v) = x.extract::<i64>() {
-                let b_box: Box<dyn ToSql> = Box::new(v);
-                Ok(b_box)
-            } else if let Ok(v) = x.extract::<f64>() {
-                let b_box: Box<dyn ToSql> = Box::new(v);
-                Ok(b_box)
-            } else if let Ok(v) = x.extract::<String>() {
-                let b_box: Box<dyn ToSql> = Box::new(v);
-                Ok(b_box)
-            } else if let Ok(v) = x.extract::<bool>() {
-                let b_box: Box<dyn ToSql> = Box::new(v);
-                Ok(b_box)
-            } else {
-                Err(PyErr::new::<PyTypeError, _>("Unsupported type"))
-            }
-        })
-        .collect::<Result<Vec<Box<dyn ToSql>>, PyErr>>()?;
+    let tds_args = to_exec_args(args)?;
+
     let mutex = conn.0.clone();
     pyo3_asyncio::tokio::future_into_py(py, async move {
-        let prms: Vec<_> = tds_args.iter().map(|x| &(*x)).collect();
         let res = mutex
             .clone()
             .lock()
             .await
             .execute(
                 query,
-                &tds_args
+                tds_args
                     .iter()
-                    .map(|x| x.borrow() as &dyn ToSql)
+                    .map(|x| x.0.borrow() as &dyn ToSql)
                     .collect::<Vec<&dyn ToSql>>()
                     .as_slice(),
             )
@@ -198,6 +307,67 @@ fn execute_sql<'a>(
     })
 }
 
+#[pyfunction]
+fn execute_sql_with_result<'a>(
+    py: Python<'a>,
+    conn: &MsSqlConnection,
+    query: String,
+    args: Vec<&PyAny>,
+) -> PyResult<&'a PyAny> {
+    let tds_args = to_exec_args(args)?;
+
+    let mutex = conn.0.clone();
+    pyo3_asyncio::tokio::future_into_py(py, async move {
+        let arc = mutex.clone();
+        let mut conn = arc.lock().await;
+        let res = conn
+            .query(
+                query,
+                tds_args
+                    .iter()
+                    .map(|x| x.0.borrow() as &dyn ToSql)
+                    .collect::<Vec<&dyn ToSql>>()
+                    .as_slice(),
+            )
+            .await;
+
+        match res {
+            Ok(mut stream) => {
+                let mut meta = None;
+                let mut rows = vec![];
+                while let Some(item) = stream
+                    .try_next()
+                    .await
+                    .map_err(|er| PyErr::new::<PyIOError, _>(format!("Error executing: {er}")))?
+                {
+                    match item {
+                        // our first item is the column data always
+                        QueryItem::Metadata(m) if m.result_index() == 0 => {
+                            meta = Some(m);
+                            // the first result column info can be handled here
+                        }
+                        // ... and from there on from 0..N rows
+                        QueryItem::Row(row) if row.result_index() == 0 => rows.push(row),
+                        // the second result set returns first another metadata item
+                        QueryItem::Metadata(meta) => {
+                            break;
+                        }
+                        // ...and, again, we get rows from the second resultset
+                        QueryItem::Row(row) => {
+                            break;
+                        }
+                    }
+                }
+                Ok(Python::with_gil(|py| {
+                    let d: Py<PyDict> = into_dict_result(py, meta, rows).into();
+                    d
+                }))
+            }
+            Err(er) => Err(PyErr::new::<PyIOError, _>(format!("Error executing: {er}"))),
+        }
+    })
+}
+
 #[pyfunction]
 fn insert_arrow_reader_to_sql<'a>(
     py: Python<'a>,
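
A small consumer-side helper, not part of the library, showing one way to reshape the dict returned by execute_sql_with_result into per-row dictionaries keyed by column name:

def rows_as_dicts(result: dict) -> list[dict]:
    """Turn {"columns": [...], "rows": [...]} into a list of per-row dicts."""
    names = [col["name"] for col in result.get("columns", [])]
    return [dict(zip(names, row)) for row in result["rows"]]


# Inside an `async with TdsConnection(...) as con:` block:
#     res = await con.execute_sql_with_result("SELECT id, name FROM t", [])
#     for record in rows_as_dicts(res):
#         print(record["id"], record["name"])
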
@@ -237,6 +407,7 @@ fn _lowlevel(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(insert_arrow_stream_to_sql, m)?)?;
     m.add_function(wrap_pyfunction!(connect_sql, m)?)?;
     m.add_function(wrap_pyfunction!(execute_sql, m)?)?;
+    m.add_function(wrap_pyfunction!(execute_sql_with_result, m)?)?;
    m.add_function(wrap_pyfunction!(insert_arrow_reader_to_sql, m)?)?;
 
     Ok(())
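
The functions registered above can also be called directly. A hedged sketch, assuming the compiled PyO3 module is importable as lakeapi2sql._lowlevel (the usual maturin layout) and using placeholder connection details:

import asyncio

import lakeapi2sql._lowlevel as lvd  # assumed import path for the compiled module


async def main() -> None:
    conn = await lvd.connect_sql(
        "Server=tcp:localhost,1433;Database=mydb;TrustServerCertificate=true", None
    )
    affected = await lvd.execute_sql(conn, "DELETE FROM t WHERE id = @P1", [42])
    result = await lvd.execute_sql_with_result(conn, "SELECT id, name FROM t", [])
    print(affected, result["rows"])


asyncio.run(main())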
