micromegas_analytics/lakehouse/
get_payload_function.rs1use async_trait::async_trait;
2use datafusion::{
3 arrow::{
4 array::{Array, BinaryBuilder, StringArray},
5 datatypes::DataType,
6 },
7 common::{internal_err, not_impl_err},
8 error::DataFusionError,
9 logical_expr::{
10 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
11 async_udf::AsyncScalarUDFImpl,
12 },
13};
14use futures::stream::StreamExt;
15use micromegas_ingestion::data_lake_connection::DataLakeConnection;
16use micromegas_tracing::prelude::*;
17use std::sync::Arc;
18
/// Scalar UDF `get_payload`: fetches a block's raw payload bytes from blob
/// storage given `(process_id, stream_id, block_id)` string columns.
///
/// Registered as an async UDF — the actual fetch logic lives in the
/// `AsyncScalarUDFImpl` implementation below.
#[derive(Debug)]
pub struct GetPayload {
    // Exact signature: three Utf8 arguments (see `GetPayload::new`).
    signature: Signature,
    // Data lake connection providing access to blob storage.
    lake: Arc<DataLakeConnection>,
}
25
impl PartialEq for GetPayload {
    /// Equality compares only the signature; the `lake` connection handle is
    /// deliberately ignored, so two instances over different lakes with the
    /// same signature compare equal.
    fn eq(&self, other: &Self) -> bool {
        self.signature == other.signature
    }
}
31
// Total equality holds because `eq` delegates entirely to `Signature`'s
// equality, which is itself an equivalence relation.
impl Eq for GetPayload {}
33
impl std::hash::Hash for GetPayload {
    // Must stay consistent with `PartialEq`: both consider only `signature`
    // and ignore the `lake` handle.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.signature.hash(state);
    }
}
39
40impl GetPayload {
41 pub fn new(lake: Arc<DataLakeConnection>) -> Self {
42 Self {
43 signature: Signature::exact(
44 vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
45 Volatility::Immutable,
46 ),
47 lake,
48 }
49 }
50}
51
52impl ScalarUDFImpl for GetPayload {
53 fn as_any(&self) -> &dyn std::any::Any {
54 self
55 }
56
57 fn name(&self) -> &str {
58 "get_payload"
59 }
60
61 fn signature(&self) -> &Signature {
62 &self.signature
63 }
64
65 fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
66 Ok(DataType::Binary)
67 }
68
69 fn invoke_with_args(
70 &self,
71 _args: datafusion::logical_expr::ScalarFunctionArgs,
72 ) -> datafusion::error::Result<ColumnarValue> {
73 not_impl_err!("GetPayload can only be called from async contexts")
74 }
75}
76
77#[async_trait]
78impl AsyncScalarUDFImpl for GetPayload {
79 async fn invoke_async_with_args(
80 &self,
81 args: ScalarFunctionArgs,
82 ) -> datafusion::error::Result<ColumnarValue> {
83 let args = ColumnarValue::values_to_arrays(&args.args)?;
84 if args.len() != 3 {
85 return internal_err!("wrong number of arguments to get_payload()");
86 }
87 let process_ids = args[0]
88 .as_any()
89 .downcast_ref::<StringArray>()
90 .ok_or_else(|| {
91 DataFusionError::Execution("downcasting process_ids in GetPayload".into())
92 })?
93 .clone();
94 let stream_ids = args[1]
95 .as_any()
96 .downcast_ref::<StringArray>()
97 .ok_or_else(|| {
98 DataFusionError::Execution("downcasting stream_ids in GetPayload".into())
99 })?
100 .clone();
101 let block_ids = args[2]
102 .as_any()
103 .downcast_ref::<StringArray>()
104 .ok_or_else(|| {
105 DataFusionError::Execution("downcasting block_ids in GetPayload".into())
106 })?
107 .clone();
108 let lake = self.lake.clone();
109 let mut stream = futures::stream::iter(0..process_ids.len())
110 .map(|i| {
111 let process_id = process_ids.value(i);
112 let stream_id = stream_ids.value(i);
113 let block_id = block_ids.value(i);
114 let obj_path = format!("blobs/{process_id}/{stream_id}/{block_id}");
115 let lake = lake.clone();
116 spawn_with_context(async move { lake.blob_storage.read_blob(&obj_path).await })
117 })
118 .buffered(10);
119 let mut result_builder = BinaryBuilder::with_capacity(block_ids.len(), 1024 * 1024);
120 while let Some(res) = stream.next().await {
121 result_builder.append_value(
122 res.map_err(|e| {
123 DataFusionError::Execution(format!("error downloading payload: {e:?}"))
124 })?
125 .map_err(|e| {
126 DataFusionError::Execution(format!("error downloading payload: {e:?}"))
127 })?,
128 );
129 }
130 Ok(ColumnarValue::Array(Arc::new(result_builder.finish())))
131 }
132}