micromegas_datafusion_extensions/jsonb/
keys.rs

1use datafusion::arrow::array::{
2    Array, ArrayRef, DictionaryArray, GenericBinaryArray, Int32Array, ListBuilder, StringBuilder,
3};
4use datafusion::arrow::datatypes::{DataType, Field, Int32Type};
5use datafusion::common::{Result, internal_err};
6use datafusion::error::DataFusionError;
7use datafusion::logical_expr::{
8    ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
9};
10use jsonb::RawJsonb;
11use std::any::Any;
12use std::collections::HashMap;
13use std::sync::Arc;
14
15/// A scalar UDF that extracts the keys from a JSONB object.
16///
17/// Accepts both Binary and Dictionary<Int32, Binary> inputs.
18/// Returns Dictionary<Int32, List<Utf8>> containing the object keys, or null if input is not an object.
19/// Dictionary encoding is used because JSONB values (especially properties) are often repeated.
20#[derive(Debug, PartialEq, Eq, Hash)]
21pub struct JsonbObjectKeys {
22    signature: Signature,
23}
24
25impl JsonbObjectKeys {
26    pub fn new() -> Self {
27        Self {
28            signature: Signature::any(1, Volatility::Immutable),
29        }
30    }
31}
32
33impl Default for JsonbObjectKeys {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39fn extract_keys_from_jsonb(jsonb_bytes: &[u8]) -> Result<Option<Vec<String>>> {
40    let jsonb = RawJsonb::new(jsonb_bytes);
41    match jsonb.object_keys() {
42        Ok(Some(keys_jsonb)) => {
43            // keys_jsonb is a JSONB array of string keys
44            let keys_raw = keys_jsonb.as_raw();
45            match keys_raw.array_values() {
46                Ok(Some(values)) => {
47                    let mut keys = Vec::with_capacity(values.len());
48                    for value in values {
49                        let raw = value.as_raw();
50                        match raw.as_str() {
51                            Ok(Some(s)) => keys.push(s.to_string()),
52                            Ok(None) => {
53                                // Key is not a string (shouldn't happen for object keys)
54                                return Ok(None);
55                            }
56                            Err(e) => return Err(DataFusionError::External(e.into())),
57                        }
58                    }
59                    Ok(Some(keys))
60                }
61                Ok(None) => Ok(Some(Vec::new())), // Empty array
62                Err(e) => Err(DataFusionError::External(e.into())),
63            }
64        }
65        Ok(None) => Ok(None), // Input is not an object
66        Err(e) => Err(DataFusionError::External(e.into())),
67    }
68}
69
70impl ScalarUDFImpl for JsonbObjectKeys {
71    fn as_any(&self) -> &dyn Any {
72        self
73    }
74
75    fn name(&self) -> &str {
76        "jsonb_object_keys"
77    }
78
79    fn signature(&self) -> &Signature {
80        &self.signature
81    }
82
83    fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
84        Ok(DataType::Dictionary(
85            Box::new(DataType::Int32),
86            Box::new(DataType::List(Arc::new(Field::new_list_field(
87                DataType::Utf8,
88                true,
89            )))),
90        ))
91    }
92
93    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
94        let args = ColumnarValue::values_to_arrays(&args.args)?;
95        if args.len() != 1 {
96            return internal_err!("wrong number of arguments to jsonb_object_keys()");
97        }
98
99        match args[0].data_type() {
100            DataType::Binary => {
101                let binary_array = args[0]
102                    .as_any()
103                    .downcast_ref::<GenericBinaryArray<i32>>()
104                    .ok_or_else(|| {
105                        DataFusionError::Internal("error casting to binary array".into())
106                    })?;
107
108                let result = build_dict_list_array(binary_array.len(), |i| {
109                    if binary_array.is_null(i) {
110                        Ok(None)
111                    } else {
112                        extract_keys_from_jsonb(binary_array.value(i))
113                    }
114                })?;
115                Ok(ColumnarValue::Array(result))
116            }
117            DataType::Dictionary(_, value_type)
118                if matches!(value_type.as_ref(), DataType::Binary) =>
119            {
120                let dict_array = args[0]
121                    .as_any()
122                    .downcast_ref::<DictionaryArray<Int32Type>>()
123                    .ok_or_else(|| {
124                        DataFusionError::Internal("error casting dictionary array".into())
125                    })?;
126
127                let binary_values = dict_array
128                    .values()
129                    .as_any()
130                    .downcast_ref::<GenericBinaryArray<i32>>()
131                    .ok_or_else(|| {
132                        DataFusionError::Internal("dictionary values are not a binary array".into())
133                    })?;
134
135                let result = build_dict_list_array(dict_array.len(), |i| {
136                    if dict_array.is_null(i) {
137                        Ok(None)
138                    } else {
139                        let key_index = dict_array.keys().value(i) as usize;
140                        if key_index < binary_values.len() {
141                            extract_keys_from_jsonb(binary_values.value(key_index))
142                        } else {
143                            internal_err!("Dictionary key index out of bounds in jsonb_object_keys")
144                        }
145                    }
146                })?;
147                Ok(ColumnarValue::Array(result))
148            }
149            _ => internal_err!(
150                "jsonb_object_keys: unsupported input type, expected Binary or Dictionary<Int32, Binary>"
151            ),
152        }
153    }
154}
155
156/// Build a Dictionary<Int32, List<Utf8>> array from a function that returns keys for each index.
157/// Uses a HashMap to deduplicate identical key lists for memory efficiency.
158/// Returns None from get_keys to indicate a null output (distinct from Some(empty vec) for empty objects).
159fn build_dict_list_array<F>(len: usize, mut get_keys: F) -> Result<ArrayRef>
160where
161    F: FnMut(usize) -> Result<Option<Vec<String>>>,
162{
163    // Map from key list to dictionary index (only for non-null results)
164    let mut unique_lists: HashMap<Vec<String>, i32> = HashMap::new();
165    let mut key_indices: Vec<Option<i32>> = Vec::with_capacity(len);
166    let mut ordered_lists: Vec<Vec<String>> = Vec::new();
167
168    // First pass: collect all values and deduplicate
169    for i in 0..len {
170        let keys = get_keys(i)?;
171        match keys {
172            Some(key_list) => {
173                if let Some(idx) = unique_lists.get(&key_list) {
174                    key_indices.push(Some(*idx));
175                } else {
176                    let idx = ordered_lists.len() as i32;
177                    unique_lists.insert(key_list.clone(), idx);
178                    key_indices.push(Some(idx));
179                    ordered_lists.push(key_list);
180                }
181            }
182            None => {
183                // Null input produces null dictionary entry (null key)
184                key_indices.push(None);
185            }
186        }
187    }
188
189    // Build the values array (List<Utf8>) from unique lists
190    let mut list_builder = ListBuilder::new(StringBuilder::new());
191    for keys in &ordered_lists {
192        for key in keys {
193            list_builder.values().append_value(key);
194        }
195        list_builder.append(true);
196    }
197    let values_array = Arc::new(list_builder.finish());
198
199    // Build the keys array (None values become null keys)
200    let keys_array = Int32Array::from(key_indices);
201
202    // Construct the dictionary array
203    let dict_array =
204        DictionaryArray::<Int32Type>::try_new(keys_array, values_array).map_err(|e| {
205            DataFusionError::Internal(format!("Failed to create dictionary array: {e}"))
206        })?;
207
208    Ok(Arc::new(dict_array))
209}
210
211/// Creates a user-defined function to extract the keys from a JSONB object.
212pub fn make_jsonb_object_keys_udf() -> ScalarUDF {
213    ScalarUDF::new_from_impl(JsonbObjectKeys::new())
214}