micromegas/client/
frame_budget_reporting.rs

1use super::flightsql_client::Client;
2use anyhow::{Context, Result};
3use async_stream::try_stream;
4use chrono::{DateTime, Utc};
5use datafusion::{
6    arrow::{
7        self,
8        array::{ListBuilder, RecordBatch, StringBuilder, StructBuilder, TimestampNanosecondArray},
9        datatypes::{DataType, Field, Fields, TimestampNanosecondType},
10    },
11    catalog::MemTable,
12    error::DataFusionError,
13    logical_expr::ScalarUDF,
14    physical_plan::stream::RecordBatchReceiverStreamBuilder,
15    prelude::*,
16    scalar::ScalarValue,
17};
18use futures::StreamExt;
19use futures::stream::BoxStream;
20use micromegas_analytics::{
21    dfext::{
22        string_column_accessor::string_column_by_name,
23        typed_column::{
24            get_only_primitive_value, get_only_string_value,
25            get_single_row_primitive_value_by_name, typed_column_by_name,
26        },
27    },
28    properties::property_get::PropertyGet,
29    time::TimeRange,
30};
31use std::{collections::HashMap, sync::Arc};
32
/// Defines how spans are grouped when computing per-frame budget statistics.
///
/// Consumed by `fetch_spans_batch` to decide what the `budget` column contains.
#[derive(Clone)]
pub enum GroupBy {
    /// Group spans by an explicit budget category: the map associates a span name
    /// (key) with its budget category (value). Spans whose name is absent from the
    /// map are filtered out of the results.
    Budget(HashMap<String, String>),
    /// Use each span's own name as its budget category; no spans are filtered.
    SpanName,
}
41
42/// Converts a map of span names to budget categories into a `ScalarValue` representing a list of properties.
43///
44/// This function is used to pass the budget mapping to DataFusion as a scalar value.
45pub fn budget_map_to_properties(
46    span_name_to_budget: &HashMap<String, String>,
47) -> Result<ScalarValue> {
48    let prop_struct_fields = vec![
49        Field::new("key", DataType::Utf8, false),
50        Field::new("value", DataType::Utf8, false),
51    ];
52    let prop_field = Arc::new(Field::new(
53        "Property",
54        DataType::Struct(Fields::from(prop_struct_fields.clone())),
55        false,
56    ));
57    let mut props_builder =
58        ListBuilder::new(StructBuilder::from_fields(prop_struct_fields, 10)).with_field(prop_field);
59
60    for (k, v) in span_name_to_budget.iter() {
61        let property_builder = props_builder.values();
62        let key_builder = property_builder
63            .field_builder::<StringBuilder>(0)
64            .with_context(|| "getting key field builder")?;
65        key_builder.append_value(k);
66        let value_builder = property_builder
67            .field_builder::<StringBuilder>(1)
68            .with_context(|| "getting value field builder")?;
69        value_builder.append_value(v);
70        property_builder.append(true);
71    }
72    props_builder.append(true);
73    let array = props_builder.finish();
74    Ok(ScalarValue::List(Arc::new(array)))
75}
76
77/// Retrieves the time range (min begin, max end) from a `RecordBatch`.
78///
79/// This function assumes the `RecordBatch` contains "begin" and "end" columns of type `TimestampNanosecondType`.
80pub fn get_record_batch_time_range(rb: &RecordBatch) -> Result<Option<TimeRange>> {
81    if rb.num_rows() == 0 {
82        return Ok(None);
83    }
84    let begin_column: &TimestampNanosecondArray = typed_column_by_name(rb, "begin")?;
85    let end_column: &TimestampNanosecondArray = typed_column_by_name(rb, "end")?;
86    let min_begin = DateTime::from_timestamp_nanos(
87        arrow::compute::min(begin_column).with_context(|| "min(begin)")?,
88    );
89    let max_end = DateTime::from_timestamp_nanos(
90        arrow::compute::max(end_column).with_context(|| "max(end)")?,
91    );
92    Ok(Some(TimeRange::new(min_begin, max_end)))
93}
94
95/// Fetches spans for a given stream and frames, grouped by the specified configuration.
96///
97/// This function queries the FlightSQL server for spans within the time range of the provided frames
98/// and groups them according to the `group_by_config`.
99pub async fn fetch_spans_batch(
100    client: &mut Client,
101    stream_id: &str,
102    frames_rb: RecordBatch,
103    group_by_config: &GroupBy,
104) -> Result<Vec<RecordBatch>> {
105    let time_range = get_record_batch_time_range(&frames_rb)?;
106    if time_range.is_none() {
107        return Ok(vec![]);
108    }
109    let time_range = time_range.unwrap();
110    match group_by_config {
111        GroupBy::Budget(span_to_budget) => {
112            let sql = format!(
113                "SELECT name, begin, end, duration
114                 FROM view_instance('thread_spans', '{stream_id}')
115                 "
116            );
117            let spans_rbs = client.query(sql, Some(time_range)).await?;
118
119            // add budget column locally
120            let ctx = SessionContext::new();
121            let table = MemTable::try_new(spans_rbs[0].schema(), vec![spans_rbs])?;
122            ctx.register_table("spans", Arc::new(table))?;
123            ctx.register_udf(ScalarUDF::from(PropertyGet::new()));
124
125            let spans = ctx
126		.sql(
127		    "SELECT name, begin, end, duration, property_get($span_to_budget_map, name) as budget
128                     FROM spans
129                     WHERE property_get($span_to_budget_map, name) IS NOT NULL",
130		)
131		.await?
132		.with_param_values(vec![(
133		    "span_to_budget_map",
134		    budget_map_to_properties(span_to_budget)?,
135		)])?
136		.collect()
137		.await?;
138            Ok(spans)
139        }
140        GroupBy::SpanName => {
141            let sql = format!(
142                "SELECT name, name as budget, begin, end, duration
143                 FROM view_instance('thread_spans', '{stream_id}')
144                "
145            );
146            let spans_rbs = client.query(sql, Some(time_range)).await?;
147            Ok(spans_rbs)
148        }
149    }
150}
151
152/// Extracts top offenders from the frame statistics.
153///
154/// This function queries the `frame_stats` table (which is expected to be registered in the `SessionContext`)
155/// and returns the top 100 offenders by `duration_in_frame` for each budget.
156pub async fn extract_top_offenders(ctx: &SessionContext) -> Result<Vec<RecordBatch>> {
157    let budgets_rbs = ctx
158        .sql("SELECT DISTINCT budget FROM frame_stats ORDER BY budget")
159        .await?
160        .collect()
161        .await?;
162    let top_offenders_df = ctx
163        .sql(
164            "SELECT budget, duration_in_frame, begin_frame, end_frame, process_id
165             FROM frame_stats
166             WHERE budget = $budget
167             ORDER BY duration_in_frame DESC
168             LIMIT 100
169             ",
170        )
171        .await?;
172    let mut builder =
173        RecordBatchReceiverStreamBuilder::new(top_offenders_df.schema().inner().clone(), 100);
174    for budgets_rb in budgets_rbs {
175        let budget_column = string_column_by_name(&budgets_rb, "budget")?;
176        for budget_row in 0..budgets_rb.num_rows() {
177            let budget = budget_column.value(budget_row)?;
178            let df = top_offenders_df
179                .clone()
180                .with_param_values(vec![("budget", ScalarValue::Utf8(Some(budget.to_owned())))])?;
181            let sender = builder.tx();
182            builder.spawn(async move {
183                for rb in df.collect().await? {
184                    sender.send(Ok(rb)).await.map_err(|e| {
185                        DataFusionError::Execution(format!("sending record batch: {e:?}"))
186                    })?;
187                }
188                Ok(())
189            });
190        }
191    }
192    let mut top_offenders_rbs = vec![];
193    let mut top_stream = builder.build();
194    while let Some(rb_res) = top_stream.next().await {
195        top_offenders_rbs.push(rb_res?);
196    }
197    Ok(top_offenders_rbs)
198}
199
/// Computes frame statistics for a batch of frames.
///
/// For every frame (row) of `frames_rb`, runs an aggregation over the `spans`
/// table (which must already be registered in `ctx`): per budget, the number of
/// spans (`count_in_frame`) and the total span duration (`duration_in_frame`)
/// of spans falling entirely inside the frame's [begin, end] interval. The
/// per-frame queries are spawned concurrently and their result batches gathered
/// through a record-batch channel, so output order is not guaranteed to follow
/// frame order.
pub async fn compute_frame_stats_for_batch(
    ctx: &SessionContext,
    frames_rb: RecordBatch,
    process_id: &str,
) -> Result<Vec<RecordBatch>> {
    // Parameterized template; $begin_frame/$end_frame/$process_id are bound per
    // frame in the loop below. The frame bounds are echoed back as columns so
    // each stats row identifies the frame it belongs to.
    let frame_stats_df = ctx
        .sql(
            "SELECT budget,
                    count(*) as count_in_frame,
                    sum(duration) as duration_in_frame,
                    to_timestamp_nanos($begin_frame) as begin_frame,
                    to_timestamp_nanos($end_frame) as end_frame,
                    arrow_cast($process_id, 'Utf8') as process_id
             FROM spans
             WHERE begin >= $begin_frame
             AND end <= $end_frame
             GROUP BY budget
             ",
        )
        .await
        .with_context(|| "frame_stats_df")?;

    // Channel-backed stream with a buffer of up to 100 in-flight record batches.
    let mut builder =
        RecordBatchReceiverStreamBuilder::new(frame_stats_df.schema().inner().clone(), 100);
    // Timezone string shared by all bound timestamp scalars (one Arc, cheap clones).
    let utc: Arc<str> = Arc::from("+00:00");
    let begin_frame_column: &TimestampNanosecondArray =
        typed_column_by_name(&frames_rb, "begin")
            .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
    let end_frame_column: &TimestampNanosecondArray = typed_column_by_name(&frames_rb, "end")
        .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
    for iframe in 0..frames_rb.num_rows() {
        let begin_frame = begin_frame_column.value(iframe);
        let end_frame = end_frame_column.value(iframe);
        let df = frame_stats_df.clone().with_param_values(vec![
            (
                "begin_frame",
                ScalarValue::TimestampNanosecond(Some(begin_frame), Some(utc.clone())),
            ),
            (
                "end_frame",
                ScalarValue::TimestampNanosecond(Some(end_frame), Some(utc.clone())),
            ),
            ("process_id", ScalarValue::Utf8(Some(process_id.to_owned()))),
        ])?;
        // One task per frame; each forwards its result batches into the shared channel.
        let sender = builder.tx();
        builder.spawn(async move {
            for rb in df.collect().await? {
                sender.send(Ok(rb)).await.map_err(|e| {
                    DataFusionError::Execution(format!("sending record batch: {e:?}"))
                })?;
            }
            Ok(())
        });
    }

    // Drain the merged stream; any task error surfaces here via `?`.
    let mut frame_stats_rbs = vec![];
    let mut stream = builder.build();
    while let Some(rb_res) = stream.next().await {
        frame_stats_rbs.push(rb_res?);
    }
    Ok(frame_stats_rbs)
}
266
267/// Merges top offenders from multiple record batches.
268///
269/// This function takes a vector of `RecordBatch`es, combines them into a single table,
270/// and then calls `extract_top_offenders` to re-extract the top offenders from the merged data.
271pub async fn merge_top_offenders(top_offenders: Vec<RecordBatch>) -> Result<Vec<RecordBatch>> {
272    if top_offenders.is_empty() {
273        return Ok(top_offenders);
274    }
275    let ctx = SessionContext::new();
276    let table = MemTable::try_new(top_offenders[0].schema(), vec![top_offenders])?;
277    // it works because offenders have the same schema as frame_stats entries
278    ctx.register_table("frame_stats", Arc::new(table))?;
279    extract_top_offenders(&ctx).await
280}
281
282/// Processes a batch of frames, computing frame statistics and extracting top offenders.
283///
284/// This function first computes frame statistics using `compute_frame_stats_for_batch`,
285/// then aggregates these statistics, and finally extracts top offenders.
286pub async fn process_frame_batch(
287    ctx: &SessionContext,
288    frames_rb: RecordBatch,
289    process_id: &str,
290) -> Result<(Vec<RecordBatch>, Vec<RecordBatch>)> {
291    let frame_stats_rbs = compute_frame_stats_for_batch(ctx, frames_rb, process_id).await?;
292    let ctx = SessionContext::new(); // new temp context to keep frame_stats from leaking out
293    let table = MemTable::try_new(frame_stats_rbs[0].schema(), vec![frame_stats_rbs])?;
294    ctx.register_table("frame_stats", Arc::new(table))?;
295    let agg_rbs = ctx
296        .sql(
297            "SELECT budget,
298                    count(*) as nb_frames,
299                    sum(count_in_frame) as sum_counts,
300                    sum(duration_in_frame) as sum_duration,
301                    min(duration_in_frame) as min_duration,
302                    max(duration_in_frame) as max_duration
303             FROM frame_stats
304             GROUP BY budget
305             ",
306        )
307        .await?
308        .collect()
309        .await?;
310    let top_offenders_rbs = extract_top_offenders(&ctx).await?;
311    Ok((agg_rbs, top_offenders_rbs))
312}
313
314/// Retrieves the start time of a process.
315///
316/// This function queries the `processes` table to get the start time for a given `process_id`.
317pub async fn get_process_start_time(
318    client: &mut Client,
319    process_id: &str,
320) -> Result<DateTime<Utc>> {
321    let sql = format!(
322        "SELECT start_time
323         FROM processes
324         WHERE process_id = '{process_id}'"
325    );
326    let rbs = client.query(sql, None).await?;
327    let start_time =
328        DateTime::from_timestamp_nanos(get_only_primitive_value::<TimestampNanosecondType>(&rbs)?);
329    Ok(start_time)
330}
331
/// Retrieves the stream ID of the main thread for a given process.
///
/// Queries the `blocks` table for the first `stream_id` of the given process whose
/// stream carries a `thread-name` property equal to `main_thread_name`, restricted
/// to `query_range`. Expects exactly one row (LIMIT 1 + `get_only_string_value`).
///
/// NOTE(review): `process_id` and `main_thread_name` are interpolated directly into
/// the SQL text — assumes the caller supplies trusted values; confirm whether the
/// client supports bound parameters here.
pub async fn get_main_thread_stream_id(
    client: &mut Client,
    process_id: &str,
    main_thread_name: &str,
    query_range: TimeRange,
) -> Result<String> {
    let sql = format!(
        r#"SELECT stream_id
	 FROM blocks
	 WHERE process_id = '{process_id}'
	 AND property_get("streams.properties", 'thread-name') = '{main_thread_name}'
         LIMIT 1"#
    );
    let rbs = client.query(sql, Some(query_range)).await?;
    get_only_string_value(&rbs)
}
352
353/// Retrieves the time range of a stream.
354///
355/// This function queries the `blocks` table to find the minimum `begin_time` and maximum `end_time`
356/// for a given `stream_id`.
357pub async fn get_stream_time_range(client: &mut Client, stream_id: &str) -> Result<TimeRange> {
358    let sql = format!(
359        "SELECT min(begin_time) as min_begin_time, max(end_time) as max_end_time
360         FROM blocks
361         WHERE stream_id='{stream_id}'"
362    );
363    let rbs = client.query(sql, None).await?;
364    let begin = DateTime::from_timestamp_nanos(get_single_row_primitive_value_by_name::<
365        TimestampNanosecondType,
366    >(&rbs, "min_begin_time")?);
367    let end = DateTime::from_timestamp_nanos(get_single_row_primitive_value_by_name::<
368        TimestampNanosecondType,
369    >(&rbs, "max_end_time")?);
370    Ok(TimeRange::new(begin, end))
371}
372
373/// Retrieves frames for a given stream within a time range and top-level span name.
374///
375/// This function queries the `thread_spans` view to get the `begin` and `end` times
376/// of spans that match the `top_level_span_name` within the specified `time_range`.
377pub async fn get_frames(
378    client: &mut Client,
379    stream_id: &str,
380    time_range: TimeRange,
381    top_level_span_name: &str,
382) -> Result<Vec<RecordBatch>> {
383    let begin_iso = time_range.begin.to_rfc3339();
384    let end_iso = time_range.end.to_rfc3339();
385    let sql = format!(
386        "SELECT begin, end
387         FROM view_instance('thread_spans', '{stream_id}')
388         WHERE name = '{top_level_span_name}'
389         AND begin >= '{begin_iso}'
390         AND end <= '{end_iso}'
391         ORDER BY begin"
392    );
393    client.query(sql, Some(time_range)).await
394}
395
396/// Generates a stream of record batches from a vector of record batches.
397///
398/// This function takes a `Vec<RecordBatch>` and converts it into a `BoxStream`,
399/// optionally slicing large record batches into smaller ones.
400pub fn gen_frame_batches(
401    frames_record_batches: Vec<RecordBatch>,
402) -> BoxStream<'static, Result<RecordBatch>> {
403    Box::pin(try_stream! {
404        for b in frames_record_batches
405        {
406        if b.num_rows() == 0{
407            continue;
408        }
409
410        let max_slice_size = 1024;
411        let nb_slices = (b.num_rows() / max_slice_size) + 1;
412        for islice in 0..nb_slices {
413            let begin_index = islice * max_slice_size;
414            if begin_index >= b.num_rows() {
415            // can happen when num_rows == max_slice_size
416            break;
417            }
418            let end_index = std::cmp::min((islice + 1) * max_slice_size, b.num_rows());
419            let b = b.slice(begin_index, end_index - begin_index);
420            yield b;
421        }
422        }
423    })
424}
425
/// Generates and sends span batches to a channel.
///
/// Pipeline: resolve the main thread's stream id, clamp the stream's time range to
/// `time_range`, fetch the top-level frames, then for each frame batch fetch the
/// matching spans (grouped per `group_by_config`) and send
/// `(frame_batch, span_batches, process_id)` tuples over `sender`.
///
/// Returns an error if any query fails or if the receiving end of `sender` is
/// dropped before all batches are sent.
pub async fn gen_span_batches(
    sender: tokio::sync::mpsc::Sender<(RecordBatch, Vec<RecordBatch>, String)>,
    client: &mut Client,
    process_id: &str,
    time_range: TimeRange,
    main_thread_name: &str,
    top_level_span_name: &str,
    group_by_config: &GroupBy,
) -> Result<()> {
    // TODO: fetch the main thread's stream id together with the process metadata
    // to save a round-trip.
    let main_thread_stream_id =
        get_main_thread_stream_id(client, process_id, main_thread_name, time_range)
            .await
            .with_context(|| "get_main_thread_stream_id")?;
    let mut main_thread_time_range = get_stream_time_range(client, &main_thread_stream_id)
        .await
        .with_context(|| "get_stream_time_range")?;
    // Clamp the stream's range to the caller-requested window.
    main_thread_time_range.begin = main_thread_time_range.begin.max(time_range.begin);
    main_thread_time_range.end = main_thread_time_range.end.min(time_range.end);
    let frames_record_batches = get_frames(
        client,
        &main_thread_stream_id,
        main_thread_time_range,
        top_level_span_name,
    )
    .await
    .with_context(|| "get_frames")?;
    // Process frames in bounded slices (see gen_frame_batches) to cap memory use
    // per span query.
    let mut frame_batch_stream = gen_frame_batches(frames_record_batches);
    while let Some(res) = frame_batch_stream.next().await {
        let frame_batch = res?;
        let spans_rbs = fetch_spans_batch(
            client,
            &main_thread_stream_id,
            frame_batch.clone(),
            group_by_config,
        )
        .await
        .with_context(|| "fetch_spans_batch")?;
        sender
            .send((frame_batch, spans_rbs, process_id.to_owned()))
            .await?;
    }
    Ok(())
}