micromegas_datafusion_extensions/histogram/
accumulator.rs1use std::sync::Arc;
2
3use datafusion::{
4 arrow::{
5 array::{
6 Array, ArrayBuilder, ArrayRef, Float64Array, Int64Array, ListBuilder, PrimitiveBuilder,
7 StructBuilder, UInt64Builder,
8 },
9 datatypes::{DataType, Field, Float64Type, UInt64Type},
10 },
11 error::DataFusionError,
12 logical_expr::Accumulator,
13 scalar::ScalarValue,
14};
15
16use super::histogram_udaf::HistogramArray;
17
/// Running state of a fixed-range, equal-width histogram aggregate.
///
/// Alongside the per-bin counts, the accumulator keeps summary statistics
/// (extrema, sum, sum of squares and count) of every value recorded or merged.
#[derive(Debug)]
pub struct HistogramAccumulator {
    /// Lower bound of the binned range; `None` until the accumulator is configured.
    start: Option<f64>,
    /// Upper bound of the binned range; `None` until the accumulator is configured.
    end: Option<f64>,
    /// Smallest value seen so far (`f64::MAX` while nothing has been recorded).
    min: f64,
    /// Largest value seen so far (`f64::MIN` while nothing has been recorded).
    max: f64,
    /// Sum of all recorded values.
    sum: f64,
    /// Sum of the squares of all recorded values.
    sum_sq: f64,
    /// Number of recorded values.
    count: u64,
    /// Number of values that fell into each bin.
    bins: Vec<u64>,
}
30
31impl HistogramAccumulator {
32 pub fn new(start: f64, end: f64, nb_bins: usize) -> Self {
33 let bins = vec![0; nb_bins];
34 Self {
35 start: Some(start),
36 end: Some(end),
37 bins,
38 min: f64::MAX,
39 max: f64::MIN,
40 sum: 0.0,
41 sum_sq: 0.0,
42 count: 0,
43 }
44 }
45
46 pub fn new_non_configured() -> Self {
47 Self {
48 start: None,
49 end: None,
50 min: f64::MAX,
51 max: f64::MIN,
52 sum: 0.0,
53 sum_sq: 0.0,
54 count: 0,
55 bins: Vec::new(),
56 }
57 }
58
59 pub fn configure_from_params(
61 &mut self,
62 start: f64,
63 end: f64,
64 nb_bins: i64,
65 ) -> datafusion::error::Result<()> {
66 if nb_bins < 1 {
67 return Err(DataFusionError::Execution(format!(
68 "make_histogram: nb_bins must be >= 1, got {nb_bins}"
69 )));
70 }
71 if !start.is_finite() {
72 return Err(DataFusionError::Execution(format!(
73 "make_histogram: start must be finite, got {start}"
74 )));
75 }
76 if !end.is_finite() {
77 return Err(DataFusionError::Execution(format!(
78 "make_histogram: end must be finite, got {end}"
79 )));
80 }
81 if start > end {
82 return Err(DataFusionError::Execution(format!(
83 "make_histogram: start ({start}) must be <= end ({end})"
84 )));
85 }
86 self.start = Some(start);
87 self.end = Some(end);
88 self.bins.resize(nb_bins as usize, 0);
89 Ok(())
90 }
91
92 pub fn configure(&mut self, histo_array: &HistogramArray) -> datafusion::error::Result<()> {
95 if self.start.is_some() {
96 return Ok(());
97 }
98 let Some(idx) = (0..histo_array.len()).find(|&i| !histo_array.is_null_at(i)) else {
99 return Ok(());
100 };
101 self.start = Some(histo_array.get_start(idx)?);
102 self.end = Some(histo_array.get_end(idx)?);
103 self.bins.resize(histo_array.get_bins(idx)?.len(), 0);
104 Ok(())
105 }
106
107 pub fn update_batch_scalars(
108 &mut self,
109 scalars: &Float64Array,
110 ) -> datafusion::error::Result<()> {
111 if self.start.is_none() || self.end.is_none() {
112 return Err(DataFusionError::Execution(
113 "can't record scalar in a non-configured histogram".into(),
114 ));
115 }
116 let start = self.start.unwrap();
117 let range = self.end.unwrap() - start;
118 let bin_width = range / (self.bins.len() as f64);
119 for i in 0..scalars.len() {
120 if !scalars.is_null(i) {
121 let v = scalars.value(i);
122 self.min = self.min.min(v);
123 self.max = self.max.max(v);
124 self.sum += v;
125 self.sum_sq += v * v;
126 self.count += 1;
127 let bin_index = (((v - start) / bin_width).floor()) as usize;
128 let bin_index = bin_index.clamp(0, self.bins.len() - 1);
129 self.bins[bin_index] += 1;
130 }
131 }
132 Ok(())
133 }
134
135 pub fn merge_histograms(
136 &mut self,
137 histo_array: &HistogramArray,
138 ) -> datafusion::error::Result<()> {
139 self.configure(histo_array)?;
140 for index_histo in 0..histo_array.len() {
141 if histo_array.is_null_at(index_histo) {
142 continue;
143 }
144 let start = histo_array.get_start(index_histo)?;
145 if self.start.unwrap() != start {
146 return Err(DataFusionError::Execution(
147 "Error merging incompatible histograms".into(),
148 ));
149 }
150 let end = histo_array.get_end(index_histo)?;
151 if self.end.unwrap() != end {
152 return Err(DataFusionError::Execution(
153 "Error merging incompatible histograms".into(),
154 ));
155 }
156
157 let min = histo_array.get_min(index_histo)?;
158 let max = histo_array.get_max(index_histo)?;
159 let sum = histo_array.get_sum(index_histo)?;
160 let sum_sq = histo_array.get_sum_sq(index_histo)?;
161 let count = histo_array.get_count(index_histo)?;
162 let bins = histo_array.get_bins(index_histo)?;
163 if bins.len() != self.bins.len() {
164 return Err(DataFusionError::Execution(
165 "Error merging incompatible histograms".into(),
166 ));
167 }
168 self.min = self.min.min(min);
169 self.max = self.max.max(max);
170 self.sum += sum;
171 self.sum_sq += sum_sq;
172 self.count += count;
173
174 for i in 0..self.bins.len() {
176 self.bins[i] += bins.value(i);
177 }
178 }
179 Ok(())
180 }
181}
182
183impl Accumulator for HistogramAccumulator {
184 fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion::error::Result<()> {
185 match values.len() {
190 4 => {
191 if values[0].is_empty() {
192 return Ok(());
193 }
194 for (i, name) in [(0usize, "start"), (1, "end"), (2, "nb_bins")] {
195 if values[i].is_null(0) {
196 return Err(DataFusionError::Execution(format!(
197 "make_histogram: {name} argument is null"
198 )));
199 }
200 }
201 let batch_start = values[0]
202 .as_any()
203 .downcast_ref::<Float64Array>()
204 .ok_or_else(|| {
205 DataFusionError::Execution("values[0] should be a Float64Array".into())
206 })?
207 .value(0);
208 let batch_end = values[1]
209 .as_any()
210 .downcast_ref::<Float64Array>()
211 .ok_or_else(|| {
212 DataFusionError::Execution("values[1] should be a Float64Array".into())
213 })?
214 .value(0);
215 let batch_nb_bins = values[2]
216 .as_any()
217 .downcast_ref::<Int64Array>()
218 .ok_or_else(|| {
219 DataFusionError::Execution("values[2] should be an Int64Array".into())
220 })?
221 .value(0);
222 if let Some(configured_start) = self.start {
223 let configured_end = self.end.expect("end is set whenever start is set");
224 if configured_start != batch_start
225 || configured_end != batch_end
226 || self.bins.len() != batch_nb_bins as usize
227 {
228 return Err(DataFusionError::Execution(
229 "make_histogram: bounds/bins changed between batches".into(),
230 ));
231 }
232 } else {
233 self.configure_from_params(batch_start, batch_end, batch_nb_bins)?;
234 }
235 let scalars = values[3]
236 .as_any()
237 .downcast_ref::<Float64Array>()
238 .ok_or_else(|| {
239 DataFusionError::Execution("values[3] should be a Float64Array".into())
240 })?;
241 self.update_batch_scalars(scalars)
242 }
243 1 => {
244 let histo_array: HistogramArray = values[0].as_ref().try_into()?;
245 self.merge_histograms(&histo_array)
246 }
247
248 other => Err(DataFusionError::Execution(format!(
249 "invalid arguments to HistogramAccumulator::update_batch, nb_values={other}"
250 ))),
251 }
252 }
253
254 fn evaluate(&mut self) -> datafusion::error::Result<datafusion::scalar::ScalarValue> {
255 let fields = state_arrow_fields();
256 let mut struct_builder = StructBuilder::from_fields(fields, 1);
257 let start_builder = struct_builder
258 .field_builder::<PrimitiveBuilder<Float64Type>>(0)
259 .ok_or_else(|| DataFusionError::Execution("Error accessing to start builder".into()))?;
260 start_builder.append_value(self.start.unwrap_or(0.0));
261
262 let end_builder = struct_builder
263 .field_builder::<PrimitiveBuilder<Float64Type>>(1)
264 .ok_or_else(|| DataFusionError::Execution("Error accessing to end builder".into()))?;
265 end_builder.append_value(self.end.unwrap_or(0.0));
266
267 let min_builder = struct_builder
268 .field_builder::<PrimitiveBuilder<Float64Type>>(2)
269 .ok_or_else(|| DataFusionError::Execution("Error accessing to min builder".into()))?;
270 min_builder.append_value(self.min);
271
272 let max_builder = struct_builder
273 .field_builder::<PrimitiveBuilder<Float64Type>>(3)
274 .ok_or_else(|| DataFusionError::Execution("Error accessing to max builder".into()))?;
275 max_builder.append_value(self.max);
276
277 let sum_builder = struct_builder
278 .field_builder::<PrimitiveBuilder<Float64Type>>(4)
279 .ok_or_else(|| DataFusionError::Execution("Error accessing to sum builder".into()))?;
280 sum_builder.append_value(self.sum);
281
282 let sum_sq_builder = struct_builder
283 .field_builder::<PrimitiveBuilder<Float64Type>>(5)
284 .ok_or_else(|| {
285 DataFusionError::Execution("Error accessing to sum_sq builder".into())
286 })?;
287 sum_sq_builder.append_value(self.sum_sq);
288
289 let count_builder = struct_builder
290 .field_builder::<PrimitiveBuilder<UInt64Type>>(6)
291 .ok_or_else(|| DataFusionError::Execution("Error accessing to count builder".into()))?;
292 count_builder.append_value(self.count);
293
294 let bins_builder = struct_builder
295 .field_builder::<ListBuilder<Box<dyn ArrayBuilder>>>(7)
296 .ok_or_else(|| DataFusionError::Execution("Error accessing to bins builder".into()))?;
297 let bin_array_builder = bins_builder
298 .values()
299 .as_any_mut()
300 .downcast_mut::<UInt64Builder>()
301 .ok_or_else(|| {
302 DataFusionError::Execution("Error accessing to bins array builder".into())
303 })?;
304 bin_array_builder.append_slice(&self.bins);
305 bins_builder.append(true);
306 struct_builder.append(self.start.is_some());
307 Ok(ScalarValue::Struct(Arc::new(struct_builder.finish())))
308 }
309
310 fn size(&self) -> usize {
311 size_of_val(self) + size_of_val(&self.bins)
312 }
313
314 fn state(&mut self) -> datafusion::error::Result<Vec<datafusion::scalar::ScalarValue>> {
315 Ok(vec![self.evaluate()?])
316 }
317
318 fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion::error::Result<()> {
319 for state in states {
320 let histo_array: HistogramArray = state.try_into()?;
321 self.merge_histograms(&histo_array)?;
322 }
323 Ok(())
324 }
325}
326
327pub fn state_arrow_fields() -> Vec<Field> {
329 vec![
330 Field::new("start", DataType::Float64, false),
331 Field::new("end", DataType::Float64, false),
332 Field::new("min", DataType::Float64, false),
333 Field::new("max", DataType::Float64, false),
334 Field::new("sum", DataType::Float64, false),
335 Field::new("sum_sq", DataType::Float64, false),
336 Field::new("count", DataType::UInt64, false),
337 Field::new(
338 "bins",
339 DataType::List(Arc::new(Field::new("bin", DataType::UInt64, false))),
340 false,
341 ),
342 ]
343}