micromegas_auth/
oidc.rs

1use crate::types::{AuthContext, AuthProvider, AuthType};
2use anyhow::{Result, anyhow};
3use base64::Engine;
4use chrono::{DateTime, TimeDelta, Utc};
5use jsonwebtoken::{Algorithm, Validation, decode, decode_header};
6use moka::future::Cache;
7use openidconnect::core::{CoreJsonWebKeySet, CoreProviderMetadata};
8use openidconnect::{IssuerUrl, JsonWebKey};
9use rsa::pkcs1::EncodeRsaPublicKey;
10use rsa::{BigUint, RsaPublicKey};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15
16/// Create HTTP client for OIDC operations with security best practices
17///
18/// This client is configured with SSRF protection:
19/// - No redirects allowed (prevents open redirect attacks)
20/// - Suitable for OIDC discovery and token exchange operations
21///
22/// # Security
23///
24/// The no-redirect policy is important for OIDC operations because:
25/// - Prevents attackers from redirecting requests to internal services
26/// - Ensures OIDC endpoints are accessed directly without intermediate hops
27/// - Protects against SSRF (Server-Side Request Forgery) attacks
28///
29/// # Example
30///
31/// ```rust
32/// use micromegas_auth::oidc::create_http_client;
33///
34/// # async fn example() -> anyhow::Result<()> {
35/// let client = create_http_client()?;
36/// let response = client.get("https://accounts.google.com/.well-known/openid-configuration")
37///     .send()
38///     .await?;
39/// # Ok(())
40/// # }
41/// ```
42pub fn create_http_client() -> Result<reqwest::Client> {
43    reqwest::ClientBuilder::new()
44        .redirect(reqwest::redirect::Policy::none())
45        .build()
46        .map_err(|e| anyhow!("Failed to create HTTP client: {e:?}"))
47}
48
49/// Fetch JWKS from the OIDC provider using openidconnect's built-in discovery
50async fn fetch_jwks(issuer_url: &IssuerUrl) -> Result<Arc<CoreJsonWebKeySet>> {
51    let http_client = create_http_client()?;
52
53    // Use openidconnect's built-in OIDC discovery
54    let metadata = CoreProviderMetadata::discover_async(issuer_url.clone(), &http_client)
55        .await
56        .map_err(|e| {
57            anyhow!(
58                "Failed to discover OIDC metadata from {}: {e:?}",
59                issuer_url
60            )
61        })?;
62
63    // Fetch JWKS from jwks_uri
64    let jwks_uri = metadata.jwks_uri();
65    let jwks: CoreJsonWebKeySet = http_client
66        .get(jwks_uri.url().as_str())
67        .send()
68        .await
69        .map_err(|e| anyhow!("Failed to fetch JWKS from {}: {e:?}", jwks_uri))?
70        .json()
71        .await
72        .map_err(|e| anyhow!("Failed to parse JWKS: {e:?}"))?;
73
74    Ok(Arc::new(jwks))
75}
76
77/// JWKS cache for an OIDC issuer
78///
79/// Caches JSON Web Key Sets with automatic TTL expiration.
80/// Uses moka for thread-safe caching with atomic cache miss handling.
81struct JwksCache {
82    issuer_url: IssuerUrl,
83    cache: Cache<String, Arc<CoreJsonWebKeySet>>,
84}
85
86impl JwksCache {
87    /// Create a new JWKS cache
88    fn new(issuer_url: IssuerUrl, ttl: Duration) -> Self {
89        let cache = Cache::builder().time_to_live(ttl).build();
90
91        Self { issuer_url, cache }
92    }
93
94    /// Get the JWKS, fetching from the issuer if not cached
95    async fn get(&self) -> Result<Arc<CoreJsonWebKeySet>> {
96        let issuer_url = self.issuer_url.clone();
97
98        self.cache
99            .try_get_with(
100                "jwks".to_string(),
101                async move { fetch_jwks(&issuer_url).await },
102            )
103            .await
104            .map_err(|e| anyhow!("Failed to fetch JWKS: {e:?}"))
105    }
106}
107
108/// Configuration for a single OIDC issuer
109#[derive(Debug, Clone, Deserialize)]
110pub struct OidcIssuer {
111    /// Issuer URL (e.g., <https://accounts.google.com>)
112    pub issuer: String,
113    /// Expected audience
114    /// - For access tokens: API audience (e.g., "<https://api.example.com>")
115    /// - For ID tokens: client ID
116    pub audience: String,
117}
118
/// Default JWKS refresh interval, in seconds (one hour).
const DEFAULT_JWKS_REFRESH_INTERVAL_SECS: u64 = 60 * 60;
/// Default capacity of the validated-token cache, in entries.
const DEFAULT_TOKEN_CACHE_SIZE: u64 = 1_000;
/// Default time-to-live of a validated-token cache entry, in seconds (five minutes).
const DEFAULT_TOKEN_CACHE_TTL_SECS: u64 = 5 * 60;
122
123/// OIDC configuration
124#[derive(Debug, Clone, Deserialize)]
125#[serde(default)]
126pub struct OidcConfig {
127    /// List of configured OIDC issuers
128    pub issuers: Vec<OidcIssuer>,
129    /// JWKS refresh interval in seconds (default: 3600 = 1 hour)
130    pub jwks_refresh_interval_secs: u64,
131    /// Token cache size (default: 1000)
132    pub token_cache_size: u64,
133    /// Token cache TTL in seconds (default: 300 = 5 min)
134    pub token_cache_ttl_secs: u64,
135}
136
137impl Default for OidcConfig {
138    fn default() -> Self {
139        Self {
140            issuers: Vec::new(),
141            jwks_refresh_interval_secs: DEFAULT_JWKS_REFRESH_INTERVAL_SECS,
142            token_cache_size: DEFAULT_TOKEN_CACHE_SIZE,
143            token_cache_ttl_secs: DEFAULT_TOKEN_CACHE_TTL_SECS,
144        }
145    }
146}
147
148impl OidcConfig {
149    /// Load OIDC configuration from environment variable
150    pub fn from_env() -> Result<Self> {
151        Self::from_env_var("MICROMEGAS_OIDC_CONFIG")
152    }
153
154    /// Load OIDC configuration from a named environment variable
155    pub fn from_env_var(name: &str) -> Result<Self> {
156        let json =
157            std::env::var(name).map_err(|_| anyhow!("{name} environment variable not set"))?;
158        let config: OidcConfig =
159            serde_json::from_str(&json).map_err(|e| anyhow!("Failed to parse {name}: {e:?}"))?;
160        Ok(config)
161    }
162}
163
164/// Audience can be either a string or an array of strings in OIDC tokens
165#[derive(Debug, Serialize, Deserialize)]
166#[serde(untagged)]
167enum Audience {
168    Single(String),
169    Multiple(Vec<String>),
170}
171
172impl Audience {
173    fn contains(&self, aud: &str) -> bool {
174        match self {
175            Audience::Single(s) => s == aud,
176            Audience::Multiple(v) => v.iter().any(|a| a == aud),
177        }
178    }
179}
180
181/// JWT Claims from OIDC ID token or access token
182///
183/// This struct supports multiple OIDC providers by including various email claim fields.
184/// Different providers use different claim names for email addresses:
185///
186/// Email extraction priority (see `get_email()` method):
187/// 1. `verified_primary_email` - Azure AD optional claim (most reliable, verified by provider)
188/// 2. `email` - Standard OIDC claim (Google, some Azure AD configurations)
189/// 3. `namespaced_email` - Custom namespaced claims (Auth0: `https://micromegas.io/email`)
190/// 4. `preferred_username` - Azure AD standard claim (often contains email)
191/// 5. `upn` - User Principal Name (Azure AD enterprise)
192/// 6. `unique_name` - Legacy Azure AD claim
193#[derive(Debug, Serialize, Deserialize)]
194struct Claims {
195    /// Issuer - identifies the principal that issued the JWT
196    iss: String,
197    /// Subject - identifies the principal that is the subject of the JWT
198    sub: String,
199    /// Audience - identifies the recipients that the JWT is intended for
200    /// Can be either a single string or an array of strings
201    aud: Audience,
202    /// Expiration time - identifies the expiration time on or after which the JWT must not be accepted
203    exp: i64,
204    /// Email address of the user (standard OIDC claim, used by Google and some Azure AD configurations)
205    #[serde(skip_serializing_if = "Option::is_none")]
206    email: Option<String>,
207    /// Verified primary email (Azure AD optional claim - most reliable, verified by provider)
208    #[serde(skip_serializing_if = "Option::is_none")]
209    verified_primary_email: Option<String>,
210    /// Preferred username (Azure AD standard claim, often contains email)
211    #[serde(skip_serializing_if = "Option::is_none")]
212    preferred_username: Option<String>,
213    /// User Principal Name (Azure AD enterprise claim)
214    #[serde(skip_serializing_if = "Option::is_none")]
215    upn: Option<String>,
216    /// Unique name (legacy Azure AD claim from older tokens)
217    #[serde(skip_serializing_if = "Option::is_none")]
218    unique_name: Option<String>,
219    /// Custom namespaced email claim (Auth0 custom claims via Actions)
220    #[serde(rename = "https://micromegas.io/email")]
221    #[serde(skip_serializing_if = "Option::is_none")]
222    namespaced_email: Option<String>,
223    /// Custom namespaced name claim (Auth0 custom claims via Actions)
224    #[serde(rename = "https://micromegas.io/name")]
225    #[serde(skip_serializing_if = "Option::is_none")]
226    namespaced_name: Option<String>,
227}
228
229impl Claims {
230    /// Get email from various possible claim fields
231    /// Priority order: verified_primary_email (most reliable) → email → namespaced_email → preferred_username → upn → unique_name
232    fn get_email(&self) -> Option<String> {
233        self.verified_primary_email
234            .clone()
235            .or_else(|| self.email.clone())
236            .or_else(|| self.namespaced_email.clone())
237            .or_else(|| self.preferred_username.clone())
238            .or_else(|| self.upn.clone())
239            .or_else(|| self.unique_name.clone())
240    }
241}
242
243/// OIDC issuer client for token validation
244struct OidcIssuerClient {
245    issuer: String,
246    audience: String,
247    jwks_cache: JwksCache,
248}
249
250impl OidcIssuerClient {
251    fn new(issuer: String, audience: String, jwks_ttl: Duration) -> Result<Self> {
252        let issuer_url = IssuerUrl::new(issuer.clone())
253            .map_err(|e| anyhow!("Invalid issuer URL '{}': {e:?}", issuer))?;
254
255        Ok(Self {
256            issuer,
257            audience,
258            jwks_cache: JwksCache::new(issuer_url, jwks_ttl),
259        })
260    }
261}
262
263/// Load admin users from the named environment variable
264fn load_admin_users(admin_var: &str) -> Vec<String> {
265    match std::env::var(admin_var) {
266        Ok(json) => serde_json::from_str::<Vec<String>>(&json).unwrap_or_default(),
267        Err(_) => vec![],
268    }
269}
270
271/// Convert a JWK to a DecodingKey for jsonwebtoken
272fn jwk_to_decoding_key(
273    jwk: &openidconnect::core::CoreJsonWebKey,
274) -> Result<jsonwebtoken::DecodingKey> {
275    // Serialize the JWK to JSON to extract parameters
276    let jwk_json =
277        serde_json::to_value(jwk).map_err(|e| anyhow!("Failed to serialize JWK: {e:?}"))?;
278
279    // Extract n and e parameters
280    let n = jwk_json
281        .get("n")
282        .and_then(|v| v.as_str())
283        .ok_or_else(|| anyhow!("JWK missing 'n' parameter"))?;
284    let e = jwk_json
285        .get("e")
286        .and_then(|v| v.as_str())
287        .ok_or_else(|| anyhow!("JWK missing 'e' parameter"))?;
288
289    // Decode base64url encoded parameters
290    let n_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
291        .decode(n.as_bytes())
292        .map_err(|e| anyhow!("Failed to decode 'n': {e:?}"))?;
293    let e_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
294        .decode(e.as_bytes())
295        .map_err(|e| anyhow!("Failed to decode 'e': {e:?}"))?;
296
297    // Create RSA public key
298    let n_bigint = BigUint::from_bytes_be(&n_bytes);
299    let e_bigint = BigUint::from_bytes_be(&e_bytes);
300
301    let public_key = RsaPublicKey::new(n_bigint, e_bigint)
302        .map_err(|e| anyhow!("Failed to create RSA public key: {e:?}"))?;
303
304    // Convert to PEM format
305    let pem = public_key
306        .to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
307        .map_err(|e| anyhow!("Failed to encode public key as PEM: {e:?}"))?;
308
309    // Create DecodingKey
310    jsonwebtoken::DecodingKey::from_rsa_pem(pem.as_bytes())
311        .map_err(|e| anyhow!("Failed to create decoding key: {e:?}"))
312}
313
314/// OIDC authentication provider
315///
316/// Validates JWT tokens (access or ID tokens) from configured OIDC providers.
317/// Caches both JWKS and validated tokens for performance.
318pub struct OidcAuthProvider {
319    /// Map from issuer URL to list of clients (supports multiple audiences per issuer)
320    clients: HashMap<String, Vec<Arc<OidcIssuerClient>>>,
321    /// Cache for validated tokens
322    token_cache: Cache<String, Arc<AuthContext>>,
323    /// Admin users (by email or subject)
324    admin_users: Vec<String>,
325}
326
327impl std::fmt::Debug for OidcAuthProvider {
328    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329        f.debug_struct("OidcAuthProvider")
330            .field("num_clients", &self.clients.len())
331            .field("admin_users", &"(not printed)")
332            .finish()
333    }
334}
335
336impl OidcAuthProvider {
337    /// Create a new OIDC authentication provider.
338    ///
339    /// `admin_var` is the environment-variable name from which the admin user list
340    /// is loaded (e.g. `"MICROMEGAS_ADMINS"` or `"MICROMEGAS_ANALYTICS_ADMINS"`).
341    pub async fn new(config: OidcConfig, admin_var: &str) -> Result<Self> {
342        if config.issuers.is_empty() {
343            return Err(anyhow!("At least one OIDC issuer must be configured"));
344        }
345
346        micromegas_tracing::info!("Configuring OIDC with {} issuer(s)", config.issuers.len());
347        for (idx, issuer_config) in config.issuers.iter().enumerate() {
348            micromegas_tracing::info!(
349                "  Issuer {}: {} (audience: {})",
350                idx + 1,
351                issuer_config.issuer,
352                issuer_config.audience
353            );
354        }
355
356        let jwks_ttl = Duration::from_secs(config.jwks_refresh_interval_secs);
357        let mut clients: HashMap<String, Vec<Arc<OidcIssuerClient>>> = HashMap::new();
358
359        // Initialize a client for each configured issuer
360        // Supports multiple audiences for the same issuer (e.g., access tokens + ID tokens)
361        for issuer_config in config.issuers {
362            let client = OidcIssuerClient::new(
363                issuer_config.issuer.clone(),
364                issuer_config.audience,
365                jwks_ttl,
366            )?;
367
368            clients
369                .entry(issuer_config.issuer)
370                .or_default()
371                .push(Arc::new(client));
372        }
373
374        // Create token cache
375        let token_cache = Cache::builder()
376            .max_capacity(config.token_cache_size)
377            .time_to_live(Duration::from_secs(config.token_cache_ttl_secs))
378            .build();
379
380        // Load admin users from environment
381        let admin_users = load_admin_users(admin_var);
382
383        Ok(Self {
384            clients,
385            token_cache,
386            admin_users,
387        })
388    }
389
390    fn is_admin(&self, subject: &str, email: Option<&str>) -> bool {
391        self.admin_users
392            .iter()
393            .any(|admin| admin == subject || email.map(|e| admin == e).unwrap_or(false))
394    }
395
396    /// Decode JWT payload without validation to extract issuer
397    ///
398    /// JWTs are structured as: header.payload.signature
399    /// Both header and payload are base64url-encoded JSON
400    fn decode_payload_unsafe(&self, token: &str) -> Result<Claims> {
401        let parts: Vec<&str> = token.split('.').collect();
402        if parts.len() != 3 {
403            return Err(anyhow!("Invalid JWT format"));
404        }
405
406        // Decode the payload (second part)
407        let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
408            .decode(parts[1].as_bytes())
409            .map_err(|e| anyhow!("Failed to decode JWT payload: {e:?}"))?;
410
411        let claims: Claims = serde_json::from_slice(&payload_bytes)
412            .map_err(|e| anyhow!("Failed to parse JWT claims: {e:?}"))?;
413
414        Ok(claims)
415    }
416
417    /// Validate a JWT token (access token or ID token) and return authentication context
418    ///
419    /// This implementation follows OAuth 2.0 best practices by:
420    /// 1. Extracting kid from JWT header for direct key lookup (falls back to trying all keys)
421    /// 2. Extracting issuer from JWT payload for direct client lookup
422    /// 3. Using O(1) lookups instead of O(n*m) iteration
423    /// 4. Eliminating timing side-channels
424    ///
425    /// Supports both:
426    /// - Access tokens (with API audience, used by Auth0 and similar)
427    /// - ID tokens (with client_id audience, used by Google/Azure AD)
428    async fn validate_jwt_token(&self, token: &str) -> Result<AuthContext> {
429        // Step 1: Decode header (unsigned) to get key ID (kid)
430        let header = decode_header(token).map_err(|e| anyhow!("Invalid JWT header: {e:?}"))?;
431
432        // kid is optional - some providers omit it when there's only one key
433        let kid = header.kid;
434
435        // Step 2: Decode payload (unsigned) to get issuer and expiration
436        let unverified_claims = self.decode_payload_unsafe(token)?;
437
438        // Step 3: Look up clients for this issuer
439        let issuer_clients = self
440            .clients
441            .get(&unverified_claims.iss)
442            .ok_or_else(|| anyhow!("Unknown issuer: {}", unverified_claims.iss))?;
443
444        // Step 4: Fetch JWKS once (all clients for the same issuer share the same JWKS)
445        // Use the first client's cache - they all point to the same issuer's JWKS endpoint
446        let first_client = issuer_clients
447            .first()
448            .ok_or_else(|| anyhow!("No clients configured for issuer"))?;
449
450        let jwks = first_client
451            .jwks_cache
452            .get()
453            .await
454            .map_err(|e| anyhow!("Failed to fetch JWKS: {e:?}"))?;
455
456        // Step 5: Find the key(s) to try
457        // If kid is present, look up directly; otherwise try all keys
458        let keys_to_try: Vec<_> = if let Some(ref kid_value) = kid {
459            // Direct lookup by kid
460            jwks.keys()
461                .iter()
462                .filter(|k| k.key_id().map(|id| id.as_str()) == Some(kid_value.as_str()))
463                .collect()
464        } else {
465            // No kid provided - try all keys (common for single-key providers)
466            jwks.keys().iter().collect()
467        };
468
469        if keys_to_try.is_empty() {
470            return Err(if let Some(kid_value) = kid {
471                anyhow!("Key with kid '{}' not found in JWKS", kid_value)
472            } else {
473                anyhow!("No keys found in JWKS")
474            });
475        }
476
477        // Step 6: Try each key until one works
478        let mut key_error = anyhow!("No valid key found");
479        for key in keys_to_try {
480            match jwk_to_decoding_key(key) {
481                Ok(decoding_key) => {
482                    match self
483                        .try_validate_with_key(token, &decoding_key, issuer_clients)
484                        .await
485                    {
486                        Ok(auth_ctx) => return Ok(auth_ctx),
487                        Err(e) => key_error = e,
488                    }
489                }
490                Err(e) => key_error = e,
491            }
492        }
493
494        Err(key_error)
495    }
496
497    /// Try to validate a token with a specific decoding key against all configured audiences
498    async fn try_validate_with_key(
499        &self,
500        token: &str,
501        decoding_key: &jsonwebtoken::DecodingKey,
502        issuer_clients: &[Arc<OidcIssuerClient>],
503    ) -> Result<AuthContext> {
504        let configured_audiences: Vec<String> =
505            issuer_clients.iter().map(|c| c.audience.clone()).collect();
506        let mut last_error = anyhow!("No matching audience found");
507
508        for client in issuer_clients {
509            // Validate token with specific key and issuer
510            let mut validation = Validation::new(Algorithm::RS256);
511            // Don't validate audience yet - we'll do it manually
512            validation.validate_aud = false;
513            validation.set_issuer(&[&client.issuer]);
514
515            let claims = match decode::<Claims>(token, decoding_key, &validation) {
516                Ok(token_data) => token_data.claims,
517                Err(e) => {
518                    last_error = anyhow!("Token validation failed: {e:?}");
519                    continue;
520                }
521            };
522
523            // Manually validate audience - if it matches, we found the right client!
524            if claims.aud.contains(&client.audience) {
525                // Success! Validate expiration and return auth context
526                let expires_at = DateTime::from_timestamp(claims.exp, 0)
527                    .ok_or_else(|| anyhow!("Invalid expiration timestamp"))?;
528
529                if expires_at < Utc::now() {
530                    return Err(anyhow!("Token has expired"));
531                }
532
533                let email = claims.get_email();
534                let is_admin = self.is_admin(&claims.sub, email.as_deref());
535
536                return Ok(AuthContext {
537                    subject: claims.sub,
538                    email,
539                    issuer: claims.iss,
540                    audience: Some(client.audience.clone()),
541                    expires_at: Some(expires_at),
542                    auth_type: AuthType::Oidc,
543                    is_admin,
544                    allow_delegation: false,
545                });
546            } else {
547                // Provide detailed error with configured vs actual audiences
548                let actual_audiences = match &claims.aud {
549                    Audience::Single(s) => vec![s.clone()],
550                    Audience::Multiple(v) => v.clone(),
551                };
552                last_error = anyhow!(
553                    "Token audience mismatch - configured audiences: {:?}, token audiences: {:?}",
554                    configured_audiences,
555                    actual_audiences
556                );
557            }
558        }
559
560        // If we get here, none of the clients matched
561        Err(last_error)
562    }
563}
564
565#[async_trait::async_trait]
566impl AuthProvider for OidcAuthProvider {
567    async fn validate_request(
568        &self,
569        parts: &dyn crate::types::RequestParts,
570    ) -> Result<AuthContext> {
571        let token = parts
572            .bearer_token()
573            .ok_or_else(|| anyhow!("missing bearer token"))?;
574
575        // Check token cache first, but verify token hasn't expired
576        // Use 30-second grace period to account for clock skew and network latency
577        if let Some(cached) = self.token_cache.get(token).await {
578            let is_expired = cached
579                .expires_at
580                .map(|exp| exp <= Utc::now() + TimeDelta::seconds(30))
581                .unwrap_or(false);
582
583            if is_expired {
584                self.token_cache.remove(token).await;
585            } else {
586                return Ok((*cached).clone());
587            }
588        }
589
590        // Validate token
591        let auth_ctx = self.validate_jwt_token(token).await?;
592
593        // Cache the result
594        self.token_cache
595            .insert(token.to_string(), Arc::new(auth_ctx.clone()))
596            .await;
597
598        Ok(auth_ctx)
599    }
600}