// micromegas/client/flightsql_client_factory.rs

1use super::flightsql_client::Client;
2use anyhow::{Context, Result};
3use async_trait::async_trait;
4use http::Uri;
5use tonic::transport::{Channel, ClientTlsConfig};
6
7/// A trait for creating FlightSQL clients.
8#[async_trait]
9pub trait FlightSQLClientFactory: Send + Sync {
10    async fn make_client(&self) -> Result<Client>;
11}
12
/// A FlightSQL client factory that uses a bearer token for authentication.
///
/// Each call to `make_client` opens a new gRPC channel to `url` and attaches
/// the token as an `Authorization: Bearer …` header.
pub struct BearerFlightSQLClientFactory {
    /// FlightSQL server URL; parsed into a URI each time a client is made.
    url: String,
    /// Bearer token, with or without a leading `Bearer ` prefix.
    token: String,
    /// Optional client type identifier, sent as the `x-client-type` header.
    client_type: Option<String>,
}

// Hand-written rather than derived so the bearer token never ends up in logs,
// error reports or panic messages.
impl std::fmt::Debug for BearerFlightSQLClientFactory {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BearerFlightSQLClientFactory")
            .field("url", &self.url)
            .field("token", &"<redacted>")
            .field("client_type", &self.client_type)
            .finish()
    }
}
19
20impl BearerFlightSQLClientFactory {
21    /// Creates a new `BearerFlightSQLClientFactory`.
22    ///
23    /// # Arguments
24    ///
25    /// * `url` - The FlightSQL server URL.
26    /// * `token` - The bearer token to use for authentication.
27    pub fn new(url: String, token: String) -> Self {
28        Self {
29            url,
30            token,
31            client_type: None,
32        }
33    }
34
35    /// Creates a new `BearerFlightSQLClientFactory` with a specific client type identifier.
36    ///
37    /// # Arguments
38    ///
39    /// * `url` - The FlightSQL server URL.
40    /// * `token` - The bearer token to use for authentication.
41    /// * `client_type` - The client type identifier (e.g., "web", "cli", "python").
42    pub fn new_with_client_type(url: String, token: String, client_type: String) -> Self {
43        Self {
44            url,
45            token,
46            client_type: Some(client_type),
47        }
48    }
49
50    /// Creates a new `BearerFlightSQLClientFactory` that reads the URL from the
51    /// `MICROMEGAS_FLIGHTSQL_URL` environment variable.
52    pub fn from_env(token: String) -> Result<Self> {
53        let url = std::env::var("MICROMEGAS_FLIGHTSQL_URL")
54            .with_context(|| "error reading MICROMEGAS_FLIGHTSQL_URL environment variable")?;
55        Ok(Self {
56            url,
57            token,
58            client_type: None,
59        })
60    }
61
62    /// Creates a new `BearerFlightSQLClientFactory` that reads the URL from the
63    /// `MICROMEGAS_FLIGHTSQL_URL` environment variable, with a client type.
64    pub fn from_env_with_client_type(token: String, client_type: String) -> Result<Self> {
65        let url = std::env::var("MICROMEGAS_FLIGHTSQL_URL")
66            .with_context(|| "error reading MICROMEGAS_FLIGHTSQL_URL environment variable")?;
67        Ok(Self {
68            url,
69            token,
70            client_type: Some(client_type),
71        })
72    }
73}
74
75#[async_trait]
76impl FlightSQLClientFactory for BearerFlightSQLClientFactory {
77    async fn make_client(&self) -> Result<Client> {
78        let flight_url = self
79            .url
80            .parse::<Uri>()
81            .with_context(|| "parsing flightsql url")?;
82        let mut endpoint = Channel::builder(flight_url.clone());
83        if flight_url.scheme_str() == Some("https") {
84            let tls_config = ClientTlsConfig::new().with_native_roots();
85            endpoint = endpoint
86                .tls_config(tls_config)
87                .with_context(|| "tls_config")?;
88        }
89        let channel = endpoint
90            .connect()
91            .await
92            .with_context(|| "connecting grpc channel")?;
93        let mut client = Client::new(channel);
94        let auth_value = if self.token.starts_with("Bearer ") {
95            self.token.clone()
96        } else {
97            format!("Bearer {}", self.token)
98        };
99
100        client
101            .inner_mut()
102            .set_header(http::header::AUTHORIZATION.as_str(), auth_value);
103
104        // Set client type header if provided
105        if let Some(client_type) = &self.client_type {
106            client
107                .inner_mut()
108                .set_header("x-client-type", client_type.clone());
109        }
110
111        // Preserve dictionary encoding for bandwidth efficiency
112        client.inner_mut().set_header("preserve_dictionary", "true");
113
114        Ok(client)
115    }
116}