// micromegas_datafusion_extensions/jsonb/get.rs

use datafusion::arrow::array::{
    Array, BinaryDictionaryBuilder, DictionaryArray, GenericBinaryArray, StringArray,
};
use datafusion::arrow::datatypes::{DataType, Int32Type};
use datafusion::common::{Result, internal_err};
use datafusion::error::DataFusionError;
use datafusion::logical_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use jsonb::RawJsonb;
use std::any::Any;
use std::sync::Arc;

14/// A scalar UDF that retrieves a value from a JSONB object by name.
15///
16/// Accepts both Binary and Dictionary<Int32, Binary> inputs.
17/// Returns Dictionary<Int32, Binary> for memory efficiency.
18#[derive(Debug, PartialEq, Eq, Hash)]
19pub struct JsonbGet {
20    signature: Signature,
21}
22
23impl JsonbGet {
24    pub fn new() -> Self {
25        Self {
26            signature: Signature::any(2, Volatility::Immutable),
27        }
28    }
29}
30
31impl Default for JsonbGet {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37fn extract_jsonb_value(jsonb_bytes: &[u8], name: &str) -> Result<Option<Vec<u8>>> {
38    let jsonb = RawJsonb::new(jsonb_bytes);
39    match jsonb.get_by_name(name, true) {
40        Ok(Some(value)) => Ok(Some(value.to_vec())),
41        Ok(None) => Ok(None),
42        Err(e) => Err(DataFusionError::External(e.into())),
43    }
44}
45
46impl ScalarUDFImpl for JsonbGet {
47    fn as_any(&self) -> &dyn Any {
48        self
49    }
50
51    fn name(&self) -> &str {
52        "jsonb_get"
53    }
54
55    fn signature(&self) -> &Signature {
56        &self.signature
57    }
58
59    fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
60        Ok(DataType::Dictionary(
61            Box::new(DataType::Int32),
62            Box::new(DataType::Binary),
63        ))
64    }
65
66    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
67        let args = ColumnarValue::values_to_arrays(&args.args)?;
68        if args.len() != 2 {
69            return internal_err!("wrong number of arguments to jsonb_get()");
70        }
71
72        let names = args[1]
73            .as_any()
74            .downcast_ref::<StringArray>()
75            .ok_or_else(|| {
76                DataFusionError::Execution("second argument must be a string array".into())
77            })?;
78
79        match args[0].data_type() {
80            DataType::Binary => {
81                // Handle plain Binary JSONB array
82                let binary_array = args[0]
83                    .as_any()
84                    .downcast_ref::<GenericBinaryArray<i32>>()
85                    .ok_or_else(|| {
86                        DataFusionError::Internal("error casting to binary array".into())
87                    })?;
88
89                if binary_array.len() != names.len() {
90                    return internal_err!("arrays of different lengths in jsonb_get()");
91                }
92
93                let mut dict_builder = BinaryDictionaryBuilder::<Int32Type>::new();
94                for i in 0..binary_array.len() {
95                    if binary_array.is_null(i) {
96                        dict_builder.append_null();
97                    } else {
98                        let jsonb_bytes = binary_array.value(i);
99                        let name = names.value(i);
100                        if let Some(value) = extract_jsonb_value(jsonb_bytes, name)? {
101                            dict_builder.append_value(&value);
102                        } else {
103                            dict_builder.append_null();
104                        }
105                    }
106                }
107                Ok(ColumnarValue::Array(Arc::new(dict_builder.finish())))
108            }
109            DataType::Dictionary(_, value_type)
110                if matches!(value_type.as_ref(), DataType::Binary) =>
111            {
112                // Handle dictionary-encoded JSONB array
113                let dict_array = args[0]
114                    .as_any()
115                    .downcast_ref::<DictionaryArray<Int32Type>>()
116                    .ok_or_else(|| {
117                        DataFusionError::Internal("error casting dictionary array".into())
118                    })?;
119
120                if dict_array.len() != names.len() {
121                    return internal_err!("arrays of different lengths in jsonb_get()");
122                }
123
124                let binary_values = dict_array
125                    .values()
126                    .as_any()
127                    .downcast_ref::<GenericBinaryArray<i32>>()
128                    .ok_or_else(|| {
129                        DataFusionError::Internal("dictionary values are not a binary array".into())
130                    })?;
131
132                let mut dict_builder = BinaryDictionaryBuilder::<Int32Type>::new();
133                for i in 0..dict_array.len() {
134                    if dict_array.is_null(i) {
135                        dict_builder.append_null();
136                    } else {
137                        let key_index = dict_array.keys().value(i) as usize;
138                        if key_index < binary_values.len() {
139                            let jsonb_bytes = binary_values.value(key_index);
140                            let name = names.value(i);
141                            if let Some(value) = extract_jsonb_value(jsonb_bytes, name)? {
142                                dict_builder.append_value(&value);
143                            } else {
144                                dict_builder.append_null();
145                            }
146                        } else {
147                            return internal_err!(
148                                "Dictionary key index out of bounds in jsonb_get"
149                            );
150                        }
151                    }
152                }
153                Ok(ColumnarValue::Array(Arc::new(dict_builder.finish())))
154            }
155            _ => internal_err!(
156                "jsonb_get: unsupported input type, expected Binary or Dictionary<Int32, Binary>"
157            ),
158        }
159    }
160}
161
162/// Creates a user-defined function to get a value from a JSONB object by name.
163pub fn make_jsonb_get_udf() -> ScalarUDF {
164    ScalarUDF::new_from_impl(JsonbGet::new())
165}