micromegas_datafusion_extensions/histogram/expand.rs

1use super::histogram_udaf::HistogramArray;
2use async_trait::async_trait;
3use datafusion::arrow::array::{ArrayRef, Float64Array, StructArray, UInt64Array};
4use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
5use datafusion::arrow::record_batch::RecordBatch;
6use datafusion::catalog::Session;
7use datafusion::catalog::TableFunctionImpl;
8use datafusion::catalog::TableProvider;
9use datafusion::datasource::TableType;
10use datafusion::datasource::memory::{DataSourceExec, MemorySourceConfig};
11use datafusion::error::DataFusionError;
12use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder};
13use datafusion::physical_plan::ExecutionPlan;
14use datafusion::prelude::Expr;
15use datafusion::scalar::ScalarValue;
16use std::any::Any;
17use std::sync::Arc;
18
/// A DataFusion `TableFunctionImpl` that expands a histogram struct into rows of (bin_center, count).
///
/// The function itself is stateless (the struct has no fields); all state lives
/// in the `ExpandHistogramTableProvider` returned from `call`.
///
/// Usage:
/// ```sql
/// SELECT bin_center, count
/// FROM expand_histogram(
///   (SELECT make_histogram(0.0, 100.0, 100, value)
///    FROM measures WHERE name = 'cpu_usage')
/// )
/// ```
#[derive(Debug)]
pub struct ExpandHistogramTableFunction {}
31
32impl ExpandHistogramTableFunction {
33    pub fn new() -> Self {
34        Self {}
35    }
36}
37
38impl Default for ExpandHistogramTableFunction {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
/// The source of histogram data - either a literal value or a subquery to evaluate.
#[derive(Debug, Clone)]
enum HistogramSource {
    /// A histogram already materialized as a scalar (struct) value.
    Literal(ScalarValue),
    /// A logical plan executed lazily at scan time; its first row's single
    /// column is expected to hold the histogram struct.
    Subquery(Arc<LogicalPlan>),
}
50
51impl TableFunctionImpl for ExpandHistogramTableFunction {
52    fn call(&self, args: &[Expr]) -> datafusion::error::Result<Arc<dyn TableProvider>> {
53        if args.len() != 1 {
54            return Err(DataFusionError::Plan(
55                "expand_histogram requires exactly one argument (a histogram)".into(),
56            ));
57        }
58
59        // Extract the histogram from the expression
60        let source = match &args[0] {
61            Expr::Literal(scalar, _metadata) => HistogramSource::Literal(scalar.clone()),
62            Expr::ScalarSubquery(subquery) => HistogramSource::Subquery(subquery.subquery.clone()),
63            other => {
64                let plan = LogicalPlanBuilder::empty(true)
65                    .project(vec![other.clone()])?
66                    .build()?;
67                HistogramSource::Subquery(Arc::new(plan))
68            }
69        };
70
71        Ok(Arc::new(ExpandHistogramTableProvider { source }))
72    }
73}
74
75fn output_schema() -> SchemaRef {
76    Arc::new(Schema::new(vec![
77        Field::new("bin_center", DataType::Float64, false),
78        Field::new("count", DataType::UInt64, false),
79    ]))
80}
81
82fn expand_histogram_to_batch(
83    histo_array: &HistogramArray,
84    index: usize,
85) -> Result<RecordBatch, DataFusionError> {
86    let start = histo_array.get_start(index)?;
87    let end = histo_array.get_end(index)?;
88    let bins = histo_array.get_bins(index)?;
89
90    let num_bins = bins.len();
91    if num_bins == 0 {
92        return Ok(RecordBatch::new_empty(output_schema()));
93    }
94
95    // Handle edge case where start == end (all values in a single point)
96    let bin_width = if (end - start).abs() < f64::EPSILON {
97        1.0 // Use unit width when range is zero
98    } else {
99        (end - start) / (num_bins as f64)
100    };
101
102    let mut bin_centers = Vec::with_capacity(num_bins);
103    let mut counts = Vec::with_capacity(num_bins);
104
105    for i in 0..num_bins {
106        let bin_center = start + (i as f64 + 0.5) * bin_width;
107        bin_centers.push(bin_center);
108        counts.push(bins.value(i));
109    }
110
111    let bin_center_array: ArrayRef = Arc::new(Float64Array::from(bin_centers));
112    let count_array: ArrayRef = Arc::new(UInt64Array::from(counts));
113
114    RecordBatch::try_new(output_schema(), vec![bin_center_array, count_array])
115        .map_err(|e| DataFusionError::External(e.into()))
116}
117
118fn extract_histogram_from_struct(
119    struct_array: &Arc<StructArray>,
120) -> Result<RecordBatch, DataFusionError> {
121    let histo_array = HistogramArray::new(struct_array.clone());
122    if histo_array.is_empty() {
123        return Ok(RecordBatch::new_empty(output_schema()));
124    }
125    expand_histogram_to_batch(&histo_array, 0)
126}
127
128fn scalar_to_batch(scalar: &ScalarValue) -> Result<RecordBatch, DataFusionError> {
129    match scalar {
130        ScalarValue::Struct(struct_array) => extract_histogram_from_struct(struct_array),
131        ScalarValue::Dictionary(_, inner) => scalar_to_batch(inner.as_ref()),
132        _ => Err(DataFusionError::Plan(format!(
133            "expand_histogram argument must be a struct (histogram), got: {:?}",
134            scalar.data_type()
135        ))),
136    }
137}
138
/// Table provider for expanding histogram data.
#[derive(Debug)]
pub struct ExpandHistogramTableProvider {
    /// Where the histogram comes from: an already-materialized scalar, or a
    /// logical plan evaluated when `scan` runs.
    source: HistogramSource,
}
144
145impl ExpandHistogramTableProvider {
146    /// Creates a new provider from a histogram scalar value.
147    pub fn from_scalar(scalar: ScalarValue) -> Result<Self, DataFusionError> {
148        if !matches!(scalar, ScalarValue::Struct(_)) {
149            return Err(DataFusionError::Plan(format!(
150                "expand_histogram argument must be a struct (histogram), got: {:?}",
151                scalar.data_type()
152            )));
153        }
154        Ok(Self {
155            source: HistogramSource::Literal(scalar),
156        })
157    }
158}
159
160#[async_trait]
161impl TableProvider for ExpandHistogramTableProvider {
162    fn as_any(&self) -> &dyn Any {
163        self
164    }
165
166    fn schema(&self) -> SchemaRef {
167        output_schema()
168    }
169
170    fn table_type(&self) -> TableType {
171        TableType::Temporary
172    }
173
174    async fn scan(
175        &self,
176        state: &dyn Session,
177        projection: Option<&Vec<usize>>,
178        _filters: &[Expr],
179        limit: Option<usize>,
180    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
181        let mut record_batch = match &self.source {
182            HistogramSource::Literal(scalar) => scalar_to_batch(scalar)?,
183            HistogramSource::Subquery(plan) => {
184                // Execute the subquery to get the histogram scalar
185                let physical_plan = state.create_physical_plan(plan).await?;
186                let task_ctx = state.task_ctx();
187                let batches = datafusion::physical_plan::collect(physical_plan, task_ctx).await?;
188
189                if batches.is_empty() || batches[0].num_rows() == 0 {
190                    return Err(DataFusionError::Execution(
191                        "expand_histogram subquery returned no rows".into(),
192                    ));
193                }
194
195                let batch = &batches[0];
196                if batch.num_columns() != 1 {
197                    return Err(DataFusionError::Execution(format!(
198                        "expand_histogram subquery must return exactly one column, got {}",
199                        batch.num_columns()
200                    )));
201                }
202
203                // Extract the struct from the first row
204                let column = batch.column(0);
205                let struct_array = column.as_any().downcast_ref::<StructArray>().ok_or_else(
206                    || {
207                        DataFusionError::Execution(format!(
208                            "expand_histogram subquery must return a struct (histogram), got {:?}",
209                            column.data_type()
210                        ))
211                    },
212                )?;
213
214                let histo_array = HistogramArray::new(Arc::new(struct_array.clone()));
215                if histo_array.is_empty() {
216                    RecordBatch::new_empty(output_schema())
217                } else {
218                    expand_histogram_to_batch(&histo_array, 0)?
219                }
220            }
221        };
222
223        // Apply limit if specified
224        if let Some(n) = limit
225            && n < record_batch.num_rows()
226        {
227            record_batch = record_batch.slice(0, n);
228        }
229
230        let source = MemorySourceConfig::try_new(
231            &[vec![record_batch]],
232            self.schema(),
233            projection.map(|v| v.to_owned()),
234        )?;
235        Ok(DataSourceExec::from_data_source(source))
236    }
237}