1use super::flightsql_client::Client;
2use anyhow::{Context, Result};
3use async_stream::try_stream;
4use chrono::{DateTime, Utc};
5use datafusion::{
6 arrow::{
7 self,
8 array::{ListBuilder, RecordBatch, StringBuilder, StructBuilder, TimestampNanosecondArray},
9 datatypes::{DataType, Field, Fields, TimestampNanosecondType},
10 },
11 catalog::MemTable,
12 error::DataFusionError,
13 logical_expr::ScalarUDF,
14 physical_plan::stream::RecordBatchReceiverStreamBuilder,
15 prelude::*,
16 scalar::ScalarValue,
17};
18use futures::StreamExt;
19use futures::stream::BoxStream;
20use micromegas_analytics::{
21 dfext::{
22 string_column_accessor::string_column_by_name,
23 typed_column::{
24 get_only_primitive_value, get_only_string_value,
25 get_single_row_primitive_value_by_name, typed_column_by_name,
26 },
27 },
28 properties::property_get::PropertyGet,
29 time::TimeRange,
30};
31use std::{collections::HashMap, sync::Arc};
32
/// Controls how spans are grouped into budgets when computing frame statistics.
#[derive(Clone)]
pub enum GroupBy {
    /// Group spans by an explicit budget assignment: maps a span name to the
    /// name of the budget it belongs to. Spans with no entry in the map are
    /// dropped by `fetch_spans_batch`.
    Budget(HashMap<String, String>),
    /// Each distinct span name acts as its own budget.
    SpanName,
}
41
42pub fn budget_map_to_properties(
46 span_name_to_budget: &HashMap<String, String>,
47) -> Result<ScalarValue> {
48 let prop_struct_fields = vec![
49 Field::new("key", DataType::Utf8, false),
50 Field::new("value", DataType::Utf8, false),
51 ];
52 let prop_field = Arc::new(Field::new(
53 "Property",
54 DataType::Struct(Fields::from(prop_struct_fields.clone())),
55 false,
56 ));
57 let mut props_builder =
58 ListBuilder::new(StructBuilder::from_fields(prop_struct_fields, 10)).with_field(prop_field);
59
60 for (k, v) in span_name_to_budget.iter() {
61 let property_builder = props_builder.values();
62 let key_builder = property_builder
63 .field_builder::<StringBuilder>(0)
64 .with_context(|| "getting key field builder")?;
65 key_builder.append_value(k);
66 let value_builder = property_builder
67 .field_builder::<StringBuilder>(1)
68 .with_context(|| "getting value field builder")?;
69 value_builder.append_value(v);
70 property_builder.append(true);
71 }
72 props_builder.append(true);
73 let array = props_builder.finish();
74 Ok(ScalarValue::List(Arc::new(array)))
75}
76
77pub fn get_record_batch_time_range(rb: &RecordBatch) -> Result<Option<TimeRange>> {
81 if rb.num_rows() == 0 {
82 return Ok(None);
83 }
84 let begin_column: &TimestampNanosecondArray = typed_column_by_name(rb, "begin")?;
85 let end_column: &TimestampNanosecondArray = typed_column_by_name(rb, "end")?;
86 let min_begin = DateTime::from_timestamp_nanos(
87 arrow::compute::min(begin_column).with_context(|| "min(begin)")?,
88 );
89 let max_end = DateTime::from_timestamp_nanos(
90 arrow::compute::max(end_column).with_context(|| "max(end)")?,
91 );
92 Ok(Some(TimeRange::new(min_begin, max_end)))
93}
94
95pub async fn fetch_spans_batch(
100 client: &mut Client,
101 stream_id: &str,
102 frames_rb: RecordBatch,
103 group_by_config: &GroupBy,
104) -> Result<Vec<RecordBatch>> {
105 let time_range = get_record_batch_time_range(&frames_rb)?;
106 if time_range.is_none() {
107 return Ok(vec![]);
108 }
109 let time_range = time_range.unwrap();
110 match group_by_config {
111 GroupBy::Budget(span_to_budget) => {
112 let sql = format!(
113 "SELECT name, begin, end, duration
114 FROM view_instance('thread_spans', '{stream_id}')
115 "
116 );
117 let spans_rbs = client.query(sql, Some(time_range)).await?;
118
119 let ctx = SessionContext::new();
121 let table = MemTable::try_new(spans_rbs[0].schema(), vec![spans_rbs])?;
122 ctx.register_table("spans", Arc::new(table))?;
123 ctx.register_udf(ScalarUDF::from(PropertyGet::new()));
124
125 let spans = ctx
126 .sql(
127 "SELECT name, begin, end, duration, property_get($span_to_budget_map, name) as budget
128 FROM spans
129 WHERE property_get($span_to_budget_map, name) IS NOT NULL",
130 )
131 .await?
132 .with_param_values(vec![(
133 "span_to_budget_map",
134 budget_map_to_properties(span_to_budget)?,
135 )])?
136 .collect()
137 .await?;
138 Ok(spans)
139 }
140 GroupBy::SpanName => {
141 let sql = format!(
142 "SELECT name, name as budget, begin, end, duration
143 FROM view_instance('thread_spans', '{stream_id}')
144 "
145 );
146 let spans_rbs = client.query(sql, Some(time_range)).await?;
147 Ok(spans_rbs)
148 }
149 }
150}
151
/// Extracts, for each distinct budget present in the `frame_stats` table, the
/// 100 frames with the largest `duration_in_frame`.
///
/// `ctx` must have a `frame_stats` table registered with at least the columns
/// `budget`, `duration_in_frame`, `begin_frame`, `end_frame` and `process_id`.
/// One parameterized query is prepared once and re-bound per budget; the
/// per-budget queries run concurrently on spawned tasks, so the order of the
/// returned batches across budgets is not deterministic.
pub async fn extract_top_offenders(ctx: &SessionContext) -> Result<Vec<RecordBatch>> {
    // Enumerate the budgets to query; one top-100 extraction per budget.
    let budgets_rbs = ctx
        .sql("SELECT DISTINCT budget FROM frame_stats ORDER BY budget")
        .await?
        .collect()
        .await?;
    // Prepared once; $budget is bound per iteration below.
    let top_offenders_df = ctx
        .sql(
            "SELECT budget, duration_in_frame, begin_frame, end_frame, process_id
             FROM frame_stats
             WHERE budget = $budget
             ORDER BY duration_in_frame DESC
             LIMIT 100
             ",
        )
        .await?;
    // Channel-backed stream merging batches from all spawned tasks
    // (channel capacity: 100 batches).
    let mut builder =
        RecordBatchReceiverStreamBuilder::new(top_offenders_df.schema().inner().clone(), 100);
    for budgets_rb in budgets_rbs {
        let budget_column = string_column_by_name(&budgets_rb, "budget")?;
        for budget_row in 0..budgets_rb.num_rows() {
            let budget = budget_column.value(budget_row)?;
            let df = top_offenders_df
                .clone()
                .with_param_values(vec![("budget", ScalarValue::Utf8(Some(budget.to_owned())))])?;
            let sender = builder.tx();
            // Each budget's query runs on its own task; result batches are
            // forwarded to the shared channel as they are produced.
            builder.spawn(async move {
                for rb in df.collect().await? {
                    sender.send(Ok(rb)).await.map_err(|e| {
                        DataFusionError::Execution(format!("sending record batch: {e:?}"))
                    })?;
                }
                Ok(())
            });
        }
    }
    // Drain the merged stream; the first task error aborts the collection.
    let mut top_offenders_rbs = vec![];
    let mut top_stream = builder.build();
    while let Some(rb_res) = top_stream.next().await {
        top_offenders_rbs.push(rb_res?);
    }
    Ok(top_offenders_rbs)
}
199
/// Computes per-budget statistics (span count and summed duration) for every
/// frame in `frames_rb`, against the `spans` table registered in `ctx`.
///
/// One parameterized query is prepared once and re-bound per frame
/// ($begin_frame, $end_frame, $process_id); per-frame queries run concurrently
/// on spawned tasks, so the order of the returned batches across frames is not
/// deterministic. Only spans fully contained in a frame (begin >= frame begin
/// AND end <= frame end) are counted.
pub async fn compute_frame_stats_for_batch(
    ctx: &SessionContext,
    frames_rb: RecordBatch,
    process_id: &str,
) -> Result<Vec<RecordBatch>> {
    let frame_stats_df = ctx
        .sql(
            "SELECT budget,
             count(*) as count_in_frame,
             sum(duration) as duration_in_frame,
             to_timestamp_nanos($begin_frame) as begin_frame,
             to_timestamp_nanos($end_frame) as end_frame,
             arrow_cast($process_id, 'Utf8') as process_id
             FROM spans
             WHERE begin >= $begin_frame
             AND end <= $end_frame
             GROUP BY budget
             ",
        )
        .await
        .with_context(|| "frame_stats_df")?;

    // Channel-backed stream merging batches from all spawned tasks
    // (channel capacity: 100 batches).
    let mut builder =
        RecordBatchReceiverStreamBuilder::new(frame_stats_df.schema().inner().clone(), 100);
    // Timezone string shared by all bound timestamp parameters (UTC offset).
    let utc: Arc<str> = Arc::from("+00:00");
    let begin_frame_column: &TimestampNanosecondArray =
        typed_column_by_name(&frames_rb, "begin")
            .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
    let end_frame_column: &TimestampNanosecondArray = typed_column_by_name(&frames_rb, "end")
        .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
    for iframe in 0..frames_rb.num_rows() {
        let begin_frame = begin_frame_column.value(iframe);
        let end_frame = end_frame_column.value(iframe);
        // Bind this frame's bounds and the owning process id to the prepared
        // statement; `df` is moved into the spawned task below.
        let df = frame_stats_df.clone().with_param_values(vec![
            (
                "begin_frame",
                ScalarValue::TimestampNanosecond(Some(begin_frame), Some(utc.clone())),
            ),
            (
                "end_frame",
                ScalarValue::TimestampNanosecond(Some(end_frame), Some(utc.clone())),
            ),
            ("process_id", ScalarValue::Utf8(Some(process_id.to_owned()))),
        ])?;
        let sender = builder.tx();
        // Each frame's query runs on its own task; result batches are
        // forwarded to the shared channel as they are produced.
        builder.spawn(async move {
            for rb in df.collect().await? {
                sender.send(Ok(rb)).await.map_err(|e| {
                    DataFusionError::Execution(format!("sending record batch: {e:?}"))
                })?;
            }
            Ok(())
        });
    }

    // Drain the merged stream; the first task error aborts the collection.
    let mut frame_stats_rbs = vec![];
    let mut stream = builder.build();
    while let Some(rb_res) = stream.next().await {
        frame_stats_rbs.push(rb_res?);
    }
    Ok(frame_stats_rbs)
}
266
267pub async fn merge_top_offenders(top_offenders: Vec<RecordBatch>) -> Result<Vec<RecordBatch>> {
272 if top_offenders.is_empty() {
273 return Ok(top_offenders);
274 }
275 let ctx = SessionContext::new();
276 let table = MemTable::try_new(top_offenders[0].schema(), vec![top_offenders])?;
277 ctx.register_table("frame_stats", Arc::new(table))?;
279 extract_top_offenders(&ctx).await
280}
281
282pub async fn process_frame_batch(
287 ctx: &SessionContext,
288 frames_rb: RecordBatch,
289 process_id: &str,
290) -> Result<(Vec<RecordBatch>, Vec<RecordBatch>)> {
291 let frame_stats_rbs = compute_frame_stats_for_batch(ctx, frames_rb, process_id).await?;
292 let ctx = SessionContext::new(); let table = MemTable::try_new(frame_stats_rbs[0].schema(), vec![frame_stats_rbs])?;
294 ctx.register_table("frame_stats", Arc::new(table))?;
295 let agg_rbs = ctx
296 .sql(
297 "SELECT budget,
298 count(*) as nb_frames,
299 sum(count_in_frame) as sum_counts,
300 sum(duration_in_frame) as sum_duration,
301 min(duration_in_frame) as min_duration,
302 max(duration_in_frame) as max_duration
303 FROM frame_stats
304 GROUP BY budget
305 ",
306 )
307 .await?
308 .collect()
309 .await?;
310 let top_offenders_rbs = extract_top_offenders(&ctx).await?;
311 Ok((agg_rbs, top_offenders_rbs))
312}
313
314pub async fn get_process_start_time(
318 client: &mut Client,
319 process_id: &str,
320) -> Result<DateTime<Utc>> {
321 let sql = format!(
322 "SELECT start_time
323 FROM processes
324 WHERE process_id = '{process_id}'"
325 );
326 let rbs = client.query(sql, None).await?;
327 let start_time =
328 DateTime::from_timestamp_nanos(get_only_primitive_value::<TimestampNanosecondType>(&rbs)?);
329 Ok(start_time)
330}
331
/// Finds the id of the event stream belonging to the thread named
/// `main_thread_name` within the given process.
///
/// Queries the `blocks` table, matching on the stream's `thread-name`
/// property, and takes the first matching stream id within `query_range`.
/// NOTE(review): presumably errors when no block matches — depends on
/// `get_only_string_value` semantics; confirm.
pub async fn get_main_thread_stream_id(
    client: &mut Client,
    process_id: &str,
    main_thread_name: &str,
    query_range: TimeRange,
) -> Result<String> {
    let sql = format!(
        r#"SELECT stream_id
        FROM blocks
        WHERE process_id = '{process_id}'
        AND property_get("streams.properties", 'thread-name') = '{main_thread_name}'
        LIMIT 1"#
    );
    let rbs = client.query(sql, Some(query_range)).await?;
    get_only_string_value(&rbs)
}
352
353pub async fn get_stream_time_range(client: &mut Client, stream_id: &str) -> Result<TimeRange> {
358 let sql = format!(
359 "SELECT min(begin_time) as min_begin_time, max(end_time) as max_end_time
360 FROM blocks
361 WHERE stream_id='{stream_id}'"
362 );
363 let rbs = client.query(sql, None).await?;
364 let begin = DateTime::from_timestamp_nanos(get_single_row_primitive_value_by_name::<
365 TimestampNanosecondType,
366 >(&rbs, "min_begin_time")?);
367 let end = DateTime::from_timestamp_nanos(get_single_row_primitive_value_by_name::<
368 TimestampNanosecondType,
369 >(&rbs, "max_end_time")?);
370 Ok(TimeRange::new(begin, end))
371}
372
/// Fetches the begin/end times of all top-level frames — spans named
/// `top_level_span_name` — in a thread stream, ordered by begin time.
///
/// The range is injected both into the SQL text (to filter spans) and passed
/// to the client alongside the query.
/// NOTE(review): the client-side range presumably bounds the server's block
/// scan — confirm against `Client::query` semantics.
pub async fn get_frames(
    client: &mut Client,
    stream_id: &str,
    time_range: TimeRange,
    top_level_span_name: &str,
) -> Result<Vec<RecordBatch>> {
    let begin_iso = time_range.begin.to_rfc3339();
    let end_iso = time_range.end.to_rfc3339();
    let sql = format!(
        "SELECT begin, end
         FROM view_instance('thread_spans', '{stream_id}')
         WHERE name = '{top_level_span_name}'
         AND begin >= '{begin_iso}'
         AND end <= '{end_iso}'
         ORDER BY begin"
    );
    client.query(sql, Some(time_range)).await
}
395
396pub fn gen_frame_batches(
401 frames_record_batches: Vec<RecordBatch>,
402) -> BoxStream<'static, Result<RecordBatch>> {
403 Box::pin(try_stream! {
404 for b in frames_record_batches
405 {
406 if b.num_rows() == 0{
407 continue;
408 }
409
410 let max_slice_size = 1024;
411 let nb_slices = (b.num_rows() / max_slice_size) + 1;
412 for islice in 0..nb_slices {
413 let begin_index = islice * max_slice_size;
414 if begin_index >= b.num_rows() {
415 break;
417 }
418 let end_index = std::cmp::min((islice + 1) * max_slice_size, b.num_rows());
419 let b = b.slice(begin_index, end_index - begin_index);
420 yield b;
421 }
422 }
423 })
424}
425
426pub async fn gen_span_batches(
428 sender: tokio::sync::mpsc::Sender<(RecordBatch, Vec<RecordBatch>, String)>,
429 client: &mut Client,
430 process_id: &str,
431 time_range: TimeRange,
432 main_thread_name: &str,
433 top_level_span_name: &str,
434 group_by_config: &GroupBy,
435) -> Result<()> {
436 let main_thread_stream_id =
438 get_main_thread_stream_id(client, process_id, main_thread_name, time_range)
439 .await
440 .with_context(|| "get_main_thread_stream_id")?;
441 let mut main_thread_time_range = get_stream_time_range(client, &main_thread_stream_id)
442 .await
443 .with_context(|| "get_stream_time_range")?;
444 main_thread_time_range.begin = main_thread_time_range.begin.max(time_range.begin);
445 main_thread_time_range.end = main_thread_time_range.end.min(time_range.end);
446 let frames_record_batches = get_frames(
447 client,
448 &main_thread_stream_id,
449 main_thread_time_range,
450 top_level_span_name,
451 )
452 .await
453 .with_context(|| "get_frames")?;
454 let mut frame_batch_stream = gen_frame_batches(frames_record_batches);
455 while let Some(res) = frame_batch_stream.next().await {
456 let frame_batch = res?;
457 let spans_rbs = fetch_spans_batch(
458 client,
459 &main_thread_stream_id,
460 frame_batch.clone(),
461 group_by_config,
462 )
463 .await
464 .with_context(|| "fetch_spans_batch")?;
465 sender
466 .send((frame_batch, spans_rbs, process_id.to_owned()))
467 .await?;
468 }
469 Ok(())
470}