forked from mirrors/kingfisher
JWT validation performs OpenID Connect discovery using the iss claim and verifies signatures via JWKS
This commit is contained in:
parent
a5bdbeb313
commit
601ca05fc8
6 changed files with 245 additions and 30 deletions
|
|
@ -8,6 +8,8 @@ All notable changes to this project will be documented in this file.
|
|||
- Added baseline feature with `--baseline-file` and `--manage-baseline` flags
|
||||
- Introduced `--exclude` option for skipping paths
|
||||
- Added tests covering baseline and exclude workflow
|
||||
- Added validation for JWT tokens that checks `exp` and `nbf` claims
|
||||
- JWT validation performs OpenID Connect discovery using the `iss` claim and verifies signatures via JWKS
|
||||
|
||||
|
||||
## [1.20.0]
|
||||
|
|
|
|||
|
|
@ -162,6 +162,8 @@ atty = "0.2.14"
|
|||
self_update = { version = "0.42.0", default-features = false, features = ["rustls", "archive-tar", "archive-zip", "compression-flate2"] }
|
||||
semver = "1.0.26"
|
||||
globset = "0.4.16"
|
||||
jsonwebtoken = "9.3.1"
|
||||
ipnet = "2.11.0"
|
||||
|
||||
[dependencies.tikv-jemallocator]
|
||||
version = "0.6"
|
||||
|
|
|
|||
|
|
@ -22,4 +22,6 @@ rules:
|
|||
- https://datatracker.ietf.org/doc/html/rfc7519
|
||||
- https://en.wikipedia.org/wiki/Base64#URL_applications
|
||||
- https://datatracker.ietf.org/doc/html/rfc4648
|
||||
- https://developer.okta.com/blog/2018/06/20/what-happens-if-your-jwt-is-stolen
|
||||
- https://developer.okta.com/blog/2018/06/20/what-happens-if-your-jwt-is-stolen
|
||||
validation:
|
||||
type: JWT
|
||||
|
|
@ -38,6 +38,7 @@ pub enum Validation {
|
|||
GCP,
|
||||
MongoDB,
|
||||
Postgres,
|
||||
JWT,
|
||||
Raw(String),
|
||||
Http(HttpValidation),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ mod aws;
|
|||
mod azure;
|
||||
mod gcp;
|
||||
mod httpvalidation;
|
||||
mod jwt;
|
||||
mod mongodb;
|
||||
mod postgres;
|
||||
mod utils;
|
||||
|
|
@ -58,35 +59,6 @@ pub fn init_validation_caches() {
|
|||
IN_FLIGHT.set(DashMap::new()).ok();
|
||||
}
|
||||
|
||||
// #[derive(Clone, FilterReflection, ParseFilter)]
|
||||
// #[filter(
|
||||
// name = "b64enc",
|
||||
// description = "Encodes the input string using Base64 encoding",
|
||||
// parsed(B64EncFilter)
|
||||
// )]
|
||||
// pub struct B64EncFilterParser;
|
||||
|
||||
// #[derive(Debug, Default, Clone)]
|
||||
// pub struct B64EncFilter;
|
||||
|
||||
// impl std::fmt::Display for B64EncFilter {
|
||||
// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// write!(f, "b64enc")
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl Filter for B64EncFilter {
|
||||
// fn evaluate(
|
||||
// &self,
|
||||
// input: &dyn ValueView,
|
||||
// _runtime: &dyn Runtime,
|
||||
// ) -> Result<Value, LiquidError> {
|
||||
// let input_str = input.to_kstr().into_owned();
|
||||
// let encoded = general_purpose::STANDARD.encode(input_str.as_bytes());
|
||||
// Ok(Value::scalar(encoded))
|
||||
// }
|
||||
// }
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CachedResponse {
|
||||
pub body: String,
|
||||
|
|
@ -700,7 +672,36 @@ async fn timed_validate_single_match<'a>(
|
|||
},
|
||||
);
|
||||
}
|
||||
// ---------------------------------------------------- JWT validator
|
||||
Some(Validation::JWT) => {
|
||||
let token = captured_values
|
||||
.iter()
|
||||
.find(|(n, ..)| n == "TOKEN")
|
||||
.map(|(_, v, ..)| v.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
if token.is_empty() {
|
||||
m.validation_success = false;
|
||||
m.validation_response_body = "JWT token not found.".to_string();
|
||||
m.validation_response_status = StatusCode::BAD_REQUEST;
|
||||
commit_and_return(m);
|
||||
return;
|
||||
}
|
||||
|
||||
match jwt::validate_jwt(&token, client).await {
|
||||
Ok((ok, msg)) => {
|
||||
m.validation_success = ok;
|
||||
m.validation_response_body = msg;
|
||||
m.validation_response_status =
|
||||
if ok { StatusCode::OK } else { StatusCode::UNAUTHORIZED };
|
||||
}
|
||||
Err(e) => {
|
||||
m.validation_success = false;
|
||||
m.validation_response_body = format!("JWT validation error: {}", e);
|
||||
m.validation_response_status = StatusCode::BAD_REQUEST;
|
||||
}
|
||||
}
|
||||
}
|
||||
// ---------------------------------------------------- AWS validator
|
||||
Some(Validation::AWS) => {
|
||||
let secret = captured_values
|
||||
|
|
|
|||
207
src/validation/jwt.rs
Normal file
207
src/validation/jwt.rs
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
|
||||
use chrono::Utc;
|
||||
use ipnet::IpNet;
|
||||
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, Validation as JwtValidation};
|
||||
use reqwest::{redirect::Policy, Client, Url};
|
||||
use serde::Deserialize;
|
||||
use tokio::net::lookup_host;
|
||||
|
||||
use super::utils::check_url_resolvable;
|
||||
|
||||
/// RFC 1918 + loopback + link-local nets we refuse to contact
|
||||
const BLOCKED_NETS: &[&str] = &[
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16", // private
|
||||
"127.0.0.0/8",
|
||||
"169.254.0.0/16", // loopback / link-local
|
||||
];
|
||||
|
||||
// aud is allowed to be either a string or an array, so let Serde flatten it.
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum Aud {
|
||||
Str(String),
|
||||
Arr(Vec<String>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Claims {
|
||||
exp: Option<i64>,
|
||||
nbf: Option<i64>,
|
||||
iss: Option<String>,
|
||||
aud: Option<Aud>,
|
||||
}
|
||||
|
||||
pub async fn validate_jwt(token: &str, client: &Client) -> Result<(bool, String)> {
|
||||
// --- insecure payload decode -------------------------------------------------
|
||||
let claims: Claims = {
|
||||
let payload_b64 = token.split('.').nth(1).ok_or_else(|| anyhow!("invalid JWT format"))?;
|
||||
let payload_json = URL_SAFE_NO_PAD
|
||||
.decode(payload_b64)
|
||||
.map_err(|e| anyhow!("invalid base64 in payload: {e}"))?;
|
||||
serde_json::from_slice(&payload_json).map_err(|e| anyhow!("invalid JSON claims: {e}"))?
|
||||
};
|
||||
|
||||
// temporal checks
|
||||
let now = Utc::now().timestamp();
|
||||
if let Some(nbf) = claims.nbf {
|
||||
if now < nbf {
|
||||
return Ok((false, format!("Token not valid before {nbf}")));
|
||||
}
|
||||
}
|
||||
if let Some(exp) = claims.exp {
|
||||
if now > exp {
|
||||
return Ok((false, format!("Token expired at {exp}")));
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
let issuer = claims.iss.clone().unwrap_or_default();
|
||||
|
||||
if let Some(iss) = claims.iss.clone() {
|
||||
// parse header now (kid, alg)
|
||||
let header = decode_header(token).map_err(|e| anyhow!("decode header: {e}"))?;
|
||||
|
||||
// build discovery URL and fetch it (redirects disabled)
|
||||
let config_url = format!("{}/.well-known/openid-configuration", iss.trim_end_matches('/'));
|
||||
let no_redirect_client = Client::builder()
|
||||
.redirect(Policy::none())
|
||||
.build()
|
||||
.map_err(|e| anyhow!("client build: {e}"))?;
|
||||
|
||||
let cfg_resp = no_redirect_client
|
||||
.get(&config_url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("issuer discovery failed: {e}"))?;
|
||||
|
||||
if !cfg_resp.status().is_success() {
|
||||
return Ok((false, format!("issuer discovery failed: {}", cfg_resp.status())));
|
||||
}
|
||||
|
||||
let cfg_json: serde_json::Value =
|
||||
cfg_resp.json().await.map_err(|e| anyhow!("invalid discovery JSON: {e}"))?;
|
||||
|
||||
// extract jwks_uri
|
||||
let jwks_uri = cfg_json
|
||||
.get("jwks_uri")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow!("jwks_uri missing"))?;
|
||||
|
||||
// must be HTTPS
|
||||
let url = Url::parse(jwks_uri).map_err(|e| anyhow!("invalid jwks_uri: {e}"))?;
|
||||
if url.scheme() != "https" {
|
||||
return Ok((false, "jwks_uri must use https".to_string()));
|
||||
}
|
||||
|
||||
// host must match issuer host — prevents open redirects / SSRF-on-other-host
|
||||
let iss_host = Url::parse(&iss)
|
||||
.map_err(|e| anyhow!("invalid iss: {e}"))?
|
||||
.host_str()
|
||||
.unwrap_or_default()
|
||||
.to_ascii_lowercase();
|
||||
let jwks_host = url.host_str().unwrap_or_default().to_ascii_lowercase();
|
||||
if jwks_host != iss_host {
|
||||
return Ok((
|
||||
false,
|
||||
format!("jwks_uri host ({jwks_host}) must match issuer host ({iss_host})"),
|
||||
));
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// DNS resolution + private-range block
|
||||
for addr in lookup_host((jwks_host.as_str(), 443)).await? {
|
||||
if is_blocked_ip(addr.ip()) {
|
||||
return Ok((false, "jwks_uri resolves to private or link-local IP".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// reachability check (existing helper)
|
||||
check_url_resolvable(&url).await.map_err(|e| anyhow!("jwks uri unresolvable: {e}"))?;
|
||||
|
||||
// fetch JWKS with redirect-free client
|
||||
let jwks_resp = no_redirect_client
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("jwks fetch failed: {e}"))?;
|
||||
if !jwks_resp.status().is_success() {
|
||||
return Ok((false, format!("jwks fetch failed: {}", jwks_resp.status())));
|
||||
}
|
||||
|
||||
let jwk_set: JwkSet =
|
||||
jwks_resp.json().await.map_err(|e| anyhow!("invalid jwks json: {e}"))?;
|
||||
|
||||
// select key by kid
|
||||
let kid = header.kid.ok_or_else(|| anyhow!("no kid in header"))?;
|
||||
let jwk = jwk_set
|
||||
.keys
|
||||
.iter()
|
||||
.find(|k| k.common.key_id.as_deref() == Some(&kid))
|
||||
.ok_or_else(|| anyhow!("kid not found in jwks"))?;
|
||||
|
||||
// verify signature
|
||||
let decoding_key = DecodingKey::from_jwk(jwk).map_err(|e| anyhow!("invalid jwk: {e}"))?;
|
||||
let mut validation = JwtValidation::new(header.alg);
|
||||
validation.set_audience(&extract_aud_strings(&claims));
|
||||
validation.validate_exp = false;
|
||||
validation.validate_nbf = false;
|
||||
|
||||
decode::<Claims>(token, &decoding_key, &validation)
|
||||
.map_err(|e| anyhow!("signature verification failed: {e}"))?;
|
||||
|
||||
return Ok((
|
||||
true,
|
||||
format!("JWT valid (iss: {issuer}, aud: {:?})", extract_aud_strings(&claims)),
|
||||
));
|
||||
}
|
||||
|
||||
Ok((true, format!("JWT not expired (iss: {issuer}, aud: {:?})", extract_aud_strings(&claims))))
|
||||
}
|
||||
|
||||
/// Helper: normalize aud into a flat Vec<String>
|
||||
fn extract_aud_strings(claims: &Claims) -> Vec<String> {
|
||||
match &claims.aud {
|
||||
Some(Aud::Str(s)) => vec![s.clone()],
|
||||
Some(Aud::Arr(v)) => v.clone(),
|
||||
None => vec![],
|
||||
}
|
||||
}
|
||||
/// returns true if IP is in a blocked network
|
||||
fn is_blocked_ip(ip: std::net::IpAddr) -> bool {
|
||||
BLOCKED_NETS.iter().filter_map(|cidr| cidr.parse::<IpNet>().ok()).any(|net| net.contains(&ip))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
|
||||
use chrono::{Duration as ChronoDuration, Utc};
|
||||
use reqwest::Client;
|
||||
|
||||
use super::validate_jwt;
|
||||
|
||||
fn build_token(exp_offset: i64) -> String {
|
||||
let header = URL_SAFE_NO_PAD.encode(r#"{"alg":"none"}"#);
|
||||
let exp = (Utc::now() + ChronoDuration::seconds(exp_offset)).timestamp();
|
||||
let payload = URL_SAFE_NO_PAD.encode(format!("{{\"exp\":{exp}}}"));
|
||||
format!("{header}.{payload}.")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn valid_token() {
|
||||
let token = build_token(60);
|
||||
let client = Client::new();
|
||||
let res = validate_jwt(&token, &client).await.unwrap();
|
||||
assert!(res.0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn expired_token() {
|
||||
let token = build_token(-60);
|
||||
let client = Client::new();
|
||||
let res = validate_jwt(&token, &client).await.unwrap();
|
||||
assert!(!res.0);
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue