micromegas_datafusion_extensions/histogram/
histogram_udaf.rs

1use datafusion::{
2    arrow::{
3        array::{Array, ArrayRef, Float64Array, ListArray, StructArray, UInt64Array},
4        datatypes::{DataType, Fields},
5    },
6    error::DataFusionError,
7    logical_expr::{
8        Accumulator, AggregateUDF, ColumnarValue, Volatility, function::AccumulatorArgs,
9    },
10    physical_plan::expressions::Literal,
11    prelude::*,
12    scalar::ScalarValue,
13};
14use std::sync::Arc;
15
16use super::accumulator::{HistogramAccumulator, state_arrow_fields};
17
/// An array of histograms.
///
/// Thin wrapper around a shared Arrow [`StructArray`]; the accessor methods
/// treat each row of that array as one histogram and read its fields by
/// column position.
#[derive(Debug)]
pub struct HistogramArray {
    // Shared handle to the underlying struct array.
    inner: Arc<StructArray>,
}
23
24impl HistogramArray {
25    pub fn new(inner: Arc<StructArray>) -> Self {
26        Self { inner }
27    }
28
29    pub fn inner(&self) -> Arc<StructArray> {
30        self.inner.clone()
31    }
32
33    pub fn len(&self) -> usize {
34        self.inner.len()
35    }
36
37    pub fn is_empty(&self) -> bool {
38        self.inner.is_empty()
39    }
40
41    pub fn is_null_at(&self, index: usize) -> bool {
42        self.inner.is_null(index)
43    }
44
45    pub fn get_start(&self, index: usize) -> Result<f64, DataFusionError> {
46        let starts = self
47            .inner
48            .column(0)
49            .as_any()
50            .downcast_ref::<Float64Array>()
51            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
52        Ok(starts.value(index))
53    }
54
55    pub fn get_end(&self, index: usize) -> Result<f64, DataFusionError> {
56        let ends = self
57            .inner
58            .column(1)
59            .as_any()
60            .downcast_ref::<Float64Array>()
61            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
62        Ok(ends.value(index))
63    }
64
65    pub fn get_min(&self, index: usize) -> Result<f64, DataFusionError> {
66        let mins = self
67            .inner
68            .column(2)
69            .as_any()
70            .downcast_ref::<Float64Array>()
71            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
72        Ok(mins.value(index))
73    }
74
75    pub fn get_max(&self, index: usize) -> Result<f64, DataFusionError> {
76        let maxs = self
77            .inner
78            .column(3)
79            .as_any()
80            .downcast_ref::<Float64Array>()
81            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
82        Ok(maxs.value(index))
83    }
84
85    pub fn get_sum(&self, index: usize) -> Result<f64, DataFusionError> {
86        let sums = self
87            .inner
88            .column(4)
89            .as_any()
90            .downcast_ref::<Float64Array>()
91            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
92        Ok(sums.value(index))
93    }
94
95    pub fn get_sum_sq(&self, index: usize) -> Result<f64, DataFusionError> {
96        let sums_sq = self
97            .inner
98            .column(5)
99            .as_any()
100            .downcast_ref::<Float64Array>()
101            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
102        Ok(sums_sq.value(index))
103    }
104
105    pub fn get_count(&self, index: usize) -> Result<u64, DataFusionError> {
106        let counts = self
107            .inner
108            .column(6)
109            .as_any()
110            .downcast_ref::<UInt64Array>()
111            .ok_or_else(|| DataFusionError::Execution("downcasting to UInt64Array".into()))?;
112        Ok(counts.value(index))
113    }
114
115    pub fn get_bins(&self, index: usize) -> Result<UInt64Array, DataFusionError> {
116        let bins_list = self
117            .inner
118            .column(7)
119            .as_any()
120            .downcast_ref::<ListArray>()
121            .ok_or_else(|| DataFusionError::Execution("downcasting to ListArray".into()))?;
122        let bins = bins_list.value(index);
123        let bins = bins
124            .as_any()
125            .downcast_ref::<UInt64Array>()
126            .ok_or_else(|| DataFusionError::Execution("downcasting to UInt64Array".into()))?;
127        Ok(bins.clone())
128    }
129}
130
131impl TryFrom<&ArrayRef> for HistogramArray {
132    type Error = DataFusionError;
133
134    fn try_from(value: &ArrayRef) -> Result<Self, Self::Error> {
135        let struct_array = value
136            .as_any()
137            .downcast_ref::<StructArray>()
138            .ok_or_else(|| DataFusionError::Execution("downcasting to StructArray".into()))?;
139        let inner = Arc::new(struct_array.clone());
140        Ok(Self { inner })
141    }
142}
143
144impl TryFrom<&dyn Array> for HistogramArray {
145    type Error = DataFusionError;
146
147    fn try_from(value: &dyn Array) -> Result<Self, Self::Error> {
148        let struct_array = value
149            .as_any()
150            .downcast_ref::<StructArray>()
151            .ok_or_else(|| DataFusionError::Execution("downcasting to StructArray".into()))?;
152        let inner = Arc::new(struct_array.clone());
153        Ok(Self { inner })
154    }
155}
156
157impl TryFrom<&ColumnarValue> for HistogramArray {
158    type Error = DataFusionError;
159
160    fn try_from(value: &ColumnarValue) -> Result<Self, Self::Error> {
161        match value {
162            ColumnarValue::Array(array) => array.try_into(),
163            ColumnarValue::Scalar(scalar_value) => {
164                if let ScalarValue::Struct(array) = scalar_value {
165                    Ok(Self::new(array.clone()))
166                } else {
167                    Err(DataFusionError::Execution( "Can't convert ColumnarValue into HistogramArray: ScalarValue is not a struct".into()))
168                }
169            }
170        }
171    }
172}
173
174fn make_state(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>, DataFusionError> {
175    let start_literal = args
176        .exprs
177        .first()
178        .and_then(|e| e.as_any().downcast_ref::<Literal>())
179        .and_then(|l| {
180            if let ScalarValue::Float64(Some(v)) = l.value() {
181                Some(*v)
182            } else {
183                None
184            }
185        });
186    let end_literal = args
187        .exprs
188        .get(1)
189        .and_then(|e| e.as_any().downcast_ref::<Literal>())
190        .and_then(|l| {
191            if let ScalarValue::Float64(Some(v)) = l.value() {
192                Some(*v)
193            } else {
194                None
195            }
196        });
197    let nb_bins_literal = args
198        .exprs
199        .get(2)
200        .and_then(|e| e.as_any().downcast_ref::<Literal>())
201        .and_then(|l| {
202            if let ScalarValue::Int64(Some(v)) = l.value() {
203                Some(*v)
204            } else {
205                None
206            }
207        });
208
209    let mut acc = HistogramAccumulator::new_non_configured();
210    if let (Some(start), Some(end), Some(nb_bins)) = (start_literal, end_literal, nb_bins_literal) {
211        acc.configure_from_params(start, end, nb_bins)?;
212    }
213    Ok(Box::new(acc))
214}
215
216pub fn make_histogram_arrow_type() -> DataType {
217    DataType::Struct(Fields::from(state_arrow_fields()))
218}
219
220/// Creates a user-defined aggregate function to compute histograms.
221pub fn make_histo_udaf() -> AggregateUDF {
222    create_udaf(
223        "make_histogram",
224        vec![
225            DataType::Float64,
226            DataType::Float64,
227            DataType::Int64,
228            DataType::Float64,
229        ],
230        Arc::new(make_histogram_arrow_type()),
231        Volatility::Immutable,
232        Arc::new(&make_state),
233        Arc::new(vec![make_histogram_arrow_type()]),
234    )
235}