From 601ca05fc803f9a6fdf107d687fecd97d021bbd0 Mon Sep 17 00:00:00 2001 From: Mick Grove Date: Mon, 14 Jul 2025 15:31:44 -0700 Subject: [PATCH] JWT validation performs OpenID Connect discovery using the iss claim and verifies signatures via JWKS --- CHANGELOG.md | 2 + Cargo.toml | 2 + data/rules/jwt.yml | 4 +- src/rules/rule.rs | 1 + src/validation.rs | 59 ++++++------ src/validation/jwt.rs | 207 ++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 245 insertions(+), 30 deletions(-) create mode 100644 src/validation/jwt.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e1179b..032379c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ All notable changes to this project will be documented in this file. - Added baseline feature with `--baseline-file` and `--manage-baseline` flags - Introduced `--exclude` option for skipping paths - Added tests covering baseline and exclude workflow +- Added validation for JWT tokens that checks `exp` and `nbf` claims +- JWT validation performs OpenID Connect discovery using the `iss` claim and verifies signatures via JWKS ## [1.20.0] diff --git a/Cargo.toml b/Cargo.toml index dc2db13..e220d39 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -162,6 +162,8 @@ atty = "0.2.14" self_update = { version = "0.42.0", default-features = false, features = ["rustls", "archive-tar", "archive-zip", "compression-flate2"] } semver = "1.0.26" globset = "0.4.16" +jsonwebtoken = "9.3.1" +ipnet = "2.11.0" [dependencies.tikv-jemallocator] version = "0.6" diff --git a/data/rules/jwt.yml b/data/rules/jwt.yml index cd3f78d..e596027 100644 --- a/data/rules/jwt.yml +++ b/data/rules/jwt.yml @@ -22,4 +22,6 @@ rules: - https://datatracker.ietf.org/doc/html/rfc7519 - https://en.wikipedia.org/wiki/Base64#URL_applications - https://datatracker.ietf.org/doc/html/rfc4648 - - https://developer.okta.com/blog/2018/06/20/what-happens-if-your-jwt-is-stolen \ No newline at end of file + - https://developer.okta.com/blog/2018/06/20/what-happens-if-your-jwt-is-stolen + validation: + type: JWT \ No newline at end of file diff --git a/src/rules/rule.rs b/src/rules/rule.rs index bf923f3..a301a09 100644 --- a/src/rules/rule.rs +++ b/src/rules/rule.rs @@ -38,6 +38,7 @@ pub enum Validation { GCP, MongoDB, Postgres, + JWT, Raw(String), Http(HttpValidation), } diff --git a/src/validation.rs b/src/validation.rs index 44a7a6f..172888d 100644 --- a/src/validation.rs +++ b/src/validation.rs @@ -27,6 +27,7 @@ mod aws; mod azure; mod gcp; mod httpvalidation; +mod jwt; mod mongodb; mod postgres; mod utils; @@ -58,35 +59,6 @@ pub fn init_validation_caches() { IN_FLIGHT.set(DashMap::new()).ok(); } -// #[derive(Clone, FilterReflection, ParseFilter)] -// #[filter( -// name = "b64enc", -// description = "Encodes the input string using Base64 encoding", -// parsed(B64EncFilter) -// )] -// pub struct B64EncFilterParser; - -// #[derive(Debug, Default, Clone)] -// pub struct B64EncFilter; - -// impl std::fmt::Display for B64EncFilter { -// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { -// write!(f, "b64enc") -// } -// } - -// impl Filter for B64EncFilter { -// fn evaluate( -// &self, -// input: &dyn ValueView, -// _runtime: &dyn Runtime, -// ) -> Result { -// let input_str = input.to_kstr().into_owned(); -// let encoded = general_purpose::STANDARD.encode(input_str.as_bytes()); -// Ok(Value::scalar(encoded)) -// } -// } - #[derive(Clone)] pub struct CachedResponse { pub body: String, @@ -700,7 +672,36 @@ async fn timed_validate_single_match<'a>( }, ); } + // ---------------------------------------------------- JWT validator + Some(Validation::JWT) => { + let token = captured_values + .iter() + .find(|(n, ..)| n == "TOKEN") + .map(|(_, v, ..)| v.clone()) + .unwrap_or_default(); + if token.is_empty() { + m.validation_success = false; + m.validation_response_body = "JWT token not found.".to_string(); + m.validation_response_status = StatusCode::BAD_REQUEST; + commit_and_return(m); + return; + } + + match jwt::validate_jwt(&token, client).await { + Ok((ok, msg)) => { + m.validation_success = ok; + m.validation_response_body = msg; + m.validation_response_status = + if ok { StatusCode::OK } else { StatusCode::UNAUTHORIZED }; + } + Err(e) => { + m.validation_success = false; + m.validation_response_body = format!("JWT validation error: {}", e); + m.validation_response_status = StatusCode::BAD_REQUEST; + } + } + } // ---------------------------------------------------- AWS validator Some(Validation::AWS) => { let secret = captured_values diff --git a/src/validation/jwt.rs b/src/validation/jwt.rs new file mode 100644 index 0000000..ed2bd3c --- /dev/null +++ b/src/validation/jwt.rs @@ -0,0 +1,207 @@ +use anyhow::{anyhow, Result}; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; +use chrono::Utc; +use ipnet::IpNet; +use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, Validation as JwtValidation}; +use reqwest::{redirect::Policy, Client, Url}; +use serde::Deserialize; +use tokio::net::lookup_host; + +use super::utils::check_url_resolvable; + +/// RFC 1918 + loopback + link-local nets we refuse to contact +const BLOCKED_NETS: &[&str] = &[ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", // private + "127.0.0.0/8", + "169.254.0.0/16", // loopback / link-local +]; + +// aud is allowed to be either a string or an array, so let Serde flatten it. +#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum Aud { + Str(String), + Arr(Vec), +} + +#[derive(Debug, Deserialize)] +struct Claims { + exp: Option, + nbf: Option, + iss: Option, + aud: Option, +} + +pub async fn validate_jwt(token: &str, client: &Client) -> Result<(bool, String)> { + // --- insecure payload decode ------------------------------------------------- + let claims: Claims = { + let payload_b64 = token.split('.').nth(1).ok_or_else(|| anyhow!("invalid JWT format"))?; + let payload_json = URL_SAFE_NO_PAD + .decode(payload_b64) + .map_err(|e| anyhow!("invalid base64 in payload: {e}"))?; + serde_json::from_slice(&payload_json).map_err(|e| anyhow!("invalid JSON claims: {e}"))? + }; + + // temporal checks + let now = Utc::now().timestamp(); + if let Some(nbf) = claims.nbf { + if now < nbf { + return Ok((false, format!("Token not valid before {nbf}"))); + } + } + if let Some(exp) = claims.exp { + if now > exp { + return Ok((false, format!("Token expired at {exp}"))); + } + } + + // --------------------------------------------------------------------------- + let issuer = claims.iss.clone().unwrap_or_default(); + + if let Some(iss) = claims.iss.clone() { + // parse header now (kid, alg) + let header = decode_header(token).map_err(|e| anyhow!("decode header: {e}"))?; + + // build discovery URL and fetch it (redirects disabled) + let config_url = format!("{}/.well-known/openid-configuration", iss.trim_end_matches('/')); + let no_redirect_client = Client::builder() + .redirect(Policy::none()) + .build() + .map_err(|e| anyhow!("client build: {e}"))?; + + let cfg_resp = no_redirect_client + .get(&config_url) + .send() + .await + .map_err(|e| anyhow!("issuer discovery failed: {e}"))?; + + if !cfg_resp.status().is_success() { + return Ok((false, format!("issuer discovery failed: {}", cfg_resp.status()))); + } + + let cfg_json: serde_json::Value = + cfg_resp.json().await.map_err(|e| anyhow!("invalid discovery JSON: {e}"))?; + + // extract jwks_uri + let jwks_uri = cfg_json + .get("jwks_uri") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow!("jwks_uri missing"))?; + + // must be HTTPS + let url = Url::parse(jwks_uri).map_err(|e| anyhow!("invalid jwks_uri: {e}"))?; + if url.scheme() != "https" { + return Ok((false, "jwks_uri must use https".to_string())); + } + + // host must match issuer host  —  prevents open redirects / SSRF-on-other-host + let iss_host = Url::parse(&iss) + .map_err(|e| anyhow!("invalid iss: {e}"))? + .host_str() + .unwrap_or_default() + .to_ascii_lowercase(); + let jwks_host = url.host_str().unwrap_or_default().to_ascii_lowercase(); + if jwks_host != iss_host { + return Ok(( + false, + format!("jwks_uri host ({jwks_host}) must match issuer host ({iss_host})"), + )); + } + + // ----------------------------------------------------------------------- + // DNS resolution + private-range block + for addr in lookup_host((jwks_host.as_str(), 443)).await? { + if is_blocked_ip(addr.ip()) { + return Ok((false, "jwks_uri resolves to private or link-local IP".to_string())); + } + } + + // reachability check (existing helper) + check_url_resolvable(&url).await.map_err(|e| anyhow!("jwks uri unresolvable: {e}"))?; + + // fetch JWKS with redirect-free client + let jwks_resp = no_redirect_client + .get(url) + .send() + .await + .map_err(|e| anyhow!("jwks fetch failed: {e}"))?; + if !jwks_resp.status().is_success() { + return Ok((false, format!("jwks fetch failed: {}", jwks_resp.status()))); + } + + let jwk_set: JwkSet = + jwks_resp.json().await.map_err(|e| anyhow!("invalid jwks json: {e}"))?; + + // select key by kid + let kid = header.kid.ok_or_else(|| anyhow!("no kid in header"))?; + let jwk = jwk_set + .keys + .iter() + .find(|k| k.common.key_id.as_deref() == Some(&kid)) + .ok_or_else(|| anyhow!("kid not found in jwks"))?; + + // verify signature + let decoding_key = DecodingKey::from_jwk(jwk).map_err(|e| anyhow!("invalid jwk: {e}"))?; + let mut validation = JwtValidation::new(header.alg); + validation.set_audience(&extract_aud_strings(&claims)); + validation.validate_exp = false; + validation.validate_nbf = false; + + decode::(token, &decoding_key, &validation) + .map_err(|e| anyhow!("signature verification failed: {e}"))?; + + return Ok(( + true, + format!("JWT valid (iss: {issuer}, aud: {:?})", extract_aud_strings(&claims)), + )); + } + + Ok((true, format!("JWT not expired (iss: {issuer}, aud: {:?})", extract_aud_strings(&claims)))) +} + +/// Helper: normalize aud into a flat Vec +fn extract_aud_strings(claims: &Claims) -> Vec { + match &claims.aud { + Some(Aud::Str(s)) => vec![s.clone()], + Some(Aud::Arr(v)) => v.clone(), + None => vec![], + } +} +/// returns true if IP is in a blocked network +fn is_blocked_ip(ip: std::net::IpAddr) -> bool { + BLOCKED_NETS.iter().filter_map(|cidr| cidr.parse::().ok()).any(|net| net.contains(&ip)) +} + +#[cfg(test)] +mod tests { + use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; + use chrono::{Duration as ChronoDuration, Utc}; + use reqwest::Client; + + use super::validate_jwt; + + fn build_token(exp_offset: i64) -> String { + let header = URL_SAFE_NO_PAD.encode(r#"{"alg":"none"}"#); + let exp = (Utc::now() + ChronoDuration::seconds(exp_offset)).timestamp(); + let payload = URL_SAFE_NO_PAD.encode(format!("{{\"exp\":{exp}}}")); + format!("{header}.{payload}.") + } + + #[tokio::test] + async fn valid_token() { + let token = build_token(60); + let client = Client::new(); + let res = validate_jwt(&token, &client).await.unwrap(); + assert!(res.0); + } + + #[tokio::test] + async fn expired_token() { + let token = build_token(-60); + let client = Client::new(); + let res = validate_jwt(&token, &client).await.unwrap(); + assert!(!res.0); + } +}