micromegas/servers/
http_gateway.rs

1use anyhow::{Context, Result};
2use axum::{
3    Extension, Json, Router,
4    body::Body,
5    extract::ConnectInfo,
6    http::{Response, StatusCode},
7    response::IntoResponse,
8    routing::post,
9};
10use chrono::{DateTime, Utc};
11use datafusion::arrow::{
12    array::RecordBatch,
13    json::{Writer, writer::JsonArray},
14};
15use http::{HeaderMap, Uri};
16use micromegas_analytics::time::TimeRange;
17use micromegas_tracing::info;
18use serde::Deserialize;
19use std::net::SocketAddr;
20use std::sync::Arc;
21use thiserror::Error;
22use tonic::transport::{Channel, ClientTlsConfig};
23
24use crate::client::flightsql_client::Client;
25use crate::servers::http_utils;
26
27/// Configuration for forwarding HTTP headers to FlightSQL backend
28#[derive(Debug, Clone, Deserialize)]
29pub struct HeaderForwardingConfig {
30    /// Exact header names to forward (case-insensitive)
31    pub allowed_headers: Vec<String>,
32
33    /// Header prefixes to forward (e.g., "X-Custom-")
34    pub allowed_prefixes: Vec<String>,
35
36    /// Headers to explicitly block (overrides allows)
37    pub blocked_headers: Vec<String>,
38}
39
40impl Default for HeaderForwardingConfig {
41    fn default() -> Self {
42        Self {
43            // Default safe headers to forward
44            allowed_headers: vec![
45                "Authorization".to_string(),
46                "User-Agent".to_string(),
47                "X-Client-Type".to_string(),
48                "X-Correlation-ID".to_string(),
49                "X-Request-ID".to_string(),
50                "X-User-Email".to_string(),
51                "X-User-ID".to_string(),
52                "X-User-Name".to_string(),
53            ],
54            allowed_prefixes: vec![],
55            blocked_headers: vec![
56                "Cookie".to_string(),
57                "Set-Cookie".to_string(),
58                // SECURITY: Gateway always sets this from actual connection
59                "X-Client-IP".to_string(),
60            ],
61        }
62    }
63}
64
65impl HeaderForwardingConfig {
66    /// Load configuration from environment variable or use defaults
67    pub fn from_env() -> Result<Self> {
68        if let Ok(config_json) = std::env::var("MICROMEGAS_GATEWAY_HEADERS") {
69            serde_json::from_str(&config_json).context("Failed to parse MICROMEGAS_GATEWAY_HEADERS")
70        } else {
71            Ok(Self::default())
72        }
73    }
74
75    /// Check if a header should be forwarded based on configuration
76    pub fn should_forward(&self, header_name: &str) -> bool {
77        let name_lower = header_name.to_lowercase();
78
79        // Check blocked list first
80        if self
81            .blocked_headers
82            .iter()
83            .any(|h| h.to_lowercase() == name_lower)
84        {
85            return false;
86        }
87
88        // Check exact matches
89        if self
90            .allowed_headers
91            .iter()
92            .any(|h| h.to_lowercase() == name_lower)
93        {
94            return true;
95        }
96
97        // Check prefixes
98        self.allowed_prefixes
99            .iter()
100            .any(|prefix| name_lower.starts_with(&prefix.to_lowercase()))
101    }
102}
103
/// Errors produced by the gateway's HTTP handlers.
///
/// Every variant carries a human-readable message. The `IntoResponse`
/// implementation for this type picks the HTTP status code from the variant
/// and sends the carried message as the response body.
#[derive(Error, Debug)]
pub enum GatewayError {
    /// The request was rejected as invalid; answered with HTTP 400.
    #[error("Bad request: {0}")]
    BadRequest(String),

    /// The caller is not authenticated; answered with HTTP 401.
    #[error("Unauthorized: {0}")]
    Unauthorized(String),

    /// The caller is not permitted to perform the request; answered with HTTP 403.
    #[error("Forbidden: {0}")]
    Forbidden(String),

    /// The backend could not be reached or reported itself unavailable;
    /// answered with HTTP 503.
    #[error("Service unavailable: {0}")]
    ServiceUnavailable(String),

    /// Any other failure, including gateway misconfiguration; answered with HTTP 500.
    #[error("Internal server error: {0}")]
    Internal(String),
}
121
122impl IntoResponse for GatewayError {
123    fn into_response(self) -> Response<Body> {
124        let (status, message) = match self {
125            GatewayError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg),
126            GatewayError::Unauthorized(msg) => (StatusCode::UNAUTHORIZED, msg),
127            GatewayError::Forbidden(msg) => (StatusCode::FORBIDDEN, msg),
128            GatewayError::ServiceUnavailable(msg) => (StatusCode::SERVICE_UNAVAILABLE, msg),
129            GatewayError::Internal(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg),
130        };
131        (status, message).into_response()
132    }
133}
134
/// JSON body accepted by the gateway query handler.
///
/// `time_range_begin` and `time_range_end` may both be omitted; the handler
/// rejects requests that supply only one of them.
#[derive(Debug, Deserialize)]
pub struct QueryRequest {
    /// SQL text to execute. The handler trims surrounding whitespace and
    /// rejects the request when the result is empty or larger than its size limit.
    sql: String,
    /// Optional time range filter - begin timestamp in RFC3339 format
    /// Example: "2024-01-01T00:00:00Z"
    #[serde(default)]
    time_range_begin: Option<String>,
    /// Optional time range filter - end timestamp in RFC3339 format
    /// Example: "2024-01-02T00:00:00Z"
    #[serde(default)]
    time_range_end: Option<String>,
}
147
148/// Build origin tracking metadata for FlightSQL queries
149/// Augments the client type by appending "+gateway" to preserve the full client chain
150///
151/// This function only sets origin tracking headers that the gateway controls:
152/// - x-client-type: augmented with "+gateway"
153/// - x-request-id: generated if not present
154/// - x-client-ip: extracted from actual connection (prevents spoofing)
155///
156/// User attribution headers (x-user-id, x-user-email) are forwarded from client
157/// if present in allowed_headers config. FlightSQL validates these against the
158/// Authorization token.
159pub fn build_origin_metadata(
160    headers: &HeaderMap,
161    addr: &SocketAddr,
162) -> tonic::metadata::MetadataMap {
163    let mut metadata = tonic::metadata::MetadataMap::new();
164
165    // 1. Client Type - augment existing or set to "unknown+gateway"
166    let original_client_type = headers
167        .get("x-client-type")
168        .and_then(|v| v.to_str().ok())
169        .unwrap_or("unknown");
170    let augmented_client_type = format!("{original_client_type}+gateway");
171    if let Ok(value) = augmented_client_type.parse() {
172        metadata.insert("x-client-type", value);
173    }
174
175    // 2. Request ID - generate UUID if not present
176    let request_id = headers
177        .get("x-request-id")
178        .and_then(|v| v.to_str().ok())
179        .map(|s| s.to_string())
180        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
181    if let Ok(value) = request_id.parse() {
182        metadata.insert("x-request-id", value);
183    }
184
185    // 3. Client IP - ALWAYS extract from connection (never from client header)
186    // SECURITY: Prevents IP spoofing in audit logs
187    let mut extensions = http::Extensions::new();
188    extensions.insert(axum::extract::ConnectInfo(*addr));
189    let client_ip = http_utils::get_client_ip(headers, &extensions);
190    if let Ok(value) = client_ip.parse() {
191        metadata.insert("x-client-ip", value);
192    }
193
194    metadata
195}
196
197pub async fn handle_query(
198    Extension(config): Extension<Arc<HeaderForwardingConfig>>,
199    ConnectInfo(addr): ConnectInfo<SocketAddr>,
200    headers: HeaderMap,
201    Json(request): Json<QueryRequest>,
202) -> Result<String, GatewayError> {
203    let start_time = std::time::Instant::now();
204
205    // Build origin tracking metadata
206    let origin_metadata = build_origin_metadata(&headers, &addr);
207    let client_type_header = origin_metadata
208        .get("x-client-type")
209        .and_then(|v| v.to_str().ok())
210        .unwrap_or("unknown+gateway");
211    let request_id_header = origin_metadata
212        .get("x-request-id")
213        .and_then(|v| v.to_str().ok())
214        .unwrap_or("unknown");
215
216    // Request validation
217    let sql = request.sql.trim();
218    if sql.is_empty() {
219        return Err(GatewayError::BadRequest(
220            "SQL query cannot be empty".to_string(),
221        ));
222    }
223
224    // Basic size limit (1MB for SQL query)
225    const MAX_SQL_SIZE: usize = 1_048_576;
226    if sql.len() > MAX_SQL_SIZE {
227        return Err(GatewayError::BadRequest(format!(
228            "SQL query too large: {} bytes (max: {} bytes)",
229            sql.len(),
230            MAX_SQL_SIZE
231        )));
232    }
233
234    // Parse time range if provided
235    let time_range = match (&request.time_range_begin, &request.time_range_end) {
236        (Some(begin_str), Some(end_str)) => {
237            let begin = DateTime::parse_from_rfc3339(begin_str)
238                .map_err(|e| {
239                    GatewayError::BadRequest(format!(
240                        "Invalid time_range_begin format (expected RFC3339): {e}"
241                    ))
242                })?
243                .with_timezone(&Utc);
244            let end = DateTime::parse_from_rfc3339(end_str)
245                .map_err(|e| {
246                    GatewayError::BadRequest(format!(
247                        "Invalid time_range_end format (expected RFC3339): {e}"
248                    ))
249                })?
250                .with_timezone(&Utc);
251
252            if begin > end {
253                return Err(GatewayError::BadRequest(
254                    "time_range_begin must be before time_range_end".to_string(),
255                ));
256            }
257
258            Some(TimeRange::new(begin, end))
259        }
260        (Some(_), None) => {
261            return Err(GatewayError::BadRequest(
262                "time_range_end must be provided when time_range_begin is specified".to_string(),
263            ));
264        }
265        (None, Some(_)) => {
266            return Err(GatewayError::BadRequest(
267                "time_range_begin must be provided when time_range_end is specified".to_string(),
268            ));
269        }
270        (None, None) => None,
271    };
272
273    info!(
274        "Gateway request: request_id={}, client_type={}, time_range={:?}, sql={}",
275        request_id_header, client_type_header, time_range, sql
276    );
277
278    // Connect to FlightSQL backend
279    let flight_url = std::env::var("MICROMEGAS_FLIGHTSQL_URL")
280        .map_err(|_| GatewayError::Internal("MICROMEGAS_FLIGHTSQL_URL not configured".to_string()))?
281        .parse::<Uri>()
282        .map_err(|e| GatewayError::Internal(format!("Invalid FlightSQL URL: {e}")))?;
283
284    let tls_config = ClientTlsConfig::new().with_native_roots();
285    let channel = Channel::builder(flight_url)
286        .tls_config(tls_config)
287        .map_err(|e| GatewayError::Internal(format!("TLS config error: {e}")))?
288        .connect()
289        .await
290        .map_err(|e| {
291            GatewayError::ServiceUnavailable(format!("Failed to connect to FlightSQL: {e}"))
292        })?;
293
294    // Create client and set headers
295    let mut client = Client::new(channel);
296
297    client
298        .inner_mut()
299        .set_header("x-client-type", client_type_header);
300    client
301        .inner_mut()
302        .set_header("x-request-id", request_id_header);
303
304    if let Some(client_ip) = origin_metadata.get("x-client-ip")
305        && let Ok(ip_str) = client_ip.to_str()
306    {
307        client.inner_mut().set_header("x-client-ip", ip_str);
308    }
309
310    // Forward allowed headers from client
311    for (name, value) in headers.iter() {
312        let header_name = name.as_str();
313
314        // Skip headers already set by origin metadata
315        if header_name.eq_ignore_ascii_case("x-client-type")
316            || header_name.eq_ignore_ascii_case("x-request-id")
317            || header_name.eq_ignore_ascii_case("x-client-ip")
318        {
319            continue; // Origin metadata takes precedence
320        }
321
322        if config.should_forward(header_name)
323            && let Ok(value_str) = value.to_str()
324        {
325            client.inner_mut().set_header(header_name, value_str);
326        }
327    }
328
329    // Execute query with error handling
330    let batches = client
331        .query(sql.to_string(), time_range)
332        .await
333        .map_err(|e| {
334            // Map tonic errors to appropriate HTTP status codes
335            if let Some(status) = e.downcast_ref::<tonic::Status>() {
336                match status.code() {
337                    tonic::Code::Unauthenticated => {
338                        GatewayError::Unauthorized(status.message().to_string())
339                    }
340                    tonic::Code::PermissionDenied => {
341                        GatewayError::Forbidden(status.message().to_string())
342                    }
343                    tonic::Code::InvalidArgument => {
344                        GatewayError::BadRequest(status.message().to_string())
345                    }
346                    tonic::Code::Unavailable => {
347                        GatewayError::ServiceUnavailable(status.message().to_string())
348                    }
349                    _ => GatewayError::Internal(format!("Query failed: {}", status.message())),
350                }
351            } else {
352                GatewayError::Internal(format!("Query execution error: {e:?}"))
353            }
354        })?;
355
356    let elapsed = start_time.elapsed();
357    info!(
358        "Gateway request completed: request_id={}, duration={:?}",
359        request_id_header, elapsed
360    );
361
362    if batches.is_empty() {
363        return Ok("[]".to_string());
364    }
365
366    let mut buffer = Vec::new();
367    let mut json_writer = Writer::<_, JsonArray>::new(&mut buffer);
368    let batch_refs: Vec<&RecordBatch> = batches.iter().collect();
369    json_writer
370        .write_batches(&batch_refs)
371        .map_err(|e| GatewayError::Internal(format!("Failed to serialize results: {e}")))?;
372    json_writer
373        .finish()
374        .map_err(|e| GatewayError::Internal(format!("Failed to finish JSON output: {e}")))?;
375
376    String::from_utf8(buffer)
377        .map_err(|e| GatewayError::Internal(format!("Invalid UTF-8 in results: {e}")))
378}
379
380pub fn register_routes(router: Router) -> Router {
381    router.route("/gateway/query", post(handle_query))
382}