micromegas_auth/
oidc.rs

1use crate::types::{AuthContext, AuthProvider, AuthType};
2use anyhow::{Result, anyhow};
3use base64::Engine;
4use chrono::{DateTime, TimeDelta, Utc};
5use jsonwebtoken::{Algorithm, Validation, decode, decode_header};
6use moka::future::Cache;
7use openidconnect::core::{CoreJsonWebKeySet, CoreProviderMetadata};
8use openidconnect::{IssuerUrl, JsonWebKey};
9use rsa::pkcs1::EncodeRsaPublicKey;
10use rsa::{BigUint, RsaPublicKey};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15
16/// Create HTTP client for OIDC operations with security best practices
17///
18/// This client is configured with SSRF protection:
19/// - No redirects allowed (prevents open redirect attacks)
20/// - Suitable for OIDC discovery and token exchange operations
21///
22/// # Security
23///
24/// The no-redirect policy is important for OIDC operations because:
25/// - Prevents attackers from redirecting requests to internal services
26/// - Ensures OIDC endpoints are accessed directly without intermediate hops
27/// - Protects against SSRF (Server-Side Request Forgery) attacks
28///
29/// # Example
30///
31/// ```rust
32/// use micromegas_auth::oidc::create_http_client;
33///
34/// # async fn example() -> anyhow::Result<()> {
35/// let client = create_http_client()?;
36/// let response = client.get("https://accounts.google.com/.well-known/openid-configuration")
37///     .send()
38///     .await?;
39/// # Ok(())
40/// # }
41/// ```
42pub fn create_http_client() -> Result<reqwest::Client> {
43    reqwest::ClientBuilder::new()
44        .redirect(reqwest::redirect::Policy::none())
45        .build()
46        .map_err(|e| anyhow!("Failed to create HTTP client: {e:?}"))
47}
48
49/// Fetch JWKS from the OIDC provider using openidconnect's built-in discovery
50async fn fetch_jwks(issuer_url: &IssuerUrl) -> Result<Arc<CoreJsonWebKeySet>> {
51    let http_client = create_http_client()?;
52
53    // Use openidconnect's built-in OIDC discovery
54    let metadata = CoreProviderMetadata::discover_async(issuer_url.clone(), &http_client)
55        .await
56        .map_err(|e| {
57            anyhow!(
58                "Failed to discover OIDC metadata from {}: {e:?}",
59                issuer_url
60            )
61        })?;
62
63    // Fetch JWKS from jwks_uri
64    let jwks_uri = metadata.jwks_uri();
65    let jwks: CoreJsonWebKeySet = http_client
66        .get(jwks_uri.url().as_str())
67        .send()
68        .await
69        .map_err(|e| anyhow!("Failed to fetch JWKS from {}: {e:?}", jwks_uri))?
70        .json()
71        .await
72        .map_err(|e| anyhow!("Failed to parse JWKS: {e:?}"))?;
73
74    Ok(Arc::new(jwks))
75}
76
77/// JWKS cache for an OIDC issuer
78///
79/// Caches JSON Web Key Sets with automatic TTL expiration.
80/// Uses moka for thread-safe caching with atomic cache miss handling.
81struct JwksCache {
82    issuer_url: IssuerUrl,
83    cache: Cache<String, Arc<CoreJsonWebKeySet>>,
84}
85
86impl JwksCache {
87    /// Create a new JWKS cache
88    fn new(issuer_url: IssuerUrl, ttl: Duration) -> Self {
89        let cache = Cache::builder().time_to_live(ttl).build();
90
91        Self { issuer_url, cache }
92    }
93
94    /// Get the JWKS, fetching from the issuer if not cached
95    async fn get(&self) -> Result<Arc<CoreJsonWebKeySet>> {
96        let issuer_url = self.issuer_url.clone();
97
98        self.cache
99            .try_get_with(
100                "jwks".to_string(),
101                async move { fetch_jwks(&issuer_url).await },
102            )
103            .await
104            .map_err(|e| anyhow!("Failed to fetch JWKS: {e:?}"))
105    }
106}
107
/// Configuration for a single OIDC issuer
///
/// One entry pairs an issuer with one expected audience. The same issuer may appear
/// in several entries to accept several audiences (e.g. access tokens and ID tokens).
#[derive(Debug, Clone, Deserialize)]
pub struct OidcIssuer {
    /// Issuer URL (e.g., <https://accounts.google.com>)
    ///
    /// Used as the lookup key for incoming tokens: a token's `iss` claim must be
    /// exactly equal to this string for the issuer to be recognized.
    pub issuer: String,
    /// Expected audience
    /// - For access tokens: API audience (e.g., "<https://api.example.com>")
    /// - For ID tokens: client ID
    pub audience: String,
}
118
/// Default value of `OidcConfig::jwks_refresh_interval_secs` (3600 s = 1 hour).
const DEFAULT_JWKS_REFRESH_INTERVAL_SECS: u64 = 3600;
/// Default value of `OidcConfig::token_cache_size` (maximum capacity of the token cache).
const DEFAULT_TOKEN_CACHE_SIZE: u64 = 1000;
/// Default value of `OidcConfig::token_cache_ttl_secs` (300 s = 5 minutes).
const DEFAULT_TOKEN_CACHE_TTL_SECS: u64 = 300;
122
123/// OIDC configuration
124#[derive(Debug, Clone, Deserialize)]
125#[serde(default)]
126pub struct OidcConfig {
127    /// List of configured OIDC issuers
128    pub issuers: Vec<OidcIssuer>,
129    /// JWKS refresh interval in seconds (default: 3600 = 1 hour)
130    pub jwks_refresh_interval_secs: u64,
131    /// Token cache size (default: 1000)
132    pub token_cache_size: u64,
133    /// Token cache TTL in seconds (default: 300 = 5 min)
134    pub token_cache_ttl_secs: u64,
135}
136
137impl Default for OidcConfig {
138    fn default() -> Self {
139        Self {
140            issuers: Vec::new(),
141            jwks_refresh_interval_secs: DEFAULT_JWKS_REFRESH_INTERVAL_SECS,
142            token_cache_size: DEFAULT_TOKEN_CACHE_SIZE,
143            token_cache_ttl_secs: DEFAULT_TOKEN_CACHE_TTL_SECS,
144        }
145    }
146}
147
148impl OidcConfig {
149    /// Load OIDC configuration from environment variable
150    pub fn from_env() -> Result<Self> {
151        let json = std::env::var("MICROMEGAS_OIDC_CONFIG")
152            .map_err(|_| anyhow!("MICROMEGAS_OIDC_CONFIG environment variable not set"))?;
153        let config: OidcConfig = serde_json::from_str(&json)
154            .map_err(|e| anyhow!("Failed to parse MICROMEGAS_OIDC_CONFIG: {e:?}"))?;
155        Ok(config)
156    }
157}
158
159/// Audience can be either a string or an array of strings in OIDC tokens
160#[derive(Debug, Serialize, Deserialize)]
161#[serde(untagged)]
162enum Audience {
163    Single(String),
164    Multiple(Vec<String>),
165}
166
167impl Audience {
168    fn contains(&self, aud: &str) -> bool {
169        match self {
170            Audience::Single(s) => s == aud,
171            Audience::Multiple(v) => v.iter().any(|a| a == aud),
172        }
173    }
174}
175
176/// JWT Claims from OIDC ID token or access token
177///
178/// This struct supports multiple OIDC providers by including various email claim fields.
179/// Different providers use different claim names for email addresses:
180///
181/// Email extraction priority (see `get_email()` method):
182/// 1. `verified_primary_email` - Azure AD optional claim (most reliable, verified by provider)
183/// 2. `email` - Standard OIDC claim (Google, some Azure AD configurations)
184/// 3. `namespaced_email` - Custom namespaced claims (Auth0: `https://micromegas.io/email`)
185/// 4. `preferred_username` - Azure AD standard claim (often contains email)
186/// 5. `upn` - User Principal Name (Azure AD enterprise)
187/// 6. `unique_name` - Legacy Azure AD claim
188#[derive(Debug, Serialize, Deserialize)]
189struct Claims {
190    /// Issuer - identifies the principal that issued the JWT
191    iss: String,
192    /// Subject - identifies the principal that is the subject of the JWT
193    sub: String,
194    /// Audience - identifies the recipients that the JWT is intended for
195    /// Can be either a single string or an array of strings
196    aud: Audience,
197    /// Expiration time - identifies the expiration time on or after which the JWT must not be accepted
198    exp: i64,
199    /// Email address of the user (standard OIDC claim, used by Google and some Azure AD configurations)
200    #[serde(skip_serializing_if = "Option::is_none")]
201    email: Option<String>,
202    /// Verified primary email (Azure AD optional claim - most reliable, verified by provider)
203    #[serde(skip_serializing_if = "Option::is_none")]
204    verified_primary_email: Option<String>,
205    /// Preferred username (Azure AD standard claim, often contains email)
206    #[serde(skip_serializing_if = "Option::is_none")]
207    preferred_username: Option<String>,
208    /// User Principal Name (Azure AD enterprise claim)
209    #[serde(skip_serializing_if = "Option::is_none")]
210    upn: Option<String>,
211    /// Unique name (legacy Azure AD claim from older tokens)
212    #[serde(skip_serializing_if = "Option::is_none")]
213    unique_name: Option<String>,
214    /// Custom namespaced email claim (Auth0 custom claims via Actions)
215    #[serde(rename = "https://micromegas.io/email")]
216    #[serde(skip_serializing_if = "Option::is_none")]
217    namespaced_email: Option<String>,
218    /// Custom namespaced name claim (Auth0 custom claims via Actions)
219    #[serde(rename = "https://micromegas.io/name")]
220    #[serde(skip_serializing_if = "Option::is_none")]
221    namespaced_name: Option<String>,
222}
223
224impl Claims {
225    /// Get email from various possible claim fields
226    /// Priority order: verified_primary_email (most reliable) → email → namespaced_email → preferred_username → upn → unique_name
227    fn get_email(&self) -> Option<String> {
228        self.verified_primary_email
229            .clone()
230            .or_else(|| self.email.clone())
231            .or_else(|| self.namespaced_email.clone())
232            .or_else(|| self.preferred_username.clone())
233            .or_else(|| self.upn.clone())
234            .or_else(|| self.unique_name.clone())
235    }
236}
237
238/// OIDC issuer client for token validation
239struct OidcIssuerClient {
240    issuer: String,
241    audience: String,
242    jwks_cache: JwksCache,
243}
244
245impl OidcIssuerClient {
246    fn new(issuer: String, audience: String, jwks_ttl: Duration) -> Result<Self> {
247        let issuer_url = IssuerUrl::new(issuer.clone())
248            .map_err(|e| anyhow!("Invalid issuer URL '{}': {e:?}", issuer))?;
249
250        Ok(Self {
251            issuer,
252            audience,
253            jwks_cache: JwksCache::new(issuer_url, jwks_ttl),
254        })
255    }
256}
257
258/// Load admin users from environment variable
259fn load_admin_users() -> Vec<String> {
260    match std::env::var("MICROMEGAS_ADMINS") {
261        Ok(json) => serde_json::from_str::<Vec<String>>(&json).unwrap_or_default(),
262        Err(_) => vec![],
263    }
264}
265
266/// Convert a JWK to a DecodingKey for jsonwebtoken
267fn jwk_to_decoding_key(
268    jwk: &openidconnect::core::CoreJsonWebKey,
269) -> Result<jsonwebtoken::DecodingKey> {
270    // Serialize the JWK to JSON to extract parameters
271    let jwk_json =
272        serde_json::to_value(jwk).map_err(|e| anyhow!("Failed to serialize JWK: {e:?}"))?;
273
274    // Extract n and e parameters
275    let n = jwk_json
276        .get("n")
277        .and_then(|v| v.as_str())
278        .ok_or_else(|| anyhow!("JWK missing 'n' parameter"))?;
279    let e = jwk_json
280        .get("e")
281        .and_then(|v| v.as_str())
282        .ok_or_else(|| anyhow!("JWK missing 'e' parameter"))?;
283
284    // Decode base64url encoded parameters
285    let n_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
286        .decode(n.as_bytes())
287        .map_err(|e| anyhow!("Failed to decode 'n': {e:?}"))?;
288    let e_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
289        .decode(e.as_bytes())
290        .map_err(|e| anyhow!("Failed to decode 'e': {e:?}"))?;
291
292    // Create RSA public key
293    let n_bigint = BigUint::from_bytes_be(&n_bytes);
294    let e_bigint = BigUint::from_bytes_be(&e_bytes);
295
296    let public_key = RsaPublicKey::new(n_bigint, e_bigint)
297        .map_err(|e| anyhow!("Failed to create RSA public key: {e:?}"))?;
298
299    // Convert to PEM format
300    let pem = public_key
301        .to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
302        .map_err(|e| anyhow!("Failed to encode public key as PEM: {e:?}"))?;
303
304    // Create DecodingKey
305    jsonwebtoken::DecodingKey::from_rsa_pem(pem.as_bytes())
306        .map_err(|e| anyhow!("Failed to create decoding key: {e:?}"))
307}
308
309/// OIDC authentication provider
310///
311/// Validates JWT tokens (access or ID tokens) from configured OIDC providers.
312/// Caches both JWKS and validated tokens for performance.
313pub struct OidcAuthProvider {
314    /// Map from issuer URL to list of clients (supports multiple audiences per issuer)
315    clients: HashMap<String, Vec<Arc<OidcIssuerClient>>>,
316    /// Cache for validated tokens
317    token_cache: Cache<String, Arc<AuthContext>>,
318    /// Admin users (by email or subject)
319    admin_users: Vec<String>,
320}
321
322impl std::fmt::Debug for OidcAuthProvider {
323    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324        f.debug_struct("OidcAuthProvider")
325            .field("num_clients", &self.clients.len())
326            .field("admin_users", &"(not printed)")
327            .finish()
328    }
329}
330
331impl OidcAuthProvider {
332    /// Create a new OIDC authentication provider
333    pub async fn new(config: OidcConfig) -> Result<Self> {
334        if config.issuers.is_empty() {
335            return Err(anyhow!("At least one OIDC issuer must be configured"));
336        }
337
338        micromegas_tracing::info!("Configuring OIDC with {} issuer(s)", config.issuers.len());
339        for (idx, issuer_config) in config.issuers.iter().enumerate() {
340            micromegas_tracing::info!(
341                "  Issuer {}: {} (audience: {})",
342                idx + 1,
343                issuer_config.issuer,
344                issuer_config.audience
345            );
346        }
347
348        let jwks_ttl = Duration::from_secs(config.jwks_refresh_interval_secs);
349        let mut clients: HashMap<String, Vec<Arc<OidcIssuerClient>>> = HashMap::new();
350
351        // Initialize a client for each configured issuer
352        // Supports multiple audiences for the same issuer (e.g., access tokens + ID tokens)
353        for issuer_config in config.issuers {
354            let client = OidcIssuerClient::new(
355                issuer_config.issuer.clone(),
356                issuer_config.audience,
357                jwks_ttl,
358            )?;
359
360            clients
361                .entry(issuer_config.issuer)
362                .or_default()
363                .push(Arc::new(client));
364        }
365
366        // Create token cache
367        let token_cache = Cache::builder()
368            .max_capacity(config.token_cache_size)
369            .time_to_live(Duration::from_secs(config.token_cache_ttl_secs))
370            .build();
371
372        // Load admin users from environment
373        let admin_users = load_admin_users();
374
375        Ok(Self {
376            clients,
377            token_cache,
378            admin_users,
379        })
380    }
381
382    fn is_admin(&self, subject: &str, email: Option<&str>) -> bool {
383        self.admin_users
384            .iter()
385            .any(|admin| admin == subject || email.map(|e| admin == e).unwrap_or(false))
386    }
387
388    /// Decode JWT payload without validation to extract issuer
389    ///
390    /// JWTs are structured as: header.payload.signature
391    /// Both header and payload are base64url-encoded JSON
392    fn decode_payload_unsafe(&self, token: &str) -> Result<Claims> {
393        let parts: Vec<&str> = token.split('.').collect();
394        if parts.len() != 3 {
395            return Err(anyhow!("Invalid JWT format"));
396        }
397
398        // Decode the payload (second part)
399        let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
400            .decode(parts[1].as_bytes())
401            .map_err(|e| anyhow!("Failed to decode JWT payload: {e:?}"))?;
402
403        let claims: Claims = serde_json::from_slice(&payload_bytes)
404            .map_err(|e| anyhow!("Failed to parse JWT claims: {e:?}"))?;
405
406        Ok(claims)
407    }
408
409    /// Validate a JWT token (access token or ID token) and return authentication context
410    ///
411    /// This implementation follows OAuth 2.0 best practices by:
412    /// 1. Extracting kid from JWT header for direct key lookup (falls back to trying all keys)
413    /// 2. Extracting issuer from JWT payload for direct client lookup
414    /// 3. Using O(1) lookups instead of O(n*m) iteration
415    /// 4. Eliminating timing side-channels
416    ///
417    /// Supports both:
418    /// - Access tokens (with API audience, used by Auth0 and similar)
419    /// - ID tokens (with client_id audience, used by Google/Azure AD)
420    async fn validate_jwt_token(&self, token: &str) -> Result<AuthContext> {
421        // Step 1: Decode header (unsigned) to get key ID (kid)
422        let header = decode_header(token).map_err(|e| anyhow!("Invalid JWT header: {e:?}"))?;
423
424        // kid is optional - some providers omit it when there's only one key
425        let kid = header.kid;
426
427        // Step 2: Decode payload (unsigned) to get issuer and expiration
428        let unverified_claims = self.decode_payload_unsafe(token)?;
429
430        // Step 3: Look up clients for this issuer
431        let issuer_clients = self
432            .clients
433            .get(&unverified_claims.iss)
434            .ok_or_else(|| anyhow!("Unknown issuer: {}", unverified_claims.iss))?;
435
436        // Step 4: Fetch JWKS once (all clients for the same issuer share the same JWKS)
437        // Use the first client's cache - they all point to the same issuer's JWKS endpoint
438        let first_client = issuer_clients
439            .first()
440            .ok_or_else(|| anyhow!("No clients configured for issuer"))?;
441
442        let jwks = first_client
443            .jwks_cache
444            .get()
445            .await
446            .map_err(|e| anyhow!("Failed to fetch JWKS: {e:?}"))?;
447
448        // Step 5: Find the key(s) to try
449        // If kid is present, look up directly; otherwise try all keys
450        let keys_to_try: Vec<_> = if let Some(ref kid_value) = kid {
451            // Direct lookup by kid
452            jwks.keys()
453                .iter()
454                .filter(|k| k.key_id().map(|id| id.as_str()) == Some(kid_value.as_str()))
455                .collect()
456        } else {
457            // No kid provided - try all keys (common for single-key providers)
458            jwks.keys().iter().collect()
459        };
460
461        if keys_to_try.is_empty() {
462            return Err(if let Some(kid_value) = kid {
463                anyhow!("Key with kid '{}' not found in JWKS", kid_value)
464            } else {
465                anyhow!("No keys found in JWKS")
466            });
467        }
468
469        // Step 6: Try each key until one works
470        let mut key_error = anyhow!("No valid key found");
471        for key in keys_to_try {
472            match jwk_to_decoding_key(key) {
473                Ok(decoding_key) => {
474                    match self
475                        .try_validate_with_key(token, &decoding_key, issuer_clients)
476                        .await
477                    {
478                        Ok(auth_ctx) => return Ok(auth_ctx),
479                        Err(e) => key_error = e,
480                    }
481                }
482                Err(e) => key_error = e,
483            }
484        }
485
486        Err(key_error)
487    }
488
489    /// Try to validate a token with a specific decoding key against all configured audiences
490    async fn try_validate_with_key(
491        &self,
492        token: &str,
493        decoding_key: &jsonwebtoken::DecodingKey,
494        issuer_clients: &[Arc<OidcIssuerClient>],
495    ) -> Result<AuthContext> {
496        let configured_audiences: Vec<String> =
497            issuer_clients.iter().map(|c| c.audience.clone()).collect();
498        let mut last_error = anyhow!("No matching audience found");
499
500        for client in issuer_clients {
501            // Validate token with specific key and issuer
502            let mut validation = Validation::new(Algorithm::RS256);
503            // Don't validate audience yet - we'll do it manually
504            validation.validate_aud = false;
505            validation.set_issuer(&[&client.issuer]);
506
507            let claims = match decode::<Claims>(token, decoding_key, &validation) {
508                Ok(token_data) => token_data.claims,
509                Err(e) => {
510                    last_error = anyhow!("Token validation failed: {e:?}");
511                    continue;
512                }
513            };
514
515            // Manually validate audience - if it matches, we found the right client!
516            if claims.aud.contains(&client.audience) {
517                // Success! Validate expiration and return auth context
518                let expires_at = DateTime::from_timestamp(claims.exp, 0)
519                    .ok_or_else(|| anyhow!("Invalid expiration timestamp"))?;
520
521                if expires_at < Utc::now() {
522                    return Err(anyhow!("Token has expired"));
523                }
524
525                let email = claims.get_email();
526                let is_admin = self.is_admin(&claims.sub, email.as_deref());
527
528                return Ok(AuthContext {
529                    subject: claims.sub,
530                    email,
531                    issuer: claims.iss,
532                    audience: Some(client.audience.clone()),
533                    expires_at: Some(expires_at),
534                    auth_type: AuthType::Oidc,
535                    is_admin,
536                    allow_delegation: false,
537                });
538            } else {
539                // Provide detailed error with configured vs actual audiences
540                let actual_audiences = match &claims.aud {
541                    Audience::Single(s) => vec![s.clone()],
542                    Audience::Multiple(v) => v.clone(),
543                };
544                last_error = anyhow!(
545                    "Token audience mismatch - configured audiences: {:?}, token audiences: {:?}",
546                    configured_audiences,
547                    actual_audiences
548                );
549            }
550        }
551
552        // If we get here, none of the clients matched
553        Err(last_error)
554    }
555}
556
557#[async_trait::async_trait]
558impl AuthProvider for OidcAuthProvider {
559    async fn validate_request(
560        &self,
561        parts: &dyn crate::types::RequestParts,
562    ) -> Result<AuthContext> {
563        let token = parts
564            .bearer_token()
565            .ok_or_else(|| anyhow!("missing bearer token"))?;
566
567        // Check token cache first, but verify token hasn't expired
568        // Use 30-second grace period to account for clock skew and network latency
569        if let Some(cached) = self.token_cache.get(token).await {
570            let is_expired = cached
571                .expires_at
572                .map(|exp| exp <= Utc::now() + TimeDelta::seconds(30))
573                .unwrap_or(false);
574
575            if is_expired {
576                self.token_cache.remove(token).await;
577            } else {
578                return Ok((*cached).clone());
579            }
580        }
581
582        // Validate token
583        let auth_ctx = self.validate_jwt_token(token).await?;
584
585        // Cache the result
586        self.token_cache
587            .insert(token.to_string(), Arc::new(auth_ctx.clone()))
588            .await;
589
590        Ok(auth_ctx)
591    }
592}