Skip to content

Commit d874c3c

Browse files
committed
adding init
1 parent 06ace1b commit d874c3c

File tree

4 files changed

+259
-0
lines changed

4 files changed

+259
-0
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ rust-version = "1.76.0"
1313

1414
[dependencies]
1515
datafusion = "42"
16+
paste = "1"
17+
1618

1719
[lints.clippy]
1820
dbg_macro = "deny"

src/common.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/// Generates a public function that builds an aggregate-function [`Expr`]
/// for a user-defined aggregate.
///
/// # Arguments
/// * `$EXPR_FN` - name of the generated function.
/// * `$arg` - whitespace-separated argument names; each becomes one
///   `datafusion::logical_expr::Expr` parameter of the generated function.
/// * `$DOC` - documentation string attached to the generated function.
/// * `$AGGREGATE_UDF_FN` - name of a zero-argument function, in scope at the
///   call site, whose result is passed as the UDF to `new_udf`.
///
/// The generated expression is built with `distinct = false` and no filter,
/// ordering or null treatment.
///
/// All `datafusion` items are referenced by absolute path, so the invoking
/// module does not need to import `AggregateFunction` itself; the invoking
/// crate must still depend on `datafusion`.
///
/// [`Expr`]: datafusion::logical_expr::Expr
#[macro_export]
macro_rules! make_udaf_expr {
    ($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
        #[doc = $DOC]
        pub fn $EXPR_FN(
            $($arg: datafusion::logical_expr::Expr,)*
        ) -> datafusion::logical_expr::Expr {
            datafusion::logical_expr::Expr::AggregateFunction(
                datafusion::logical_expr::expr::AggregateFunction::new_udf(
                    $AGGREGATE_UDF_FN(),
                    vec![$($arg),*],
                    false, // distinct
                    None,  // filter
                    None,  // order_by
                    None,  // null_treatment
                ),
            )
        }
    };
}
19+
20+
/// Generates a lazily-initialised, process-wide singleton accessor for an
/// aggregate UDF.
///
/// # Arguments
/// * `$UDAF` - the type implementing the aggregate; also used to name the
///   backing static (`STATIC_<UDAF>`).
/// * `$AGGREGATE_UDF_FN` - name of the generated accessor function.
/// * `$CREATE` - optional expression producing the value converted into an
///   `AggregateUDF`; defaults to `<$UDAF>::default()`.
///
/// The generated function stores the UDF in a `OnceLock` on first use and
/// returns a clone of the shared `Arc` on every call.
///
/// The recursive call goes through `$crate::` so callers do not have to
/// import `create_func` by name. The invoking crate must depend on `paste`
/// and `datafusion`, because the expansion refers to both by absolute path.
#[macro_export]
macro_rules! create_func {
    ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
        $crate::create_func!($UDAF, $AGGREGATE_UDF_FN, <$UDAF>::default());
    };
    ($UDAF:ty, $AGGREGATE_UDF_FN:ident, $CREATE:expr) => {
        paste::paste! {
            // The static's name embeds the type name, which is not upper case.
            #[allow(non_upper_case_globals)]
            static [< STATIC_ $UDAF >]: std::sync::OnceLock<std::sync::Arc<datafusion::logical_expr::AggregateUDF>> =
                std::sync::OnceLock::new();

            #[doc = concat!("AggregateFunction that returns a [`AggregateUDF`](datafusion::logical_expr::AggregateUDF) for [`", stringify!($UDAF), "`]")]
            pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc<datafusion::logical_expr::AggregateUDF> {
                [< STATIC_ $UDAF >]
                    .get_or_init(|| {
                        std::sync::Arc::new(datafusion::logical_expr::AggregateUDF::from($CREATE))
                    })
                    .clone()
            }
        }
    };
}
42+
43+
/// Generates both the expression-builder function and the singleton UDF
/// accessor for an aggregate.
///
/// Two forms are accepted:
///
/// * `(UDAF, expr_fn, arg1 arg2 ..., doc, udaf_fn)` - the builder takes one
///   `Expr` parameter per listed argument name (delegates to
///   `make_udaf_expr!`).
/// * `(UDAF, expr_fn, doc, udaf_fn)` - the builder takes a single
///   `Vec<Expr>` named `args`.
///
/// In both forms `create_func!` is invoked to generate `udaf_fn`.
///
/// Sibling macros are called through `$crate::` and `AggregateFunction` is
/// referenced by absolute path, so the invoking module needs no extra
/// imports beyond having `datafusion` (and `paste`, for `create_func!`) as
/// crate dependencies.
#[macro_export]
macro_rules! make_udaf_expr_and_func {
    ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
        $crate::make_udaf_expr!($EXPR_FN, $($arg)*, $DOC, $AGGREGATE_UDF_FN);
        $crate::create_func!($UDAF, $AGGREGATE_UDF_FN);
    };
    ($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
        #[doc = $DOC]
        pub fn $EXPR_FN(
            args: Vec<datafusion::logical_expr::Expr>,
        ) -> datafusion::logical_expr::Expr {
            datafusion::logical_expr::Expr::AggregateFunction(
                datafusion::logical_expr::expr::AggregateFunction::new_udf(
                    $AGGREGATE_UDF_FN(),
                    args,
                    false, // distinct
                    None,  // filter
                    None,  // order_by
                    None,  // null_treatment
                ),
            )
        }

        $crate::create_func!($UDAF, $AGGREGATE_UDF_FN);
    };
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pub mod common;
2+
pub mod max_min_by;

src/max_min_by.rs

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
use crate::create_func;
2+
use crate::make_udaf_expr;
3+
use crate::make_udaf_expr_and_func;
4+
use datafusion::arrow::datatypes::DataType;
5+
use datafusion::common::{exec_err, Result};
6+
use datafusion::functions_aggregate::first_last::last_value_udaf;
7+
use datafusion::logical_expr::expr;
8+
use datafusion::logical_expr::expr::AggregateFunction;
9+
use datafusion::logical_expr::expr::Sort;
10+
use datafusion::logical_expr::function;
11+
use datafusion::logical_expr::function::AccumulatorArgs;
12+
use datafusion::logical_expr::simplify::SimplifyInfo;
13+
use datafusion::logical_expr::Expr;
14+
use datafusion::logical_expr::Volatility;
15+
use datafusion::logical_expr::{AggregateUDFImpl, Signature};
16+
use datafusion::physical_plan::Accumulator;
17+
use std::any::Any;
18+
use std::fmt::Debug;
19+
use std::ops::Deref;
20+
21+
make_udaf_expr_and_func!(
22+
MaxByFunction,
23+
max_by,
24+
x y,
25+
"Returns the value of the first column corresponding to the maximum value in the second column.",
26+
max_by_udaf
27+
);
28+
29+
pub struct MaxByFunction {
30+
signature: Signature,
31+
}
32+
33+
impl Debug for MaxByFunction {
34+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
35+
f.debug_struct("MaxBy")
36+
.field("name", &self.name())
37+
.field("signature", &self.signature)
38+
.field("accumulator", &"<FUNC>")
39+
.finish()
40+
}
41+
}
42+
impl Default for MaxByFunction {
43+
fn default() -> Self {
44+
Self::new()
45+
}
46+
}
47+
48+
impl MaxByFunction {
49+
pub fn new() -> Self {
50+
Self {
51+
signature: Signature::user_defined(Volatility::Immutable),
52+
}
53+
}
54+
}
55+
56+
fn get_min_max_by_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
57+
match &input_types[0] {
58+
DataType::Dictionary(_, dict_value_type) => {
59+
// TODO add checker, if the value type is complex data type
60+
Ok(vec![dict_value_type.deref().clone()])
61+
}
62+
_ => Ok(input_types.to_vec()),
63+
}
64+
}
65+
66+
impl AggregateUDFImpl for MaxByFunction {
67+
fn as_any(&self) -> &dyn Any {
68+
self
69+
}
70+
71+
fn name(&self) -> &str {
72+
"max_by"
73+
}
74+
75+
fn signature(&self) -> &Signature {
76+
&self.signature
77+
}
78+
79+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
80+
Ok(arg_types[0].to_owned())
81+
}
82+
83+
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
84+
exec_err!("should not reach here")
85+
}
86+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
87+
get_min_max_by_result_type(arg_types)
88+
}
89+
90+
fn simplify(&self) -> Option<function::AggregateFunctionSimplification> {
91+
let simplify = |mut aggr_func: expr::AggregateFunction, _: &dyn SimplifyInfo| {
92+
let mut order_by = aggr_func.order_by.unwrap_or_default();
93+
let (second_arg, first_arg) = (aggr_func.args.remove(1), aggr_func.args.remove(0));
94+
95+
order_by.push(Sort::new(second_arg, true, false));
96+
97+
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
98+
last_value_udaf(),
99+
vec![first_arg],
100+
aggr_func.distinct,
101+
aggr_func.filter,
102+
Some(order_by),
103+
aggr_func.null_treatment,
104+
)))
105+
};
106+
Some(Box::new(simplify))
107+
}
108+
}
109+
110+
make_udaf_expr_and_func!(
111+
MinByFunction,
112+
min_by,
113+
x y,
114+
"Returns the value of the first column corresponding to the minimum value in the second column.",
115+
min_by_udaf
116+
);
117+
118+
pub struct MinByFunction {
119+
signature: Signature,
120+
}
121+
122+
impl Debug for MinByFunction {
123+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
124+
f.debug_struct("MinBy")
125+
.field("name", &self.name())
126+
.field("signature", &self.signature)
127+
.field("accumulator", &"<FUNC>")
128+
.finish()
129+
}
130+
}
131+
132+
impl Default for MinByFunction {
133+
fn default() -> Self {
134+
Self::new()
135+
}
136+
}
137+
138+
impl MinByFunction {
139+
pub fn new() -> Self {
140+
Self {
141+
signature: Signature::user_defined(Volatility::Immutable),
142+
}
143+
}
144+
}
145+
146+
impl AggregateUDFImpl for MinByFunction {
147+
fn as_any(&self) -> &dyn Any {
148+
self
149+
}
150+
151+
fn name(&self) -> &str {
152+
"min_by"
153+
}
154+
155+
fn signature(&self) -> &Signature {
156+
&self.signature
157+
}
158+
159+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
160+
Ok(arg_types[0].to_owned())
161+
}
162+
163+
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
164+
exec_err!("should not reach here")
165+
}
166+
167+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
168+
get_min_max_by_result_type(arg_types)
169+
}
170+
171+
fn simplify(&self) -> Option<function::AggregateFunctionSimplification> {
172+
let simplify = |mut aggr_func: expr::AggregateFunction, _: &dyn SimplifyInfo| {
173+
let mut order_by = aggr_func.order_by.unwrap_or_default();
174+
let (second_arg, first_arg) = (aggr_func.args.remove(1), aggr_func.args.remove(0));
175+
176+
order_by.push(Sort::new(second_arg, false, false)); // false for ascending sort
177+
178+
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
179+
last_value_udaf(),
180+
vec![first_arg],
181+
aggr_func.distinct,
182+
aggr_func.filter,
183+
Some(order_by),
184+
aggr_func.null_treatment,
185+
)))
186+
};
187+
Some(Box::new(simplify))
188+
}
189+
}

0 commit comments

Comments
 (0)