// micromegas_datafusion_extensions/histogram/expand.rs
use super::histogram_udaf::HistogramArray;
2use async_trait::async_trait;
3use datafusion::arrow::array::{ArrayRef, Float64Array, StructArray, UInt64Array};
4use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
5use datafusion::arrow::record_batch::RecordBatch;
6use datafusion::catalog::Session;
7use datafusion::catalog::TableFunctionImpl;
8use datafusion::catalog::TableProvider;
9use datafusion::datasource::TableType;
10use datafusion::datasource::memory::{DataSourceExec, MemorySourceConfig};
11use datafusion::error::DataFusionError;
12use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder};
13use datafusion::physical_plan::ExecutionPlan;
14use datafusion::prelude::Expr;
15use datafusion::scalar::ScalarValue;
16use std::any::Any;
17use std::sync::Arc;
18
/// Table function that expands a histogram value into a two-column table
/// of `(bin_center, count)` rows.
#[derive(Debug)]
pub struct ExpandHistogramTableFunction {}

impl Default for ExpandHistogramTableFunction {
    fn default() -> Self {
        Self::new()
    }
}

impl ExpandHistogramTableFunction {
    /// Creates a new, stateless table-function instance.
    pub fn new() -> Self {
        Self {}
    }
}
43
/// Where the histogram argument of `expand_histogram` comes from.
#[derive(Debug, Clone)]
enum HistogramSource {
    // Argument was a literal scalar, available at planning time.
    Literal(ScalarValue),
    // Argument was a subquery/expression; evaluated later, at scan time.
    Subquery(Arc<LogicalPlan>),
}
50
51impl TableFunctionImpl for ExpandHistogramTableFunction {
52 fn call(&self, args: &[Expr]) -> datafusion::error::Result<Arc<dyn TableProvider>> {
53 if args.len() != 1 {
54 return Err(DataFusionError::Plan(
55 "expand_histogram requires exactly one argument (a histogram)".into(),
56 ));
57 }
58
59 let source = match &args[0] {
61 Expr::Literal(scalar, _metadata) => HistogramSource::Literal(scalar.clone()),
62 Expr::ScalarSubquery(subquery) => HistogramSource::Subquery(subquery.subquery.clone()),
63 other => {
64 let plan = LogicalPlanBuilder::empty(true)
65 .project(vec![other.clone()])?
66 .build()?;
67 HistogramSource::Subquery(Arc::new(plan))
68 }
69 };
70
71 Ok(Arc::new(ExpandHistogramTableProvider { source }))
72 }
73}
74
75fn output_schema() -> SchemaRef {
76 Arc::new(Schema::new(vec![
77 Field::new("bin_center", DataType::Float64, false),
78 Field::new("count", DataType::UInt64, false),
79 ]))
80}
81
82fn expand_histogram_to_batch(
83 histo_array: &HistogramArray,
84 index: usize,
85) -> Result<RecordBatch, DataFusionError> {
86 let start = histo_array.get_start(index)?;
87 let end = histo_array.get_end(index)?;
88 let bins = histo_array.get_bins(index)?;
89
90 let num_bins = bins.len();
91 if num_bins == 0 {
92 return Ok(RecordBatch::new_empty(output_schema()));
93 }
94
95 let bin_width = if (end - start).abs() < f64::EPSILON {
97 1.0 } else {
99 (end - start) / (num_bins as f64)
100 };
101
102 let mut bin_centers = Vec::with_capacity(num_bins);
103 let mut counts = Vec::with_capacity(num_bins);
104
105 for i in 0..num_bins {
106 let bin_center = start + (i as f64 + 0.5) * bin_width;
107 bin_centers.push(bin_center);
108 counts.push(bins.value(i));
109 }
110
111 let bin_center_array: ArrayRef = Arc::new(Float64Array::from(bin_centers));
112 let count_array: ArrayRef = Arc::new(UInt64Array::from(counts));
113
114 RecordBatch::try_new(output_schema(), vec![bin_center_array, count_array])
115 .map_err(|e| DataFusionError::External(e.into()))
116}
117
118fn extract_histogram_from_struct(
119 struct_array: &Arc<StructArray>,
120) -> Result<RecordBatch, DataFusionError> {
121 let histo_array = HistogramArray::new(struct_array.clone());
122 if histo_array.is_empty() {
123 return Ok(RecordBatch::new_empty(output_schema()));
124 }
125 expand_histogram_to_batch(&histo_array, 0)
126}
127
128fn scalar_to_batch(scalar: &ScalarValue) -> Result<RecordBatch, DataFusionError> {
129 match scalar {
130 ScalarValue::Struct(struct_array) => extract_histogram_from_struct(struct_array),
131 ScalarValue::Dictionary(_, inner) => scalar_to_batch(inner.as_ref()),
132 _ => Err(DataFusionError::Plan(format!(
133 "expand_histogram argument must be a struct (histogram), got: {:?}",
134 scalar.data_type()
135 ))),
136 }
137}
138
/// Table provider backing `expand_histogram`; holds the histogram source
/// until `scan` materializes it into `(bin_center, count)` rows.
#[derive(Debug)]
pub struct ExpandHistogramTableProvider {
    // Planning-time literal or a deferred subquery plan.
    source: HistogramSource,
}
144
145impl ExpandHistogramTableProvider {
146 pub fn from_scalar(scalar: ScalarValue) -> Result<Self, DataFusionError> {
148 if !matches!(scalar, ScalarValue::Struct(_)) {
149 return Err(DataFusionError::Plan(format!(
150 "expand_histogram argument must be a struct (histogram), got: {:?}",
151 scalar.data_type()
152 )));
153 }
154 Ok(Self {
155 source: HistogramSource::Literal(scalar),
156 })
157 }
158}
159
#[async_trait]
impl TableProvider for ExpandHistogramTableProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Fixed two-column schema: `bin_center: Float64`, `count: UInt64`.
    fn schema(&self) -> SchemaRef {
        output_schema()
    }

    fn table_type(&self) -> TableType {
        TableType::Temporary
    }

    /// Materializes the histogram into a single in-memory record batch.
    ///
    /// For a `Literal` source the batch is built directly from the scalar.
    /// For a `Subquery` source, the stored plan is executed here, its single
    /// struct column is expanded, and the result is served via a memory
    /// data source. `limit` is applied by slicing; `projection` is delegated
    /// to `MemorySourceConfig`; filters are ignored.
    ///
    /// NOTE(review): only `batches[0]` (and row 0 of its column) is used; if
    /// the subquery ever produced multiple batches or rows, the extra data
    /// would be silently ignored — confirm the planner guarantees a single
    /// scalar result here.
    async fn scan(
        &self,
        state: &dyn Session,
        projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        limit: Option<usize>,
    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
        let mut record_batch = match &self.source {
            HistogramSource::Literal(scalar) => scalar_to_batch(scalar)?,
            HistogramSource::Subquery(plan) => {
                // Execute the deferred plan eagerly and collect all output.
                let physical_plan = state.create_physical_plan(plan).await?;
                let task_ctx = state.task_ctx();
                let batches = datafusion::physical_plan::collect(physical_plan, task_ctx).await?;

                if batches.is_empty() || batches[0].num_rows() == 0 {
                    return Err(DataFusionError::Execution(
                        "expand_histogram subquery returned no rows".into(),
                    ));
                }

                let batch = &batches[0];
                if batch.num_columns() != 1 {
                    return Err(DataFusionError::Execution(format!(
                        "expand_histogram subquery must return exactly one column, got {}",
                        batch.num_columns()
                    )));
                }

                // The single column must be a struct array (the histogram).
                let column = batch.column(0);
                let struct_array = column.as_any().downcast_ref::<StructArray>().ok_or_else(
                    || {
                        DataFusionError::Execution(format!(
                            "expand_histogram subquery must return a struct (histogram), got {:?}",
                            column.data_type()
                        ))
                    },
                )?;

                // Expand the first histogram (row 0) of the column.
                let histo_array = HistogramArray::new(Arc::new(struct_array.clone()));
                if histo_array.is_empty() {
                    RecordBatch::new_empty(output_schema())
                } else {
                    expand_histogram_to_batch(&histo_array, 0)?
                }
            }
        };

        // Apply LIMIT locally: the full result is already in memory.
        if let Some(n) = limit
            && n < record_batch.num_rows()
        {
            record_batch = record_batch.slice(0, n);
        }

        // Serve the batch through a memory data source, honoring projection.
        let source = MemorySourceConfig::try_new(
            &[vec![record_batch]],
            self.schema(),
            projection.map(|v| v.to_owned()),
        )?;
        Ok(DataSourceExec::from_data_source(source))
    }
}
237}