// micromegas_datafusion_extensions/histogram/histogram_udaf.rs

1use datafusion::{
2    arrow::{
3        array::{Array, ArrayRef, Float64Array, ListArray, StructArray, UInt64Array},
4        datatypes::{DataType, Fields},
5    },
6    error::DataFusionError,
7    logical_expr::{
8        Accumulator, AggregateUDF, ColumnarValue, Volatility, function::AccumulatorArgs,
9    },
10    physical_plan::expressions::Literal,
11    prelude::*,
12    scalar::ScalarValue,
13};
14use std::sync::Arc;
15
16use super::accumulator::{HistogramAccumulator, state_arrow_fields};
17
18/// An array of histograms.
19#[derive(Debug)]
20pub struct HistogramArray {
21    inner: Arc<StructArray>,
22}
23
24impl HistogramArray {
25    pub fn new(inner: Arc<StructArray>) -> Self {
26        Self { inner }
27    }
28
29    pub fn inner(&self) -> Arc<StructArray> {
30        self.inner.clone()
31    }
32
33    pub fn len(&self) -> usize {
34        self.inner.len()
35    }
36
37    pub fn is_empty(&self) -> bool {
38        self.inner.is_empty()
39    }
40
41    pub fn get_start(&self, index: usize) -> Result<f64, DataFusionError> {
42        let starts = self
43            .inner
44            .column(0)
45            .as_any()
46            .downcast_ref::<Float64Array>()
47            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
48        Ok(starts.value(index))
49    }
50
51    pub fn get_end(&self, index: usize) -> Result<f64, DataFusionError> {
52        let ends = self
53            .inner
54            .column(1)
55            .as_any()
56            .downcast_ref::<Float64Array>()
57            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
58        Ok(ends.value(index))
59    }
60
61    pub fn get_min(&self, index: usize) -> Result<f64, DataFusionError> {
62        let mins = self
63            .inner
64            .column(2)
65            .as_any()
66            .downcast_ref::<Float64Array>()
67            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
68        Ok(mins.value(index))
69    }
70
71    pub fn get_max(&self, index: usize) -> Result<f64, DataFusionError> {
72        let maxs = self
73            .inner
74            .column(3)
75            .as_any()
76            .downcast_ref::<Float64Array>()
77            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
78        Ok(maxs.value(index))
79    }
80
81    pub fn get_sum(&self, index: usize) -> Result<f64, DataFusionError> {
82        let sums = self
83            .inner
84            .column(4)
85            .as_any()
86            .downcast_ref::<Float64Array>()
87            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
88        Ok(sums.value(index))
89    }
90
91    pub fn get_sum_sq(&self, index: usize) -> Result<f64, DataFusionError> {
92        let sums_sq = self
93            .inner
94            .column(5)
95            .as_any()
96            .downcast_ref::<Float64Array>()
97            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
98        Ok(sums_sq.value(index))
99    }
100
101    pub fn get_count(&self, index: usize) -> Result<u64, DataFusionError> {
102        let counts = self
103            .inner
104            .column(6)
105            .as_any()
106            .downcast_ref::<UInt64Array>()
107            .ok_or_else(|| DataFusionError::Execution("downcasting to UInt64Array".into()))?;
108        Ok(counts.value(index))
109    }
110
111    pub fn get_bins(&self, index: usize) -> Result<UInt64Array, DataFusionError> {
112        let bins_list = self
113            .inner
114            .column(7)
115            .as_any()
116            .downcast_ref::<ListArray>()
117            .ok_or_else(|| DataFusionError::Execution("downcasting to ListArray".into()))?;
118        let bins = bins_list.value(index);
119        let bins = bins
120            .as_any()
121            .downcast_ref::<UInt64Array>()
122            .ok_or_else(|| DataFusionError::Execution("downcasting to UInt64Array".into()))?;
123        Ok(bins.clone())
124    }
125}
126
127impl TryFrom<&ArrayRef> for HistogramArray {
128    type Error = DataFusionError;
129
130    fn try_from(value: &ArrayRef) -> Result<Self, Self::Error> {
131        let struct_array = value
132            .as_any()
133            .downcast_ref::<StructArray>()
134            .ok_or_else(|| DataFusionError::Execution("downcasting to StructArray".into()))?;
135        let inner = Arc::new(struct_array.clone());
136        Ok(Self { inner })
137    }
138}
139
140impl TryFrom<&dyn Array> for HistogramArray {
141    type Error = DataFusionError;
142
143    fn try_from(value: &dyn Array) -> Result<Self, Self::Error> {
144        let struct_array = value
145            .as_any()
146            .downcast_ref::<StructArray>()
147            .ok_or_else(|| DataFusionError::Execution("downcasting to StructArray".into()))?;
148        let inner = Arc::new(struct_array.clone());
149        Ok(Self { inner })
150    }
151}
152
153impl TryFrom<&ColumnarValue> for HistogramArray {
154    type Error = DataFusionError;
155
156    fn try_from(value: &ColumnarValue) -> Result<Self, Self::Error> {
157        match value {
158            ColumnarValue::Array(array) => array.try_into(),
159            ColumnarValue::Scalar(scalar_value) => {
160                if let ScalarValue::Struct(array) = scalar_value {
161                    Ok(Self::new(array.clone()))
162                } else {
163                    Err(DataFusionError::Execution( "Can't convert ColumnarValue into HistogramArray: ScalarValue is not a struct".into()))
164                }
165            }
166        }
167    }
168}
169
170fn make_state(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>, DataFusionError> {
171    let start_arg = args
172        .exprs
173        .first()
174        .ok_or_else(|| DataFusionError::Execution("Reading first argument".into()))?
175        .as_any()
176        .downcast_ref::<Literal>()
177        .ok_or_else(|| DataFusionError::Execution("Downcasting first argument to Literal".into()))?
178        .value();
179    let start = if let ScalarValue::Float64(Some(start_value)) = start_arg {
180        start_value
181    } else {
182        return Err(DataFusionError::Execution(format!(
183            "arg 0 should be a float64, found {start_arg:?}"
184        )));
185    };
186
187    let end_arg = args
188        .exprs
189        .get(1)
190        .ok_or_else(|| DataFusionError::Execution("Reading argument 1".into()))?
191        .as_any()
192        .downcast_ref::<Literal>()
193        .ok_or_else(|| DataFusionError::Execution("Downcasting argument 1 to Literal".into()))?
194        .value();
195    let end = if let ScalarValue::Float64(Some(end_value)) = end_arg {
196        end_value
197    } else {
198        return Err(DataFusionError::Execution(format!(
199            "arg 0 should be a float64, found {end_arg:?}"
200        )));
201    };
202
203    let nb_bins_arg = args
204        .exprs
205        .get(2)
206        .ok_or_else(|| DataFusionError::Execution("Reading argument 2".into()))?
207        .as_any()
208        .downcast_ref::<Literal>()
209        .ok_or_else(|| DataFusionError::Execution("Downcasting argument 2 to Literal".into()))?
210        .value();
211    let nb_bins = if let ScalarValue::Int64(Some(nb_bins_value)) = nb_bins_arg {
212        nb_bins_value
213    } else {
214        return Err(DataFusionError::Execution(format!(
215            "arg 0 should be a int64, found {nb_bins_arg:?}"
216        )));
217    };
218
219    Ok(Box::new(HistogramAccumulator::new(
220        *start,
221        *end,
222        *nb_bins as usize,
223    )))
224}
225
226pub fn make_histogram_arrow_type() -> DataType {
227    DataType::Struct(Fields::from(state_arrow_fields()))
228}
229
230/// Creates a user-defined aggregate function to compute histograms.
231pub fn make_histo_udaf() -> AggregateUDF {
232    create_udaf(
233        "make_histogram",
234        vec![
235            DataType::Float64,
236            DataType::Float64,
237            DataType::Int64,
238            DataType::Float64,
239        ],
240        Arc::new(make_histogram_arrow_type()),
241        Volatility::Immutable,
242        Arc::new(&make_state),
243        Arc::new(vec![make_histogram_arrow_type()]),
244    )
245}