micromegas_datafusion_extensions/histogram/
expand.rs1use super::histogram_udaf::HistogramArray;
2use async_trait::async_trait;
3use datafusion::arrow::array::{ArrayRef, Float64Array, StructArray, UInt64Array};
4use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
5use datafusion::arrow::record_batch::RecordBatch;
6use datafusion::catalog::Session;
7use datafusion::catalog::TableFunctionImpl;
8use datafusion::catalog::TableProvider;
9use datafusion::datasource::TableType;
10use datafusion::datasource::memory::{DataSourceExec, MemorySourceConfig};
11use datafusion::error::DataFusionError;
12use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder};
13use datafusion::physical_plan::ExecutionPlan;
14use datafusion::prelude::Expr;
15use datafusion::scalar::ScalarValue;
16use std::any::Any;
17use std::sync::Arc;
18
/// Table function behind `expand_histogram(<histogram>)`.
///
/// The type carries no state: all of the work happens in its
/// `TableFunctionImpl::call` implementation, which turns the single
/// histogram argument into a table provider exposing one row per bin.
#[derive(Debug, Default)]
pub struct ExpandHistogramTableFunction {}

impl ExpandHistogramTableFunction {
    /// Creates the (stateless) table function.
    pub fn new() -> Self {
        Self::default()
    }
}
43
44#[derive(Debug, Clone)]
46enum HistogramSource {
47 Literal(ScalarValue),
48 Subquery(Arc<LogicalPlan>),
49}
50
51impl TableFunctionImpl for ExpandHistogramTableFunction {
52 fn call(&self, args: &[Expr]) -> datafusion::error::Result<Arc<dyn TableProvider>> {
53 if args.len() != 1 {
54 return Err(DataFusionError::Plan(
55 "expand_histogram requires exactly one argument (a histogram)".into(),
56 ));
57 }
58
59 let source = match &args[0] {
61 Expr::Literal(scalar, _metadata) => HistogramSource::Literal(scalar.clone()),
62 Expr::ScalarSubquery(subquery) => HistogramSource::Subquery(subquery.subquery.clone()),
63 other => {
64 let plan = LogicalPlanBuilder::empty(true)
65 .project(vec![other.clone()])?
66 .build()?;
67 HistogramSource::Subquery(Arc::new(plan))
68 }
69 };
70
71 Ok(Arc::new(ExpandHistogramTableProvider { source }))
72 }
73}
74
75fn output_schema() -> SchemaRef {
76 Arc::new(Schema::new(vec![
77 Field::new("bin_center", DataType::Float64, false),
78 Field::new("count", DataType::UInt64, false),
79 ]))
80}
81
82fn expand_histogram_to_batch(
83 histo_array: &HistogramArray,
84 index: usize,
85) -> Result<RecordBatch, DataFusionError> {
86 if histo_array.is_null_at(index) {
87 return Ok(RecordBatch::new_empty(output_schema()));
88 }
89 let start = histo_array.get_start(index)?;
90 let end = histo_array.get_end(index)?;
91 let bins = histo_array.get_bins(index)?;
92
93 let num_bins = bins.len();
94 if num_bins == 0 {
95 return Ok(RecordBatch::new_empty(output_schema()));
96 }
97
98 let bin_width = if (end - start).abs() < f64::EPSILON {
100 1.0 } else {
102 (end - start) / (num_bins as f64)
103 };
104
105 let mut bin_centers = Vec::with_capacity(num_bins);
106 let mut counts = Vec::with_capacity(num_bins);
107
108 for i in 0..num_bins {
109 let bin_center = start + (i as f64 + 0.5) * bin_width;
110 bin_centers.push(bin_center);
111 counts.push(bins.value(i));
112 }
113
114 let bin_center_array: ArrayRef = Arc::new(Float64Array::from(bin_centers));
115 let count_array: ArrayRef = Arc::new(UInt64Array::from(counts));
116
117 RecordBatch::try_new(output_schema(), vec![bin_center_array, count_array])
118 .map_err(|e| DataFusionError::External(e.into()))
119}
120
121fn extract_histogram_from_struct(
122 struct_array: &Arc<StructArray>,
123) -> Result<RecordBatch, DataFusionError> {
124 let histo_array = HistogramArray::new(struct_array.clone());
125 if histo_array.is_empty() {
126 return Ok(RecordBatch::new_empty(output_schema()));
127 }
128 expand_histogram_to_batch(&histo_array, 0)
129}
130
131fn scalar_to_batch(scalar: &ScalarValue) -> Result<RecordBatch, DataFusionError> {
132 match scalar {
133 ScalarValue::Struct(struct_array) => extract_histogram_from_struct(struct_array),
134 ScalarValue::Dictionary(_, inner) => scalar_to_batch(inner.as_ref()),
135 _ => Err(DataFusionError::Plan(format!(
136 "expand_histogram argument must be a struct (histogram), got: {:?}",
137 scalar.data_type()
138 ))),
139 }
140}
141
142#[derive(Debug)]
144pub struct ExpandHistogramTableProvider {
145 source: HistogramSource,
146}
147
148impl ExpandHistogramTableProvider {
149 pub fn from_scalar(scalar: ScalarValue) -> Result<Self, DataFusionError> {
151 if !matches!(scalar, ScalarValue::Struct(_)) {
152 return Err(DataFusionError::Plan(format!(
153 "expand_histogram argument must be a struct (histogram), got: {:?}",
154 scalar.data_type()
155 )));
156 }
157 Ok(Self {
158 source: HistogramSource::Literal(scalar),
159 })
160 }
161}
162
163#[async_trait]
164impl TableProvider for ExpandHistogramTableProvider {
165 fn as_any(&self) -> &dyn Any {
166 self
167 }
168
169 fn schema(&self) -> SchemaRef {
170 output_schema()
171 }
172
173 fn table_type(&self) -> TableType {
174 TableType::Temporary
175 }
176
177 async fn scan(
178 &self,
179 state: &dyn Session,
180 projection: Option<&Vec<usize>>,
181 _filters: &[Expr],
182 limit: Option<usize>,
183 ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
184 let mut record_batch = match &self.source {
185 HistogramSource::Literal(scalar) => scalar_to_batch(scalar)?,
186 HistogramSource::Subquery(plan) => {
187 let physical_plan = state.create_physical_plan(plan).await?;
189 let task_ctx = state.task_ctx();
190 let batches = datafusion::physical_plan::collect(physical_plan, task_ctx).await?;
191
192 if batches.is_empty() || batches[0].num_rows() == 0 {
193 return Err(DataFusionError::Execution(
194 "expand_histogram subquery returned no rows".into(),
195 ));
196 }
197
198 let batch = &batches[0];
199 if batch.num_columns() != 1 {
200 return Err(DataFusionError::Execution(format!(
201 "expand_histogram subquery must return exactly one column, got {}",
202 batch.num_columns()
203 )));
204 }
205
206 let column = batch.column(0);
208 let struct_array = column.as_any().downcast_ref::<StructArray>().ok_or_else(
209 || {
210 DataFusionError::Execution(format!(
211 "expand_histogram subquery must return a struct (histogram), got {:?}",
212 column.data_type()
213 ))
214 },
215 )?;
216
217 let histo_array = HistogramArray::new(Arc::new(struct_array.clone()));
218 if histo_array.is_empty() {
219 RecordBatch::new_empty(output_schema())
220 } else {
221 expand_histogram_to_batch(&histo_array, 0)?
222 }
223 }
224 };
225
226 if let Some(n) = limit
228 && n < record_batch.num_rows()
229 {
230 record_batch = record_batch.slice(0, n);
231 }
232
233 let source = MemorySourceConfig::try_new(
234 &[vec![record_batch]],
235 self.schema(),
236 projection.map(|v| v.to_owned()),
237 )?;
238 Ok(DataSourceExec::from_data_source(source))
239 }
240}