Skip to content

Commit 5bfc904

Browse files
authored
Merge pull request #14 from bmsuisse/dev
better error handling
2 parents b2a6329 + 8099b99 commit 5bfc904

File tree

8 files changed: +116 additions, -140 deletions

Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ log = "0.4.19"
1515
pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }
1616
pyo3-log = "0.9.0"
1717
rust_decimal = "1.32.0"
18+
thiserror = "1.0.59"
1819

1920
time = "0.3.22"
2021
tokio = { version = "1.28.2", features = ["net", "macros"] }
@@ -25,7 +26,7 @@ reqwest = { version = "0.11.18", features = [
2526
"stream",
2627
"rustls-tls-native-roots",
2728
], default-features = false }
28-
tiberius = { git = "https://github.com/aersam/tiberius.git", branch = "bulk", features = [
29+
tiberius = { git = "https://github.com/aersam/tiberius.git", branch = "expose_ado_net", features = [
2930
"time",
3031
"sql-browser-tokio",
3132
"rust_decimal",
@@ -35,7 +36,7 @@ tiberius = { git = "https://github.com/aersam/tiberius.git", branch = "bulk", fe
3536

3637
[target.'cfg(not(target_os="linux"))'.dependencies]
3738
reqwest = { version = "0.11.18", features = ["stream"] }
38-
tiberius = { git = "https://github.com/aersam/tiberius.git", branch = "bulk", features = [
39+
tiberius = { git = "https://github.com/aersam/tiberius.git", branch = "expose_ado_net", features = [
3940
"time",
4041
"sql-browser-tokio",
4142
"rust_decimal",

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ build-backend = "maturin"
55
[project]
66
name = "lakeapi2sql"
77
requires-python = ">=3.10"
8-
version = "0.8.3"
8+
version = "0.8.4"
99
classifiers = [
1010
"Programming Language :: Rust",
1111
"Programming Language :: Python :: Implementation :: CPython",

src/arrow_convert.rs

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -38,32 +38,7 @@ use tiberius::ColumnType;
3838
use tiberius::ToSql;
3939
use tiberius::TokenRow;
4040

41-
#[derive(Debug)]
42-
pub(crate) struct NotSupportedError {
43-
dtype: DataType,
44-
column_type: ColumnType,
45-
}
46-
impl Display for NotSupportedError {
47-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48-
f.write_fmt(format_args!(
49-
"Cannot use data type {}. Sql Type: {:?}",
50-
self.dtype, self.column_type
51-
))
52-
}
53-
}
54-
impl std::error::Error for NotSupportedError {
55-
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
56-
None
57-
}
58-
59-
fn description(&self) -> &str {
60-
"description() is deprecated; use Display"
61-
}
62-
63-
fn cause(&self) -> Option<&dyn std::error::Error> {
64-
self.source()
65-
}
66-
}
41+
use crate::error::LakeApi2SqlError;
6742

6843
fn to_col_dt<'a>(d: ColumnData<'_>) -> ColumnData<'a> {
6944
match d {
@@ -77,11 +52,14 @@ fn to_col_dt<'a>(d: ColumnData<'_>) -> ColumnData<'a> {
7752
pub(crate) fn get_token_rows<'a, 'b>(
7853
batch: &'a RecordBatch,
7954
colsnames: &'b Vec<(String, ColumnType)>,
80-
) -> Result<Vec<TokenRow<'a>>, Box<dyn std::error::Error + Send + Sync>> {
81-
let unix_min_date = Date::from_calendar_date(1970, tiberius::time::time::Month::January, 1)?;
82-
let sql_min_date = Date::from_calendar_date(1, tiberius::time::time::Month::January, 1)?;
83-
let sql_min_datetime = Date::from_calendar_date(1900, tiberius::time::time::Month::January, 1)?;
84-
let unix_min: PrimitiveDateTime = unix_min_date.with_time(Time::from_hms(0, 0, 0)?);
55+
) -> Result<Vec<TokenRow<'a>>, LakeApi2SqlError> {
56+
let unix_min_date =
57+
Date::from_calendar_date(1970, tiberius::time::time::Month::January, 1).unwrap();
58+
let sql_min_date =
59+
Date::from_calendar_date(1, tiberius::time::time::Month::January, 1).unwrap();
60+
let sql_min_datetime =
61+
Date::from_calendar_date(1900, tiberius::time::time::Month::January, 1).unwrap();
62+
let unix_min: PrimitiveDateTime = unix_min_date.with_time(Time::from_hms(0, 0, 0).unwrap());
8563
let sql_min_to_unix_min = (unix_min_date - sql_min_date).whole_days();
8664
let sql_min_dt_to_unix_min = (unix_min_date - sql_min_datetime).whole_days();
8765
let rows = batch.num_rows();
@@ -503,7 +481,7 @@ pub(crate) fn get_token_rows<'a, 'b>(
503481
}
504482
arrow::datatypes::DataType::Decimal128(_, s) => {
505483
let ba = col.as_any().downcast_ref::<Decimal128Array>().unwrap();
506-
let scale: u8 = s.clone().try_into()?;
484+
let scale: u8 = s.clone().try_into().unwrap();
507485
let mut rowindex = 0;
508486
match coltype {
509487
ColumnType::Numericn | ColumnType::Decimaln => {
@@ -525,18 +503,18 @@ pub(crate) fn get_token_rows<'a, 'b>(
525503
}
526504
}
527505
_ => {
528-
return Err(Box::new(NotSupportedError {
506+
return Err(LakeApi2SqlError::NotSupported {
529507
dtype: col.data_type().clone(),
530508
column_type: coltype.clone(),
531-
}))
509+
})
532510
} //other => panic!("Not supported {:?}", other),
533511
}
534512
}
535513
dt => {
536-
return Err(Box::new(NotSupportedError {
514+
return Err(LakeApi2SqlError::NotSupported {
537515
dtype: dt.clone(),
538516
column_type: coltype.clone(),
539-
}))
517+
})
540518
} //other => panic!("Not supported {:?}", other),
541519
}
542520
}

src/bulk_insert.rs

Lines changed: 22 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ use std::{fmt::Display, sync::Arc};
22

33
use arrow::ffi_stream::ArrowArrayStreamReader;
44
use arrow::record_batch::RecordBatchReader;
5-
use arrow::{
6-
datatypes::Schema, error::ArrowError, ipc::reader::StreamReader, record_batch::RecordBatch,
7-
};
5+
use arrow::{datatypes::Schema, ipc::reader::StreamReader, record_batch::RecordBatch};
86
use futures::stream::TryStreamExt;
97
use log::info;
108
use tiberius::Client;
@@ -19,58 +17,13 @@ use tokio::sync::mpsc;
1917
use tokio::task;
2018

2119
use crate::arrow_convert::get_token_rows;
22-
23-
#[derive(Debug)]
24-
pub(crate) struct ArrowErrorWrap {
25-
error: ArrowError,
26-
}
27-
impl Display for ArrowErrorWrap {
28-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29-
f.write_fmt(format_args!("arrow error {}", self.error))
30-
}
31-
}
32-
impl std::error::Error for ArrowErrorWrap {
33-
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
34-
None
35-
}
36-
37-
fn description(&self) -> &str {
38-
"description() is deprecated; use Display"
39-
}
40-
41-
fn cause(&self) -> Option<&dyn std::error::Error> {
42-
self.source()
43-
}
44-
}
45-
46-
#[derive(Debug)]
47-
pub(crate) struct SendErrorWrap {
48-
error: String,
49-
}
50-
impl Display for SendErrorWrap {
51-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52-
f.write_fmt(format_args!("send error {}", self.error))
53-
}
54-
}
55-
impl std::error::Error for SendErrorWrap {
56-
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
57-
None
58-
}
59-
60-
fn description(&self) -> &str {
61-
"description() is deprecated; use Display"
62-
}
63-
64-
fn cause(&self) -> Option<&dyn std::error::Error> {
65-
self.source()
66-
}
67-
}
20+
use crate::error::LakeApi2SqlError;
6821

6922
async fn get_cols_from_table(
7023
db_client: &mut Client<Compat<TcpStream>>,
7124
table_name: &str,
7225
column_names: &[&str],
73-
) -> Result<Vec<(String, ColumnType)>, Box<dyn std::error::Error + Send + Sync>> {
26+
) -> Result<Vec<(String, ColumnType)>, LakeApi2SqlError> {
7427
let cols_sql = match column_names.len() {
7528
0 => "*".to_owned(),
7629
_ => column_names
@@ -95,7 +48,7 @@ pub async fn bulk_insert_batch<'a>(
9548
blk: &mut tiberius::BulkLoadRequest<'a, Compat<TcpStream>>,
9649
batch: &'a RecordBatch,
9750
collist: &'a Vec<(String, ColumnType)>,
98-
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
51+
) -> Result<(), LakeApi2SqlError> {
9952
let nrows = batch.num_rows();
10053
info!("{table_name}: received {nrows}");
10154
let rows = task::block_in_place(|| get_token_rows(batch, &collist))?;
@@ -114,7 +67,7 @@ pub async fn bulk_insert<'a>(
11467
url: &str,
11568
user: &str,
11669
password: &str,
117-
) -> Result<Arc<Schema>, Box<dyn std::error::Error + Send + Sync>> {
70+
) -> Result<Arc<Schema>, LakeApi2SqlError> {
11871
//let mut row = TokenRow::new();
11972
//row.push(1.into_sql());
12073
//blk.send(row).await?;
@@ -141,32 +94,22 @@ pub async fn bulk_insert<'a>(
14194
.compat();
14295
let (tx, mut rx) = mpsc::channel::<RecordBatch>(2);
14396
let syncstr = SyncIoBridge::new(res);
144-
let worker = tokio::task::spawn_blocking(
145-
move || -> Result<Arc<Schema>, Box<dyn std::error::Error + Send + Sync>> {
146-
let reader = StreamReader::try_new(syncstr, None);
147-
if let Err(err) = reader {
148-
return Err(Box::new(ArrowErrorWrap { error: err }));
149-
}
150-
let mut reader = reader.unwrap();
151-
let schema = reader.schema();
152-
loop {
153-
match reader.next() {
154-
Some(x) => match x {
155-
Ok(b) => {
156-
tx.blocking_send(b).map_err(|e| {
157-
Box::new(SendErrorWrap {
158-
error: e.to_string(),
159-
})
160-
})?;
161-
}
162-
Err(l) => println!("{:?}", l),
163-
},
164-
None => break,
165-
};
166-
}
167-
Ok(schema)
168-
},
169-
);
97+
let worker = tokio::task::spawn_blocking(move || -> Result<Arc<Schema>, LakeApi2SqlError> {
98+
let mut reader = StreamReader::try_new(syncstr, None)?;
99+
let schema = reader.schema();
100+
loop {
101+
match reader.next() {
102+
Some(x) => match x {
103+
Ok(b) => {
104+
tx.blocking_send(b)?;
105+
}
106+
Err(l) => println!("{:?}", l),
107+
},
108+
None => break,
109+
};
110+
}
111+
Ok(schema)
112+
});
170113
while let Some(v) = rx.recv().await {
171114
let mut blk = db_client
172115
.bulk_insert_with_options(
@@ -188,7 +131,7 @@ pub async fn bulk_insert_reader(
188131
table_name: &str,
189132
column_names: &[&str],
190133
reader: &mut ArrowArrayStreamReader,
191-
) -> Result<Arc<Schema>, Box<dyn std::error::Error + Send + Sync>> {
134+
) -> Result<Arc<Schema>, LakeApi2SqlError> {
192135
//let mut row = TokenRow::new();
193136
//row.push(1.into_sql());
194137
//blk.send(row).await?;

src/connect.rs

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use tiberius::client::AdoNetConfig;
2+
use tiberius::client::ConfigString;
13
use tiberius::error::Error;
24
use tiberius::AuthMethod;
35
use tiberius::Client;
@@ -7,10 +9,12 @@ use tokio::net::TcpStream;
79
use tokio_util::compat::Compat;
810
use tokio_util::compat::TokioAsyncWriteCompatExt;
911

12+
use crate::error::LakeApi2SqlError;
13+
1014
pub async fn connect_sql(
1115
con_str: &str,
1216
aad_token: Option<String>,
13-
) -> Result<Client<Compat<TcpStream>>, Box<dyn std::error::Error + Send + Sync>> {
17+
) -> Result<Client<Compat<TcpStream>>, LakeApi2SqlError> {
1418
let mut config = Config::from_ado_string(con_str)?;
1519
if let Some(tv) = aad_token.clone() {
1620
config.authentication(AuthMethod::AADToken(tv));
@@ -26,18 +30,30 @@ pub async fn connect_sql(
2630
Ok(client) => client,
2731
// The server wants us to redirect to a different address
2832
Err(Error::Routing { host, port }) => {
29-
let mut config = Config::from_ado_string(con_str)?;
33+
let ado_cfg: AdoNetConfig = con_str.parse()?;
34+
let ado_cfg2: AdoNetConfig = con_str.parse()?;
35+
36+
let mut config = Config::from_config_string(ado_cfg)?;
3037
if let Some(tv) = aad_token {
3138
config.authentication(AuthMethod::AADToken(tv));
3239
}
3340
config.host(&host);
3441
config.port(port);
42+
let instance = match ado_cfg2.server() {
43+
Ok(v) => v.instance,
44+
Err(_) => None,
45+
};
46+
if instance.is_some() {
47+
let tcp = TcpStream::connect_named(&config).await?;
48+
tcp.set_nodelay(true)?;
49+
Client::connect(config, tcp.compat_write()).await?
50+
} else {
51+
let tcp = TcpStream::connect(config.get_addr()).await?;
52+
tcp.set_nodelay(true)?;
3553

36-
let tcp = TcpStream::connect(config.get_addr()).await?;
37-
tcp.set_nodelay(true)?;
38-
39-
// we should not have more than one redirect, so we'll short-circuit here.
40-
Client::connect(config, tcp.compat_write()).await?
54+
// we should not have more than one redirect, so we'll short-circuit here.
55+
Client::connect(config, tcp.compat_write()).await?
56+
}
4157
}
4258
Err(e) => Err(e)?,
4359
};

src/error.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
use pyo3::{exceptions::*, PyErr};
2+
use thiserror::Error;
3+
4+
#[derive(Error, Debug)]
5+
6+
pub enum LakeApi2SqlError {
7+
#[error("Error connecting: {dtype} to {column_type:?}")]
8+
NotSupported {
9+
dtype: arrow::datatypes::DataType,
10+
column_type: tiberius::ColumnType,
11+
},
12+
13+
#[error("Error joining: {0}")]
14+
JoinError(#[from] tokio::task::JoinError),
15+
16+
#[error("Arrow Error: {0}")]
17+
ArrowError(#[from] arrow::error::ArrowError),
18+
19+
#[error("IO Error: {0}")]
20+
IOError(#[from] std::io::Error),
21+
22+
#[error("HTTP Error: {0}")]
23+
HttpError(#[from] reqwest::Error),
24+
25+
#[error("Send Error: {0}")]
26+
SendError(#[from] tokio::sync::mpsc::error::SendError<arrow::array::RecordBatch>),
27+
28+
#[error(transparent)]
29+
TiberiusError(#[from] tiberius::error::Error),
30+
}
31+
32+
/// Converts a `LakeApi2SqlError` into a Python exception.
///
/// Mapping: `NotSupported` becomes `PyTypeError`, `ArrowError` becomes
/// `PyValueError`, and every other variant becomes `PyIOError`.
///
/// The exception message is the `Debug` rendering (`{:?}`) of the error, not
/// the `Display` text declared with `#[error(...)]` on the enum.
impl From<LakeApi2SqlError> for PyErr {
33+
fn from(val: LakeApi2SqlError) -> Self {
34+
match val {
35+
// Bind the whole value so the variant and its fields are formatted together.
v @ LakeApi2SqlError::NotSupported {
36+
dtype: _,
37+
column_type: _,
38+
} => PyErr::new::<PyTypeError, _>(format!("{:?}", v)),
39+
// For the wrapping variants only the inner error is formatted.
LakeApi2SqlError::JoinError(e) => PyErr::new::<PyIOError, _>(format!("{:?}", e)),
40+
LakeApi2SqlError::ArrowError(e) => PyErr::new::<PyValueError, _>(format!("{:?}", e)),
41+
LakeApi2SqlError::IOError(e) => PyErr::new::<PyIOError, _>(format!("{:?}", e)),
42+
LakeApi2SqlError::HttpError(e) => PyErr::new::<PyIOError, _>(format!("{:?}", e)),
43+
LakeApi2SqlError::SendError(e) => PyErr::new::<PyIOError, _>(format!("{:?}", e)),
44+
LakeApi2SqlError::TiberiusError(e) => PyErr::new::<PyIOError, _>(format!("{:?}", e)),
45+
}
46+
}
47+
}

0 commit comments

Comments
 (0)