micromegas_analytics/lakehouse/
get_payload_function.rs

1use async_trait::async_trait;
2use datafusion::{
3    arrow::{
4        array::{Array, BinaryBuilder, StringArray},
5        datatypes::DataType,
6    },
7    common::{internal_err, not_impl_err},
8    error::DataFusionError,
9    logical_expr::{
10        ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
11        async_udf::AsyncScalarUDFImpl,
12    },
13};
14use futures::stream::StreamExt;
15use micromegas_ingestion::data_lake_connection::DataLakeConnection;
16use micromegas_tracing::prelude::*;
17use std::sync::Arc;
18
19/// A scalar UDF that retrieves the payload of a block from the data lake.
20#[derive(Debug)]
21pub struct GetPayload {
22    signature: Signature,
23    lake: Arc<DataLakeConnection>,
24}
25
26impl PartialEq for GetPayload {
27    fn eq(&self, other: &Self) -> bool {
28        self.signature == other.signature
29    }
30}
31
32impl Eq for GetPayload {}
33
34impl std::hash::Hash for GetPayload {
35    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
36        self.signature.hash(state);
37    }
38}
39
40impl GetPayload {
41    pub fn new(lake: Arc<DataLakeConnection>) -> Self {
42        Self {
43            signature: Signature::exact(
44                vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
45                Volatility::Immutable,
46            ),
47            lake,
48        }
49    }
50}
51
52impl ScalarUDFImpl for GetPayload {
53    fn as_any(&self) -> &dyn std::any::Any {
54        self
55    }
56
57    fn name(&self) -> &str {
58        "get_payload"
59    }
60
61    fn signature(&self) -> &Signature {
62        &self.signature
63    }
64
65    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
66        Ok(DataType::Binary)
67    }
68
69    fn invoke_with_args(
70        &self,
71        _args: datafusion::logical_expr::ScalarFunctionArgs,
72    ) -> datafusion::error::Result<ColumnarValue> {
73        not_impl_err!("GetPayload can only be called from async contexts")
74    }
75}
76
77#[async_trait]
78impl AsyncScalarUDFImpl for GetPayload {
79    async fn invoke_async_with_args(
80        &self,
81        args: ScalarFunctionArgs,
82    ) -> datafusion::error::Result<ColumnarValue> {
83        let args = ColumnarValue::values_to_arrays(&args.args)?;
84        if args.len() != 3 {
85            return internal_err!("wrong number of arguments to get_payload()");
86        }
87        let process_ids = args[0]
88            .as_any()
89            .downcast_ref::<StringArray>()
90            .ok_or_else(|| {
91                DataFusionError::Execution("downcasting process_ids in GetPayload".into())
92            })?
93            .clone();
94        let stream_ids = args[1]
95            .as_any()
96            .downcast_ref::<StringArray>()
97            .ok_or_else(|| {
98                DataFusionError::Execution("downcasting stream_ids in GetPayload".into())
99            })?
100            .clone();
101        let block_ids = args[2]
102            .as_any()
103            .downcast_ref::<StringArray>()
104            .ok_or_else(|| {
105                DataFusionError::Execution("downcasting block_ids in GetPayload".into())
106            })?
107            .clone();
108        let lake = self.lake.clone();
109        let mut stream = futures::stream::iter(0..process_ids.len())
110            .map(|i| {
111                let process_id = process_ids.value(i);
112                let stream_id = stream_ids.value(i);
113                let block_id = block_ids.value(i);
114                let obj_path = format!("blobs/{process_id}/{stream_id}/{block_id}");
115                let lake = lake.clone();
116                spawn_with_context(async move { lake.blob_storage.read_blob(&obj_path).await })
117            })
118            .buffered(10);
119        let mut result_builder = BinaryBuilder::with_capacity(block_ids.len(), 1024 * 1024);
120        while let Some(res) = stream.next().await {
121            result_builder.append_value(
122                res.map_err(|e| {
123                    DataFusionError::Execution(format!("error downloading payload: {e:?}"))
124                })?
125                .map_err(|e| {
126                    DataFusionError::Execution(format!("error downloading payload: {e:?}"))
127                })?,
128            );
129        }
130        Ok(ColumnarValue::Array(Arc::new(result_builder.finish())))
131    }
132}