1use crate::types::{AuthContext, AuthProvider, AuthType};
2use anyhow::{Result, anyhow};
3use base64::Engine;
4use chrono::{DateTime, TimeDelta, Utc};
5use jsonwebtoken::{Algorithm, Validation, decode, decode_header};
6use moka::future::Cache;
7use openidconnect::core::{CoreJsonWebKeySet, CoreProviderMetadata};
8use openidconnect::{IssuerUrl, JsonWebKey};
9use rsa::pkcs1::EncodeRsaPublicKey;
10use rsa::{BigUint, RsaPublicKey};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15
16pub fn create_http_client() -> Result<reqwest::Client> {
43 reqwest::ClientBuilder::new()
44 .redirect(reqwest::redirect::Policy::none())
45 .build()
46 .map_err(|e| anyhow!("Failed to create HTTP client: {e:?}"))
47}
48
49async fn fetch_jwks(issuer_url: &IssuerUrl) -> Result<Arc<CoreJsonWebKeySet>> {
51 let http_client = create_http_client()?;
52
53 let metadata = CoreProviderMetadata::discover_async(issuer_url.clone(), &http_client)
55 .await
56 .map_err(|e| {
57 anyhow!(
58 "Failed to discover OIDC metadata from {}: {e:?}",
59 issuer_url
60 )
61 })?;
62
63 let jwks_uri = metadata.jwks_uri();
65 let jwks: CoreJsonWebKeySet = http_client
66 .get(jwks_uri.url().as_str())
67 .send()
68 .await
69 .map_err(|e| anyhow!("Failed to fetch JWKS from {}: {e:?}", jwks_uri))?
70 .json()
71 .await
72 .map_err(|e| anyhow!("Failed to parse JWKS: {e:?}"))?;
73
74 Ok(Arc::new(jwks))
75}
76
77struct JwksCache {
82 issuer_url: IssuerUrl,
83 cache: Cache<String, Arc<CoreJsonWebKeySet>>,
84}
85
86impl JwksCache {
87 fn new(issuer_url: IssuerUrl, ttl: Duration) -> Self {
89 let cache = Cache::builder().time_to_live(ttl).build();
90
91 Self { issuer_url, cache }
92 }
93
94 async fn get(&self) -> Result<Arc<CoreJsonWebKeySet>> {
96 let issuer_url = self.issuer_url.clone();
97
98 self.cache
99 .try_get_with(
100 "jwks".to_string(),
101 async move { fetch_jwks(&issuer_url).await },
102 )
103 .await
104 .map_err(|e| anyhow!("Failed to fetch JWKS: {e:?}"))
105 }
106}
107
/// One trusted token issuer, together with the audience its tokens must carry.
#[derive(Debug, Clone, Deserialize)]
pub struct OidcIssuer {
    /// Issuer URL. It is used for OIDC discovery and must equal the token's `iss` claim exactly.
    pub issuer: String,
    /// Value expected in the token's `aud` claim.
    pub audience: String,
}
118
/// Default lifetime, in seconds, of a cached JWKS before it is fetched again (1 hour).
const DEFAULT_JWKS_REFRESH_INTERVAL_SECS: u64 = 3600;
/// Default maximum number of validated tokens kept in the token cache.
const DEFAULT_TOKEN_CACHE_SIZE: u64 = 1000;
/// Default lifetime, in seconds, of a validated-token cache entry (5 minutes).
const DEFAULT_TOKEN_CACHE_TTL_SECS: u64 = 300;
122
123#[derive(Debug, Clone, Deserialize)]
125#[serde(default)]
126pub struct OidcConfig {
127 pub issuers: Vec<OidcIssuer>,
129 pub jwks_refresh_interval_secs: u64,
131 pub token_cache_size: u64,
133 pub token_cache_ttl_secs: u64,
135}
136
137impl Default for OidcConfig {
138 fn default() -> Self {
139 Self {
140 issuers: Vec::new(),
141 jwks_refresh_interval_secs: DEFAULT_JWKS_REFRESH_INTERVAL_SECS,
142 token_cache_size: DEFAULT_TOKEN_CACHE_SIZE,
143 token_cache_ttl_secs: DEFAULT_TOKEN_CACHE_TTL_SECS,
144 }
145 }
146}
147
148impl OidcConfig {
149 pub fn from_env() -> Result<Self> {
151 Self::from_env_var("MICROMEGAS_OIDC_CONFIG")
152 }
153
154 pub fn from_env_var(name: &str) -> Result<Self> {
156 let json =
157 std::env::var(name).map_err(|_| anyhow!("{name} environment variable not set"))?;
158 let config: OidcConfig =
159 serde_json::from_str(&json).map_err(|e| anyhow!("Failed to parse {name}: {e:?}"))?;
160 Ok(config)
161 }
162}
163
164#[derive(Debug, Serialize, Deserialize)]
166#[serde(untagged)]
167enum Audience {
168 Single(String),
169 Multiple(Vec<String>),
170}
171
172impl Audience {
173 fn contains(&self, aud: &str) -> bool {
174 match self {
175 Audience::Single(s) => s == aud,
176 Audience::Multiple(v) => v.iter().any(|a| a == aud),
177 }
178 }
179}
180
181#[derive(Debug, Serialize, Deserialize)]
194struct Claims {
195 iss: String,
197 sub: String,
199 aud: Audience,
202 exp: i64,
204 #[serde(skip_serializing_if = "Option::is_none")]
206 email: Option<String>,
207 #[serde(skip_serializing_if = "Option::is_none")]
209 verified_primary_email: Option<String>,
210 #[serde(skip_serializing_if = "Option::is_none")]
212 preferred_username: Option<String>,
213 #[serde(skip_serializing_if = "Option::is_none")]
215 upn: Option<String>,
216 #[serde(skip_serializing_if = "Option::is_none")]
218 unique_name: Option<String>,
219 #[serde(rename = "https://micromegas.io/email")]
221 #[serde(skip_serializing_if = "Option::is_none")]
222 namespaced_email: Option<String>,
223 #[serde(rename = "https://micromegas.io/name")]
225 #[serde(skip_serializing_if = "Option::is_none")]
226 namespaced_name: Option<String>,
227}
228
229impl Claims {
230 fn get_email(&self) -> Option<String> {
233 self.verified_primary_email
234 .clone()
235 .or_else(|| self.email.clone())
236 .or_else(|| self.namespaced_email.clone())
237 .or_else(|| self.preferred_username.clone())
238 .or_else(|| self.upn.clone())
239 .or_else(|| self.unique_name.clone())
240 }
241}
242
243struct OidcIssuerClient {
245 issuer: String,
246 audience: String,
247 jwks_cache: JwksCache,
248}
249
250impl OidcIssuerClient {
251 fn new(issuer: String, audience: String, jwks_ttl: Duration) -> Result<Self> {
252 let issuer_url = IssuerUrl::new(issuer.clone())
253 .map_err(|e| anyhow!("Invalid issuer URL '{}': {e:?}", issuer))?;
254
255 Ok(Self {
256 issuer,
257 audience,
258 jwks_cache: JwksCache::new(issuer_url, jwks_ttl),
259 })
260 }
261}
262
263fn load_admin_users(admin_var: &str) -> Vec<String> {
265 match std::env::var(admin_var) {
266 Ok(json) => serde_json::from_str::<Vec<String>>(&json).unwrap_or_default(),
267 Err(_) => vec![],
268 }
269}
270
271fn jwk_to_decoding_key(
273 jwk: &openidconnect::core::CoreJsonWebKey,
274) -> Result<jsonwebtoken::DecodingKey> {
275 let jwk_json =
277 serde_json::to_value(jwk).map_err(|e| anyhow!("Failed to serialize JWK: {e:?}"))?;
278
279 let n = jwk_json
281 .get("n")
282 .and_then(|v| v.as_str())
283 .ok_or_else(|| anyhow!("JWK missing 'n' parameter"))?;
284 let e = jwk_json
285 .get("e")
286 .and_then(|v| v.as_str())
287 .ok_or_else(|| anyhow!("JWK missing 'e' parameter"))?;
288
289 let n_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
291 .decode(n.as_bytes())
292 .map_err(|e| anyhow!("Failed to decode 'n': {e:?}"))?;
293 let e_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
294 .decode(e.as_bytes())
295 .map_err(|e| anyhow!("Failed to decode 'e': {e:?}"))?;
296
297 let n_bigint = BigUint::from_bytes_be(&n_bytes);
299 let e_bigint = BigUint::from_bytes_be(&e_bytes);
300
301 let public_key = RsaPublicKey::new(n_bigint, e_bigint)
302 .map_err(|e| anyhow!("Failed to create RSA public key: {e:?}"))?;
303
304 let pem = public_key
306 .to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
307 .map_err(|e| anyhow!("Failed to encode public key as PEM: {e:?}"))?;
308
309 jsonwebtoken::DecodingKey::from_rsa_pem(pem.as_bytes())
311 .map_err(|e| anyhow!("Failed to create decoding key: {e:?}"))
312}
313
/// Authentication provider that validates OIDC-issued JWT bearer tokens against one or more
/// configured issuers.
pub struct OidcAuthProvider {
    /// Issuer clients grouped by issuer string (matched against the token's `iss` claim).
    /// There is one client per configured audience of that issuer.
    clients: HashMap<String, Vec<Arc<OidcIssuerClient>>>,
    /// Cache of validated tokens, keyed by the raw token string.
    token_cache: Cache<String, Arc<AuthContext>>,
    /// Subjects or emails granted admin rights, loaded from an environment variable.
    admin_users: Vec<String>,
}
326
327impl std::fmt::Debug for OidcAuthProvider {
328 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329 f.debug_struct("OidcAuthProvider")
330 .field("num_clients", &self.clients.len())
331 .field("admin_users", &"(not printed)")
332 .finish()
333 }
334}
335
336impl OidcAuthProvider {
337 pub async fn new(config: OidcConfig, admin_var: &str) -> Result<Self> {
342 if config.issuers.is_empty() {
343 return Err(anyhow!("At least one OIDC issuer must be configured"));
344 }
345
346 micromegas_tracing::info!("Configuring OIDC with {} issuer(s)", config.issuers.len());
347 for (idx, issuer_config) in config.issuers.iter().enumerate() {
348 micromegas_tracing::info!(
349 " Issuer {}: {} (audience: {})",
350 idx + 1,
351 issuer_config.issuer,
352 issuer_config.audience
353 );
354 }
355
356 let jwks_ttl = Duration::from_secs(config.jwks_refresh_interval_secs);
357 let mut clients: HashMap<String, Vec<Arc<OidcIssuerClient>>> = HashMap::new();
358
359 for issuer_config in config.issuers {
362 let client = OidcIssuerClient::new(
363 issuer_config.issuer.clone(),
364 issuer_config.audience,
365 jwks_ttl,
366 )?;
367
368 clients
369 .entry(issuer_config.issuer)
370 .or_default()
371 .push(Arc::new(client));
372 }
373
374 let token_cache = Cache::builder()
376 .max_capacity(config.token_cache_size)
377 .time_to_live(Duration::from_secs(config.token_cache_ttl_secs))
378 .build();
379
380 let admin_users = load_admin_users(admin_var);
382
383 Ok(Self {
384 clients,
385 token_cache,
386 admin_users,
387 })
388 }
389
390 fn is_admin(&self, subject: &str, email: Option<&str>) -> bool {
391 self.admin_users
392 .iter()
393 .any(|admin| admin == subject || email.map(|e| admin == e).unwrap_or(false))
394 }
395
396 fn decode_payload_unsafe(&self, token: &str) -> Result<Claims> {
401 let parts: Vec<&str> = token.split('.').collect();
402 if parts.len() != 3 {
403 return Err(anyhow!("Invalid JWT format"));
404 }
405
406 let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
408 .decode(parts[1].as_bytes())
409 .map_err(|e| anyhow!("Failed to decode JWT payload: {e:?}"))?;
410
411 let claims: Claims = serde_json::from_slice(&payload_bytes)
412 .map_err(|e| anyhow!("Failed to parse JWT claims: {e:?}"))?;
413
414 Ok(claims)
415 }
416
417 async fn validate_jwt_token(&self, token: &str) -> Result<AuthContext> {
429 let header = decode_header(token).map_err(|e| anyhow!("Invalid JWT header: {e:?}"))?;
431
432 let kid = header.kid;
434
435 let unverified_claims = self.decode_payload_unsafe(token)?;
437
438 let issuer_clients = self
440 .clients
441 .get(&unverified_claims.iss)
442 .ok_or_else(|| anyhow!("Unknown issuer: {}", unverified_claims.iss))?;
443
444 let first_client = issuer_clients
447 .first()
448 .ok_or_else(|| anyhow!("No clients configured for issuer"))?;
449
450 let jwks = first_client
451 .jwks_cache
452 .get()
453 .await
454 .map_err(|e| anyhow!("Failed to fetch JWKS: {e:?}"))?;
455
456 let keys_to_try: Vec<_> = if let Some(ref kid_value) = kid {
459 jwks.keys()
461 .iter()
462 .filter(|k| k.key_id().map(|id| id.as_str()) == Some(kid_value.as_str()))
463 .collect()
464 } else {
465 jwks.keys().iter().collect()
467 };
468
469 if keys_to_try.is_empty() {
470 return Err(if let Some(kid_value) = kid {
471 anyhow!("Key with kid '{}' not found in JWKS", kid_value)
472 } else {
473 anyhow!("No keys found in JWKS")
474 });
475 }
476
477 let mut key_error = anyhow!("No valid key found");
479 for key in keys_to_try {
480 match jwk_to_decoding_key(key) {
481 Ok(decoding_key) => {
482 match self
483 .try_validate_with_key(token, &decoding_key, issuer_clients)
484 .await
485 {
486 Ok(auth_ctx) => return Ok(auth_ctx),
487 Err(e) => key_error = e,
488 }
489 }
490 Err(e) => key_error = e,
491 }
492 }
493
494 Err(key_error)
495 }
496
497 async fn try_validate_with_key(
499 &self,
500 token: &str,
501 decoding_key: &jsonwebtoken::DecodingKey,
502 issuer_clients: &[Arc<OidcIssuerClient>],
503 ) -> Result<AuthContext> {
504 let configured_audiences: Vec<String> =
505 issuer_clients.iter().map(|c| c.audience.clone()).collect();
506 let mut last_error = anyhow!("No matching audience found");
507
508 for client in issuer_clients {
509 let mut validation = Validation::new(Algorithm::RS256);
511 validation.validate_aud = false;
513 validation.set_issuer(&[&client.issuer]);
514
515 let claims = match decode::<Claims>(token, decoding_key, &validation) {
516 Ok(token_data) => token_data.claims,
517 Err(e) => {
518 last_error = anyhow!("Token validation failed: {e:?}");
519 continue;
520 }
521 };
522
523 if claims.aud.contains(&client.audience) {
525 let expires_at = DateTime::from_timestamp(claims.exp, 0)
527 .ok_or_else(|| anyhow!("Invalid expiration timestamp"))?;
528
529 if expires_at < Utc::now() {
530 return Err(anyhow!("Token has expired"));
531 }
532
533 let email = claims.get_email();
534 let is_admin = self.is_admin(&claims.sub, email.as_deref());
535
536 return Ok(AuthContext {
537 subject: claims.sub,
538 email,
539 issuer: claims.iss,
540 audience: Some(client.audience.clone()),
541 expires_at: Some(expires_at),
542 auth_type: AuthType::Oidc,
543 is_admin,
544 allow_delegation: false,
545 });
546 } else {
547 let actual_audiences = match &claims.aud {
549 Audience::Single(s) => vec![s.clone()],
550 Audience::Multiple(v) => v.clone(),
551 };
552 last_error = anyhow!(
553 "Token audience mismatch - configured audiences: {:?}, token audiences: {:?}",
554 configured_audiences,
555 actual_audiences
556 );
557 }
558 }
559
560 Err(last_error)
562 }
563}
564
565#[async_trait::async_trait]
566impl AuthProvider for OidcAuthProvider {
567 async fn validate_request(
568 &self,
569 parts: &dyn crate::types::RequestParts,
570 ) -> Result<AuthContext> {
571 let token = parts
572 .bearer_token()
573 .ok_or_else(|| anyhow!("missing bearer token"))?;
574
575 if let Some(cached) = self.token_cache.get(token).await {
578 let is_expired = cached
579 .expires_at
580 .map(|exp| exp <= Utc::now() + TimeDelta::seconds(30))
581 .unwrap_or(false);
582
583 if is_expired {
584 self.token_cache.remove(token).await;
585 } else {
586 return Ok((*cached).clone());
587 }
588 }
589
590 let auth_ctx = self.validate_jwt_token(token).await?;
592
593 self.token_cache
595 .insert(token.to_string(), Arc::new(auth_ctx.clone()))
596 .await;
597
598 Ok(auth_ctx)
599 }
600}