micromegas_datafusion_extensions/histogram/
histogram_udaf.rs1use datafusion::{
2 arrow::{
3 array::{Array, ArrayRef, Float64Array, ListArray, StructArray, UInt64Array},
4 datatypes::{DataType, Fields},
5 },
6 error::DataFusionError,
7 logical_expr::{
8 Accumulator, AggregateUDF, ColumnarValue, Volatility, function::AccumulatorArgs,
9 },
10 physical_plan::expressions::Literal,
11 prelude::*,
12 scalar::ScalarValue,
13};
14use std::sync::Arc;
15
16use super::accumulator::{HistogramAccumulator, state_arrow_fields};
17
18#[derive(Debug)]
20pub struct HistogramArray {
21 inner: Arc<StructArray>,
22}
23
24impl HistogramArray {
25 pub fn new(inner: Arc<StructArray>) -> Self {
26 Self { inner }
27 }
28
29 pub fn inner(&self) -> Arc<StructArray> {
30 self.inner.clone()
31 }
32
33 pub fn len(&self) -> usize {
34 self.inner.len()
35 }
36
37 pub fn is_empty(&self) -> bool {
38 self.inner.is_empty()
39 }
40
41 pub fn is_null_at(&self, index: usize) -> bool {
42 self.inner.is_null(index)
43 }
44
45 pub fn get_start(&self, index: usize) -> Result<f64, DataFusionError> {
46 let starts = self
47 .inner
48 .column(0)
49 .as_any()
50 .downcast_ref::<Float64Array>()
51 .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
52 Ok(starts.value(index))
53 }
54
55 pub fn get_end(&self, index: usize) -> Result<f64, DataFusionError> {
56 let ends = self
57 .inner
58 .column(1)
59 .as_any()
60 .downcast_ref::<Float64Array>()
61 .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
62 Ok(ends.value(index))
63 }
64
65 pub fn get_min(&self, index: usize) -> Result<f64, DataFusionError> {
66 let mins = self
67 .inner
68 .column(2)
69 .as_any()
70 .downcast_ref::<Float64Array>()
71 .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
72 Ok(mins.value(index))
73 }
74
75 pub fn get_max(&self, index: usize) -> Result<f64, DataFusionError> {
76 let maxs = self
77 .inner
78 .column(3)
79 .as_any()
80 .downcast_ref::<Float64Array>()
81 .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
82 Ok(maxs.value(index))
83 }
84
85 pub fn get_sum(&self, index: usize) -> Result<f64, DataFusionError> {
86 let sums = self
87 .inner
88 .column(4)
89 .as_any()
90 .downcast_ref::<Float64Array>()
91 .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
92 Ok(sums.value(index))
93 }
94
95 pub fn get_sum_sq(&self, index: usize) -> Result<f64, DataFusionError> {
96 let sums_sq = self
97 .inner
98 .column(5)
99 .as_any()
100 .downcast_ref::<Float64Array>()
101 .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
102 Ok(sums_sq.value(index))
103 }
104
105 pub fn get_count(&self, index: usize) -> Result<u64, DataFusionError> {
106 let counts = self
107 .inner
108 .column(6)
109 .as_any()
110 .downcast_ref::<UInt64Array>()
111 .ok_or_else(|| DataFusionError::Execution("downcasting to UInt64Array".into()))?;
112 Ok(counts.value(index))
113 }
114
115 pub fn get_bins(&self, index: usize) -> Result<UInt64Array, DataFusionError> {
116 let bins_list = self
117 .inner
118 .column(7)
119 .as_any()
120 .downcast_ref::<ListArray>()
121 .ok_or_else(|| DataFusionError::Execution("downcasting to ListArray".into()))?;
122 let bins = bins_list.value(index);
123 let bins = bins
124 .as_any()
125 .downcast_ref::<UInt64Array>()
126 .ok_or_else(|| DataFusionError::Execution("downcasting to UInt64Array".into()))?;
127 Ok(bins.clone())
128 }
129}
130
131impl TryFrom<&ArrayRef> for HistogramArray {
132 type Error = DataFusionError;
133
134 fn try_from(value: &ArrayRef) -> Result<Self, Self::Error> {
135 let struct_array = value
136 .as_any()
137 .downcast_ref::<StructArray>()
138 .ok_or_else(|| DataFusionError::Execution("downcasting to StructArray".into()))?;
139 let inner = Arc::new(struct_array.clone());
140 Ok(Self { inner })
141 }
142}
143
144impl TryFrom<&dyn Array> for HistogramArray {
145 type Error = DataFusionError;
146
147 fn try_from(value: &dyn Array) -> Result<Self, Self::Error> {
148 let struct_array = value
149 .as_any()
150 .downcast_ref::<StructArray>()
151 .ok_or_else(|| DataFusionError::Execution("downcasting to StructArray".into()))?;
152 let inner = Arc::new(struct_array.clone());
153 Ok(Self { inner })
154 }
155}
156
157impl TryFrom<&ColumnarValue> for HistogramArray {
158 type Error = DataFusionError;
159
160 fn try_from(value: &ColumnarValue) -> Result<Self, Self::Error> {
161 match value {
162 ColumnarValue::Array(array) => array.try_into(),
163 ColumnarValue::Scalar(scalar_value) => {
164 if let ScalarValue::Struct(array) = scalar_value {
165 Ok(Self::new(array.clone()))
166 } else {
167 Err(DataFusionError::Execution( "Can't convert ColumnarValue into HistogramArray: ScalarValue is not a struct".into()))
168 }
169 }
170 }
171 }
172}
173
174fn make_state(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>, DataFusionError> {
175 let start_literal = args
176 .exprs
177 .first()
178 .and_then(|e| e.as_any().downcast_ref::<Literal>())
179 .and_then(|l| {
180 if let ScalarValue::Float64(Some(v)) = l.value() {
181 Some(*v)
182 } else {
183 None
184 }
185 });
186 let end_literal = args
187 .exprs
188 .get(1)
189 .and_then(|e| e.as_any().downcast_ref::<Literal>())
190 .and_then(|l| {
191 if let ScalarValue::Float64(Some(v)) = l.value() {
192 Some(*v)
193 } else {
194 None
195 }
196 });
197 let nb_bins_literal = args
198 .exprs
199 .get(2)
200 .and_then(|e| e.as_any().downcast_ref::<Literal>())
201 .and_then(|l| {
202 if let ScalarValue::Int64(Some(v)) = l.value() {
203 Some(*v)
204 } else {
205 None
206 }
207 });
208
209 let mut acc = HistogramAccumulator::new_non_configured();
210 if let (Some(start), Some(end), Some(nb_bins)) = (start_literal, end_literal, nb_bins_literal) {
211 acc.configure_from_params(start, end, nb_bins)?;
212 }
213 Ok(Box::new(acc))
214}
215
216pub fn make_histogram_arrow_type() -> DataType {
217 DataType::Struct(Fields::from(state_arrow_fields()))
218}
219
220pub fn make_histo_udaf() -> AggregateUDF {
222 create_udaf(
223 "make_histogram",
224 vec![
225 DataType::Float64,
226 DataType::Float64,
227 DataType::Int64,
228 DataType::Float64,
229 ],
230 Arc::new(make_histogram_arrow_type()),
231 Volatility::Immutable,
232 Arc::new(&make_state),
233 Arc::new(vec![make_histogram_arrow_type()]),
234 )
235}