micromegas_datafusion_extensions/histogram/
histogram_udaf.rs1use datafusion::{
2 arrow::{
3 array::{Array, ArrayRef, Float64Array, ListArray, StructArray, UInt64Array},
4 datatypes::{DataType, Fields},
5 },
6 error::DataFusionError,
7 logical_expr::{
8 Accumulator, AggregateUDF, ColumnarValue, Volatility, function::AccumulatorArgs,
9 },
10 physical_plan::expressions::Literal,
11 prelude::*,
12 scalar::ScalarValue,
13};
14use std::sync::Arc;
15
16use super::accumulator::{HistogramAccumulator, state_arrow_fields};
17
18#[derive(Debug)]
20pub struct HistogramArray {
21 inner: Arc<StructArray>,
22}
23
24impl HistogramArray {
25 pub fn new(inner: Arc<StructArray>) -> Self {
26 Self { inner }
27 }
28
29 pub fn inner(&self) -> Arc<StructArray> {
30 self.inner.clone()
31 }
32
33 pub fn len(&self) -> usize {
34 self.inner.len()
35 }
36
37 pub fn is_empty(&self) -> bool {
38 self.inner.is_empty()
39 }
40
41 pub fn get_start(&self, index: usize) -> Result<f64, DataFusionError> {
42 let starts = self
43 .inner
44 .column(0)
45 .as_any()
46 .downcast_ref::<Float64Array>()
47 .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
48 Ok(starts.value(index))
49 }
50
51 pub fn get_end(&self, index: usize) -> Result<f64, DataFusionError> {
52 let ends = self
53 .inner
54 .column(1)
55 .as_any()
56 .downcast_ref::<Float64Array>()
57 .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
58 Ok(ends.value(index))
59 }
60
61 pub fn get_min(&self, index: usize) -> Result<f64, DataFusionError> {
62 let mins = self
63 .inner
64 .column(2)
65 .as_any()
66 .downcast_ref::<Float64Array>()
67 .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
68 Ok(mins.value(index))
69 }
70
71 pub fn get_max(&self, index: usize) -> Result<f64, DataFusionError> {
72 let maxs = self
73 .inner
74 .column(3)
75 .as_any()
76 .downcast_ref::<Float64Array>()
77 .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
78 Ok(maxs.value(index))
79 }
80
81 pub fn get_sum(&self, index: usize) -> Result<f64, DataFusionError> {
82 let sums = self
83 .inner
84 .column(4)
85 .as_any()
86 .downcast_ref::<Float64Array>()
87 .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
88 Ok(sums.value(index))
89 }
90
91 pub fn get_sum_sq(&self, index: usize) -> Result<f64, DataFusionError> {
92 let sums_sq = self
93 .inner
94 .column(5)
95 .as_any()
96 .downcast_ref::<Float64Array>()
97 .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
98 Ok(sums_sq.value(index))
99 }
100
101 pub fn get_count(&self, index: usize) -> Result<u64, DataFusionError> {
102 let counts = self
103 .inner
104 .column(6)
105 .as_any()
106 .downcast_ref::<UInt64Array>()
107 .ok_or_else(|| DataFusionError::Execution("downcasting to UInt64Array".into()))?;
108 Ok(counts.value(index))
109 }
110
111 pub fn get_bins(&self, index: usize) -> Result<UInt64Array, DataFusionError> {
112 let bins_list = self
113 .inner
114 .column(7)
115 .as_any()
116 .downcast_ref::<ListArray>()
117 .ok_or_else(|| DataFusionError::Execution("downcasting to ListArray".into()))?;
118 let bins = bins_list.value(index);
119 let bins = bins
120 .as_any()
121 .downcast_ref::<UInt64Array>()
122 .ok_or_else(|| DataFusionError::Execution("downcasting to UInt64Array".into()))?;
123 Ok(bins.clone())
124 }
125}
126
127impl TryFrom<&ArrayRef> for HistogramArray {
128 type Error = DataFusionError;
129
130 fn try_from(value: &ArrayRef) -> Result<Self, Self::Error> {
131 let struct_array = value
132 .as_any()
133 .downcast_ref::<StructArray>()
134 .ok_or_else(|| DataFusionError::Execution("downcasting to StructArray".into()))?;
135 let inner = Arc::new(struct_array.clone());
136 Ok(Self { inner })
137 }
138}
139
140impl TryFrom<&dyn Array> for HistogramArray {
141 type Error = DataFusionError;
142
143 fn try_from(value: &dyn Array) -> Result<Self, Self::Error> {
144 let struct_array = value
145 .as_any()
146 .downcast_ref::<StructArray>()
147 .ok_or_else(|| DataFusionError::Execution("downcasting to StructArray".into()))?;
148 let inner = Arc::new(struct_array.clone());
149 Ok(Self { inner })
150 }
151}
152
153impl TryFrom<&ColumnarValue> for HistogramArray {
154 type Error = DataFusionError;
155
156 fn try_from(value: &ColumnarValue) -> Result<Self, Self::Error> {
157 match value {
158 ColumnarValue::Array(array) => array.try_into(),
159 ColumnarValue::Scalar(scalar_value) => {
160 if let ScalarValue::Struct(array) = scalar_value {
161 Ok(Self::new(array.clone()))
162 } else {
163 Err(DataFusionError::Execution( "Can't convert ColumnarValue into HistogramArray: ScalarValue is not a struct".into()))
164 }
165 }
166 }
167 }
168}
169
170fn make_state(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>, DataFusionError> {
171 let start_arg = args
172 .exprs
173 .first()
174 .ok_or_else(|| DataFusionError::Execution("Reading first argument".into()))?
175 .as_any()
176 .downcast_ref::<Literal>()
177 .ok_or_else(|| DataFusionError::Execution("Downcasting first argument to Literal".into()))?
178 .value();
179 let start = if let ScalarValue::Float64(Some(start_value)) = start_arg {
180 start_value
181 } else {
182 return Err(DataFusionError::Execution(format!(
183 "arg 0 should be a float64, found {start_arg:?}"
184 )));
185 };
186
187 let end_arg = args
188 .exprs
189 .get(1)
190 .ok_or_else(|| DataFusionError::Execution("Reading argument 1".into()))?
191 .as_any()
192 .downcast_ref::<Literal>()
193 .ok_or_else(|| DataFusionError::Execution("Downcasting argument 1 to Literal".into()))?
194 .value();
195 let end = if let ScalarValue::Float64(Some(end_value)) = end_arg {
196 end_value
197 } else {
198 return Err(DataFusionError::Execution(format!(
199 "arg 0 should be a float64, found {end_arg:?}"
200 )));
201 };
202
203 let nb_bins_arg = args
204 .exprs
205 .get(2)
206 .ok_or_else(|| DataFusionError::Execution("Reading argument 2".into()))?
207 .as_any()
208 .downcast_ref::<Literal>()
209 .ok_or_else(|| DataFusionError::Execution("Downcasting argument 2 to Literal".into()))?
210 .value();
211 let nb_bins = if let ScalarValue::Int64(Some(nb_bins_value)) = nb_bins_arg {
212 nb_bins_value
213 } else {
214 return Err(DataFusionError::Execution(format!(
215 "arg 0 should be a int64, found {nb_bins_arg:?}"
216 )));
217 };
218
219 Ok(Box::new(HistogramAccumulator::new(
220 *start,
221 *end,
222 *nb_bins as usize,
223 )))
224}
225
226pub fn make_histogram_arrow_type() -> DataType {
227 DataType::Struct(Fields::from(state_arrow_fields()))
228}
229
230pub fn make_histo_udaf() -> AggregateUDF {
232 create_udaf(
233 "make_histogram",
234 vec![
235 DataType::Float64,
236 DataType::Float64,
237 DataType::Int64,
238 DataType::Float64,
239 ],
240 Arc::new(make_histogram_arrow_type()),
241 Volatility::Immutable,
242 Arc::new(&make_state),
243 Arc::new(vec![make_histogram_arrow_type()]),
244 )
245}