1use anyhow::{Context, Result};
2use axum::{
3 Extension, Json, Router,
4 body::Body,
5 extract::ConnectInfo,
6 http::{Response, StatusCode},
7 response::IntoResponse,
8 routing::post,
9};
10use chrono::{DateTime, Utc};
11use datafusion::arrow::{
12 array::RecordBatch,
13 json::{Writer, writer::JsonArray},
14};
15use http::{HeaderMap, Uri};
16use micromegas_analytics::time::TimeRange;
17use micromegas_tracing::info;
18use serde::Deserialize;
19use std::net::SocketAddr;
20use std::sync::Arc;
21use thiserror::Error;
22use tonic::transport::{Channel, ClientTlsConfig};
23
24use crate::client::flightsql_client::Client;
25use crate::servers::http_utils;
26
27#[derive(Debug, Clone, Deserialize)]
29pub struct HeaderForwardingConfig {
30 pub allowed_headers: Vec<String>,
32
33 pub allowed_prefixes: Vec<String>,
35
36 pub blocked_headers: Vec<String>,
38}
39
40impl Default for HeaderForwardingConfig {
41 fn default() -> Self {
42 Self {
43 allowed_headers: vec![
45 "Authorization".to_string(),
46 "User-Agent".to_string(),
47 "X-Client-Type".to_string(),
48 "X-Correlation-ID".to_string(),
49 "X-Request-ID".to_string(),
50 "X-User-Email".to_string(),
51 "X-User-ID".to_string(),
52 "X-User-Name".to_string(),
53 ],
54 allowed_prefixes: vec![],
55 blocked_headers: vec![
56 "Cookie".to_string(),
57 "Set-Cookie".to_string(),
58 "X-Client-IP".to_string(),
60 ],
61 }
62 }
63}
64
65impl HeaderForwardingConfig {
66 pub fn from_env() -> Result<Self> {
68 if let Ok(config_json) = std::env::var("MICROMEGAS_GATEWAY_HEADERS") {
69 serde_json::from_str(&config_json).context("Failed to parse MICROMEGAS_GATEWAY_HEADERS")
70 } else {
71 Ok(Self::default())
72 }
73 }
74
75 pub fn should_forward(&self, header_name: &str) -> bool {
77 let name_lower = header_name.to_lowercase();
78
79 if self
81 .blocked_headers
82 .iter()
83 .any(|h| h.to_lowercase() == name_lower)
84 {
85 return false;
86 }
87
88 if self
90 .allowed_headers
91 .iter()
92 .any(|h| h.to_lowercase() == name_lower)
93 {
94 return true;
95 }
96
97 self.allowed_prefixes
99 .iter()
100 .any(|prefix| name_lower.starts_with(&prefix.to_lowercase()))
101 }
102}
103
104#[derive(Error, Debug)]
105pub enum GatewayError {
106 #[error("Bad request: {0}")]
107 BadRequest(String),
108
109 #[error("Unauthorized: {0}")]
110 Unauthorized(String),
111
112 #[error("Forbidden: {0}")]
113 Forbidden(String),
114
115 #[error("Service unavailable: {0}")]
116 ServiceUnavailable(String),
117
118 #[error("Internal server error: {0}")]
119 Internal(String),
120}
121
122impl IntoResponse for GatewayError {
123 fn into_response(self) -> Response<Body> {
124 let (status, message) = match self {
125 GatewayError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg),
126 GatewayError::Unauthorized(msg) => (StatusCode::UNAUTHORIZED, msg),
127 GatewayError::Forbidden(msg) => (StatusCode::FORBIDDEN, msg),
128 GatewayError::ServiceUnavailable(msg) => (StatusCode::SERVICE_UNAVAILABLE, msg),
129 GatewayError::Internal(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg),
130 };
131 (status, message).into_response()
132 }
133}
134
135#[derive(Debug, Deserialize)]
136pub struct QueryRequest {
137 sql: String,
138 #[serde(default)]
141 time_range_begin: Option<String>,
142 #[serde(default)]
145 time_range_end: Option<String>,
146}
147
148pub fn build_origin_metadata(
160 headers: &HeaderMap,
161 addr: &SocketAddr,
162) -> tonic::metadata::MetadataMap {
163 let mut metadata = tonic::metadata::MetadataMap::new();
164
165 let original_client_type = headers
167 .get("x-client-type")
168 .and_then(|v| v.to_str().ok())
169 .unwrap_or("unknown");
170 let augmented_client_type = format!("{original_client_type}+gateway");
171 if let Ok(value) = augmented_client_type.parse() {
172 metadata.insert("x-client-type", value);
173 }
174
175 let request_id = headers
177 .get("x-request-id")
178 .and_then(|v| v.to_str().ok())
179 .map(|s| s.to_string())
180 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
181 if let Ok(value) = request_id.parse() {
182 metadata.insert("x-request-id", value);
183 }
184
185 let mut extensions = http::Extensions::new();
188 extensions.insert(axum::extract::ConnectInfo(*addr));
189 let client_ip = http_utils::get_client_ip(headers, &extensions);
190 if let Ok(value) = client_ip.parse() {
191 metadata.insert("x-client-ip", value);
192 }
193
194 metadata
195}
196
197pub async fn handle_query(
198 Extension(config): Extension<Arc<HeaderForwardingConfig>>,
199 ConnectInfo(addr): ConnectInfo<SocketAddr>,
200 headers: HeaderMap,
201 Json(request): Json<QueryRequest>,
202) -> Result<String, GatewayError> {
203 let start_time = std::time::Instant::now();
204
205 let origin_metadata = build_origin_metadata(&headers, &addr);
207 let client_type_header = origin_metadata
208 .get("x-client-type")
209 .and_then(|v| v.to_str().ok())
210 .unwrap_or("unknown+gateway");
211 let request_id_header = origin_metadata
212 .get("x-request-id")
213 .and_then(|v| v.to_str().ok())
214 .unwrap_or("unknown");
215
216 let sql = request.sql.trim();
218 if sql.is_empty() {
219 return Err(GatewayError::BadRequest(
220 "SQL query cannot be empty".to_string(),
221 ));
222 }
223
224 const MAX_SQL_SIZE: usize = 1_048_576;
226 if sql.len() > MAX_SQL_SIZE {
227 return Err(GatewayError::BadRequest(format!(
228 "SQL query too large: {} bytes (max: {} bytes)",
229 sql.len(),
230 MAX_SQL_SIZE
231 )));
232 }
233
234 let time_range = match (&request.time_range_begin, &request.time_range_end) {
236 (Some(begin_str), Some(end_str)) => {
237 let begin = DateTime::parse_from_rfc3339(begin_str)
238 .map_err(|e| {
239 GatewayError::BadRequest(format!(
240 "Invalid time_range_begin format (expected RFC3339): {e}"
241 ))
242 })?
243 .with_timezone(&Utc);
244 let end = DateTime::parse_from_rfc3339(end_str)
245 .map_err(|e| {
246 GatewayError::BadRequest(format!(
247 "Invalid time_range_end format (expected RFC3339): {e}"
248 ))
249 })?
250 .with_timezone(&Utc);
251
252 if begin > end {
253 return Err(GatewayError::BadRequest(
254 "time_range_begin must be before time_range_end".to_string(),
255 ));
256 }
257
258 Some(TimeRange::new(begin, end))
259 }
260 (Some(_), None) => {
261 return Err(GatewayError::BadRequest(
262 "time_range_end must be provided when time_range_begin is specified".to_string(),
263 ));
264 }
265 (None, Some(_)) => {
266 return Err(GatewayError::BadRequest(
267 "time_range_begin must be provided when time_range_end is specified".to_string(),
268 ));
269 }
270 (None, None) => None,
271 };
272
273 info!(
274 "Gateway request: request_id={}, client_type={}, time_range={:?}, sql={}",
275 request_id_header, client_type_header, time_range, sql
276 );
277
278 let flight_url = std::env::var("MICROMEGAS_FLIGHTSQL_URL")
280 .map_err(|_| GatewayError::Internal("MICROMEGAS_FLIGHTSQL_URL not configured".to_string()))?
281 .parse::<Uri>()
282 .map_err(|e| GatewayError::Internal(format!("Invalid FlightSQL URL: {e}")))?;
283
284 let tls_config = ClientTlsConfig::new().with_native_roots();
285 let channel = Channel::builder(flight_url)
286 .tls_config(tls_config)
287 .map_err(|e| GatewayError::Internal(format!("TLS config error: {e}")))?
288 .connect()
289 .await
290 .map_err(|e| {
291 GatewayError::ServiceUnavailable(format!("Failed to connect to FlightSQL: {e}"))
292 })?;
293
294 let mut client = Client::new(channel);
296
297 client
298 .inner_mut()
299 .set_header("x-client-type", client_type_header);
300 client
301 .inner_mut()
302 .set_header("x-request-id", request_id_header);
303
304 if let Some(client_ip) = origin_metadata.get("x-client-ip")
305 && let Ok(ip_str) = client_ip.to_str()
306 {
307 client.inner_mut().set_header("x-client-ip", ip_str);
308 }
309
310 for (name, value) in headers.iter() {
312 let header_name = name.as_str();
313
314 if header_name.eq_ignore_ascii_case("x-client-type")
316 || header_name.eq_ignore_ascii_case("x-request-id")
317 || header_name.eq_ignore_ascii_case("x-client-ip")
318 {
319 continue; }
321
322 if config.should_forward(header_name)
323 && let Ok(value_str) = value.to_str()
324 {
325 client.inner_mut().set_header(header_name, value_str);
326 }
327 }
328
329 let batches = client
331 .query(sql.to_string(), time_range)
332 .await
333 .map_err(|e| {
334 if let Some(status) = e.downcast_ref::<tonic::Status>() {
336 match status.code() {
337 tonic::Code::Unauthenticated => {
338 GatewayError::Unauthorized(status.message().to_string())
339 }
340 tonic::Code::PermissionDenied => {
341 GatewayError::Forbidden(status.message().to_string())
342 }
343 tonic::Code::InvalidArgument => {
344 GatewayError::BadRequest(status.message().to_string())
345 }
346 tonic::Code::Unavailable => {
347 GatewayError::ServiceUnavailable(status.message().to_string())
348 }
349 _ => GatewayError::Internal(format!("Query failed: {}", status.message())),
350 }
351 } else {
352 GatewayError::Internal(format!("Query execution error: {e:?}"))
353 }
354 })?;
355
356 let elapsed = start_time.elapsed();
357 info!(
358 "Gateway request completed: request_id={}, duration={:?}",
359 request_id_header, elapsed
360 );
361
362 if batches.is_empty() {
363 return Ok("[]".to_string());
364 }
365
366 let mut buffer = Vec::new();
367 let mut json_writer = Writer::<_, JsonArray>::new(&mut buffer);
368 let batch_refs: Vec<&RecordBatch> = batches.iter().collect();
369 json_writer
370 .write_batches(&batch_refs)
371 .map_err(|e| GatewayError::Internal(format!("Failed to serialize results: {e}")))?;
372 json_writer
373 .finish()
374 .map_err(|e| GatewayError::Internal(format!("Failed to finish JSON output: {e}")))?;
375
376 String::from_utf8(buffer)
377 .map_err(|e| GatewayError::Internal(format!("Invalid UTF-8 in results: {e}")))
378}
379
380pub fn register_routes(router: Router) -> Router {
381 router.route("/gateway/query", post(handle_query))
382}