1use crate::types::{AuthContext, AuthProvider, AuthType};
2use anyhow::{Result, anyhow};
3use base64::Engine;
4use chrono::{DateTime, TimeDelta, Utc};
5use jsonwebtoken::{Algorithm, Validation, decode, decode_header};
6use moka::future::Cache;
7use openidconnect::core::{CoreJsonWebKeySet, CoreProviderMetadata};
8use openidconnect::{IssuerUrl, JsonWebKey};
9use rsa::pkcs1::EncodeRsaPublicKey;
10use rsa::{BigUint, RsaPublicKey};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15
16pub fn create_http_client() -> Result<reqwest::Client> {
43 reqwest::ClientBuilder::new()
44 .redirect(reqwest::redirect::Policy::none())
45 .build()
46 .map_err(|e| anyhow!("Failed to create HTTP client: {e:?}"))
47}
48
49async fn fetch_jwks(issuer_url: &IssuerUrl) -> Result<Arc<CoreJsonWebKeySet>> {
51 let http_client = create_http_client()?;
52
53 let metadata = CoreProviderMetadata::discover_async(issuer_url.clone(), &http_client)
55 .await
56 .map_err(|e| {
57 anyhow!(
58 "Failed to discover OIDC metadata from {}: {e:?}",
59 issuer_url
60 )
61 })?;
62
63 let jwks_uri = metadata.jwks_uri();
65 let jwks: CoreJsonWebKeySet = http_client
66 .get(jwks_uri.url().as_str())
67 .send()
68 .await
69 .map_err(|e| anyhow!("Failed to fetch JWKS from {}: {e:?}", jwks_uri))?
70 .json()
71 .await
72 .map_err(|e| anyhow!("Failed to parse JWKS: {e:?}"))?;
73
74 Ok(Arc::new(jwks))
75}
76
77struct JwksCache {
82 issuer_url: IssuerUrl,
83 cache: Cache<String, Arc<CoreJsonWebKeySet>>,
84}
85
86impl JwksCache {
87 fn new(issuer_url: IssuerUrl, ttl: Duration) -> Self {
89 let cache = Cache::builder().time_to_live(ttl).build();
90
91 Self { issuer_url, cache }
92 }
93
94 async fn get(&self) -> Result<Arc<CoreJsonWebKeySet>> {
96 let issuer_url = self.issuer_url.clone();
97
98 self.cache
99 .try_get_with(
100 "jwks".to_string(),
101 async move { fetch_jwks(&issuer_url).await },
102 )
103 .await
104 .map_err(|e| anyhow!("Failed to fetch JWKS: {e:?}"))
105 }
106}
107
108#[derive(Debug, Clone, Deserialize)]
110pub struct OidcIssuer {
111 pub issuer: String,
113 pub audience: String,
117}
118
119const DEFAULT_JWKS_REFRESH_INTERVAL_SECS: u64 = 3600;
120const DEFAULT_TOKEN_CACHE_SIZE: u64 = 1000;
121const DEFAULT_TOKEN_CACHE_TTL_SECS: u64 = 300;
122
123#[derive(Debug, Clone, Deserialize)]
125#[serde(default)]
126pub struct OidcConfig {
127 pub issuers: Vec<OidcIssuer>,
129 pub jwks_refresh_interval_secs: u64,
131 pub token_cache_size: u64,
133 pub token_cache_ttl_secs: u64,
135}
136
137impl Default for OidcConfig {
138 fn default() -> Self {
139 Self {
140 issuers: Vec::new(),
141 jwks_refresh_interval_secs: DEFAULT_JWKS_REFRESH_INTERVAL_SECS,
142 token_cache_size: DEFAULT_TOKEN_CACHE_SIZE,
143 token_cache_ttl_secs: DEFAULT_TOKEN_CACHE_TTL_SECS,
144 }
145 }
146}
147
148impl OidcConfig {
149 pub fn from_env() -> Result<Self> {
151 let json = std::env::var("MICROMEGAS_OIDC_CONFIG")
152 .map_err(|_| anyhow!("MICROMEGAS_OIDC_CONFIG environment variable not set"))?;
153 let config: OidcConfig = serde_json::from_str(&json)
154 .map_err(|e| anyhow!("Failed to parse MICROMEGAS_OIDC_CONFIG: {e:?}"))?;
155 Ok(config)
156 }
157}
158
159#[derive(Debug, Serialize, Deserialize)]
161#[serde(untagged)]
162enum Audience {
163 Single(String),
164 Multiple(Vec<String>),
165}
166
167impl Audience {
168 fn contains(&self, aud: &str) -> bool {
169 match self {
170 Audience::Single(s) => s == aud,
171 Audience::Multiple(v) => v.iter().any(|a| a == aud),
172 }
173 }
174}
175
176#[derive(Debug, Serialize, Deserialize)]
189struct Claims {
190 iss: String,
192 sub: String,
194 aud: Audience,
197 exp: i64,
199 #[serde(skip_serializing_if = "Option::is_none")]
201 email: Option<String>,
202 #[serde(skip_serializing_if = "Option::is_none")]
204 verified_primary_email: Option<String>,
205 #[serde(skip_serializing_if = "Option::is_none")]
207 preferred_username: Option<String>,
208 #[serde(skip_serializing_if = "Option::is_none")]
210 upn: Option<String>,
211 #[serde(skip_serializing_if = "Option::is_none")]
213 unique_name: Option<String>,
214 #[serde(rename = "https://micromegas.io/email")]
216 #[serde(skip_serializing_if = "Option::is_none")]
217 namespaced_email: Option<String>,
218 #[serde(rename = "https://micromegas.io/name")]
220 #[serde(skip_serializing_if = "Option::is_none")]
221 namespaced_name: Option<String>,
222}
223
224impl Claims {
225 fn get_email(&self) -> Option<String> {
228 self.verified_primary_email
229 .clone()
230 .or_else(|| self.email.clone())
231 .or_else(|| self.namespaced_email.clone())
232 .or_else(|| self.preferred_username.clone())
233 .or_else(|| self.upn.clone())
234 .or_else(|| self.unique_name.clone())
235 }
236}
237
238struct OidcIssuerClient {
240 issuer: String,
241 audience: String,
242 jwks_cache: JwksCache,
243}
244
245impl OidcIssuerClient {
246 fn new(issuer: String, audience: String, jwks_ttl: Duration) -> Result<Self> {
247 let issuer_url = IssuerUrl::new(issuer.clone())
248 .map_err(|e| anyhow!("Invalid issuer URL '{}': {e:?}", issuer))?;
249
250 Ok(Self {
251 issuer,
252 audience,
253 jwks_cache: JwksCache::new(issuer_url, jwks_ttl),
254 })
255 }
256}
257
258fn load_admin_users() -> Vec<String> {
260 match std::env::var("MICROMEGAS_ADMINS") {
261 Ok(json) => serde_json::from_str::<Vec<String>>(&json).unwrap_or_default(),
262 Err(_) => vec![],
263 }
264}
265
266fn jwk_to_decoding_key(
268 jwk: &openidconnect::core::CoreJsonWebKey,
269) -> Result<jsonwebtoken::DecodingKey> {
270 let jwk_json =
272 serde_json::to_value(jwk).map_err(|e| anyhow!("Failed to serialize JWK: {e:?}"))?;
273
274 let n = jwk_json
276 .get("n")
277 .and_then(|v| v.as_str())
278 .ok_or_else(|| anyhow!("JWK missing 'n' parameter"))?;
279 let e = jwk_json
280 .get("e")
281 .and_then(|v| v.as_str())
282 .ok_or_else(|| anyhow!("JWK missing 'e' parameter"))?;
283
284 let n_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
286 .decode(n.as_bytes())
287 .map_err(|e| anyhow!("Failed to decode 'n': {e:?}"))?;
288 let e_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
289 .decode(e.as_bytes())
290 .map_err(|e| anyhow!("Failed to decode 'e': {e:?}"))?;
291
292 let n_bigint = BigUint::from_bytes_be(&n_bytes);
294 let e_bigint = BigUint::from_bytes_be(&e_bytes);
295
296 let public_key = RsaPublicKey::new(n_bigint, e_bigint)
297 .map_err(|e| anyhow!("Failed to create RSA public key: {e:?}"))?;
298
299 let pem = public_key
301 .to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
302 .map_err(|e| anyhow!("Failed to encode public key as PEM: {e:?}"))?;
303
304 jsonwebtoken::DecodingKey::from_rsa_pem(pem.as_bytes())
306 .map_err(|e| anyhow!("Failed to create decoding key: {e:?}"))
307}
308
309pub struct OidcAuthProvider {
314 clients: HashMap<String, Vec<Arc<OidcIssuerClient>>>,
316 token_cache: Cache<String, Arc<AuthContext>>,
318 admin_users: Vec<String>,
320}
321
322impl std::fmt::Debug for OidcAuthProvider {
323 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324 f.debug_struct("OidcAuthProvider")
325 .field("num_clients", &self.clients.len())
326 .field("admin_users", &"(not printed)")
327 .finish()
328 }
329}
330
331impl OidcAuthProvider {
332 pub async fn new(config: OidcConfig) -> Result<Self> {
334 if config.issuers.is_empty() {
335 return Err(anyhow!("At least one OIDC issuer must be configured"));
336 }
337
338 micromegas_tracing::info!("Configuring OIDC with {} issuer(s)", config.issuers.len());
339 for (idx, issuer_config) in config.issuers.iter().enumerate() {
340 micromegas_tracing::info!(
341 " Issuer {}: {} (audience: {})",
342 idx + 1,
343 issuer_config.issuer,
344 issuer_config.audience
345 );
346 }
347
348 let jwks_ttl = Duration::from_secs(config.jwks_refresh_interval_secs);
349 let mut clients: HashMap<String, Vec<Arc<OidcIssuerClient>>> = HashMap::new();
350
351 for issuer_config in config.issuers {
354 let client = OidcIssuerClient::new(
355 issuer_config.issuer.clone(),
356 issuer_config.audience,
357 jwks_ttl,
358 )?;
359
360 clients
361 .entry(issuer_config.issuer)
362 .or_default()
363 .push(Arc::new(client));
364 }
365
366 let token_cache = Cache::builder()
368 .max_capacity(config.token_cache_size)
369 .time_to_live(Duration::from_secs(config.token_cache_ttl_secs))
370 .build();
371
372 let admin_users = load_admin_users();
374
375 Ok(Self {
376 clients,
377 token_cache,
378 admin_users,
379 })
380 }
381
382 fn is_admin(&self, subject: &str, email: Option<&str>) -> bool {
383 self.admin_users
384 .iter()
385 .any(|admin| admin == subject || email.map(|e| admin == e).unwrap_or(false))
386 }
387
388 fn decode_payload_unsafe(&self, token: &str) -> Result<Claims> {
393 let parts: Vec<&str> = token.split('.').collect();
394 if parts.len() != 3 {
395 return Err(anyhow!("Invalid JWT format"));
396 }
397
398 let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
400 .decode(parts[1].as_bytes())
401 .map_err(|e| anyhow!("Failed to decode JWT payload: {e:?}"))?;
402
403 let claims: Claims = serde_json::from_slice(&payload_bytes)
404 .map_err(|e| anyhow!("Failed to parse JWT claims: {e:?}"))?;
405
406 Ok(claims)
407 }
408
409 async fn validate_jwt_token(&self, token: &str) -> Result<AuthContext> {
421 let header = decode_header(token).map_err(|e| anyhow!("Invalid JWT header: {e:?}"))?;
423
424 let kid = header.kid;
426
427 let unverified_claims = self.decode_payload_unsafe(token)?;
429
430 let issuer_clients = self
432 .clients
433 .get(&unverified_claims.iss)
434 .ok_or_else(|| anyhow!("Unknown issuer: {}", unverified_claims.iss))?;
435
436 let first_client = issuer_clients
439 .first()
440 .ok_or_else(|| anyhow!("No clients configured for issuer"))?;
441
442 let jwks = first_client
443 .jwks_cache
444 .get()
445 .await
446 .map_err(|e| anyhow!("Failed to fetch JWKS: {e:?}"))?;
447
448 let keys_to_try: Vec<_> = if let Some(ref kid_value) = kid {
451 jwks.keys()
453 .iter()
454 .filter(|k| k.key_id().map(|id| id.as_str()) == Some(kid_value.as_str()))
455 .collect()
456 } else {
457 jwks.keys().iter().collect()
459 };
460
461 if keys_to_try.is_empty() {
462 return Err(if let Some(kid_value) = kid {
463 anyhow!("Key with kid '{}' not found in JWKS", kid_value)
464 } else {
465 anyhow!("No keys found in JWKS")
466 });
467 }
468
469 let mut key_error = anyhow!("No valid key found");
471 for key in keys_to_try {
472 match jwk_to_decoding_key(key) {
473 Ok(decoding_key) => {
474 match self
475 .try_validate_with_key(token, &decoding_key, issuer_clients)
476 .await
477 {
478 Ok(auth_ctx) => return Ok(auth_ctx),
479 Err(e) => key_error = e,
480 }
481 }
482 Err(e) => key_error = e,
483 }
484 }
485
486 Err(key_error)
487 }
488
489 async fn try_validate_with_key(
491 &self,
492 token: &str,
493 decoding_key: &jsonwebtoken::DecodingKey,
494 issuer_clients: &[Arc<OidcIssuerClient>],
495 ) -> Result<AuthContext> {
496 let configured_audiences: Vec<String> =
497 issuer_clients.iter().map(|c| c.audience.clone()).collect();
498 let mut last_error = anyhow!("No matching audience found");
499
500 for client in issuer_clients {
501 let mut validation = Validation::new(Algorithm::RS256);
503 validation.validate_aud = false;
505 validation.set_issuer(&[&client.issuer]);
506
507 let claims = match decode::<Claims>(token, decoding_key, &validation) {
508 Ok(token_data) => token_data.claims,
509 Err(e) => {
510 last_error = anyhow!("Token validation failed: {e:?}");
511 continue;
512 }
513 };
514
515 if claims.aud.contains(&client.audience) {
517 let expires_at = DateTime::from_timestamp(claims.exp, 0)
519 .ok_or_else(|| anyhow!("Invalid expiration timestamp"))?;
520
521 if expires_at < Utc::now() {
522 return Err(anyhow!("Token has expired"));
523 }
524
525 let email = claims.get_email();
526 let is_admin = self.is_admin(&claims.sub, email.as_deref());
527
528 return Ok(AuthContext {
529 subject: claims.sub,
530 email,
531 issuer: claims.iss,
532 audience: Some(client.audience.clone()),
533 expires_at: Some(expires_at),
534 auth_type: AuthType::Oidc,
535 is_admin,
536 allow_delegation: false,
537 });
538 } else {
539 let actual_audiences = match &claims.aud {
541 Audience::Single(s) => vec![s.clone()],
542 Audience::Multiple(v) => v.clone(),
543 };
544 last_error = anyhow!(
545 "Token audience mismatch - configured audiences: {:?}, token audiences: {:?}",
546 configured_audiences,
547 actual_audiences
548 );
549 }
550 }
551
552 Err(last_error)
554 }
555}
556
557#[async_trait::async_trait]
558impl AuthProvider for OidcAuthProvider {
559 async fn validate_request(
560 &self,
561 parts: &dyn crate::types::RequestParts,
562 ) -> Result<AuthContext> {
563 let token = parts
564 .bearer_token()
565 .ok_or_else(|| anyhow!("missing bearer token"))?;
566
567 if let Some(cached) = self.token_cache.get(token).await {
570 let is_expired = cached
571 .expires_at
572 .map(|exp| exp <= Utc::now() + TimeDelta::seconds(30))
573 .unwrap_or(false);
574
575 if is_expired {
576 self.token_cache.remove(token).await;
577 } else {
578 return Ok((*cached).clone());
579 }
580 }
581
582 let auth_ctx = self.validate_jwt_token(token).await?;
584
585 self.token_cache
587 .insert(token.to_string(), Arc::new(auth_ctx.clone()))
588 .await;
589
590 Ok(auth_ctx)
591 }
592}