Skip to content

Commit 49aac6b

Browse files
committed
adding max by and min by function
1 parent f89f200 commit 49aac6b

File tree

4 files changed

+323
-2
lines changed

4 files changed

+323
-2
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ datafusion = "42"
3838
hashbrown = { version = "0.14.5", features = ["raw"] }
3939
log = "^0.4"
4040
paste = "1"
41+
arrow = { version = "53.0.0", features = ["test_utils"] }
4142

4243
[dev-dependencies]
4344
arrow = { version = "53.0.0", features = ["test_utils"] }

src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ use datafusion::logical_expr::AggregateUDF;
2626
#[macro_use]
2727
pub mod macros;
2828
pub mod common;
29+
pub mod max_min_by;
2930
pub mod mode;
30-
3131
pub mod expr_extra_fn {
3232
pub use super::mode::mode;
3333
}
3434

3535
pub fn all_extra_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
36-
vec![mode_udaf()]
36+
vec![mode_udaf(), max_min_by::max_by_udaf(), max_min_by::min_by_udaf()]
3737
}
3838

3939
/// Registers all enabled packages with a [`FunctionRegistry`]

src/max_min_by.rs

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
use arrow::datatypes::DataType;
2+
use datafusion::error::DataFusionError;
3+
use datafusion::functions_aggregate::first_last::last_value_udaf;
4+
use datafusion::logical_expr::expr::AggregateFunction;
5+
use datafusion::logical_expr::expr::Sort;
6+
use datafusion::logical_expr::simplify::SimplifyInfo;
7+
use datafusion::logical_expr::{expr, function, Accumulator, AggregateUDFImpl};
8+
use datafusion::prelude::Expr;
9+
use datafusion::{
10+
common::exec_err,
11+
logical_expr::{function::AccumulatorArgs, Signature, Volatility},
12+
};
13+
use std::any::Any;
14+
use std::fmt::Debug;
15+
use std::ops::Deref;
16+
17+
make_udaf_expr_and_func!(
18+
MaxByFunction,
19+
max_by,
20+
x y,
21+
"Returns the value of the first column corresponding to the maximum value in the second column.",
22+
max_by_udaf
23+
);
24+
25+
pub struct MaxByFunction {
26+
signature: Signature,
27+
}
28+
29+
impl Debug for MaxByFunction {
30+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
31+
f.debug_struct("MaxBy")
32+
.field("name", &self.name())
33+
.field("signature", &self.signature)
34+
.field("accumulator", &"<FUNC>")
35+
.finish()
36+
}
37+
}
38+
impl Default for MaxByFunction {
39+
fn default() -> Self {
40+
Self::new()
41+
}
42+
}
43+
44+
impl MaxByFunction {
45+
pub fn new() -> Self {
46+
Self {
47+
signature: Signature::user_defined(Volatility::Immutable),
48+
}
49+
}
50+
}
51+
52+
fn get_min_max_by_result_type(input_types: &[DataType]) -> Result<Vec<DataType>, DataFusionError> {
53+
match &input_types[0] {
54+
DataType::Dictionary(_, dict_value_type) => {
55+
// TODO add checker, if the value type is complex data type
56+
Ok(vec![dict_value_type.deref().clone()])
57+
}
58+
_ => Ok(input_types.to_vec()),
59+
}
60+
}
61+
62+
impl AggregateUDFImpl for MaxByFunction {
63+
fn as_any(&self) -> &dyn Any {
64+
self
65+
}
66+
67+
fn name(&self) -> &str {
68+
"max_by"
69+
}
70+
71+
fn signature(&self) -> &Signature {
72+
&self.signature
73+
}
74+
75+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType, DataFusionError> {
76+
Ok(arg_types[0].to_owned())
77+
}
78+
79+
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>, DataFusionError> {
80+
exec_err!("should not reach here")
81+
}
82+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>, DataFusionError> {
83+
get_min_max_by_result_type(arg_types)
84+
}
85+
86+
fn simplify(&self) -> Option<function::AggregateFunctionSimplification> {
87+
let simplify = |mut aggr_func: expr::AggregateFunction, _: &dyn SimplifyInfo| {
88+
let mut order_by = aggr_func.order_by.unwrap_or_default();
89+
let (second_arg, first_arg) = (aggr_func.args.remove(1), aggr_func.args.remove(0));
90+
91+
order_by.push(Sort::new(second_arg, true, false));
92+
93+
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
94+
last_value_udaf(),
95+
vec![first_arg],
96+
aggr_func.distinct,
97+
aggr_func.filter,
98+
Some(order_by),
99+
aggr_func.null_treatment,
100+
)))
101+
};
102+
Some(Box::new(simplify))
103+
}
104+
}
105+
106+
make_udaf_expr_and_func!(
107+
MinByFunction,
108+
min_by,
109+
x y,
110+
"Returns the value of the first column corresponding to the minimum value in the second column.",
111+
min_by_udaf
112+
);
113+
114+
pub struct MinByFunction {
115+
signature: Signature,
116+
}
117+
118+
impl Debug for MinByFunction {
119+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
120+
f.debug_struct("MinBy")
121+
.field("name", &self.name())
122+
.field("signature", &self.signature)
123+
.field("accumulator", &"<FUNC>")
124+
.finish()
125+
}
126+
}
127+
128+
impl Default for MinByFunction {
129+
fn default() -> Self {
130+
Self::new()
131+
}
132+
}
133+
134+
impl MinByFunction {
135+
pub fn new() -> Self {
136+
Self {
137+
signature: Signature::user_defined(Volatility::Immutable),
138+
}
139+
}
140+
}
141+
142+
impl AggregateUDFImpl for MinByFunction {
143+
fn as_any(&self) -> &dyn Any {
144+
self
145+
}
146+
147+
fn name(&self) -> &str {
148+
"min_by"
149+
}
150+
151+
fn signature(&self) -> &Signature {
152+
&self.signature
153+
}
154+
155+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType, DataFusionError> {
156+
Ok(arg_types[0].to_owned())
157+
}
158+
159+
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>, DataFusionError> {
160+
exec_err!("should not reach here")
161+
}
162+
163+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>, DataFusionError> {
164+
get_min_max_by_result_type(arg_types)
165+
}
166+
167+
fn simplify(&self) -> Option<function::AggregateFunctionSimplification> {
168+
let simplify = |mut aggr_func: expr::AggregateFunction, _: &dyn SimplifyInfo| {
169+
let mut order_by = aggr_func.order_by.unwrap_or_default();
170+
let (second_arg, first_arg) = (aggr_func.args.remove(1), aggr_func.args.remove(0));
171+
172+
order_by.push(Sort::new(second_arg, false, false)); // false for ascending sort
173+
174+
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
175+
last_value_udaf(),
176+
vec![first_arg],
177+
aggr_func.distinct,
178+
aggr_func.filter,
179+
Some(order_by),
180+
aggr_func.null_treatment,
181+
)))
182+
};
183+
Some(Box::new(simplify))
184+
}
185+
}

tests/main.rs

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,138 @@ async fn test_mode_time64() {
115115
- +-----------------------------+
116116
"###);
117117
}
118+
119+
#[tokio::test]
120+
async fn test_max_by_and_min_by() {
121+
let mut execution = TestExecution::new().await.unwrap();
122+
123+
// Test max_by with numbers
124+
let actual = execution
125+
.run_and_format("SELECT max_by(x, y) FROM VALUES (1, 10), (2, 5), (3, 15), (4, 8) as tab(x, y);")
126+
.await;
127+
128+
insta::assert_yaml_snapshot!(actual, @r###"
129+
- +---------------------+
130+
- "| max_by(tab.x,tab.y) |"
131+
- +---------------------+
132+
- "| 3 |"
133+
- +---------------------+
134+
"###);
135+
136+
// Test min_by with numbers
137+
let actual = execution
138+
.run_and_format("SELECT min_by(x, y) FROM VALUES (1, 10), (2, 5), (3, 15), (4, 8) as tab(x, y);")
139+
.await;
140+
141+
insta::assert_yaml_snapshot!(actual, @r###"
142+
- +---------------------+
143+
- "| min_by(tab.x,tab.y) |"
144+
- +---------------------+
145+
- "| 2 |"
146+
- +---------------------+
147+
"###);
148+
149+
// Test max_by with strings
150+
let actual = execution
151+
.run_and_format("SELECT max_by(name, length(name)) FROM VALUES ('Alice'), ('Bob'), ('Charlie') as tab(name);")
152+
.await;
153+
154+
insta::assert_yaml_snapshot!(actual, @r###"
155+
- +---------------------------------------------+
156+
- "| max_by(tab.name,character_length(tab.name)) |"
157+
- +---------------------------------------------+
158+
- "| Charlie |"
159+
- +---------------------------------------------+
160+
"###);
161+
162+
// Test min_by with strings
163+
let actual = execution
164+
.run_and_format("SELECT min_by(name, length(name)) FROM VALUES ('Alice'), ('Bob'), ('Charlie') as tab(name);")
165+
.await;
166+
167+
insta::assert_yaml_snapshot!(actual, @r###"
168+
- +---------------------------------------------+
169+
- "| min_by(tab.name,character_length(tab.name)) |"
170+
- +---------------------------------------------+
171+
- "| Bob |"
172+
- +---------------------------------------------+
173+
"###);
174+
175+
// Test max_by with null values
176+
let actual = execution
177+
.run_and_format("SELECT max_by(x, y) FROM VALUES (1, 10), (2, null), (3, 15), (null, 8) as tab(x, y);")
178+
.await;
179+
180+
insta::assert_yaml_snapshot!(actual, @r###"
181+
- +---------------------+
182+
- "| max_by(tab.x,tab.y) |"
183+
- +---------------------+
184+
- "| 2 |"
185+
- +---------------------+
186+
"###);
187+
188+
// Test min_by with null values
189+
let actual = execution
190+
.run_and_format("SELECT min_by(x, y) FROM VALUES (1, 10), (2, null), (3, 15), (null, 8) as tab(x, y);")
191+
.await;
192+
193+
insta::assert_yaml_snapshot!(actual, @r###"
194+
- +---------------------+
195+
- "| min_by(tab.x,tab.y) |"
196+
- +---------------------+
197+
- "| 2 |"
198+
- +---------------------+
199+
"###);
200+
201+
// Test max_by with a single value
202+
let actual = execution
203+
.run_and_format("SELECT max_by(x, y) FROM VALUES (1, 10) as tab(x, y);")
204+
.await;
205+
206+
insta::assert_yaml_snapshot!(actual, @r###"
207+
- +---------------------+
208+
- "| max_by(tab.x,tab.y) |"
209+
- +---------------------+
210+
- "| 1 |"
211+
- +---------------------+
212+
"###);
213+
214+
// Test min_by with a single value
215+
let actual = execution
216+
.run_and_format("SELECT min_by(x, y) FROM VALUES (1, 10) as tab(x, y);")
217+
.await;
218+
219+
insta::assert_yaml_snapshot!(actual, @r###"
220+
- +---------------------+
221+
- "| min_by(tab.x,tab.y) |"
222+
- +---------------------+
223+
- "| 1 |"
224+
- +---------------------+
225+
"###);
226+
227+
// Test max_by with an empty set
228+
let actual = execution
229+
.run_and_format("SELECT max_by(x, y) FROM (SELECT * FROM (VALUES (1, 10)) WHERE 1=0) as tab(x, y);")
230+
.await;
231+
232+
insta::assert_yaml_snapshot!(actual, @r###"
233+
- +---------------------+
234+
- "| max_by(tab.x,tab.y) |"
235+
- +---------------------+
236+
- "| |"
237+
- +---------------------+
238+
"###);
239+
240+
// Test min_by with an empty set
241+
let actual = execution
242+
.run_and_format("SELECT min_by(x, y) FROM (SELECT * FROM (VALUES (1, 10)) WHERE 1=0) as tab(x, y);")
243+
.await;
244+
245+
insta::assert_yaml_snapshot!(actual, @r###"
246+
- +---------------------+
247+
- "| min_by(tab.x,tab.y) |"
248+
- +---------------------+
249+
- "| |"
250+
- +---------------------+
251+
"###);
252+
}

0 commit comments

Comments
 (0)