micromegas_datafusion_extensions/histogram/
accumulator.rs

1use std::sync::Arc;
2
3use datafusion::{
4    arrow::{
5        array::{
6            Array, ArrayBuilder, ArrayRef, Float64Array, ListBuilder, PrimitiveBuilder,
7            StructBuilder, UInt64Builder,
8        },
9        datatypes::{DataType, Field, Float64Type, UInt64Type},
10    },
11    error::DataFusionError,
12    logical_expr::Accumulator,
13    scalar::ScalarValue,
14};
15
16use super::histogram_udaf::HistogramArray;
17
/// Running state of a histogram aggregation: a `[start, end]` range split into
/// equal-width bins, plus summary statistics over every value recorded or merged.
#[derive(Debug)]
pub struct HistogramAccumulator {
    /// Lower bound of the range; `None` while the accumulator is not configured.
    start: Option<f64>,
    /// Upper bound of the range; `None` while the accumulator is not configured.
    end: Option<f64>,
    /// Smallest value recorded or merged so far.
    min: f64,
    /// Largest value recorded or merged so far.
    max: f64,
    /// Sum of all values recorded or merged so far.
    sum: f64,
    /// Sum of the squares of all values recorded or merged so far.
    sum_sq: f64,
    /// Number of values recorded or merged so far.
    count: u64,
    /// Number of values counted in each bin.
    bins: Vec<u64>,
}
30
31impl HistogramAccumulator {
32    pub fn new(start: f64, end: f64, nb_bins: usize) -> Self {
33        let bins = vec![0; nb_bins];
34        Self {
35            start: Some(start),
36            end: Some(end),
37            bins,
38            min: f64::MAX,
39            max: f64::MIN,
40            sum: 0.0,
41            sum_sq: 0.0,
42            count: 0,
43        }
44    }
45
46    pub fn new_non_configured() -> Self {
47        Self {
48            start: None,
49            end: None,
50            min: f64::MAX,
51            max: f64::MIN,
52            sum: 0.0,
53            sum_sq: 0.0,
54            count: 0,
55            bins: Vec::new(),
56        }
57    }
58
59    /// if not configured, will take the first instance of the array as a template
60    /// if already configured or if the array is empty, will do nothing
61    pub fn configure(&mut self, histo_array: &HistogramArray) -> datafusion::error::Result<()> {
62        if self.start.is_some() {
63            return Ok(());
64        }
65        if histo_array.is_empty() {
66            return Ok(());
67        }
68        self.start = Some(histo_array.get_start(0)?);
69        self.end = Some(histo_array.get_end(0)?);
70        self.bins.resize(histo_array.get_bins(0)?.len(), 0);
71        Ok(())
72    }
73
74    pub fn update_batch_scalars(
75        &mut self,
76        scalars: &Float64Array,
77    ) -> datafusion::error::Result<()> {
78        if self.start.is_none() || self.end.is_none() {
79            return Err(DataFusionError::Execution(
80                "can't record scalar in a non-configured histogram".into(),
81            ));
82        }
83        let start = self.start.unwrap();
84        let range = self.end.unwrap() - start;
85        let bin_width = range / (self.bins.len() as f64);
86        for i in 0..scalars.len() {
87            if !scalars.is_null(i) {
88                let v = scalars.value(i);
89                self.min = self.min.min(v);
90                self.max = self.max.max(v);
91                self.sum += v;
92                self.sum_sq += v * v;
93                self.count += 1;
94                let bin_index = (((v - start) / bin_width).floor()) as usize;
95                let bin_index = bin_index.clamp(0, self.bins.len() - 1);
96                self.bins[bin_index] += 1;
97            }
98        }
99        Ok(())
100    }
101
102    pub fn merge_histograms(
103        &mut self,
104        histo_array: &HistogramArray,
105    ) -> datafusion::error::Result<()> {
106        self.configure(histo_array)?;
107        for index_histo in 0..histo_array.len() {
108            let start = histo_array.get_start(index_histo)?;
109            if self.start.unwrap() != start {
110                return Err(DataFusionError::Execution(
111                    "Error merging incompatible histograms".into(),
112                ));
113            }
114            let end = histo_array.get_end(index_histo)?;
115            if self.end.unwrap() != end {
116                return Err(DataFusionError::Execution(
117                    "Error merging incompatible histograms".into(),
118                ));
119            }
120
121            let min = histo_array.get_min(index_histo)?;
122            let max = histo_array.get_max(index_histo)?;
123            let sum = histo_array.get_sum(index_histo)?;
124            let sum_sq = histo_array.get_sum_sq(index_histo)?;
125            let count = histo_array.get_count(index_histo)?;
126            let bins = histo_array.get_bins(index_histo)?;
127            if bins.len() != self.bins.len() {
128                return Err(DataFusionError::Execution(
129                    "Error merging incompatible histograms".into(),
130                ));
131            }
132            self.min = self.min.min(min);
133            self.max = self.max.max(max);
134            self.sum += sum;
135            self.sum_sq += sum_sq;
136            self.count += count;
137
138            // optim opportunity: use arrow compute
139            for i in 0..self.bins.len() {
140                self.bins[i] += bins.value(i);
141            }
142        }
143        Ok(())
144    }
145}
146
147impl Accumulator for HistogramAccumulator {
148    fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion::error::Result<()> {
149        // we support two signatures
150        // scalar case: [starts, ends, bin_counts, scalars_to_reduce]
151        // merge case: [histograms]
152
153        match values.len() {
154            4 => {
155                let scalars = values[3]
156                    .as_any()
157                    .downcast_ref::<Float64Array>()
158                    .ok_or_else(|| {
159                        DataFusionError::Execution("values[3] should ne a Float64Array".into())
160                    })?;
161                self.update_batch_scalars(scalars)
162            }
163            1 => {
164                let histo_array: HistogramArray = values[0].as_ref().try_into()?;
165                self.merge_histograms(&histo_array)
166            }
167
168            other => Err(DataFusionError::Execution(format!(
169                "invalid arguments to HistogramAccumulator::update_batch, nb_values={other}"
170            ))),
171        }
172    }
173
174    fn evaluate(&mut self) -> datafusion::error::Result<datafusion::scalar::ScalarValue> {
175        let fields = state_arrow_fields();
176        let mut struct_builder = StructBuilder::from_fields(fields, 1);
177        let start_builder = struct_builder
178            .field_builder::<PrimitiveBuilder<Float64Type>>(0)
179            .ok_or_else(|| DataFusionError::Execution("Error accessing to start builder".into()))?;
180        if let Some(start) = self.start {
181            start_builder.append_value(start);
182        } else {
183            start_builder.append_null();
184        }
185
186        let end_builder = struct_builder
187            .field_builder::<PrimitiveBuilder<Float64Type>>(1)
188            .ok_or_else(|| DataFusionError::Execution("Error accessing to end builder".into()))?;
189        if let Some(end) = self.end {
190            end_builder.append_value(end);
191        } else {
192            end_builder.append_null();
193        }
194
195        let min_builder = struct_builder
196            .field_builder::<PrimitiveBuilder<Float64Type>>(2)
197            .ok_or_else(|| DataFusionError::Execution("Error accessing to min builder".into()))?;
198        min_builder.append_value(self.min);
199
200        let max_builder = struct_builder
201            .field_builder::<PrimitiveBuilder<Float64Type>>(3)
202            .ok_or_else(|| DataFusionError::Execution("Error accessing to max builder".into()))?;
203        max_builder.append_value(self.max);
204
205        let sum_builder = struct_builder
206            .field_builder::<PrimitiveBuilder<Float64Type>>(4)
207            .ok_or_else(|| DataFusionError::Execution("Error accessing to sum builder".into()))?;
208        sum_builder.append_value(self.sum);
209
210        let sum_sq_builder = struct_builder
211            .field_builder::<PrimitiveBuilder<Float64Type>>(5)
212            .ok_or_else(|| {
213                DataFusionError::Execution("Error accessing to sum_sq builder".into())
214            })?;
215        sum_sq_builder.append_value(self.sum_sq);
216
217        let count_builder = struct_builder
218            .field_builder::<PrimitiveBuilder<UInt64Type>>(6)
219            .ok_or_else(|| DataFusionError::Execution("Error accessing to count builder".into()))?;
220        count_builder.append_value(self.count);
221
222        let bins_builder = struct_builder
223            .field_builder::<ListBuilder<Box<dyn ArrayBuilder>>>(7)
224            .ok_or_else(|| DataFusionError::Execution("Error accessing to bins builder".into()))?;
225        let bin_array_builder = bins_builder
226            .values()
227            .as_any_mut()
228            .downcast_mut::<UInt64Builder>()
229            .ok_or_else(|| {
230                DataFusionError::Execution("Error accessing to bins array builder".into())
231            })?;
232        bin_array_builder.append_slice(&self.bins);
233        bins_builder.append(true);
234        struct_builder.append(true);
235        Ok(ScalarValue::Struct(Arc::new(struct_builder.finish())))
236    }
237
238    fn size(&self) -> usize {
239        size_of_val(self) + size_of_val(&self.bins)
240    }
241
242    fn state(&mut self) -> datafusion::error::Result<Vec<datafusion::scalar::ScalarValue>> {
243        Ok(vec![self.evaluate()?])
244    }
245
246    fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion::error::Result<()> {
247        for state in states {
248            let histo_array: HistogramArray = state.try_into()?;
249            self.merge_histograms(&histo_array)?;
250        }
251        Ok(())
252    }
253}
254
255/// Returns the Arrow fields for the histogram state.
256pub fn state_arrow_fields() -> Vec<Field> {
257    vec![
258        Field::new("start", DataType::Float64, false),
259        Field::new("end", DataType::Float64, false),
260        Field::new("min", DataType::Float64, false),
261        Field::new("max", DataType::Float64, false),
262        Field::new("sum", DataType::Float64, false),
263        Field::new("sum_sq", DataType::Float64, false),
264        Field::new("count", DataType::UInt64, false),
265        Field::new(
266            "bins",
267            DataType::List(Arc::new(Field::new("bin", DataType::UInt64, false))),
268            false,
269        ),
270    ]
271}