micromegas_datafusion_extensions/histogram/
accumulator.rs

1use std::sync::Arc;
2
3use datafusion::{
4    arrow::{
5        array::{
6            Array, ArrayBuilder, ArrayRef, Float64Array, Int64Array, ListBuilder, PrimitiveBuilder,
7            StructBuilder, UInt64Builder,
8        },
9        datatypes::{DataType, Field, Float64Type, UInt64Type},
10    },
11    error::DataFusionError,
12    logical_expr::Accumulator,
13    scalar::ScalarValue,
14};
15
16use super::histogram_udaf::HistogramArray;
17
/// Aggregation state for a fixed-bin histogram plus running summary
/// statistics over every sample recorded into it.
#[derive(Debug)]
pub struct HistogramAccumulator {
    /// Lower bound of the binned range; `None` while the accumulator is not configured.
    start: Option<f64>,
    /// Upper bound of the binned range; `None` while the accumulator is not configured.
    end: Option<f64>,
    /// Smallest sample recorded so far.
    min: f64,
    /// Largest sample recorded so far.
    max: f64,
    /// Sum of all recorded samples.
    sum: f64,
    /// Sum of the squares of all recorded samples.
    sum_sq: f64,
    /// Number of recorded samples.
    count: u64,
    /// Per-bin sample counts.
    bins: Vec<u64>,
}
30
31impl HistogramAccumulator {
32    pub fn new(start: f64, end: f64, nb_bins: usize) -> Self {
33        let bins = vec![0; nb_bins];
34        Self {
35            start: Some(start),
36            end: Some(end),
37            bins,
38            min: f64::MAX,
39            max: f64::MIN,
40            sum: 0.0,
41            sum_sq: 0.0,
42            count: 0,
43        }
44    }
45
46    pub fn new_non_configured() -> Self {
47        Self {
48            start: None,
49            end: None,
50            min: f64::MAX,
51            max: f64::MIN,
52            sum: 0.0,
53            sum_sq: 0.0,
54            count: 0,
55            bins: Vec::new(),
56        }
57    }
58
59    /// Validates and sets histogram bounds from scalar parameters.
60    pub fn configure_from_params(
61        &mut self,
62        start: f64,
63        end: f64,
64        nb_bins: i64,
65    ) -> datafusion::error::Result<()> {
66        if nb_bins < 1 {
67            return Err(DataFusionError::Execution(format!(
68                "make_histogram: nb_bins must be >= 1, got {nb_bins}"
69            )));
70        }
71        if !start.is_finite() {
72            return Err(DataFusionError::Execution(format!(
73                "make_histogram: start must be finite, got {start}"
74            )));
75        }
76        if !end.is_finite() {
77            return Err(DataFusionError::Execution(format!(
78                "make_histogram: end must be finite, got {end}"
79            )));
80        }
81        if start > end {
82            return Err(DataFusionError::Execution(format!(
83                "make_histogram: start ({start}) must be <= end ({end})"
84            )));
85        }
86        self.start = Some(start);
87        self.end = Some(end);
88        self.bins.resize(nb_bins as usize, 0);
89        Ok(())
90    }
91
92    /// If not configured, scans for the first non-null row and uses it as a template.
93    /// If already configured or if all rows are null, does nothing.
94    pub fn configure(&mut self, histo_array: &HistogramArray) -> datafusion::error::Result<()> {
95        if self.start.is_some() {
96            return Ok(());
97        }
98        let Some(idx) = (0..histo_array.len()).find(|&i| !histo_array.is_null_at(i)) else {
99            return Ok(());
100        };
101        self.start = Some(histo_array.get_start(idx)?);
102        self.end = Some(histo_array.get_end(idx)?);
103        self.bins.resize(histo_array.get_bins(idx)?.len(), 0);
104        Ok(())
105    }
106
107    pub fn update_batch_scalars(
108        &mut self,
109        scalars: &Float64Array,
110    ) -> datafusion::error::Result<()> {
111        if self.start.is_none() || self.end.is_none() {
112            return Err(DataFusionError::Execution(
113                "can't record scalar in a non-configured histogram".into(),
114            ));
115        }
116        let start = self.start.unwrap();
117        let range = self.end.unwrap() - start;
118        let bin_width = range / (self.bins.len() as f64);
119        for i in 0..scalars.len() {
120            if !scalars.is_null(i) {
121                let v = scalars.value(i);
122                self.min = self.min.min(v);
123                self.max = self.max.max(v);
124                self.sum += v;
125                self.sum_sq += v * v;
126                self.count += 1;
127                let bin_index = (((v - start) / bin_width).floor()) as usize;
128                let bin_index = bin_index.clamp(0, self.bins.len() - 1);
129                self.bins[bin_index] += 1;
130            }
131        }
132        Ok(())
133    }
134
135    pub fn merge_histograms(
136        &mut self,
137        histo_array: &HistogramArray,
138    ) -> datafusion::error::Result<()> {
139        self.configure(histo_array)?;
140        for index_histo in 0..histo_array.len() {
141            if histo_array.is_null_at(index_histo) {
142                continue;
143            }
144            let start = histo_array.get_start(index_histo)?;
145            if self.start.unwrap() != start {
146                return Err(DataFusionError::Execution(
147                    "Error merging incompatible histograms".into(),
148                ));
149            }
150            let end = histo_array.get_end(index_histo)?;
151            if self.end.unwrap() != end {
152                return Err(DataFusionError::Execution(
153                    "Error merging incompatible histograms".into(),
154                ));
155            }
156
157            let min = histo_array.get_min(index_histo)?;
158            let max = histo_array.get_max(index_histo)?;
159            let sum = histo_array.get_sum(index_histo)?;
160            let sum_sq = histo_array.get_sum_sq(index_histo)?;
161            let count = histo_array.get_count(index_histo)?;
162            let bins = histo_array.get_bins(index_histo)?;
163            if bins.len() != self.bins.len() {
164                return Err(DataFusionError::Execution(
165                    "Error merging incompatible histograms".into(),
166                ));
167            }
168            self.min = self.min.min(min);
169            self.max = self.max.max(max);
170            self.sum += sum;
171            self.sum_sq += sum_sq;
172            self.count += count;
173
174            // optim opportunity: use arrow compute
175            for i in 0..self.bins.len() {
176                self.bins[i] += bins.value(i);
177            }
178        }
179        Ok(())
180    }
181}
182
183impl Accumulator for HistogramAccumulator {
184    fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion::error::Result<()> {
185        // we support two signatures
186        // scalar case: [starts, ends, bin_counts, scalars_to_reduce]
187        // merge case: [histograms]
188
189        match values.len() {
190            4 => {
191                if values[0].is_empty() {
192                    return Ok(());
193                }
194                for (i, name) in [(0usize, "start"), (1, "end"), (2, "nb_bins")] {
195                    if values[i].is_null(0) {
196                        return Err(DataFusionError::Execution(format!(
197                            "make_histogram: {name} argument is null"
198                        )));
199                    }
200                }
201                let batch_start = values[0]
202                    .as_any()
203                    .downcast_ref::<Float64Array>()
204                    .ok_or_else(|| {
205                        DataFusionError::Execution("values[0] should be a Float64Array".into())
206                    })?
207                    .value(0);
208                let batch_end = values[1]
209                    .as_any()
210                    .downcast_ref::<Float64Array>()
211                    .ok_or_else(|| {
212                        DataFusionError::Execution("values[1] should be a Float64Array".into())
213                    })?
214                    .value(0);
215                let batch_nb_bins = values[2]
216                    .as_any()
217                    .downcast_ref::<Int64Array>()
218                    .ok_or_else(|| {
219                        DataFusionError::Execution("values[2] should be an Int64Array".into())
220                    })?
221                    .value(0);
222                if let Some(configured_start) = self.start {
223                    let configured_end = self.end.expect("end is set whenever start is set");
224                    if configured_start != batch_start
225                        || configured_end != batch_end
226                        || self.bins.len() != batch_nb_bins as usize
227                    {
228                        return Err(DataFusionError::Execution(
229                            "make_histogram: bounds/bins changed between batches".into(),
230                        ));
231                    }
232                } else {
233                    self.configure_from_params(batch_start, batch_end, batch_nb_bins)?;
234                }
235                let scalars = values[3]
236                    .as_any()
237                    .downcast_ref::<Float64Array>()
238                    .ok_or_else(|| {
239                        DataFusionError::Execution("values[3] should be a Float64Array".into())
240                    })?;
241                self.update_batch_scalars(scalars)
242            }
243            1 => {
244                let histo_array: HistogramArray = values[0].as_ref().try_into()?;
245                self.merge_histograms(&histo_array)
246            }
247
248            other => Err(DataFusionError::Execution(format!(
249                "invalid arguments to HistogramAccumulator::update_batch, nb_values={other}"
250            ))),
251        }
252    }
253
254    fn evaluate(&mut self) -> datafusion::error::Result<datafusion::scalar::ScalarValue> {
255        let fields = state_arrow_fields();
256        let mut struct_builder = StructBuilder::from_fields(fields, 1);
257        let start_builder = struct_builder
258            .field_builder::<PrimitiveBuilder<Float64Type>>(0)
259            .ok_or_else(|| DataFusionError::Execution("Error accessing to start builder".into()))?;
260        start_builder.append_value(self.start.unwrap_or(0.0));
261
262        let end_builder = struct_builder
263            .field_builder::<PrimitiveBuilder<Float64Type>>(1)
264            .ok_or_else(|| DataFusionError::Execution("Error accessing to end builder".into()))?;
265        end_builder.append_value(self.end.unwrap_or(0.0));
266
267        let min_builder = struct_builder
268            .field_builder::<PrimitiveBuilder<Float64Type>>(2)
269            .ok_or_else(|| DataFusionError::Execution("Error accessing to min builder".into()))?;
270        min_builder.append_value(self.min);
271
272        let max_builder = struct_builder
273            .field_builder::<PrimitiveBuilder<Float64Type>>(3)
274            .ok_or_else(|| DataFusionError::Execution("Error accessing to max builder".into()))?;
275        max_builder.append_value(self.max);
276
277        let sum_builder = struct_builder
278            .field_builder::<PrimitiveBuilder<Float64Type>>(4)
279            .ok_or_else(|| DataFusionError::Execution("Error accessing to sum builder".into()))?;
280        sum_builder.append_value(self.sum);
281
282        let sum_sq_builder = struct_builder
283            .field_builder::<PrimitiveBuilder<Float64Type>>(5)
284            .ok_or_else(|| {
285                DataFusionError::Execution("Error accessing to sum_sq builder".into())
286            })?;
287        sum_sq_builder.append_value(self.sum_sq);
288
289        let count_builder = struct_builder
290            .field_builder::<PrimitiveBuilder<UInt64Type>>(6)
291            .ok_or_else(|| DataFusionError::Execution("Error accessing to count builder".into()))?;
292        count_builder.append_value(self.count);
293
294        let bins_builder = struct_builder
295            .field_builder::<ListBuilder<Box<dyn ArrayBuilder>>>(7)
296            .ok_or_else(|| DataFusionError::Execution("Error accessing to bins builder".into()))?;
297        let bin_array_builder = bins_builder
298            .values()
299            .as_any_mut()
300            .downcast_mut::<UInt64Builder>()
301            .ok_or_else(|| {
302                DataFusionError::Execution("Error accessing to bins array builder".into())
303            })?;
304        bin_array_builder.append_slice(&self.bins);
305        bins_builder.append(true);
306        struct_builder.append(self.start.is_some());
307        Ok(ScalarValue::Struct(Arc::new(struct_builder.finish())))
308    }
309
310    fn size(&self) -> usize {
311        size_of_val(self) + size_of_val(&self.bins)
312    }
313
314    fn state(&mut self) -> datafusion::error::Result<Vec<datafusion::scalar::ScalarValue>> {
315        Ok(vec![self.evaluate()?])
316    }
317
318    fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion::error::Result<()> {
319        for state in states {
320            let histo_array: HistogramArray = state.try_into()?;
321            self.merge_histograms(&histo_array)?;
322        }
323        Ok(())
324    }
325}
326
327/// Returns the Arrow fields for the histogram state.
328pub fn state_arrow_fields() -> Vec<Field> {
329    vec![
330        Field::new("start", DataType::Float64, false),
331        Field::new("end", DataType::Float64, false),
332        Field::new("min", DataType::Float64, false),
333        Field::new("max", DataType::Float64, false),
334        Field::new("sum", DataType::Float64, false),
335        Field::new("sum_sq", DataType::Float64, false),
336        Field::new("count", DataType::UInt64, false),
337        Field::new(
338            "bins",
339            DataType::List(Arc::new(Field::new("bin", DataType::UInt64, false))),
340            false,
341        ),
342    ]
343}