micromegas_datafusion_extensions/jsonb/each.rs

use async_trait::async_trait;
use datafusion::arrow::array::{Array, ArrayRef, BinaryArray, DictionaryArray, GenericBinaryArray};
use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Schema, SchemaRef};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::catalog::Session;
use datafusion::catalog::TableFunctionImpl;
use datafusion::catalog::TableProvider;
use datafusion::datasource::TableType;
use datafusion::datasource::memory::{DataSourceExec, MemorySourceConfig};
use datafusion::error::DataFusionError;
use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::Expr;
use datafusion::scalar::ScalarValue;
use jsonb::RawJsonb;
use std::any::Any;
use std::sync::Arc;

/// A DataFusion `TableFunctionImpl` that expands a JSONB object or array into rows of (key, value).
///
/// For objects, `key` is the field name. For arrays, `key` is the element index (as a string).
///
/// Usage:
/// ```sql
/// SELECT key, jsonb_as_string(value)
/// FROM jsonb_each(
///   (SELECT properties FROM processes WHERE process_id = '...')
/// )
/// ```
#[derive(Debug)]
pub struct JsonbEachTableFunction {}

impl JsonbEachTableFunction {
    /// Creates a new `jsonb_each` table function.
    pub fn new() -> Self {
        Self {}
    }
}

impl Default for JsonbEachTableFunction {
    fn default() -> Self {
        Self::new()
    }
}

/// The source of JSONB data — either a literal value or a subquery to evaluate.
#[derive(Debug, Clone)]
enum JsonbSource {
    Literal(ScalarValue),
    Subquery(Arc<LogicalPlan>),
}

impl TableFunctionImpl for JsonbEachTableFunction {
    fn call(&self, args: &[Expr]) -> datafusion::error::Result<Arc<dyn TableProvider>> {
        if args.len() != 1 {
            return Err(DataFusionError::Plan(
                "jsonb_each requires exactly one argument (a JSONB object or array)".into(),
            ));
        }

        let source = match &args[0] {
            Expr::Literal(scalar, _metadata) => JsonbSource::Literal(scalar.clone()),
            Expr::ScalarSubquery(subquery) => JsonbSource::Subquery(subquery.subquery.clone()),
            other => {
                // Any other expression: evaluate it as a one-row projection
                // over an empty relation.
                let plan = LogicalPlanBuilder::empty(true)
                    .project(vec![other.clone()])?
                    .build()?;
                JsonbSource::Subquery(Arc::new(plan))
            }
        };

        Ok(Arc::new(JsonbEachTableProvider { source }))
    }
}
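
// Registration sketch (assumes a live `SessionContext`; `register_udtf` is
// DataFusion's entry point for table functions):
//
//     let ctx = datafusion::prelude::SessionContext::new();
//     ctx.register_udtf("jsonb_each", Arc::new(JsonbEachTableFunction::new()));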

fn output_schema() -> SchemaRef {
    Arc::new(Schema::new(vec![
        Field::new("key", DataType::Utf8, false),
        Field::new("value", DataType::Binary, false),
    ]))
}

/// Extract key-value entries from JSONB bytes.
///
/// For objects, uses `object_each()` with field names as keys.
/// For arrays, uses `array_values()` with element indices as keys.
fn extract_entries_from_jsonb(
    jsonb_bytes: &[u8],
) -> Result<Vec<(String, Vec<u8>)>, DataFusionError> {
    let jsonb = RawJsonb::new(jsonb_bytes);
    match jsonb.object_each() {
        Ok(Some(entries)) => {
            return Ok(entries
                .into_iter()
                .map(|(k, v)| (k, v.as_ref().to_vec()))
                .collect());
        }
        Ok(None) => {} // not an object; fall through and try the array path
        Err(e) => return Err(DataFusionError::External(e.into())),
    }
    match jsonb.array_values() {
        Ok(Some(values)) => Ok(values
            .into_iter()
            .enumerate()
            .map(|(i, v)| (i.to_string(), v.as_ref().to_vec()))
            .collect()),
        Ok(None) => Err(DataFusionError::Execution(
            "jsonb_each: input is not a JSONB object or array".into(),
        )),
        Err(e) => Err(DataFusionError::External(e.into())),
    }
}
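
// For illustration (hypothetical inputs): the object `{"a": 1, "b": 2}` yields
// `[("a", <jsonb 1>), ("b", <jsonb 2>)]`, the array `[10, 20]` yields
// `[("0", <jsonb 10>), ("1", <jsonb 20>)]`, and a scalar such as `42` is
// rejected with an execution error.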

fn entries_to_batch(entries: &[(String, Vec<u8>)]) -> Result<RecordBatch, DataFusionError> {
    if entries.is_empty() {
        return Ok(RecordBatch::new_empty(output_schema()));
    }

    let keys: Vec<&str> = entries.iter().map(|(k, _)| k.as_str()).collect();
    let values: Vec<&[u8]> = entries.iter().map(|(_, v)| v.as_slice()).collect();

    let key_array: ArrayRef = Arc::new(datafusion::arrow::array::StringArray::from(keys));
    let value_array: ArrayRef = Arc::new(BinaryArray::from(values));

    RecordBatch::try_new(output_schema(), vec![key_array, value_array])
        .map_err(|e| DataFusionError::External(e.into()))
}

fn scalar_to_entries(scalar: &ScalarValue) -> Result<Vec<(String, Vec<u8>)>, DataFusionError> {
    match scalar {
        ScalarValue::Binary(Some(bytes)) => extract_entries_from_jsonb(bytes),
        ScalarValue::Binary(None) => Ok(vec![]),
        ScalarValue::Dictionary(_, inner) => scalar_to_entries(inner.as_ref()),
        _ => Err(DataFusionError::Plan(format!(
            "jsonb_each argument must be Binary (JSONB), got: {:?}",
            scalar.data_type()
        ))),
    }
}

/// Extract JSONB bytes from all rows of a column, handling both plain Binary
/// and Dictionary<Int32, Binary> encodings.
fn extract_all_jsonb_bytes_from_column(column: &ArrayRef) -> Result<Vec<Vec<u8>>, DataFusionError> {
    match column.data_type() {
        DataType::Binary => {
            let binary_array = column
                .as_any()
                .downcast_ref::<GenericBinaryArray<i32>>()
                .ok_or_else(|| {
                    DataFusionError::Execution("failed to cast column to BinaryArray".into())
                })?;
            Ok((0..binary_array.len())
                .filter(|&i| !binary_array.is_null(i))
                .map(|i| binary_array.value(i).to_vec())
                .collect())
        }
        DataType::Dictionary(_, value_type) if matches!(value_type.as_ref(), DataType::Binary) => {
            let dict_array = column
                .as_any()
                .downcast_ref::<DictionaryArray<Int32Type>>()
                .ok_or_else(|| {
                    DataFusionError::Execution(
                        "failed to cast column to DictionaryArray<Int32, Binary>".into(),
                    )
                })?;
            let binary_values = dict_array
                .values()
                .as_any()
                .downcast_ref::<GenericBinaryArray<i32>>()
                .ok_or_else(|| {
                    DataFusionError::Execution("dictionary values are not a binary array".into())
                })?;
            Ok((0..dict_array.len())
                .filter(|&i| !dict_array.is_null(i))
                .map(|i| {
                    let key_index = dict_array.keys().value(i) as usize;
                    binary_values.value(key_index).to_vec()
                })
                .collect())
        }
        other => Err(DataFusionError::Execution(format!(
            "jsonb_each subquery must return a Binary or Dictionary<Int32, Binary> column, got: {other:?}"
        ))),
    }
}

/// Table provider for expanding JSONB objects or arrays into key-value rows.
#[derive(Debug)]
pub struct JsonbEachTableProvider {
    source: JsonbSource,
}

impl JsonbEachTableProvider {
    /// Creates a new provider from a JSONB scalar value (for testing).
    pub fn from_scalar(scalar: ScalarValue) -> Result<Self, DataFusionError> {
        if !matches!(&scalar, ScalarValue::Binary(Some(_))) {
            return Err(DataFusionError::Plan(format!(
                "jsonb_each argument must be Binary (JSONB), got: {:?}",
                scalar.data_type()
            )));
        }
        Ok(Self {
            source: JsonbSource::Literal(scalar),
        })
    }
}

#[async_trait]
impl TableProvider for JsonbEachTableProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        output_schema()
    }

    fn table_type(&self) -> TableType {
        TableType::Temporary
    }

    async fn scan(
        &self,
        state: &dyn Session,
        projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        limit: Option<usize>,
    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
        let entries = match &self.source {
            JsonbSource::Literal(scalar) => scalar_to_entries(scalar)?,
            JsonbSource::Subquery(plan) => {
                let physical_plan = state.create_physical_plan(plan).await?;
                let task_ctx = state.task_ctx();
                let batches = datafusion::physical_plan::collect(physical_plan, task_ctx).await?;

                if batches.is_empty() || batches.iter().all(|b| b.num_rows() == 0) {
                    return Err(DataFusionError::Execution(
                        "jsonb_each subquery returned no rows".into(),
                    ));
                }

                let mut all_entries = Vec::new();
                for batch in &batches {
                    if batch.num_columns() != 1 {
                        return Err(DataFusionError::Execution(format!(
                            "jsonb_each subquery must return exactly one column, got {}",
                            batch.num_columns()
                        )));
                    }
                    for jsonb_bytes in extract_all_jsonb_bytes_from_column(batch.column(0))? {
                        all_entries.extend(extract_entries_from_jsonb(&jsonb_bytes)?);
                    }
                }
                all_entries
            }
        };

        let mut record_batch = entries_to_batch(&entries)?;

        // Apply limit if specified
        if let Some(n) = limit
            && n < record_batch.num_rows()
        {
            record_batch = record_batch.slice(0, n);
        }

        let source = MemorySourceConfig::try_new(
            &[vec![record_batch]],
            self.schema(),
            projection.map(|v| v.to_owned()),
        )?;
        Ok(DataSourceExec::from_data_source(source))
    }
}