// micromegas_analytics/lakehouse/sql_partition_spec.rs

1use super::{
2    dataframe_time_bounds::DataFrameTimeBounds,
3    view::{PartitionSpec, ViewMetadata},
4    write_partition::write_partition_from_rows,
5};
6use crate::{
7    dfext::typed_column::typed_column_by_name, lakehouse::write_partition::PartitionRowSet,
8    record_batch_transformer::RecordBatchTransformer, response_writer::Logger, time::TimeRange,
9};
10use anyhow::Result;
11use async_trait::async_trait;
12use datafusion::{
13    arrow::{
14        array::{Int64Array, RecordBatch},
15        datatypes::Schema,
16    },
17    prelude::*,
18};
19use futures::StreamExt;
20use micromegas_ingestion::data_lake_connection::DataLakeConnection;
21use micromegas_tracing::prelude::*;
22use std::sync::Arc;
23
/// A `PartitionSpec` implementation for SQL-defined partitions.
///
/// The partition's content is produced by running `extract_query` against
/// the DataFusion session when `write` is invoked.
pub struct SqlPartitionSpec {
    // Session used to run `extract_query` and to compute time bounds.
    ctx: SessionContext,
    // Post-processing applied to every record batch produced by the query.
    transformer: Arc<dyn RecordBatchTransformer>,
    // Strategy used to derive the event-time range of each record batch.
    compute_time_bounds: Arc<dyn DataFrameTimeBounds>,
    // Schema of the partition being written.
    schema: Arc<Schema>,
    // SQL query extracting the rows that make up the partition.
    extract_query: String,
    // Identity of the view this partition belongs to.
    view_metadata: ViewMetadata,
    // Time range covered by this partition.
    insert_range: TimeRange,
    // Number of source rows; doubles as the source data hash
    // (see `get_source_data_hash`).
    record_count: i64,
}
35
impl SqlPartitionSpec {
    /// Creates a new `SqlPartitionSpec`.
    ///
    /// * `ctx` - session used to execute `extract_query`
    /// * `transformer` - transformation applied to each extracted record batch
    /// * `compute_time_bounds` - strategy computing each batch's event-time range
    /// * `schema` - schema of the partition to be written
    /// * `extract_query` - SQL producing the partition's rows
    /// * `view_metadata` - identity of the owning view
    /// * `insert_range` - time range covered by the partition
    /// * `record_count` - source row count (also used as the data hash)
    #[expect(clippy::too_many_arguments)]
    pub fn new(
        ctx: SessionContext,
        transformer: Arc<dyn RecordBatchTransformer>,
        compute_time_bounds: Arc<dyn DataFrameTimeBounds>,
        schema: Arc<Schema>,
        extract_query: String,
        view_metadata: ViewMetadata,
        insert_range: TimeRange,
        record_count: i64,
    ) -> Self {
        Self {
            ctx,
            transformer,
            compute_time_bounds,
            schema,
            extract_query,
            view_metadata,
            insert_range,
            record_count,
        }
    }
}
60
61impl std::fmt::Debug for SqlPartitionSpec {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        write!(f, "SqlPartitionSpec")
64    }
65}
66
#[async_trait]
impl PartitionSpec for SqlPartitionSpec {
    /// A partition is empty when the count query matched no source rows.
    fn is_empty(&self) -> bool {
        self.record_count < 1
    }

    /// The source data hash is the source row count encoded little-endian;
    /// two generations with the same row count are therefore considered
    /// identical by whoever compares hashes.
    fn get_source_data_hash(&self) -> Vec<u8> {
        self.record_count.to_le_bytes().to_vec()
    }

    /// Executes `extract_query` and streams the resulting record batches to a
    /// concurrent writer task that materializes the partition in the lake.
    async fn write(&self, lake: Arc<DataLakeConnection>, logger: Arc<dyn Logger>) -> Result<()> {
        // Allow empty record_count - write_partition_from_rows will create
        // an empty partition record if no data is sent through the channel
        let desc = format!(
            "[{}, {}] {} {}",
            self.view_metadata.view_set_name,
            self.view_metadata.view_instance_id,
            self.insert_range.begin.to_rfc3339(),
            self.insert_range.end.to_rfc3339()
        );
        logger.write_log_entry(format!("writing {desc}")).await?;
        // Lazily execute the extraction query as a stream of record batches.
        let df = self.ctx.sql(&self.extract_query).await?;
        let mut stream = df.execute_stream().await?;

        // Capacity-1 channel: batch production is paced by the writer task
        // (backpressure), so at most one transformed batch sits in flight.
        let (tx, rx) = tokio::sync::mpsc::channel(1);
        // spawn_with_context presumably propagates the tracing/logging context
        // into the spawned writer task - TODO confirm against micromegas_tracing.
        let join_handle = spawn_with_context(write_partition_from_rows(
            lake.clone(),
            self.view_metadata.clone(),
            self.schema.clone(),
            self.insert_range,
            self.get_source_data_hash(),
            rx,
            logger.clone(),
        ));

        while let Some(rb_res) = stream.next().await {
            // Apply the view's transformation before measuring time bounds.
            let rb = self.transformer.transform(rb_res?).await?;
            // Re-register the batch in the session so the configured strategy
            // can compute its event-time range with DataFrame machinery.
            let event_time_range = self
                .compute_time_bounds
                .get_time_bounds(self.ctx.read_batch(rb.clone())?)
                .await?;
            // NOTE(review): if the writer task fails and drops `rx`, this send
            // returns a SendError and we propagate that instead of the task's
            // own error (the join handle is never awaited on this path).
            tx.send(PartitionRowSet::new(event_time_range, rb)).await?;
        }
        // Closing the sender signals end-of-data to the writer task.
        drop(tx);
        // First `?`: task panic/cancellation; second `?`: the writer's own Result.
        join_handle.await??;
        Ok(())
    }
}
115
116/// Fetches a `SqlPartitionSpec` by executing a count query and an extract query.
117#[expect(clippy::too_many_arguments)]
118pub async fn fetch_sql_partition_spec(
119    ctx: SessionContext,
120    transformer: Arc<dyn RecordBatchTransformer>,
121    compute_time_bounds: Arc<dyn DataFrameTimeBounds>,
122    schema: Arc<Schema>,
123    count_src_sql: String,
124    extract_query: String,
125    view_metadata: ViewMetadata,
126    insert_range: TimeRange,
127) -> Result<SqlPartitionSpec> {
128    let df = ctx.sql(&count_src_sql).await?;
129    let batches: Vec<RecordBatch> = df.collect().await?;
130    if batches.len() != 1 {
131        anyhow::bail!("fetch_sql_partition_spec: query should return a single batch");
132    }
133    let rb = &batches[0];
134    let count_column: &Int64Array = typed_column_by_name(rb, "count")?;
135    if count_column.len() != 1 {
136        anyhow::bail!("fetch_sql_partition_spec: query should return a single row");
137    }
138    let count = count_column.value(0);
139    if count > 0 {
140        trace!(
141            "fetch_sql_partition_spec for view {}, count={count}",
142            &*view_metadata.view_set_name
143        );
144    }
145    Ok(SqlPartitionSpec::new(
146        ctx,
147        transformer,
148        compute_time_bounds,
149        schema,
150        extract_query,
151        view_metadata,
152        insert_range,
153        count,
154    ))
155}