micromegas_datafusion_extensions/histogram/
accumulator.rs1use std::sync::Arc;
2
3use datafusion::{
4 arrow::{
5 array::{
6 Array, ArrayBuilder, ArrayRef, Float64Array, ListBuilder, PrimitiveBuilder,
7 StructBuilder, UInt64Builder,
8 },
9 datatypes::{DataType, Field, Float64Type, UInt64Type},
10 },
11 error::DataFusionError,
12 logical_expr::Accumulator,
13 scalar::ScalarValue,
14};
15
16use super::histogram_udaf::HistogramArray;
17
/// Running state of a fixed-range histogram aggregate.
///
/// Besides the per-bin counts, the accumulator tracks summary statistics
/// (min, max, sum, sum of squares, count) of every value it has recorded or
/// merged.
#[derive(Debug)]
pub struct HistogramAccumulator {
    /// Lower bound of the binned range; `None` until the accumulator is configured.
    start: Option<f64>,
    /// Upper bound of the binned range; `None` until the accumulator is configured.
    end: Option<f64>,
    /// Smallest value seen so far; starts at `f64::MAX` so any value replaces it.
    min: f64,
    /// Largest value seen so far; starts at `f64::MIN` so any value replaces it.
    max: f64,
    /// Sum of all values seen so far.
    sum: f64,
    /// Sum of the squares of all values seen so far.
    sum_sq: f64,
    /// Number of values seen so far.
    count: u64,
    /// Number of values that fell into each bin.
    bins: Vec<u64>,
}
30
31impl HistogramAccumulator {
32 pub fn new(start: f64, end: f64, nb_bins: usize) -> Self {
33 let bins = vec![0; nb_bins];
34 Self {
35 start: Some(start),
36 end: Some(end),
37 bins,
38 min: f64::MAX,
39 max: f64::MIN,
40 sum: 0.0,
41 sum_sq: 0.0,
42 count: 0,
43 }
44 }
45
46 pub fn new_non_configured() -> Self {
47 Self {
48 start: None,
49 end: None,
50 min: f64::MAX,
51 max: f64::MIN,
52 sum: 0.0,
53 sum_sq: 0.0,
54 count: 0,
55 bins: Vec::new(),
56 }
57 }
58
59 pub fn configure(&mut self, histo_array: &HistogramArray) -> datafusion::error::Result<()> {
62 if self.start.is_some() {
63 return Ok(());
64 }
65 if histo_array.is_empty() {
66 return Ok(());
67 }
68 self.start = Some(histo_array.get_start(0)?);
69 self.end = Some(histo_array.get_end(0)?);
70 self.bins.resize(histo_array.get_bins(0)?.len(), 0);
71 Ok(())
72 }
73
74 pub fn update_batch_scalars(
75 &mut self,
76 scalars: &Float64Array,
77 ) -> datafusion::error::Result<()> {
78 if self.start.is_none() || self.end.is_none() {
79 return Err(DataFusionError::Execution(
80 "can't record scalar in a non-configured histogram".into(),
81 ));
82 }
83 let start = self.start.unwrap();
84 let range = self.end.unwrap() - start;
85 let bin_width = range / (self.bins.len() as f64);
86 for i in 0..scalars.len() {
87 if !scalars.is_null(i) {
88 let v = scalars.value(i);
89 self.min = self.min.min(v);
90 self.max = self.max.max(v);
91 self.sum += v;
92 self.sum_sq += v * v;
93 self.count += 1;
94 let bin_index = (((v - start) / bin_width).floor()) as usize;
95 let bin_index = bin_index.clamp(0, self.bins.len() - 1);
96 self.bins[bin_index] += 1;
97 }
98 }
99 Ok(())
100 }
101
102 pub fn merge_histograms(
103 &mut self,
104 histo_array: &HistogramArray,
105 ) -> datafusion::error::Result<()> {
106 self.configure(histo_array)?;
107 for index_histo in 0..histo_array.len() {
108 let start = histo_array.get_start(index_histo)?;
109 if self.start.unwrap() != start {
110 return Err(DataFusionError::Execution(
111 "Error merging incompatible histograms".into(),
112 ));
113 }
114 let end = histo_array.get_end(index_histo)?;
115 if self.end.unwrap() != end {
116 return Err(DataFusionError::Execution(
117 "Error merging incompatible histograms".into(),
118 ));
119 }
120
121 let min = histo_array.get_min(index_histo)?;
122 let max = histo_array.get_max(index_histo)?;
123 let sum = histo_array.get_sum(index_histo)?;
124 let sum_sq = histo_array.get_sum_sq(index_histo)?;
125 let count = histo_array.get_count(index_histo)?;
126 let bins = histo_array.get_bins(index_histo)?;
127 if bins.len() != self.bins.len() {
128 return Err(DataFusionError::Execution(
129 "Error merging incompatible histograms".into(),
130 ));
131 }
132 self.min = self.min.min(min);
133 self.max = self.max.max(max);
134 self.sum += sum;
135 self.sum_sq += sum_sq;
136 self.count += count;
137
138 for i in 0..self.bins.len() {
140 self.bins[i] += bins.value(i);
141 }
142 }
143 Ok(())
144 }
145}
146
147impl Accumulator for HistogramAccumulator {
148 fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion::error::Result<()> {
149 match values.len() {
154 4 => {
155 let scalars = values[3]
156 .as_any()
157 .downcast_ref::<Float64Array>()
158 .ok_or_else(|| {
159 DataFusionError::Execution("values[3] should ne a Float64Array".into())
160 })?;
161 self.update_batch_scalars(scalars)
162 }
163 1 => {
164 let histo_array: HistogramArray = values[0].as_ref().try_into()?;
165 self.merge_histograms(&histo_array)
166 }
167
168 other => Err(DataFusionError::Execution(format!(
169 "invalid arguments to HistogramAccumulator::update_batch, nb_values={other}"
170 ))),
171 }
172 }
173
174 fn evaluate(&mut self) -> datafusion::error::Result<datafusion::scalar::ScalarValue> {
175 let fields = state_arrow_fields();
176 let mut struct_builder = StructBuilder::from_fields(fields, 1);
177 let start_builder = struct_builder
178 .field_builder::<PrimitiveBuilder<Float64Type>>(0)
179 .ok_or_else(|| DataFusionError::Execution("Error accessing to start builder".into()))?;
180 if let Some(start) = self.start {
181 start_builder.append_value(start);
182 } else {
183 start_builder.append_null();
184 }
185
186 let end_builder = struct_builder
187 .field_builder::<PrimitiveBuilder<Float64Type>>(1)
188 .ok_or_else(|| DataFusionError::Execution("Error accessing to end builder".into()))?;
189 if let Some(end) = self.end {
190 end_builder.append_value(end);
191 } else {
192 end_builder.append_null();
193 }
194
195 let min_builder = struct_builder
196 .field_builder::<PrimitiveBuilder<Float64Type>>(2)
197 .ok_or_else(|| DataFusionError::Execution("Error accessing to min builder".into()))?;
198 min_builder.append_value(self.min);
199
200 let max_builder = struct_builder
201 .field_builder::<PrimitiveBuilder<Float64Type>>(3)
202 .ok_or_else(|| DataFusionError::Execution("Error accessing to max builder".into()))?;
203 max_builder.append_value(self.max);
204
205 let sum_builder = struct_builder
206 .field_builder::<PrimitiveBuilder<Float64Type>>(4)
207 .ok_or_else(|| DataFusionError::Execution("Error accessing to sum builder".into()))?;
208 sum_builder.append_value(self.sum);
209
210 let sum_sq_builder = struct_builder
211 .field_builder::<PrimitiveBuilder<Float64Type>>(5)
212 .ok_or_else(|| {
213 DataFusionError::Execution("Error accessing to sum_sq builder".into())
214 })?;
215 sum_sq_builder.append_value(self.sum_sq);
216
217 let count_builder = struct_builder
218 .field_builder::<PrimitiveBuilder<UInt64Type>>(6)
219 .ok_or_else(|| DataFusionError::Execution("Error accessing to count builder".into()))?;
220 count_builder.append_value(self.count);
221
222 let bins_builder = struct_builder
223 .field_builder::<ListBuilder<Box<dyn ArrayBuilder>>>(7)
224 .ok_or_else(|| DataFusionError::Execution("Error accessing to bins builder".into()))?;
225 let bin_array_builder = bins_builder
226 .values()
227 .as_any_mut()
228 .downcast_mut::<UInt64Builder>()
229 .ok_or_else(|| {
230 DataFusionError::Execution("Error accessing to bins array builder".into())
231 })?;
232 bin_array_builder.append_slice(&self.bins);
233 bins_builder.append(true);
234 struct_builder.append(true);
235 Ok(ScalarValue::Struct(Arc::new(struct_builder.finish())))
236 }
237
238 fn size(&self) -> usize {
239 size_of_val(self) + size_of_val(&self.bins)
240 }
241
242 fn state(&mut self) -> datafusion::error::Result<Vec<datafusion::scalar::ScalarValue>> {
243 Ok(vec![self.evaluate()?])
244 }
245
246 fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion::error::Result<()> {
247 for state in states {
248 let histo_array: HistogramArray = state.try_into()?;
249 self.merge_histograms(&histo_array)?;
250 }
251 Ok(())
252 }
253}
254
255pub fn state_arrow_fields() -> Vec<Field> {
257 vec![
258 Field::new("start", DataType::Float64, false),
259 Field::new("end", DataType::Float64, false),
260 Field::new("min", DataType::Float64, false),
261 Field::new("max", DataType::Float64, false),
262 Field::new("sum", DataType::Float64, false),
263 Field::new("sum_sq", DataType::Float64, false),
264 Field::new("count", DataType::UInt64, false),
265 Field::new(
266 "bins",
267 DataType::List(Arc::new(Field::new("bin", DataType::UInt64, false))),
268 false,
269 ),
270 ]
271}