forked from mirrors/kingfisher
Fixed AWS access key validation to support temporary/session keys (ASIA prefix) in addition to long-lived keys (AKIA prefix).
This commit is contained in:
parent
3f0fa7afde
commit
1a40fb3bfd
20 changed files with 250 additions and 1810 deletions
|
|
@ -6,6 +6,8 @@ All notable changes to this project will be documented in this file.
|
|||
- Added revocation support for SendGrid, Tailscale, MongoDB Atlas, Twilio, and NPM using multi-step (lookup ID then delete) pattern.
|
||||
- Added new Sumo Logic rule with direct revocation support.
|
||||
- Added `docs/TOKEN_REVOCATION_SUPPORT.md` with detailed revocation implementation guide and testing examples.
|
||||
- Fixed AWS access key validation to support temporary/session keys (ASIA prefix) in addition to long-lived keys (AKIA prefix).
|
||||
- Consolidated all validator implementations into the `kingfisher-scanner` crate to eliminate code duplication. Validators for AWS, Azure, Coinbase, GCP, JWT, JDBC, MongoDB, MySQL, Postgres, and HTTP are now maintained in a single location with proper feature gating.
|
||||
|
||||
## [v1.78.0]c
|
||||
- Added "Skipped Validations" counter to scan summary output to distinguish between validations that failed (HTTP errors, connection failures) and validations that were skipped due to missing preconditions (e.g., missing dependent rules). This provides better visibility into validation coverage for large scans.
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ assets = [
|
|||
# Library crates
|
||||
kingfisher-core = { path = "crates/kingfisher-core" }
|
||||
kingfisher-rules = { path = "crates/kingfisher-rules" }
|
||||
kingfisher-scanner = { path = "crates/kingfisher-scanner" }
|
||||
kingfisher-scanner = { path = "crates/kingfisher-scanner", features = ["validation-all"] }
|
||||
|
||||
clap = { version = "4.5", features = [
|
||||
"cargo",
|
||||
|
|
|
|||
|
|
@ -1,240 +0,0 @@
|
|||
# Example rules demonstrating multi-step revocation
|
||||
#
|
||||
# This file shows how to configure 2-step revocation processes for services
|
||||
# that require looking up an ID before performing the actual revocation.
|
||||
|
||||
rules:
|
||||
# Example 1: Basic 2-step revocation with JSON extraction
|
||||
- name: Example API Token (2-step revocation)
|
||||
id: kingfisher.example_multistep.1
|
||||
pattern: |
|
||||
(?xi)
|
||||
example_api_token_
|
||||
[A-Za-z0-9]{40}
|
||||
min_entropy: 3.5
|
||||
confidence: medium
|
||||
examples:
|
||||
- "EXAMPLE_TOKEN=example_api_token_abc123def456ghi789jkl012mno345pqrs"
|
||||
references:
|
||||
- https://example.com/docs/api-tokens
|
||||
|
||||
# Standard single-step validation
|
||||
validation:
|
||||
type: Http
|
||||
content:
|
||||
request:
|
||||
method: GET
|
||||
url: https://api.example.com/v1/auth/verify
|
||||
headers:
|
||||
Authorization: "Bearer {{ TOKEN }}"
|
||||
response_matcher:
|
||||
- type: StatusMatch
|
||||
status: [200]
|
||||
|
||||
# Multi-step revocation: lookup ID first, then delete
|
||||
revocation:
|
||||
type: HttpMultiStep
|
||||
content:
|
||||
steps:
|
||||
# Step 1: Get the token's internal ID
|
||||
- name: lookup_token_id
|
||||
request:
|
||||
method: GET
|
||||
url: https://api.example.com/v1/tokens/current
|
||||
headers:
|
||||
Authorization: "Bearer {{ TOKEN }}"
|
||||
Accept: application/json
|
||||
response_matcher:
|
||||
- type: StatusMatch
|
||||
status: [200]
|
||||
- type: JsonValid
|
||||
extract:
|
||||
# Extract the token ID from JSON response
|
||||
TOKEN_ID:
|
||||
type: JsonPath
|
||||
path: "$.data.token_id"
|
||||
|
||||
# Step 2: Delete the token using its ID
|
||||
- name: delete_token
|
||||
request:
|
||||
method: DELETE
|
||||
url: https://api.example.com/v1/tokens/{{ TOKEN_ID }}
|
||||
headers:
|
||||
Authorization: "Bearer {{ TOKEN }}"
|
||||
response_matcher:
|
||||
- report_response: true
|
||||
- type: StatusMatch
|
||||
status: [204, 200]
|
||||
|
||||
# Example 2: Multi-step with multiple extractions
|
||||
- name: Complex Service Token (multi-extraction)
|
||||
id: kingfisher.example_multistep.2
|
||||
pattern: |
|
||||
(?xi)
|
||||
complex_token_
|
||||
[A-Za-z0-9_-]{32,}
|
||||
min_entropy: 3.5
|
||||
confidence: medium
|
||||
examples:
|
||||
- "TOKEN=complex_token_xyz789_abc123_def456_ghi789"
|
||||
|
||||
revocation:
|
||||
type: HttpMultiStep
|
||||
content:
|
||||
steps:
|
||||
# Step 1: Get multiple pieces of information
|
||||
- name: get_token_metadata
|
||||
request:
|
||||
method: GET
|
||||
url: https://api.complex.com/v2/tokens/info
|
||||
headers:
|
||||
Authorization: "Bearer {{ TOKEN }}"
|
||||
response_matcher:
|
||||
- type: StatusMatch
|
||||
status: [200]
|
||||
extract:
|
||||
# Extract from JSON body using JSONPath
|
||||
TOKEN_ID:
|
||||
type: JsonPath
|
||||
path: "$.id"
|
||||
|
||||
# Extract from response header
|
||||
ACCOUNT_ID:
|
||||
type: Header
|
||||
name: X-Account-ID
|
||||
|
||||
# Extract nested JSON field
|
||||
WORKSPACE_ID:
|
||||
type: JsonPath
|
||||
path: "$.workspace.id"
|
||||
|
||||
# Step 2: Use all extracted values in revocation request
|
||||
- name: revoke_token
|
||||
request:
|
||||
method: POST
|
||||
url: https://api.complex.com/v2/accounts/{{ ACCOUNT_ID }}/workspaces/{{ WORKSPACE_ID }}/tokens/{{ TOKEN_ID }}/revoke
|
||||
headers:
|
||||
Authorization: "Bearer {{ TOKEN }}"
|
||||
Content-Type: application/json
|
||||
body: '{"reason":"Token compromised","force":true}'
|
||||
response_matcher:
|
||||
- report_response: true
|
||||
- type: StatusMatch
|
||||
status: [200, 204]
|
||||
- type: WordMatch
|
||||
words: ['"success":true']
|
||||
|
||||
# Example 3: Using regex extraction
|
||||
- name: Service With XML Response
|
||||
id: kingfisher.example_multistep.3
|
||||
pattern: |
|
||||
(?xi)
|
||||
xml_service_key_
|
||||
[A-Fa-f0-9]{32}
|
||||
min_entropy: 3.5
|
||||
confidence: medium
|
||||
examples:
|
||||
- "KEY=xml_service_key_a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6"
|
||||
|
||||
revocation:
|
||||
type: HttpMultiStep
|
||||
content:
|
||||
steps:
|
||||
# Step 1: Parse XML response with regex
|
||||
- name: get_key_id_from_xml
|
||||
request:
|
||||
method: GET
|
||||
url: https://api.xmlservice.com/keys/current
|
||||
headers:
|
||||
X-API-Key: "{{ TOKEN }}"
|
||||
response_matcher:
|
||||
- type: StatusMatch
|
||||
status: [200]
|
||||
- type: XmlValid
|
||||
extract:
|
||||
# Use regex to extract from XML
|
||||
KEY_ID:
|
||||
type: Regex
|
||||
pattern: '<KeyId>([^<]+)</KeyId>'
|
||||
|
||||
# Step 2: Delete using extracted ID
|
||||
- name: delete_key
|
||||
request:
|
||||
method: DELETE
|
||||
url: https://api.xmlservice.com/keys/{{ KEY_ID }}
|
||||
headers:
|
||||
X-API-Key: "{{ TOKEN }}"
|
||||
response_matcher:
|
||||
- type: StatusMatch
|
||||
status: [200, 204]
|
||||
|
||||
# Example 4: Single-step (for comparison)
|
||||
# This shows that simple revocations don't need the multi-step approach
|
||||
- name: Simple Service Token (single-step)
|
||||
id: kingfisher.example_multistep.4
|
||||
pattern: |
|
||||
(?xi)
|
||||
simple_token_
|
||||
[A-Za-z0-9]{32}
|
||||
min_entropy: 3.5
|
||||
confidence: medium
|
||||
examples:
|
||||
- "TOKEN=simple_token_abcdefghijklmnopqrstuvwxyz123456"
|
||||
|
||||
# This service accepts the token directly for revocation
|
||||
revocation:
|
||||
type: Http
|
||||
content:
|
||||
request:
|
||||
method: DELETE
|
||||
url: https://api.simple.com/v1/tokens/current
|
||||
headers:
|
||||
Authorization: "Bearer {{ TOKEN }}"
|
||||
response_matcher:
|
||||
- report_response: true
|
||||
- type: StatusMatch
|
||||
status: [204]
|
||||
|
||||
# Example 5: Array extraction from JSON
|
||||
- name: Service With Array Response
|
||||
id: kingfisher.example_multistep.5
|
||||
pattern: |
|
||||
(?xi)
|
||||
array_service_
|
||||
[A-Za-z0-9]{28}
|
||||
min_entropy: 3.5
|
||||
confidence: medium
|
||||
examples:
|
||||
- "TOKEN=array_service_abcdefghijklmnopqrstuvwxyz12"
|
||||
|
||||
revocation:
|
||||
type: HttpMultiStep
|
||||
content:
|
||||
steps:
|
||||
# Step 1: Get the first session ID from an array
|
||||
- name: get_session_id
|
||||
request:
|
||||
method: GET
|
||||
url: https://api.arrayservice.com/sessions
|
||||
headers:
|
||||
Authorization: "Bearer {{ TOKEN }}"
|
||||
response_matcher:
|
||||
- type: StatusMatch
|
||||
status: [200]
|
||||
- type: JsonValid
|
||||
extract:
|
||||
# Extract first element from array
|
||||
SESSION_ID:
|
||||
type: JsonPath
|
||||
path: "$.sessions[0].id"
|
||||
|
||||
# Step 2: Terminate the session
|
||||
- name: terminate_session
|
||||
request:
|
||||
method: DELETE
|
||||
url: https://api.arrayservice.com/sessions/{{ SESSION_ID }}
|
||||
headers:
|
||||
Authorization: "Bearer {{ TOKEN }}"
|
||||
response_matcher:
|
||||
- type: StatusMatch
|
||||
status: [204]
|
||||
|
|
@ -2,7 +2,7 @@ rules:
|
|||
- name: Sumo Logic Access ID
|
||||
id: kingfisher.sumologic.1
|
||||
pattern: |
|
||||
(?x)
|
||||
(?xi)
|
||||
\b
|
||||
sumo
|
||||
(?:.|[\n\r]){0,32}?
|
||||
|
|
@ -19,8 +19,7 @@ rules:
|
|||
confidence: medium
|
||||
visible: false
|
||||
examples:
|
||||
- sumo_access_id=suABCDEF1234567890XYZABC
|
||||
- 'SUMO_ACCESS_ID: suXYZ123456ABC789DEF012'
|
||||
- 'config.sumologic.access.id = "suK9mP2nQ7rT4wX8"'
|
||||
|
||||
- name: Sumo Logic Access Key
|
||||
id: kingfisher.sumologic.2
|
||||
|
|
@ -43,8 +42,7 @@ rules:
|
|||
min_entropy: 3.5
|
||||
confidence: medium
|
||||
examples:
|
||||
- sumo_access_key=ABCdef123456XYZabc789012DEFghi345678PQRstu
|
||||
- 'SUMO_ACCESS_KEY: XYZ123abc456DEF789ghi012JKL345mno678PQR901stu'
|
||||
- '// SumoLogic Private Token: M7nP4qR2tV9wX5yZ8aB1cD3eF5gH7iJ9kL2mN4oP6qR8sT0uV2wX4yZ6aB8cD0eF'
|
||||
references:
|
||||
- https://help.sumologic.com/docs/manage/security/access-keys/
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ validation-aws = [
|
|||
"validation-http",
|
||||
"dep:aws-config",
|
||||
"dep:aws-credential-types",
|
||||
"dep:aws-sdk-iam",
|
||||
"dep:aws-sdk-sts",
|
||||
"dep:aws-types",
|
||||
"dep:aws-smithy-http-client",
|
||||
|
|
@ -41,10 +42,67 @@ validation-aws = [
|
|||
"dep:rand",
|
||||
]
|
||||
|
||||
# Azure credential validation
|
||||
validation-azure = [
|
||||
"validation-http",
|
||||
"dep:chrono",
|
||||
"dep:hmac",
|
||||
"dep:sha2",
|
||||
]
|
||||
|
||||
# Coinbase credential validation
|
||||
validation-coinbase = [
|
||||
"validation-http",
|
||||
"dep:chrono",
|
||||
"dep:ed25519-dalek",
|
||||
"dep:p256",
|
||||
"dep:rand",
|
||||
"dep:hex",
|
||||
]
|
||||
|
||||
# GCP credential validation
|
||||
validation-gcp = [
|
||||
"validation-http",
|
||||
"dep:chrono",
|
||||
"dep:pem",
|
||||
"dep:percent-encoding",
|
||||
"dep:ring",
|
||||
"dep:tokio",
|
||||
]
|
||||
|
||||
# JWT validation
|
||||
validation-jwt = [
|
||||
"validation-http",
|
||||
"dep:chrono",
|
||||
"dep:ipnet",
|
||||
"dep:jsonwebtoken",
|
||||
"dep:serde",
|
||||
"dep:tokio",
|
||||
]
|
||||
|
||||
# Database validation (MongoDB/MySQL/Postgres/JDBC)
|
||||
validation-database = [
|
||||
"validation-http",
|
||||
"dep:bson",
|
||||
"dep:mongodb",
|
||||
"dep:mysql_async",
|
||||
"dep:tokio-postgres",
|
||||
"dep:tokio-postgres-rustls",
|
||||
"dep:rustls",
|
||||
"dep:rustls-native-certs",
|
||||
"dep:url",
|
||||
"dep:sha1",
|
||||
]
|
||||
|
||||
# All validation features
|
||||
validation-all = [
|
||||
"validation",
|
||||
"validation-aws",
|
||||
"validation-azure",
|
||||
"validation-coinbase",
|
||||
"validation-gcp",
|
||||
"validation-jwt",
|
||||
"validation-database",
|
||||
]
|
||||
|
||||
[dependencies]
|
||||
|
|
@ -57,7 +115,7 @@ anyhow = "1.0"
|
|||
thiserror = "1.0"
|
||||
|
||||
# Serialization
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde = { version = "1.0", features = ["derive"], optional = true }
|
||||
serde_json = "1.0"
|
||||
schemars = "0.8"
|
||||
|
||||
|
|
@ -78,6 +136,7 @@ rustc-hash = "2.1"
|
|||
parking_lot = "0.12"
|
||||
thread_local = "1.1"
|
||||
once_cell = "1.21"
|
||||
crossbeam-skiplist = "0.1.3"
|
||||
|
||||
# HTTP status codes
|
||||
http = "1.4"
|
||||
|
|
@ -102,10 +161,30 @@ liquid = { version = "0.26", optional = true }
|
|||
liquid-core = { version = "0.26", optional = true }
|
||||
quick-xml = { version = "0.38", features = ["serde", "serialize"], optional = true }
|
||||
sha1 = { version = "0.10", optional = true }
|
||||
chrono = { version = "0.4.42", optional = true }
|
||||
hmac = { version = "0.12", optional = true }
|
||||
sha2 = { version = "0.10", optional = true }
|
||||
pem = { version = "3.0.6", optional = true }
|
||||
percent-encoding = { version = "2.3.2", optional = true }
|
||||
ring = { version = "0.17", optional = true }
|
||||
ipnet = { version = "2.11", optional = true }
|
||||
jsonwebtoken = { version = "10.2.0", features = ["aws-lc-rs"], optional = true }
|
||||
p256 = { version = "0.13.2", optional = true }
|
||||
ed25519-dalek = { version = "2.2", features = ["pkcs8"], optional = true }
|
||||
hex = { version = "0.4.3", optional = true }
|
||||
url = { version = "2.5.7", optional = true }
|
||||
bson = { version = "2.15.0", optional = true }
|
||||
mongodb = { version = "3.4", default-features = false, features = ["rustls-tls", "aws-auth", "compat-3-0-0", "dns-resolver"], optional = true }
|
||||
mysql_async = { version = "0.34.2", default-features = false, features = ["default-rustls"], optional = true }
|
||||
tokio-postgres = { version = "0.7", default-features = false, features = ["runtime"], optional = true }
|
||||
tokio-postgres-rustls = { version = "0.13.0", optional = true }
|
||||
rustls = { version = "0.23.35", optional = true }
|
||||
rustls-native-certs = { version = "0.8.2", optional = true }
|
||||
|
||||
# AWS validation
|
||||
aws-config = { version = "1.8", optional = true }
|
||||
aws-credential-types = { version = "1.2", optional = true }
|
||||
aws-sdk-iam = { version = "1.101.0", optional = true }
|
||||
aws-sdk-sts = { version = "1.95", optional = true }
|
||||
aws-types = { version = "1.3", optional = true }
|
||||
aws-smithy-http-client = { version = "1.1", optional = true }
|
||||
|
|
|
|||
|
|
@ -8,6 +8,10 @@ use std::{collections::HashSet, sync::RwLock, time::Duration};
|
|||
use anyhow::{anyhow, Result};
|
||||
use aws_config::{retry::RetryConfig, BehaviorVersion, SdkConfig};
|
||||
use aws_credential_types::Credentials;
|
||||
use aws_sdk_iam::{
|
||||
config::Builder as IamConfigBuilder, error::SdkError as IamSdkError,
|
||||
operation::update_access_key::UpdateAccessKeyError, types::StatusType, Client as IamClient,
|
||||
};
|
||||
use aws_sdk_sts::{
|
||||
config::Builder as StsConfigBuilder, error::SdkError,
|
||||
operation::get_caller_identity::GetCallerIdentityError, Client as StsClient,
|
||||
|
|
@ -188,18 +192,24 @@ pub fn generate_aws_cache_key(aws_access_key_id: &str, aws_secret_access_key: &s
|
|||
|
||||
/// Validate AWS credentials format before attempting validation.
|
||||
pub fn validate_aws_credentials_input(access_key_id: &str, secret_key: &str) -> Result<(), String> {
|
||||
// Validate access key ID format (typically starts with "AKIA" and is 20 chars)
|
||||
if !access_key_id.starts_with("AKIA") || access_key_id.len() != 20 {
|
||||
// Validate access key ID format (20 chars, known AWS prefixes including STS)
|
||||
if access_key_id.len() != 20 {
|
||||
return Err("Invalid AWS access key ID format".to_string());
|
||||
}
|
||||
if !access_key_id.chars().all(|c| c.is_ascii_alphanumeric()) {
|
||||
return Err("AWS access key ID contains invalid characters".to_string());
|
||||
}
|
||||
let prefix = &access_key_id[..4];
|
||||
let valid_prefix =
|
||||
matches!(prefix, "AKIA" | "AGPA" | "AIDA" | "AROA" | "AIPA" | "ANPA" | "ANVA" | "ASIA")
|
||||
|| prefix.starts_with("A3T");
|
||||
if !valid_prefix {
|
||||
return Err("Invalid AWS access key ID format".to_string());
|
||||
}
|
||||
// Validate secret key format (should be at least 40 chars)
|
||||
if secret_key.len() < 40 {
|
||||
return Err("Invalid AWS secret key format".to_string());
|
||||
}
|
||||
// Check for invalid characters
|
||||
if !access_key_id.chars().all(|c| c.is_ascii_alphanumeric()) {
|
||||
return Err("AWS access key ID contains invalid characters".to_string());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -222,6 +232,84 @@ fn is_throttling_or_transient(e: &SdkError<GetCallerIdentityError>) -> bool {
|
|||
}
|
||||
}
|
||||
|
||||
fn is_iam_throttling_or_transient(e: &IamSdkError<UpdateAccessKeyError>) -> bool {
|
||||
match e {
|
||||
IamSdkError::ServiceError(ctx) => {
|
||||
let code = ctx.err().meta().code().unwrap_or_default();
|
||||
let status: StatusCode = ctx.raw().status().into();
|
||||
code.contains("Throttl")
|
||||
|| status == StatusCode::TOO_MANY_REQUESTS
|
||||
|| status == StatusCode::SERVICE_UNAVAILABLE
|
||||
}
|
||||
IamSdkError::DispatchFailure(df) => df.is_timeout() || df.is_io(),
|
||||
IamSdkError::ResponseError(ctx) => {
|
||||
let status: StatusCode = ctx.raw().status().into();
|
||||
status == StatusCode::TOO_MANY_REQUESTS || status == StatusCode::SERVICE_UNAVAILABLE
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Revoke (deactivate) an AWS access key via IAM.
|
||||
pub async fn revoke_aws_access_key(
|
||||
aws_access_key_id: &str,
|
||||
aws_secret_access_key: &str,
|
||||
) -> Result<(bool, String)> {
|
||||
// Create static credentials
|
||||
let credentials = Credentials::new(
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
None, // session token
|
||||
None, // expiry
|
||||
"static", // provider name
|
||||
);
|
||||
let config = build_base_config(credentials).await;
|
||||
|
||||
// Create IAM client
|
||||
let iam_config = IamConfigBuilder::from(&config).interceptor(UaInterceptor).build();
|
||||
let iam_client = IamClient::from_conf(iam_config);
|
||||
|
||||
const MAX_ATTEMPTS: usize = 3;
|
||||
const ATTEMPT_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
for attempt in 1..=MAX_ATTEMPTS {
|
||||
let result = timeout(
|
||||
ATTEMPT_TIMEOUT,
|
||||
iam_client
|
||||
.update_access_key()
|
||||
.access_key_id(aws_access_key_id)
|
||||
.status(StatusType::Inactive)
|
||||
.send(),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(_)) => {
|
||||
return Ok((true, "AWS access key set to Inactive".to_string()));
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
if is_iam_throttling_or_transient(&e) {
|
||||
if attempt == MAX_ATTEMPTS {
|
||||
return Err(anyhow!("AWS revocation failed: {}", e));
|
||||
}
|
||||
} else {
|
||||
return Ok((false, e.to_string()));
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
if attempt == MAX_ATTEMPTS {
|
||||
return Err(anyhow!("AWS revocation timed out"));
|
||||
}
|
||||
}
|
||||
}
|
||||
let max_delay = 100u64 * 2u64.pow((attempt - 1) as u32);
|
||||
let sleep_ms = rng().random_range(0..=max_delay);
|
||||
sleep(Duration::from_millis(sleep_ms)).await;
|
||||
}
|
||||
|
||||
Err(anyhow!("AWS revocation failed"))
|
||||
}
|
||||
|
||||
/// Validate AWS credentials by calling STS GetCallerIdentity.
|
||||
///
|
||||
/// Returns `(is_valid, message)` where message is the ARN on success or an error message.
|
||||
|
|
|
|||
|
|
@ -10,9 +10,8 @@ use reqwest::{header::HeaderValue, Client};
|
|||
use serde_json::Value as JsonValue;
|
||||
use sha2::Sha256;
|
||||
|
||||
use crate::{
|
||||
validation::{Cache, CachedResponse, ValidationResponseBody, VALIDATION_CACHE_SECONDS},
|
||||
validation_body,
|
||||
use super::{
|
||||
validation_body, Cache, CachedResponse, ValidationResponseBody, VALIDATION_CACHE_SECONDS,
|
||||
};
|
||||
|
||||
pub fn generate_azure_cache_key(azure_json: &str) -> String {
|
||||
|
|
@ -22,14 +21,13 @@ pub fn generate_azure_cache_key(azure_json: &str) -> String {
|
|||
format!("AZURE:{:x}", h.finalize())
|
||||
}
|
||||
|
||||
/// Validate Azure Storage credentials without Azure SDK crates
|
||||
/// Validate Azure Storage credentials without Azure SDK crates.
|
||||
pub async fn validate_azure_storage_credentials(
|
||||
azure_json: &str,
|
||||
cache: &Cache,
|
||||
) -> Result<(bool, ValidationResponseBody)> {
|
||||
let cache_key = generate_azure_cache_key(azure_json);
|
||||
|
||||
/* ── short-circuit cached result ───────────────────────────── */
|
||||
if let Some(e) = cache.get(&cache_key) {
|
||||
let c = e.value();
|
||||
if c.timestamp.elapsed() < Duration::from_secs(VALIDATION_CACHE_SECONDS) {
|
||||
|
|
@ -37,7 +35,6 @@ pub async fn validate_azure_storage_credentials(
|
|||
}
|
||||
}
|
||||
|
||||
/* ── pull account + key from caller JSON ──────────────────── */
|
||||
let tok: JsonValue = serde_json::from_str(azure_json)?;
|
||||
let storage_account = tok["storage_account"].as_str().unwrap_or("");
|
||||
let storage_key = tok["storage_key"].as_str().unwrap_or("");
|
||||
|
|
@ -48,12 +45,10 @@ pub async fn validate_azure_storage_credentials(
|
|||
return Ok((false, msg));
|
||||
}
|
||||
|
||||
/* ── build SignedKey GET /?comp=list ──────────────────────── */
|
||||
let now_rfc = Utc::now().format("%a, %d %b %Y %H:%M:%S GMT").to_string();
|
||||
let url =
|
||||
format!("https://{account}.blob.core.windows.net/?comp=list", account = storage_account);
|
||||
|
||||
// canonical string-to-sign per MSFT docs .
|
||||
let canon_headers = format!("x-ms-date:{now_rfc}\nx-ms-version:2023-11-03\n");
|
||||
let canon_resource = format!("/{account}/\ncomp:list", account = storage_account);
|
||||
let string_to_sign = format!(
|
||||
|
|
@ -62,7 +57,6 @@ pub async fn validate_azure_storage_credentials(
|
|||
resource = canon_resource
|
||||
);
|
||||
|
||||
// HMAC-SHA256 -- Base64
|
||||
let key_bytes = b64.decode(storage_key)?;
|
||||
let mut mac =
|
||||
Hmac::<Sha256>::new_from_slice(&key_bytes).map_err(|_| anyhow!("invalid key length"))?;
|
||||
|
|
@ -84,7 +78,6 @@ pub async fn validate_azure_storage_credentials(
|
|||
let client = Client::builder().build()?;
|
||||
let resp = client.get(&url).headers(hdrs).send().await?;
|
||||
|
||||
/* ── capture status before `.text()` consumes resp ────────── */
|
||||
let status = resp.status();
|
||||
let body_txt = resp.text().await?;
|
||||
|
||||
|
|
@ -95,7 +88,6 @@ pub async fn validate_azure_storage_credentials(
|
|||
return Err(anyhow!(body));
|
||||
}
|
||||
|
||||
// parse XML payload
|
||||
let mut reader = Reader::from_str(&body_txt);
|
||||
reader.config_mut().trim_text(true);
|
||||
let mut buf = Vec::new();
|
||||
|
|
@ -114,7 +106,6 @@ pub async fn validate_azure_storage_credentials(
|
|||
buf.clear();
|
||||
}
|
||||
|
||||
/* ── success ─────────────────────────────────────────────── */
|
||||
let body = format!("Account: {}; Containers: {:?}", storage_account, names);
|
||||
let body_opt = validation_body::from_string(body);
|
||||
cache.insert(cache_key, CachedResponse::new(body_opt.clone(), StatusCode::OK, true));
|
||||
|
|
@ -15,11 +15,9 @@ use rand::TryRngCore;
|
|||
use reqwest::{Client, StatusCode, Url};
|
||||
use sha1::{Digest, Sha1};
|
||||
|
||||
use crate::{
|
||||
validation::{
|
||||
httpvalidation, Cache, CachedResponse, ValidationResponseBody, VALIDATION_CACHE_SECONDS,
|
||||
},
|
||||
validation_body,
|
||||
use super::http_validation as httpvalidation;
|
||||
use super::{
|
||||
validation_body, Cache, CachedResponse, ValidationResponseBody, VALIDATION_CACHE_SECONDS,
|
||||
};
|
||||
|
||||
pub fn generate_coinbase_cache_key(cred_name: &str, private_key: &str) -> String {
|
||||
|
|
@ -89,7 +87,6 @@ fn build_jwt(
|
|||
|
||||
let _ = rng.try_fill_bytes(&mut nonce);
|
||||
|
||||
// Try ECDSA (PEM encoded EC key). Fallback to raw Ed25519 base64 key.
|
||||
if let Ok(secret_key) =
|
||||
SecretKey::from_sec1_pem(&pem).or_else(|_| SecretKey::from_pkcs8_pem(&pem))
|
||||
{
|
||||
|
|
@ -118,7 +115,6 @@ fn build_jwt(
|
|||
|
||||
return Ok(format!("{signing_input}.{sig_b64}"));
|
||||
} else {
|
||||
// Assume base64-encoded Ed25519 keypair
|
||||
let key_bytes = base64::engine::general_purpose::STANDARD
|
||||
.decode(pem.as_bytes())
|
||||
.map_err(|e| anyhow!("invalid base64 key: {e}"))?;
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::validation::GLOBAL_USER_AGENT;
|
||||
use anyhow::{anyhow, Result};
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
|
||||
use chrono::{Duration as ChronoDuration, Utc};
|
||||
|
|
@ -13,6 +12,8 @@ use serde_json::Value as JsonValue;
|
|||
use tokio::sync::Semaphore;
|
||||
use tracing::debug;
|
||||
|
||||
use super::GLOBAL_USER_AGENT;
|
||||
|
||||
static GLOBAL_VALIDATOR: OnceCell<GcpValidator> = OnceCell::new();
|
||||
|
||||
pub struct GcpValidator {
|
||||
|
|
@ -52,7 +53,6 @@ impl GcpValidator {
|
|||
let _permit = self.semaphore.acquire().await?;
|
||||
let token_info: JsonValue = serde_json::from_str(gcp_json)?;
|
||||
|
||||
// Extract required fields.
|
||||
let project_id = token_info["project_id"].as_str().unwrap_or("").to_string();
|
||||
let client_email = token_info["client_email"].as_str().unwrap_or("").to_string();
|
||||
let private_key = token_info["private_key"].as_str().unwrap_or("").to_string();
|
||||
|
|
@ -185,7 +185,6 @@ impl GcpValidator {
|
|||
let iat = now.timestamp();
|
||||
let exp = (now + ChronoDuration::hours(1)).timestamp();
|
||||
|
||||
// JWT Header and Claims.
|
||||
let header = URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256","typ":"JWT"}"#);
|
||||
let claims = format!(
|
||||
r#"{{
|
||||
|
|
@ -200,12 +199,10 @@ impl GcpValidator {
|
|||
let claims_encoded = URL_SAFE_NO_PAD.encode(claims);
|
||||
let message = format!("{}.{}", header, claims_encoded);
|
||||
|
||||
// Parse PEM and create RSA key pair.
|
||||
let pem = parse(private_key_pem).map_err(|e| anyhow!("Failed to parse PEM: {}", e))?;
|
||||
let key_pair = signature::RsaKeyPair::from_pkcs8(&pem.contents())
|
||||
.map_err(|_| anyhow!("Invalid RSA private key"))?;
|
||||
|
||||
// Sign the message.
|
||||
let rng = rand::SystemRandom::new();
|
||||
let mut signature = vec![0; key_pair.public().modulus_len()];
|
||||
key_pair
|
||||
|
|
@ -1,11 +1,4 @@
|
|||
//! HTTP-based credential validation.
|
||||
//!
|
||||
//! This module provides utilities for validating credentials via HTTP requests.
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
use std::future::Future;
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use std::{collections::BTreeMap, future::Future, str::FromStr, time::Duration};
|
||||
|
||||
use anyhow::{anyhow, Error, Result};
|
||||
use http::StatusCode;
|
||||
|
|
@ -25,11 +18,6 @@ use super::GLOBAL_USER_AGENT;
|
|||
use kingfisher_rules::ResponseMatcher;
|
||||
|
||||
/// Build a deterministic cache key from the immutable parts of an HTTP request.
|
||||
///
|
||||
/// * `method` – case-insensitive HTTP verb ("GET", "POST"…)
|
||||
/// * `url` – fully-qualified URL (any query string should already be present)
|
||||
/// * `headers` – *logical* headers you intend to send (template-rendered)
|
||||
/// * `body` – optional request body
|
||||
pub fn generate_http_cache_key_parts(
|
||||
method: &str,
|
||||
url: &Url,
|
||||
|
|
@ -45,7 +33,6 @@ pub fn generate_http_cache_key_parts(
|
|||
hasher.update(url.as_bytes());
|
||||
hasher.update(b"\0");
|
||||
|
||||
// Collect headers sorted lexicographically (BTreeMap is already sorted)
|
||||
for (k, v) in headers {
|
||||
hasher.update(k.as_bytes());
|
||||
hasher.update(b":");
|
||||
|
|
@ -53,7 +40,6 @@ pub fn generate_http_cache_key_parts(
|
|||
hasher.update(b"\0");
|
||||
}
|
||||
|
||||
// Include the request body in the cache key if present
|
||||
if let Some(b) = body {
|
||||
hasher.update(b"BODY\0");
|
||||
hasher.update(b.as_bytes());
|
||||
|
|
@ -87,7 +73,6 @@ pub fn build_request_builder(
|
|||
let custom_headers = process_headers(headers, parser, globals, url)
|
||||
.map_err(|e| format!("Error processing headers: {}", e))?;
|
||||
|
||||
// Prepare a standard set of headers
|
||||
let user_agent = GLOBAL_USER_AGENT.as_str();
|
||||
let standard_headers = [
|
||||
(header::USER_AGENT, user_agent),
|
||||
|
|
@ -99,7 +84,6 @@ pub fn build_request_builder(
|
|||
(header::ACCEPT_ENCODING, "gzip, deflate, br"),
|
||||
(header::CONNECTION, "keep-alive"),
|
||||
];
|
||||
|
||||
let mut combined_headers = HeaderMap::new();
|
||||
for (name, value) in &standard_headers {
|
||||
if let Ok(hv) = HeaderValue::from_str(value) {
|
||||
|
|
@ -111,7 +95,6 @@ pub fn build_request_builder(
|
|||
}
|
||||
request_builder = request_builder.headers(combined_headers);
|
||||
|
||||
// If a body template is provided, parse and render it
|
||||
if let Some(body_template) = body {
|
||||
let template = parser
|
||||
.parse(body_template)
|
||||
|
|
@ -157,7 +140,6 @@ pub fn process_headers(
|
|||
|
||||
let cleaned_key = key.trim().replace(&['\n', '\r'][..], "");
|
||||
let cleaned_value = header_value.trim().replace(&['\n', '\r'][..], "");
|
||||
|
||||
let name = match HeaderName::from_str(&cleaned_key) {
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
|
|
@ -170,7 +152,6 @@ pub fn process_headers(
|
|||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let value = match HeaderValue::from_str(&cleaned_value) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
|
|
@ -188,7 +169,7 @@ pub fn process_headers(
|
|||
Ok(headers_map)
|
||||
}
|
||||
|
||||
/// Exponential-backoff retry helper.
|
||||
/// Exponential‐backoff retry helper that always returns `Result<T, anyhow::Error>`.
|
||||
async fn retry_with_backoff<F, Fut, T>(
|
||||
mut operation: F,
|
||||
is_retryable: impl Fn(&Result<T, Error>, usize) -> bool,
|
||||
|
|
@ -216,7 +197,6 @@ where
|
|||
Err(anyhow!("Max retries reached"))
|
||||
}
|
||||
|
||||
/// Retry a multipart request with exponential backoff.
|
||||
pub async fn retry_multipart_request<F, Fut>(
|
||||
mut build_request: F,
|
||||
max_retries: usize,
|
||||
|
|
@ -256,7 +236,6 @@ where
|
|||
.await
|
||||
}
|
||||
|
||||
/// Retry an HTTP request with exponential backoff.
|
||||
pub async fn retry_request(
|
||||
request_builder: RequestBuilder,
|
||||
max_retries: u32,
|
||||
|
|
@ -414,42 +393,3 @@ pub async fn check_url_resolvable(url: &Url) -> Result<(), Box<dyn std::error::E
|
|||
let addr = format!("{}:{}", host, port);
|
||||
lookup_host(addr).await?.next().ok_or_else(|| "Failed to resolve URL".into()).map(|_| ())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_includes_body() {
|
||||
let url = Url::from_str("https://example.com/api").unwrap();
|
||||
let headers =
|
||||
BTreeMap::from([("Content-Type".to_string(), "application/json".to_string())]);
|
||||
|
||||
let key_no_body = generate_http_cache_key_parts("POST", &url, &headers, None);
|
||||
let key_body_a =
|
||||
generate_http_cache_key_parts("POST", &url, &headers, Some(r#"{"value": "abc"}"#));
|
||||
let key_body_b =
|
||||
generate_http_cache_key_parts("POST", &url, &headers, Some(r#"{"value": "xyz"}"#));
|
||||
|
||||
assert_ne!(key_no_body, key_body_a);
|
||||
assert_ne!(key_no_body, key_body_b);
|
||||
assert_ne!(key_body_a, key_body_b);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_response_word_match() {
|
||||
let matchers = vec![ResponseMatcher::WordMatch {
|
||||
r#type: "word-match".to_string(),
|
||||
words: vec!["test".to_string()],
|
||||
match_all_words: true,
|
||||
negative: false,
|
||||
}];
|
||||
let status = StatusCode::OK;
|
||||
let body = "This is a test";
|
||||
let headers = HeaderMap::new();
|
||||
let html_allowed = false;
|
||||
|
||||
let result = validate_response(&matchers, body, &status, &headers, html_allowed);
|
||||
assert!(result);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,10 +19,6 @@ pub fn generate_jdbc_cache_key(raw: &str) -> String {
|
|||
}
|
||||
|
||||
/// Validate a JDBC connection string by dispatching to the supported backend validators.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `jdbc_conn` - The JDBC connection string to validate
|
||||
/// * `lax_tls` - If true, accept self-signed or invalid certificates
|
||||
pub async fn validate_jdbc(jdbc_conn: &str, lax_tls: bool) -> Result<JdbcValidationOutcome> {
|
||||
let trimmed = jdbc_conn.trim();
|
||||
if !trimmed.to_ascii_lowercase().starts_with("jdbc:") {
|
||||
|
|
@ -90,14 +86,12 @@ fn normalize_postgres_url(subname: &str) -> Result<String> {
|
|||
return Err(anyhow!("Postgres JDBC connection string is empty"));
|
||||
}
|
||||
|
||||
// First try parsing using the standard JDBC layout, otherwise fall back to a canonical URL.
|
||||
let candidate = format!("postgresql:{}", trimmed);
|
||||
let mut url = Url::parse(&candidate).or_else(|_| {
|
||||
let fallback = format!("postgresql://{}", trimmed.trim_start_matches('/'));
|
||||
Url::parse(&fallback)
|
||||
})?;
|
||||
|
||||
// Extract credentials from the query string when they are present.
|
||||
let mut user = None;
|
||||
let mut password = None;
|
||||
if url.query().is_some() {
|
||||
|
|
@ -129,30 +123,3 @@ fn normalize_postgres_url(subname: &str) -> Result<String> {
|
|||
|
||||
Ok(url.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::normalize_postgres_url;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn normalizes_postgres_query_credentials() {
|
||||
let normalized = normalize_postgres_url(
|
||||
"//db.example.com:5432/app?user=admin&password=s3cr3t&sslmode=require",
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(normalized, "postgresql://admin:s3cr3t@db.example.com:5432/app?sslmode=require");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preserves_existing_credentials() {
|
||||
let normalized =
|
||||
normalize_postgres_url("//db.example.com:5432/app?sslmode=prefer").unwrap();
|
||||
assert_eq!(normalized, "postgresql://db.example.com:5432/app?sslmode=prefer");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_empty_input() {
|
||||
assert!(normalize_postgres_url("").is_err());
|
||||
}
|
||||
}
|
||||
|
|
@ -10,11 +10,9 @@ use reqwest::{redirect::Policy, Client, Url};
|
|||
use serde::Deserialize;
|
||||
use tokio::net::lookup_host;
|
||||
|
||||
use super::utils::check_url_resolvable;
|
||||
use super::http_validation::check_url_resolvable;
|
||||
|
||||
/// Global redirect-free client with strict TLS validation.
|
||||
/// Building a `Client` is comparatively expensive; re-using it lets reqwest
|
||||
/// share its internal connection pool and TLS sessions across JWT validations.
|
||||
static STRICT_CLIENT: Lazy<Client> = Lazy::new(|| {
|
||||
Client::builder()
|
||||
.redirect(Policy::none())
|
||||
|
|
@ -32,7 +30,6 @@ static LAX_CLIENT: Lazy<Client> = Lazy::new(|| {
|
|||
.expect("failed to build lax Client")
|
||||
});
|
||||
|
||||
/// Get the appropriate client based on TLS mode.
|
||||
fn get_client(lax_tls: bool) -> &'static Client {
|
||||
if lax_tls {
|
||||
&LAX_CLIENT
|
||||
|
|
@ -41,16 +38,10 @@ fn get_client(lax_tls: bool) -> &'static Client {
|
|||
}
|
||||
}
|
||||
|
||||
/// RFC 1918 + loopback + link-local nets we refuse to contact
|
||||
const BLOCKED_NETS: &[&str] = &[
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16", // private
|
||||
"127.0.0.0/8",
|
||||
"169.254.0.0/16", // loopback / link-local
|
||||
];
|
||||
/// RFC 1918 + loopback + link-local nets we refuse to contact.
|
||||
const BLOCKED_NETS: &[&str] =
|
||||
&["10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "127.0.0.0/8", "169.254.0.0/16"];
|
||||
|
||||
// aud is allowed to be either a string or an array, so let Serde flatten it.
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum Aud {
|
||||
|
|
@ -66,25 +57,15 @@ struct Claims {
|
|||
aud: Option<Aud>,
|
||||
}
|
||||
|
||||
/// Runtime options for JWT validation policy.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct ValidateOptions {
|
||||
/// If true, accept unsigned tokens (`alg: "none"`) as long as temporal checks pass.
|
||||
/// Default is **false** (more secure).
|
||||
pub allow_alg_none: bool,
|
||||
|
||||
/// If provided and `iss` is absent, use this key to cryptographically verify the token.
|
||||
/// Useful for non-OIDC flows where you already know the verification key.
|
||||
pub fallback_decoding_key: Option<DecodingKey>,
|
||||
}
|
||||
|
||||
/// Backwards-compatible entry point with secure defaults:
|
||||
/// - `alg: none` is **rejected**
|
||||
/// - `iss` is **required** unless `fallback_decoding_key` is supplied (not supplied here)
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `token` - The JWT token to validate
|
||||
/// * `lax_tls` - If true, accept self-signed or invalid certificates for JWKS fetching
|
||||
/// Backwards-compatible entry point with secure defaults.
|
||||
pub async fn validate_jwt(token: &str, lax_tls: bool) -> Result<(bool, String)> {
|
||||
validate_jwt_with(
|
||||
token,
|
||||
|
|
@ -95,19 +76,12 @@ pub async fn validate_jwt(token: &str, lax_tls: bool) -> Result<(bool, String)>
|
|||
}
|
||||
|
||||
/// Strict validator with policy control.
|
||||
/// Returns (is_active_credential, explanation).
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `token` - The JWT token to validate
|
||||
/// * `opts` - Validation options
|
||||
/// * `lax_tls` - If true, accept self-signed or invalid certificates for JWKS fetching
|
||||
pub async fn validate_jwt_with(
|
||||
token: &str,
|
||||
opts: &ValidateOptions,
|
||||
lax_tls: bool,
|
||||
) -> Result<(bool, String)> {
|
||||
let client = get_client(lax_tls);
|
||||
// --- insecure payload decode to read claims --------------------------------
|
||||
let claims: Claims = {
|
||||
let payload_b64 = token.split('.').nth(1).ok_or_else(|| anyhow!("invalid JWT format"))?;
|
||||
let payload_json = URL_SAFE_NO_PAD
|
||||
|
|
@ -116,7 +90,6 @@ pub async fn validate_jwt_with(
|
|||
serde_json::from_slice(&payload_json).map_err(|e| anyhow!("invalid JSON claims: {e}"))?
|
||||
};
|
||||
|
||||
// temporal checks
|
||||
let now = Utc::now().timestamp();
|
||||
if let Some(nbf) = claims.nbf {
|
||||
if now < nbf {
|
||||
|
|
@ -129,7 +102,6 @@ pub async fn validate_jwt_with(
|
|||
}
|
||||
}
|
||||
|
||||
// parse header enough to read "alg" without jsonwebtoken's enum (which rejects "none")
|
||||
let header_b64 = token.split('.').next().ok_or_else(|| anyhow!("invalid JWT format"))?;
|
||||
let header_json =
|
||||
URL_SAFE_NO_PAD.decode(header_b64).map_err(|e| anyhow!("invalid base64 in header: {e}"))?;
|
||||
|
|
@ -137,10 +109,8 @@ pub async fn validate_jwt_with(
|
|||
serde_json::from_slice(&header_json).map_err(|e| anyhow!("invalid header json: {e}"))?;
|
||||
let alg_str = header_val.get("alg").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
// --- Policy: reject `alg: none` unless explicitly allowed ------------------
|
||||
if alg_str.eq_ignore_ascii_case("none") {
|
||||
if opts.allow_alg_none {
|
||||
// time-valid is enough if explicitly allowed
|
||||
return Ok((
|
||||
true,
|
||||
format!(
|
||||
|
|
@ -154,11 +124,9 @@ pub async fn validate_jwt_with(
|
|||
}
|
||||
}
|
||||
|
||||
// Safe to decode full header now that we know alg != none
|
||||
let header = decode_header(token).map_err(|e| anyhow!("decode header: {e}"))?;
|
||||
let alg = header.alg;
|
||||
|
||||
// Proactively skip HMAC-signed JWTs to avoid ambiguous liveness results.
|
||||
if matches!(alg, Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512) {
|
||||
return Ok((false, format!("HMAC-signed JWTs are not validated ({alg:?})")));
|
||||
}
|
||||
|
|
@ -166,16 +134,12 @@ pub async fn validate_jwt_with(
|
|||
let issuer = claims.iss.clone().unwrap_or_default();
|
||||
let aud_strings = extract_aud_strings(&claims);
|
||||
|
||||
// --- New rule: require `iss` OR use fallback key for crypto verification ---
|
||||
if issuer.trim().is_empty() {
|
||||
// No issuer — we may still accept if we can cryptographically verify with a fallback key
|
||||
if let Some(decoding_key) = opts.fallback_decoding_key.as_ref() {
|
||||
// Verify signature (aud checked if present)
|
||||
let mut validation = JwtValidation::new(alg);
|
||||
if !aud_strings.is_empty() {
|
||||
validation.set_audience(&aud_strings);
|
||||
}
|
||||
// We already did exp/nbf manually.
|
||||
validation.validate_exp = false;
|
||||
validation.validate_nbf = false;
|
||||
|
||||
|
|
@ -194,13 +158,10 @@ pub async fn validate_jwt_with(
|
|||
}
|
||||
}
|
||||
|
||||
// --- With `iss`: OIDC discovery + JWKS verification path -------------------
|
||||
// require kid before any network I/O
|
||||
let Some(kid) = header.kid.clone() else {
|
||||
return Ok((false, "no kid in header".into()));
|
||||
};
|
||||
|
||||
// build discovery URL and fetch it (redirects disabled)
|
||||
let config_url = format!("{}/.well-known/openid-configuration", issuer.trim_end_matches('/'));
|
||||
let cfg_resp = client
|
||||
.get(&config_url)
|
||||
|
|
@ -215,19 +176,16 @@ pub async fn validate_jwt_with(
|
|||
let cfg_json: serde_json::Value =
|
||||
cfg_resp.json().await.map_err(|e| anyhow!("invalid discovery JSON: {e}"))?;
|
||||
|
||||
// extract jwks_uri
|
||||
let jwks_uri = cfg_json
|
||||
.get("jwks_uri")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow!("jwks_uri missing"))?;
|
||||
|
||||
// must be HTTPS
|
||||
let url = Url::parse(jwks_uri).map_err(|e| anyhow!("invalid jwks_uri: {e}"))?;
|
||||
if url.scheme() != "https" {
|
||||
return Ok((false, "jwks_uri must use https".to_string()));
|
||||
}
|
||||
|
||||
// host must match issuer host
|
||||
let iss_host = Url::parse(&issuer)
|
||||
.map_err(|e| anyhow!("invalid iss: {e}"))?
|
||||
.host_str()
|
||||
|
|
@ -241,17 +199,14 @@ pub async fn validate_jwt_with(
|
|||
));
|
||||
}
|
||||
|
||||
// DNS resolution + private-range block
|
||||
for addr in lookup_host((jwks_host.as_str(), 443)).await? {
|
||||
if is_blocked_ip(addr.ip()) {
|
||||
return Ok((false, "jwks_uri resolves to private or link-local IP".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// reachability check (existing helper)
|
||||
check_url_resolvable(&url).await.map_err(|e| anyhow!("jwks uri unresolvable: {e}"))?;
|
||||
|
||||
// fetch JWKS with redirect-free client
|
||||
let jwks_resp = client.get(url).send().await.map_err(|e| anyhow!("jwks fetch failed: {e}"))?;
|
||||
if !jwks_resp.status().is_success() {
|
||||
return Ok((false, format!("jwks fetch failed: {}", jwks_resp.status())));
|
||||
|
|
@ -259,14 +214,12 @@ pub async fn validate_jwt_with(
|
|||
|
||||
let jwk_set: JwkSet = jwks_resp.json().await.map_err(|e| anyhow!("invalid jwks json: {e}"))?;
|
||||
|
||||
// select key by kid
|
||||
let jwk = jwk_set
|
||||
.keys
|
||||
.iter()
|
||||
.find(|k| k.common.key_id.as_deref() == Some(&kid))
|
||||
.ok_or_else(|| anyhow!("kid not found in jwks"))?;
|
||||
|
||||
// verify signature
|
||||
let decoding_key = DecodingKey::from_jwk(jwk).map_err(|e| anyhow!("invalid jwk: {e}"))?;
|
||||
let mut validation = JwtValidation::new(header.alg);
|
||||
if !aud_strings.is_empty() {
|
||||
|
|
@ -281,7 +234,6 @@ pub async fn validate_jwt_with(
|
|||
Ok((true, format!("JWT valid (alg: {:?}, iss: {issuer}, aud: {:?})", alg, aud_strings)))
|
||||
}
|
||||
|
||||
/// Helper: normalize aud into a flat Vec<String>
|
||||
fn extract_aud_strings(claims: &Claims) -> Vec<String> {
|
||||
match &claims.aud {
|
||||
Some(Aud::Str(s)) => vec![s.clone()],
|
||||
|
|
@ -289,97 +241,7 @@ fn extract_aud_strings(claims: &Claims) -> Vec<String> {
|
|||
None => vec![],
|
||||
}
|
||||
}
|
||||
/// returns true if IP is in a blocked network
|
||||
|
||||
fn is_blocked_ip(ip: std::net::IpAddr) -> bool {
|
||||
BLOCKED_NETS.iter().filter_map(|cidr| cidr.parse::<IpNet>().ok()).any(|net| net.contains(&ip))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{validate_jwt, validate_jwt_with, ValidateOptions};
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
|
||||
use chrono::{Duration as ChronoDuration, Utc};
|
||||
use jsonwebtoken::{encode, EncodingKey, Header};
|
||||
|
||||
fn build_unsigned_token(exp_offset: i64) -> String {
|
||||
let header = URL_SAFE_NO_PAD.encode(r#"{"alg":"none"}"#);
|
||||
let exp = (Utc::now() + ChronoDuration::seconds(exp_offset)).timestamp();
|
||||
let payload = URL_SAFE_NO_PAD.encode(format!(
|
||||
r#"{{
|
||||
"exp": {exp},
|
||||
"iss": "https://example.com",
|
||||
"aud": ["test-audience"]
|
||||
}}"#
|
||||
));
|
||||
format!("{header}.{payload}.")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn hmac_signed_tokens_skipped() {
|
||||
let mut header = Header::new(jsonwebtoken::Algorithm::HS256);
|
||||
header.kid = Some("dummy".into());
|
||||
|
||||
let payload = serde_json::json!({
|
||||
"iss": "https://example.com",
|
||||
"exp": (Utc::now() + ChronoDuration::minutes(5)).timestamp(),
|
||||
});
|
||||
|
||||
let token = encode(&header, &payload, &EncodingKey::from_secret(b"secret")).unwrap();
|
||||
let res = validate_jwt(&token, false).await.unwrap();
|
||||
assert!(!res.0);
|
||||
assert!(res.1.contains("HMAC-signed JWTs are not validated"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_kid_short_circuits_before_network() {
|
||||
let header = URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256"}"#);
|
||||
let payload = URL_SAFE_NO_PAD.encode(format!(
|
||||
r#"{{
|
||||
"exp": {},
|
||||
"iss": "https://example.com"
|
||||
}}"#,
|
||||
(Utc::now() + ChronoDuration::minutes(5)).timestamp()
|
||||
));
|
||||
let signature = URL_SAFE_NO_PAD.encode("sig");
|
||||
let token = format!("{header}.{payload}.{signature}");
|
||||
|
||||
let res = validate_jwt(&token, false).await.unwrap();
|
||||
assert!(!res.0);
|
||||
assert!(res.1.contains("no kid in header"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unsigned_token_rejected_by_default() {
|
||||
let token = build_unsigned_token(60);
|
||||
let res = validate_jwt(&token, false).await.unwrap();
|
||||
assert!(!res.0);
|
||||
assert!(res.1.contains("unsigned JWT (alg: none) not allowed"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn valid_token_allows_alg_none_when_opted_in() {
|
||||
let token = build_unsigned_token(60);
|
||||
let res = validate_jwt_with(
|
||||
&token,
|
||||
&ValidateOptions { allow_alg_none: true, fallback_decoding_key: None },
|
||||
false,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(res.0, "expected success when alg none is explicitly allowed");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn expired_token_still_rejected() {
|
||||
let token = build_unsigned_token(-60);
|
||||
let res = validate_jwt_with(
|
||||
&token,
|
||||
&ValidateOptions { allow_alg_none: true, fallback_decoding_key: None },
|
||||
false,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!res.0);
|
||||
assert!(res.1.contains("expired"));
|
||||
}
|
||||
}
|
||||
|
|
@ -25,11 +25,35 @@ mod utils;
|
|||
mod validation_body;
|
||||
|
||||
#[cfg(feature = "validation-http")]
|
||||
mod http_validation;
|
||||
pub mod http_validation;
|
||||
|
||||
#[cfg(feature = "validation-aws")]
|
||||
pub mod aws;
|
||||
|
||||
#[cfg(feature = "validation-azure")]
|
||||
pub mod azure;
|
||||
|
||||
#[cfg(feature = "validation-coinbase")]
|
||||
pub mod coinbase;
|
||||
|
||||
#[cfg(feature = "validation-gcp")]
|
||||
pub mod gcp;
|
||||
|
||||
#[cfg(feature = "validation-jwt")]
|
||||
pub mod jwt;
|
||||
|
||||
#[cfg(feature = "validation-database")]
|
||||
pub mod jdbc;
|
||||
|
||||
#[cfg(feature = "validation-database")]
|
||||
pub mod mongodb;
|
||||
|
||||
#[cfg(feature = "validation-database")]
|
||||
pub mod mysql;
|
||||
|
||||
#[cfg(feature = "validation-database")]
|
||||
pub mod postgres;
|
||||
|
||||
// Re-exports
|
||||
pub use utils::{find_closest_variable, process_captures};
|
||||
pub use validation_body::{as_str, clone_as_string, from_string, ValidationResponseBody};
|
||||
|
|
@ -42,13 +66,18 @@ pub use http_validation::{
|
|||
|
||||
#[cfg(feature = "validation-aws")]
|
||||
pub use aws::{
|
||||
aws_key_to_account_number, generate_aws_cache_key, set_aws_skip_account_ids,
|
||||
set_aws_validation_concurrency, should_skip_aws_validation, validate_aws_credentials,
|
||||
validate_aws_credentials_input,
|
||||
aws_key_to_account_number, generate_aws_cache_key, revoke_aws_access_key,
|
||||
set_aws_skip_account_ids, set_aws_validation_concurrency, should_skip_aws_validation,
|
||||
validate_aws_credentials, validate_aws_credentials_input,
|
||||
};
|
||||
|
||||
use once_cell::sync::OnceCell;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use crossbeam_skiplist::SkipMap;
|
||||
|
||||
/// User agent string used for HTTP validation requests.
|
||||
#[cfg(feature = "validation-http")]
|
||||
|
|
@ -92,6 +121,9 @@ pub fn set_user_agent_suffix<S: Into<String>>(suffix: Option<S>) {
|
|||
/// Cache duration for validation results (20 minutes).
|
||||
pub const VALIDATION_CACHE_SECONDS: u64 = 1200;
|
||||
|
||||
/// Cache type used for validation memoization.
|
||||
pub type Cache = Arc<SkipMap<String, CachedResponse>>;
|
||||
|
||||
/// A cached validation response.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CachedResponse {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
// src/validation/mongodb.rs
|
||||
use std::{net::IpAddr, time::Duration};
|
||||
|
||||
use anyhow::Result;
|
||||
|
|
@ -12,45 +11,34 @@ use tokio::time::timeout;
|
|||
use tracing::debug;
|
||||
|
||||
pub fn looks_like_mongodb_uri(uri: &str) -> bool {
|
||||
// quick scheme check first
|
||||
if !(uri.starts_with("mongodb://") || uri.starts_with("mongodb+srv://")) {
|
||||
return false;
|
||||
}
|
||||
// pure string-level parse – no network, even for +srv
|
||||
mongodb::options::ConnectionString::parse(uri).is_ok()
|
||||
}
|
||||
|
||||
/// Return true if the URI targets localhost/loopback or a unix domain socket.
|
||||
/// This is a *string-only* check—no DNS or driver IO.
|
||||
fn uri_targets_localhost(uri: &str) -> bool {
|
||||
// strip scheme
|
||||
let rest = uri
|
||||
.strip_prefix("mongodb://")
|
||||
.or_else(|| uri.strip_prefix("mongodb+srv://"))
|
||||
.unwrap_or(uri);
|
||||
|
||||
// authority ends at first '/' (before db/path); if missing, take whole rest
|
||||
let authority = rest.split_once('/').map(|(a, _)| a).unwrap_or(rest);
|
||||
|
||||
// unix domain socket forms (percent-encoded "/path/to.sock")
|
||||
let auth_lower = authority.to_ascii_lowercase();
|
||||
if auth_lower.starts_with("%2f") || authority.starts_with('/') {
|
||||
return true; // UDS → treat as local
|
||||
return true;
|
||||
}
|
||||
|
||||
// drop userinfo if present
|
||||
let hostlist = authority.rsplit_once('@').map(|(_, h)| h).unwrap_or(authority);
|
||||
|
||||
// iterate seed list (mongodb://hostA,hostB,...)
|
||||
for part in hostlist.split(',') {
|
||||
let mut host = part.trim();
|
||||
|
||||
// strip brackets for IPv6 literals
|
||||
if host.starts_with('[') && host.ends_with(']') && host.len() >= 2 {
|
||||
host = &host[1..host.len() - 1];
|
||||
}
|
||||
|
||||
// strip :port if present (only when suffix is all digits)
|
||||
if let Some(idx) = host.rfind(':') {
|
||||
if host[idx + 1..].chars().all(|c| c.is_ascii_digit()) {
|
||||
host = &host[..idx];
|
||||
|
|
@ -65,12 +53,10 @@ fn uri_targets_localhost(uri: &str) -> bool {
|
|||
false
|
||||
}
|
||||
|
||||
/// Returns true for localhost/loopback/unspecified IPs and common localhost aliases.
|
||||
fn is_local_host(h: &str) -> bool {
|
||||
let s = h.trim().trim_end_matches('.');
|
||||
let s_lower = s.to_ascii_lowercase();
|
||||
|
||||
// common aliases seen in hosts files across distros
|
||||
if matches!(
|
||||
s_lower.as_str(),
|
||||
"localhost"
|
||||
|
|
@ -83,12 +69,10 @@ fn is_local_host(h: &str) -> bool {
|
|||
return true;
|
||||
}
|
||||
|
||||
// explicit unspecified forms
|
||||
if s_lower.as_str() == "0.0.0.0" || s_lower.as_str() == "::" {
|
||||
return true;
|
||||
}
|
||||
|
||||
// literal IPs
|
||||
if let Ok(ip) = s.parse::<IpAddr>() {
|
||||
return ip.is_loopback() || ip.is_unspecified();
|
||||
}
|
||||
|
|
@ -96,32 +80,24 @@ fn is_local_host(h: &str) -> bool {
|
|||
false
|
||||
}
|
||||
|
||||
const FAST_CONNECT_MS: u64 = 700; // direct single-host URIs
|
||||
const FAST_CONNECT_MS: u64 = 700;
|
||||
const FAST_SELECT_MS: u64 = 300;
|
||||
const SRV_PARSE_MS: u64 = 2_000; // limit DNS resolution time
|
||||
const SRV_PARSE_MS: u64 = 2_000;
|
||||
const SRV_CONNECT_MS: u64 = 2500;
|
||||
const SRV_SELECT_MS: u64 = 2500;
|
||||
|
||||
/// Validates a MongoDB URI in ≤ 2 s. Returns `(bool, String)` where the
|
||||
/// boolean indicates success and the string provides a status message.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `uri` - The MongoDB connection URI to validate
|
||||
/// * `lax_tls` - If true, accept self-signed or invalid certificates
|
||||
/// Validates a MongoDB URI in ≤ 2 s.
|
||||
pub async fn validate_mongodb(uri: &str, lax_tls: bool) -> Result<(bool, String)> {
|
||||
// ---- quick reject without touching the network
|
||||
if !looks_like_mongodb_uri(uri) {
|
||||
return Ok((false, "Invalid MongoDB URI".to_string()));
|
||||
}
|
||||
|
||||
// ---- refuse localhost/loopback/UDS outright
|
||||
if uri_targets_localhost(uri) {
|
||||
return Ok((false, "Refusing to validate localhost/loopback MongoDB URIs.".to_string()));
|
||||
}
|
||||
|
||||
let is_srv = uri.starts_with("mongodb+srv://");
|
||||
|
||||
// ---- build client opts (guarded so we don't hit DNS/driver first)
|
||||
let mut opts = if is_srv {
|
||||
match timeout(Duration::from_millis(SRV_PARSE_MS), ClientOptions::parse(uri)).await {
|
||||
Ok(res) => res?,
|
||||
|
|
@ -134,27 +110,22 @@ pub async fn validate_mongodb(uri: &str, lax_tls: bool) -> Result<(bool, String)
|
|||
};
|
||||
|
||||
if !is_srv {
|
||||
// one socket, skip cluster discovery for plain 'mongodb://'
|
||||
opts.direct_connection = Some(true);
|
||||
opts.connect_timeout = Some(Duration::from_millis(FAST_CONNECT_MS));
|
||||
opts.server_selection_timeout = Some(Duration::from_millis(FAST_SELECT_MS));
|
||||
} else {
|
||||
// SRV needs DNS and replica-set discovery; fail fast
|
||||
opts.connect_timeout = Some(Duration::from_millis(SRV_CONNECT_MS));
|
||||
opts.server_selection_timeout = Some(Duration::from_millis(SRV_SELECT_MS));
|
||||
// leave direct_connection = None (driver decides)
|
||||
}
|
||||
opts.max_pool_size = Some(1);
|
||||
opts.min_pool_size = Some(0);
|
||||
|
||||
// Configure TLS options based on lax_tls setting
|
||||
if lax_tls {
|
||||
debug!("Using lax TLS mode for MongoDB connection");
|
||||
let tls_options = TlsOptions::builder().allow_invalid_certificates(true).build();
|
||||
opts.tls = Some(Tls::Enabled(tls_options));
|
||||
}
|
||||
|
||||
// ---- dial and ping
|
||||
let client = Client::with_options(opts)?;
|
||||
let res = client.database("admin").run_command(doc! { "ping": 1 }).await;
|
||||
match res {
|
||||
|
|
@ -95,10 +95,6 @@ fn targets_localhost(opts: &Opts) -> bool {
|
|||
}
|
||||
|
||||
/// Validate a MySQL connection URL.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `mysql_url` - The MySQL connection URL to validate
|
||||
/// * `lax_tls` - If true, accept self-signed or invalid certificates
|
||||
pub async fn validate_mysql(mysql_url: &str, lax_tls: bool) -> Result<(bool, Vec<String>)> {
|
||||
let opts = parse_mysql_url(mysql_url)?;
|
||||
|
||||
|
|
@ -109,7 +105,6 @@ pub async fn validate_mysql(mysql_url: &str, lax_tls: bool) -> Result<(bool, Vec
|
|||
|
||||
let mut builder = OptsBuilder::from_opts(opts).stmt_cache_size(Some(0));
|
||||
|
||||
// Configure TLS options based on lax_tls setting
|
||||
if lax_tls {
|
||||
debug!("Using lax TLS mode for MySQL connection");
|
||||
let ssl_opts = SslOpts::default().with_danger_accept_invalid_certs(true);
|
||||
|
|
@ -139,42 +134,3 @@ pub async fn validate_mysql(mysql_url: &str, lax_tls: bool) -> Result<(bool, Vec
|
|||
Err(_) => Err(anyhow!("MySQL connection timed out after {CONNECT_TIMEOUT:?}")),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_mysql_url_accepts_valid_urls() {
|
||||
let url = "mysql://user:secret@exmple.com:3306/app";
|
||||
let opts = parse_mysql_url(url).expect("expected valid MySQL URL");
|
||||
assert_eq!(opts.user(), Some("user"));
|
||||
assert_eq!(opts.pass(), Some("secret"));
|
||||
assert_eq!(opts.ip_or_hostname(), "exmple.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_mysql_url_rejects_invalid_urls() {
|
||||
for candidate in [
|
||||
"", // empty
|
||||
"mysql://user@exmple.com/app", // missing password
|
||||
"mysql://:secret@exmple.com/app", // missing username
|
||||
"mysql://user:secret@:3306/app", // missing host
|
||||
"postgres://user:secret@exmple.com", // wrong scheme
|
||||
"mysql://user:secret@exmple.com:70000/app", // invalid port
|
||||
] {
|
||||
assert!(
|
||||
parse_mysql_url(candidate).is_err(),
|
||||
"expected parsing to fail for {candidate}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_mysql_url_allows_trimming_whitespace() {
|
||||
let opts =
|
||||
parse_mysql_url(" mysql://user:secret@exmple.com:3306/app ").expect("trimmed URL");
|
||||
assert_eq!(opts.user(), Some("user"));
|
||||
assert_eq!(opts.pass(), Some("secret"));
|
||||
}
|
||||
}
|
||||
|
|
@ -22,16 +22,10 @@ const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
|
|||
static INIT_PROVIDER: OnceCell<()> = OnceCell::new();
|
||||
fn ensure_crypto_provider() {
|
||||
INIT_PROVIDER.get_or_init(|| {
|
||||
// If another part of the program already installed a provider,
|
||||
// ignore the error — we just need one global provider.
|
||||
let _ = CryptoProvider::install_default(ring::default_provider());
|
||||
});
|
||||
}
|
||||
|
||||
/// A certificate verifier that accepts any certificate (for lax TLS mode).
|
||||
///
|
||||
/// This verifier still validates signatures to ensure the connection is encrypted,
|
||||
/// but does not verify the certificate chain against trusted CAs.
|
||||
#[derive(Debug)]
|
||||
struct LaxCertVerifier(Arc<CryptoProvider>);
|
||||
|
||||
|
|
@ -44,7 +38,6 @@ impl ServerCertVerifier for LaxCertVerifier {
|
|||
_ocsp_response: &[u8],
|
||||
_now: UnixTime,
|
||||
) -> std::result::Result<ServerCertVerified, rustls::Error> {
|
||||
// Accept any certificate - this is the "lax" behavior
|
||||
Ok(ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
|
|
@ -93,14 +86,9 @@ pub fn parse_postgres_url(postgres_url: &str) -> Result<Config> {
|
|||
}
|
||||
|
||||
/// Validate a Postgres connection URL.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `postgres_url` - The Postgres connection URL to validate
|
||||
/// * `lax_tls` - If true, accept self-signed or invalid certificates
|
||||
pub async fn validate_postgres(postgres_url: &str, lax_tls: bool) -> Result<(bool, Vec<String>)> {
|
||||
let mut cfg = parse_postgres_url(postgres_url)?;
|
||||
|
||||
// --- skip localhost/loopback/unix-socket targets entirely -------------
|
||||
if has_any_local_host(&cfg) {
|
||||
debug!("Skipping Postgres validation: host is localhost/loopback or unix socket");
|
||||
return Ok((false, vec!["skipped localhost/loopback host".into()]));
|
||||
|
|
@ -117,16 +105,14 @@ pub async fn validate_postgres(postgres_url: &str, lax_tls: bool) -> Result<(boo
|
|||
fn has_any_local_host(cfg: &Config) -> bool {
|
||||
cfg.get_hosts().iter().any(|h| match h {
|
||||
#[cfg(unix)]
|
||||
Host::Unix(_) => true, // local unix socket
|
||||
Host::Unix(_) => true,
|
||||
Host::Tcp(s) => is_local_tcp_host(s),
|
||||
})
|
||||
}
|
||||
|
||||
fn is_local_tcp_host(s: &str) -> bool {
|
||||
// strip URI-style IPv6 brackets if present
|
||||
let host = s.trim_matches(|c| c == '[' || c == ']');
|
||||
|
||||
// Direct IPs
|
||||
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
|
||||
return match ip {
|
||||
std::net::IpAddr::V4(v4) => {
|
||||
|
|
@ -138,7 +124,6 @@ fn is_local_tcp_host(s: &str) -> bool {
|
|||
};
|
||||
}
|
||||
|
||||
// Common localhost hostnames
|
||||
let lower = host.to_ascii_lowercase();
|
||||
lower == "localhost"
|
||||
|| lower.starts_with("localhost.")
|
||||
|
|
@ -151,7 +136,6 @@ async fn check_postgres_db_connection(
|
|||
original_mode: SslMode,
|
||||
lax_tls: bool,
|
||||
) -> Result<(bool, Vec<String>)> {
|
||||
// First attempt with caller-supplied sslmode, optional retry without TLS.
|
||||
for attempt in 0..=1 {
|
||||
let cfg_try = cfg.clone();
|
||||
|
||||
|
|
@ -170,11 +154,9 @@ async fn check_postgres_db_connection(
|
|||
.await
|
||||
} else {
|
||||
timeout(CONNECT_TIMEOUT, async {
|
||||
// Ensure Rustls crypto provider is installed *before* using the builder
|
||||
ensure_crypto_provider();
|
||||
|
||||
let tls_cfg = if lax_tls {
|
||||
// Lax mode: accept any certificate (self-signed, expired, wrong hostname)
|
||||
debug!("Using lax TLS mode for Postgres connection");
|
||||
let provider = Arc::new(ring::default_provider());
|
||||
ClientConfig::builder()
|
||||
|
|
@ -182,7 +164,6 @@ async fn check_postgres_db_connection(
|
|||
.with_custom_certificate_verifier(Arc::new(LaxCertVerifier(provider)))
|
||||
.with_no_client_auth()
|
||||
} else {
|
||||
// Strict mode: full certificate validation
|
||||
let CertificateResult { certs, errors, .. } = load_native_certs();
|
||||
for err in errors {
|
||||
debug!("native-cert error: {err}");
|
||||
|
|
@ -262,55 +243,3 @@ fn server_requires_encryption(err_msg: &str) -> bool {
|
|||
fn missing_cluster_identifier(err_msg: &str) -> bool {
|
||||
err_msg.contains("missing cluster identifier")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
is_local_tcp_host, missing_cluster_identifier, parse_postgres_url,
|
||||
server_requires_encryption,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn detects_encryption_requirement() {
|
||||
assert!(server_requires_encryption("db error: FATAL: server requires encryption"));
|
||||
assert!(!server_requires_encryption("some other error"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_missing_cluster() {
|
||||
assert!(missing_cluster_identifier(
|
||||
"db error: FATAL: codeParamsRoutingFailed: missing cluster identifier",
|
||||
));
|
||||
assert!(!missing_cluster_identifier("another error"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_local_hosts() {
|
||||
for h in [
|
||||
"localhost",
|
||||
"LOCALHOST",
|
||||
"localhost.localdomain",
|
||||
"localhost6",
|
||||
"127.0.0.1",
|
||||
"[::1]",
|
||||
"::",
|
||||
] {
|
||||
assert!(is_local_tcp_host(h), "should treat {h} as local");
|
||||
}
|
||||
for h in ["db.example.com", "10.0.0.1"] {
|
||||
assert!(!is_local_tcp_host(h), "should not treat {h} as local");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_accepts_postgis_scheme() {
|
||||
let url = "postgis://postgres:secret@exmple.com:5432";
|
||||
assert!(parse_postgres_url(url).is_ok(), "postgis scheme should be accepted");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_rejects_invalid_port() {
|
||||
let url = "postgres://postgres:secret@exmple.com:70000";
|
||||
assert!(parse_postgres_url(url).is_err(), "invalid port should be rejected");
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# This file lists gitignore-style patterns: https://git-scm.com/docs/gitignore
|
||||
#
|
||||
# These patterns control which paths Kingfisher will scan.
|
||||
|
||||
**/objects/pack/pack-*.pack
|
||||
**/objects/pack/pack-*.idx
|
||||
**/packed-refs
|
||||
|
|
@ -7,7 +7,6 @@ use std::{
|
|||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use crossbeam_skiplist::SkipMap;
|
||||
use dashmap::DashMap;
|
||||
use http::StatusCode;
|
||||
use liquid::Object;
|
||||
|
|
@ -23,24 +22,20 @@ use crate::{
|
|||
location::OffsetSpan,
|
||||
matcher::{OwnedBlobMatch, SerializableCaptures},
|
||||
rules::rule::Validation,
|
||||
validation_body::{self, ValidationResponseBody},
|
||||
validation_body::{self},
|
||||
};
|
||||
|
||||
// Re-export TlsMode from kingfisher_rules for use in client_for_rule
|
||||
pub use kingfisher_rules::TlsMode as RuleTlsMode;
|
||||
|
||||
pub mod aws;
|
||||
pub mod azure;
|
||||
pub mod coinbase;
|
||||
pub mod gcp;
|
||||
pub mod httpvalidation;
|
||||
pub mod jdbc;
|
||||
pub mod jwt;
|
||||
pub mod mongodb;
|
||||
pub mod mysql;
|
||||
pub mod postgres;
|
||||
pub use mysql::validate_mysql;
|
||||
pub use postgres::validate_postgres;
|
||||
pub use kingfisher_scanner::validation::aws;
|
||||
pub use kingfisher_scanner::validation::http_validation as httpvalidation;
|
||||
pub use kingfisher_scanner::validation::mysql::validate_mysql;
|
||||
pub use kingfisher_scanner::validation::postgres::validate_postgres;
|
||||
pub use kingfisher_scanner::validation::CachedResponse;
|
||||
pub use kingfisher_scanner::validation::{
|
||||
azure, coinbase, gcp, jdbc, jwt, mongodb, mysql, postgres,
|
||||
};
|
||||
pub mod utils;
|
||||
|
||||
const VALIDATION_CACHE_SECONDS: u64 = 1200; // 20 minutes
|
||||
|
|
@ -88,7 +83,8 @@ pub fn set_user_agent_suffix<S: Into<String>>(suffix: Option<S>) {
|
|||
return;
|
||||
}
|
||||
|
||||
let _ = USER_AGENT_SUFFIX.set(trimmed);
|
||||
let _ = USER_AGENT_SUFFIX.set(trimmed.clone());
|
||||
kingfisher_scanner::validation::set_user_agent_suffix(Some(trimmed));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -158,7 +154,7 @@ impl ValidationClients {
|
|||
}
|
||||
|
||||
// Use SkipMap-based cache instead of a mutex-wrapped FxHashMap.
|
||||
type Cache = Arc<SkipMap<String, CachedResponse>>;
|
||||
type Cache = kingfisher_scanner::validation::Cache;
|
||||
|
||||
/// Returns an opaque 64-bit key for internal validation deduplication.
|
||||
///
|
||||
|
|
@ -227,24 +223,6 @@ pub fn is_parseable_mysql_uri(uri: &str) -> bool {
|
|||
mysql::parse_mysql_url(uri).is_ok()
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CachedResponse {
|
||||
pub body: ValidationResponseBody,
|
||||
pub status: StatusCode,
|
||||
pub is_valid: bool,
|
||||
pub timestamp: Instant,
|
||||
}
|
||||
|
||||
impl CachedResponse {
|
||||
pub fn new(body: ValidationResponseBody, status: StatusCode, is_valid: bool) -> Self {
|
||||
Self { body, status, is_valid, timestamp: Instant::now() }
|
||||
}
|
||||
|
||||
pub fn is_still_valid(&self, cache_duration: Duration) -> bool {
|
||||
self.timestamp.elapsed() < cache_duration
|
||||
}
|
||||
}
|
||||
|
||||
/// Collect dependent variables and missing dependencies from the provided matches.
|
||||
pub fn collect_variables_and_dependencies(
|
||||
matches: &[OwnedBlobMatch],
|
||||
|
|
|
|||
|
|
@ -1,453 +0,0 @@
|
|||
use std::{collections::HashSet, sync::RwLock, time::Duration};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use aws_config::{retry::RetryConfig, BehaviorVersion, SdkConfig};
|
||||
use aws_credential_types::Credentials;
|
||||
use aws_sdk_iam::{
|
||||
config::Builder as IamConfigBuilder, error::SdkError as IamSdkError,
|
||||
operation::update_access_key::UpdateAccessKeyError, types::StatusType, Client as IamClient,
|
||||
};
|
||||
use aws_sdk_sts::{
|
||||
config::Builder as StsConfigBuilder, error::SdkError,
|
||||
operation::get_caller_identity::GetCallerIdentityError, Client as StsClient,
|
||||
};
|
||||
use aws_smithy_http_client::{
|
||||
proxy::ProxyConfig, tls, Builder as HttpClientBuilder, ConnectorBuilder,
|
||||
};
|
||||
use aws_smithy_runtime_api::{
|
||||
box_error::BoxError,
|
||||
client::{
|
||||
http::SharedHttpClient,
|
||||
interceptors::{context::BeforeTransmitInterceptorContextMut, Intercept},
|
||||
runtime_components::RuntimeComponents,
|
||||
},
|
||||
};
|
||||
use aws_smithy_types::config_bag::ConfigBag;
|
||||
use aws_types::region::Region;
|
||||
use base32::Alphabet;
|
||||
use byteorder::{BigEndian, ByteOrder};
|
||||
use http::{
|
||||
header::{HeaderValue, USER_AGENT},
|
||||
StatusCode,
|
||||
};
|
||||
use once_cell::sync::{Lazy, OnceCell};
|
||||
use rand::{rng, Rng};
|
||||
use regex::Regex;
|
||||
use tokio::{
|
||||
sync::Semaphore,
|
||||
time::{sleep, timeout},
|
||||
};
|
||||
|
||||
use crate::validation::GLOBAL_USER_AGENT;
|
||||
|
||||
static AWS_VALIDATION_SEMAPHORE: OnceCell<Semaphore> = OnceCell::new();
|
||||
const BUILTIN_SKIP_ACCOUNT_IDS: &[&str] = &[
|
||||
"052310077262",
|
||||
"171436882533",
|
||||
"528757803018",
|
||||
"534261010715",
|
||||
"538784191382",
|
||||
"595918472158",
|
||||
"729780141977",
|
||||
"893192397702",
|
||||
"992382622183",
|
||||
];
|
||||
|
||||
static AWS_SKIP_ACCOUNT_IDS: Lazy<RwLock<HashSet<String>>> = Lazy::new(|| {
|
||||
let mut set = HashSet::new();
|
||||
set.extend(BUILTIN_SKIP_ACCOUNT_IDS.iter().map(|id| id.to_string()));
|
||||
RwLock::new(set)
|
||||
});
|
||||
|
||||
fn build_http_client() -> SharedHttpClient {
|
||||
HttpClientBuilder::new().build_with_connector_fn(|settings, runtime_components| {
|
||||
let mut conn_builder = ConnectorBuilder::default()
|
||||
.tls_provider(tls::Provider::Rustls(tls::rustls_provider::CryptoMode::AwsLc));
|
||||
|
||||
conn_builder.set_connector_settings(settings.cloned());
|
||||
if let Some(components) = runtime_components {
|
||||
conn_builder.set_sleep_impl(components.sleep_impl());
|
||||
}
|
||||
conn_builder.set_proxy_config(Some(ProxyConfig::from_env()));
|
||||
conn_builder.build()
|
||||
})
|
||||
}
|
||||
|
||||
async fn build_base_config(credentials: Credentials) -> SdkConfig {
|
||||
let retry_config = RetryConfig::adaptive().with_max_attempts(3);
|
||||
aws_config::defaults(BehaviorVersion::latest())
|
||||
.region(Region::new("us-east-1"))
|
||||
.credentials_provider(credentials)
|
||||
.http_client(build_http_client())
|
||||
.retry_config(retry_config)
|
||||
.load()
|
||||
.await
|
||||
}
|
||||
|
||||
fn extract_account_id(input: &str) -> Option<String> {
|
||||
let trimmed = input.trim();
|
||||
if trimmed.len() == 12 && trimmed.chars().all(|c| c.is_ascii_digit()) {
|
||||
return Some(trimmed.to_string());
|
||||
}
|
||||
|
||||
static ACCOUNT_ID_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"(\d{12})").expect("valid regex"));
|
||||
ACCOUNT_ID_RE.captures(trimmed).and_then(|caps| caps.get(1)).map(|m| m.as_str().to_string())
|
||||
}
|
||||
|
||||
/// Set the maximum number of concurrent AWS validations. Call before first use.
|
||||
pub fn set_aws_validation_concurrency(max: usize) {
|
||||
AWS_VALIDATION_SEMAPHORE.set(Semaphore::new(max)).ok();
|
||||
}
|
||||
|
||||
fn aws_validation_semaphore() -> &'static Semaphore {
|
||||
AWS_VALIDATION_SEMAPHORE.get_or_init(|| Semaphore::new(15))
|
||||
}
|
||||
|
||||
pub fn set_aws_skip_account_ids<I, S>(ids: I)
|
||||
where
|
||||
I: IntoIterator<Item = S>,
|
||||
S: Into<String>,
|
||||
{
|
||||
let mut guard = match AWS_SKIP_ACCOUNT_IDS.write() {
|
||||
Ok(g) => g,
|
||||
Err(poisoned) => poisoned.into_inner(),
|
||||
};
|
||||
guard.clear();
|
||||
|
||||
guard.extend(BUILTIN_SKIP_ACCOUNT_IDS.iter().map(|id| id.to_string()));
|
||||
|
||||
for raw in ids.into_iter() {
|
||||
let value = raw.into();
|
||||
if value.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Some(normalized) = extract_account_id(&value) {
|
||||
guard.insert(normalized);
|
||||
} else {
|
||||
tracing::warn!("Ignoring invalid AWS account ID in skip list: {value}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn should_skip_aws_validation(access_key_id: &str) -> Option<String> {
|
||||
let guard = AWS_SKIP_ACCOUNT_IDS.read().ok()?;
|
||||
if guard.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let account = aws_key_to_account_number(access_key_id).ok()?;
|
||||
if guard.contains(&account) {
|
||||
Some(account)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UaInterceptor;
|
||||
|
||||
impl Intercept for UaInterceptor {
|
||||
fn name(&self) -> &'static str {
|
||||
"ua"
|
||||
}
|
||||
|
||||
fn modify_before_transmit(
|
||||
&self,
|
||||
context: &mut BeforeTransmitInterceptorContextMut<'_>,
|
||||
_rc: &RuntimeComponents,
|
||||
_cfg: &mut ConfigBag,
|
||||
) -> std::result::Result<(), BoxError> {
|
||||
let req = context.request_mut();
|
||||
req.headers_mut().insert(
|
||||
USER_AGENT,
|
||||
HeaderValue::from_str(GLOBAL_USER_AGENT.as_str())
|
||||
.map_err(|e| format!("invalid USER_AGENT header: {e}"))?,
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a standardized cache key for AWS validation attempts
|
||||
pub fn generate_aws_cache_key(aws_access_key_id: &str, aws_secret_access_key: &str) -> String {
|
||||
use sha1::{Digest, Sha1};
|
||||
let mut hasher = Sha1::new();
|
||||
hasher.update(aws_access_key_id.as_bytes());
|
||||
hasher.update(b"\0");
|
||||
hasher.update(aws_secret_access_key.as_bytes());
|
||||
format!("AWS:{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
// Validate AWS credentials before attempting validation
|
||||
pub fn validate_aws_credentials_input(access_key_id: &str, secret_key: &str) -> Result<(), String> {
|
||||
// Validate access key ID format (typically starts with "AKIA" and is 20 chars)
|
||||
if !access_key_id.starts_with("AKIA") || access_key_id.len() != 20 {
|
||||
return Err("Invalid AWS access key ID format".to_string());
|
||||
}
|
||||
// Validate secret key format (should be at least 40 chars)
|
||||
if secret_key.len() < 40 {
|
||||
return Err("Invalid AWS secret key format".to_string());
|
||||
}
|
||||
// Check for invalid characters
|
||||
if !access_key_id.chars().all(|c| c.is_ascii_alphanumeric()) {
|
||||
return Err("AWS access key ID contains invalid characters".to_string());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_throttling_or_transient(e: &SdkError<GetCallerIdentityError>) -> bool {
|
||||
match e {
|
||||
SdkError::ServiceError(ctx) => {
|
||||
let code = ctx.err().meta().code().unwrap_or_default();
|
||||
let status: StatusCode = ctx.raw().status().into();
|
||||
code.contains("Throttl")
|
||||
|| status == StatusCode::TOO_MANY_REQUESTS
|
||||
|| status == StatusCode::SERVICE_UNAVAILABLE
|
||||
}
|
||||
SdkError::DispatchFailure(df) => df.is_timeout() || df.is_io(),
|
||||
SdkError::ResponseError(ctx) => {
|
||||
let status: StatusCode = ctx.raw().status().into();
|
||||
status == StatusCode::TOO_MANY_REQUESTS || status == StatusCode::SERVICE_UNAVAILABLE
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_iam_throttling_or_transient(e: &IamSdkError<UpdateAccessKeyError>) -> bool {
|
||||
match e {
|
||||
IamSdkError::ServiceError(ctx) => {
|
||||
let code = ctx.err().meta().code().unwrap_or_default();
|
||||
let status: StatusCode = ctx.raw().status().into();
|
||||
code.contains("Throttl")
|
||||
|| status == StatusCode::TOO_MANY_REQUESTS
|
||||
|| status == StatusCode::SERVICE_UNAVAILABLE
|
||||
}
|
||||
IamSdkError::DispatchFailure(df) => df.is_timeout() || df.is_io(),
|
||||
IamSdkError::ResponseError(ctx) => {
|
||||
let status: StatusCode = ctx.raw().status().into();
|
||||
status == StatusCode::TOO_MANY_REQUESTS || status == StatusCode::SERVICE_UNAVAILABLE
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn revoke_aws_access_key(
|
||||
aws_access_key_id: &str,
|
||||
aws_secret_access_key: &str,
|
||||
) -> Result<(bool, String)> {
|
||||
// Create static credentials
|
||||
let credentials = Credentials::new(
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
None, // session token
|
||||
None, // expiry
|
||||
"static", // provider name
|
||||
);
|
||||
let config = build_base_config(credentials).await;
|
||||
|
||||
// Create IAM client
|
||||
let iam_config = IamConfigBuilder::from(&config).interceptor(UaInterceptor).build();
|
||||
let iam_client = IamClient::from_conf(iam_config);
|
||||
|
||||
const MAX_ATTEMPTS: usize = 3;
|
||||
const ATTEMPT_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
for attempt in 1..=MAX_ATTEMPTS {
|
||||
let result = timeout(
|
||||
ATTEMPT_TIMEOUT,
|
||||
iam_client
|
||||
.update_access_key()
|
||||
.access_key_id(aws_access_key_id)
|
||||
.status(StatusType::Inactive)
|
||||
.send(),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(_)) => {
|
||||
return Ok((true, "AWS access key set to Inactive".to_string()));
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
if is_iam_throttling_or_transient(&e) {
|
||||
if attempt == MAX_ATTEMPTS {
|
||||
return Err(anyhow!("AWS revocation failed: {}", e));
|
||||
}
|
||||
} else {
|
||||
return Ok((false, e.to_string()));
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
if attempt == MAX_ATTEMPTS {
|
||||
return Err(anyhow!("AWS revocation timed out"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let max_delay = 100u64 * 2u64.pow((attempt - 1) as u32);
|
||||
let sleep_ms = rng().random_range(0..=max_delay);
|
||||
sleep(Duration::from_millis(sleep_ms)).await;
|
||||
}
|
||||
|
||||
Err(anyhow!("AWS revocation failed"))
|
||||
}
|
||||
pub async fn validate_aws_credentials(
|
||||
aws_access_key_id: &str,
|
||||
aws_secret_access_key: &str,
|
||||
) -> Result<(bool, String)> {
|
||||
let _permit = aws_validation_semaphore().acquire().await.expect("semaphore closed");
|
||||
|
||||
// Create static credentials
|
||||
let credentials = Credentials::new(
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
None, // session token
|
||||
None, // expiry
|
||||
"static", // provider name
|
||||
);
|
||||
let config = build_base_config(credentials).await;
|
||||
|
||||
// Create STS client
|
||||
let sts_config = StsConfigBuilder::from(&config).interceptor(UaInterceptor).build();
|
||||
let sts_client = StsClient::from_conf(sts_config);
|
||||
|
||||
const MAX_ATTEMPTS: usize = 3;
|
||||
const ATTEMPT_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
for attempt in 1..=MAX_ATTEMPTS {
|
||||
let result = timeout(ATTEMPT_TIMEOUT, sts_client.get_caller_identity().send()).await;
|
||||
match result {
|
||||
Ok(Ok(identity)) => {
|
||||
let arn = identity.arn.unwrap_or_else(|| "Unknown".to_string());
|
||||
return Ok((true, arn));
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
if is_throttling_or_transient(&e) {
|
||||
if attempt == MAX_ATTEMPTS {
|
||||
return Err(anyhow!("AWS validation failed: {}", e));
|
||||
}
|
||||
} else {
|
||||
return Ok((false, e.to_string()));
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
if attempt == MAX_ATTEMPTS {
|
||||
return Err(anyhow!("AWS validation timed out"));
|
||||
}
|
||||
}
|
||||
}
|
||||
let max_delay = 100u64 * 2u64.pow((attempt - 1) as u32);
|
||||
let sleep_ms = rng().random_range(0..=max_delay);
|
||||
sleep(Duration::from_millis(sleep_ms)).await;
|
||||
}
|
||||
Err(anyhow!("AWS validation failed"))
|
||||
}
|
||||
|
||||
/// Converts an AWS Key ID to an AWS Account Number.
|
||||
/// It assumes that the Key ID has a specific format and extracts the account
|
||||
/// number encoded within it. Reference: https://medium.com/@TalBeerySec/a-short-note-on-aws-key-id-f88cc4317489
|
||||
pub fn aws_key_to_account_number(aws_key_id: &str) -> Result<String, Box<dyn std::error::Error>> {
|
||||
// Ensure the AWS Key ID is at least 5 characters long (since we'll access index
|
||||
// 4)
|
||||
if aws_key_id.len() < 5 {
|
||||
return Err("AWSKeyID is too short".into());
|
||||
}
|
||||
// Check if the 5th character is 'I' or 'J'
|
||||
let fifth_char = aws_key_id.as_bytes()[4] as char;
|
||||
if fifth_char == 'I' || fifth_char == 'J' {
|
||||
let err_msg =
|
||||
format!("Not possible to retrieve account number for {} keys", &aws_key_id[..5]);
|
||||
return Err(err_msg.into());
|
||||
}
|
||||
// Remove the Key ID prefix (first 4 characters)
|
||||
let trimmed_aws_key_id = &aws_key_id[4..];
|
||||
// Decode the trimmed Key ID from base32, ensuring it's in uppercase
|
||||
let decoded =
|
||||
base32::decode(Alphabet::Rfc4648 { padding: false }, &trimmed_aws_key_id.to_uppercase())
|
||||
.ok_or("Error decoding AWSKeyID")?;
|
||||
if decoded.len() < 6 {
|
||||
return Err("Decoded AWSKeyID is too short".into());
|
||||
}
|
||||
// Create an 8-byte array initialized to zeros
|
||||
let mut data = [0u8; 8];
|
||||
// Copy decoded[0..6] into data[2..8]
|
||||
data[2..8].copy_from_slice(&decoded[0..6]);
|
||||
// Interpret data as a big-endian u64
|
||||
let z = BigEndian::read_u64(&data);
|
||||
// Define the mask
|
||||
const MASK: u64 = 0x7FFFFFFFFF80;
|
||||
// Calculate the account number
|
||||
let account_num = (z & MASK) >> 7;
|
||||
// Return the account number formatted as a 12-digit string
|
||||
Ok(format!("{:012}", account_num))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use once_cell::sync::Lazy;
|
||||
use std::sync::Mutex;
|
||||
|
||||
static TEST_GUARD: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
|
||||
|
||||
#[test]
|
||||
fn skip_account_list_normalizes_inputs() {
|
||||
let _lock = TEST_GUARD.lock().unwrap();
|
||||
|
||||
set_aws_skip_account_ids([
|
||||
" 052310077262 ",
|
||||
"arn:aws:iam::171436882533:role/demo",
|
||||
"invalid",
|
||||
]);
|
||||
|
||||
let guard = AWS_SKIP_ACCOUNT_IDS.read().unwrap();
|
||||
assert!(guard.contains("052310077262"));
|
||||
assert!(guard.contains("171436882533"));
|
||||
assert_eq!(guard.len(), BUILTIN_SKIP_ACCOUNT_IDS.len());
|
||||
drop(guard);
|
||||
|
||||
set_aws_skip_account_ids(Vec::<String>::new());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_skip_when_account_matches() {
|
||||
let _lock = TEST_GUARD.lock().unwrap();
|
||||
|
||||
set_aws_skip_account_ids(["534261010715"]);
|
||||
assert_eq!(
|
||||
should_skip_aws_validation("AKIAXYZDQCEN4B6JSJQI"),
|
||||
Some("534261010715".to_string())
|
||||
);
|
||||
|
||||
set_aws_skip_account_ids(Vec::<String>::new());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builtin_canary_accounts_are_preseeded() {
|
||||
let _lock = TEST_GUARD.lock().unwrap();
|
||||
|
||||
set_aws_skip_account_ids(Vec::<String>::new());
|
||||
assert_eq!(
|
||||
should_skip_aws_validation("AKIAXYZDQCEN4B6JSJQI"),
|
||||
Some("534261010715".to_string())
|
||||
);
|
||||
|
||||
set_aws_skip_account_ids(Vec::<String>::new());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn duplicate_accounts_are_deduplicated() {
|
||||
let _lock = TEST_GUARD.lock().unwrap();
|
||||
|
||||
set_aws_skip_account_ids([
|
||||
"534261010715",
|
||||
"arn:aws:iam::534261010715:user/canarytokens",
|
||||
" 534261010715 ",
|
||||
]);
|
||||
|
||||
let guard = AWS_SKIP_ACCOUNT_IDS.read().unwrap();
|
||||
assert_eq!(guard.iter().filter(|id| id.as_str() == "534261010715").count(), 1);
|
||||
drop(guard);
|
||||
|
||||
set_aws_skip_account_ids(Vec::<String>::new());
|
||||
}
|
||||
}
|
||||
|
|
@ -1,646 +0,0 @@
|
|||
use std::{collections::BTreeMap, future::Future, str::FromStr, time::Duration};
|
||||
|
||||
use crate::validation::GLOBAL_USER_AGENT;
|
||||
use anyhow::{anyhow, Error, Result};
|
||||
use http::StatusCode;
|
||||
use liquid::Object;
|
||||
use quick_xml::de::from_str as xml_from_str;
|
||||
use reqwest::{
|
||||
header,
|
||||
header::{HeaderMap, HeaderName, HeaderValue},
|
||||
Client, Method, RequestBuilder, Response, Url,
|
||||
};
|
||||
use serde::de::IgnoredAny;
|
||||
use sha1::{Digest, Sha1};
|
||||
use tokio::time::sleep;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::rules::rule::ResponseMatcher;
|
||||
|
||||
/// Build a deterministic cache key from the immutable parts of an HTTP request.
|
||||
///
|
||||
/// * `method` – case-insensitive HTTP verb (“GET”, “POST”…)
|
||||
/// * `url` – fully-qualified URL (any query string should already be present)
|
||||
/// * `headers` – *logical* headers you intend to send (template-rendered, lower-level additions
|
||||
/// such as `User-Agent` may be appended by the caller)
|
||||
///
|
||||
/// The parts are concatenated with `\0` separators before hashing to avoid accidental
|
||||
/// collisions such as `"GET/foo"` vs `"GE" + "T/foo"`.
|
||||
pub fn generate_http_cache_key_parts(
|
||||
method: &str,
|
||||
url: &Url,
|
||||
headers: &BTreeMap<String, String>,
|
||||
body: Option<&str>,
|
||||
) -> String {
|
||||
let method = method.to_uppercase(); // ensure "get" == "GET"
|
||||
let url = url.as_str(); // canonical form from `reqwest::Url`
|
||||
|
||||
let mut hasher = Sha1::new();
|
||||
hasher.update(method.as_bytes());
|
||||
hasher.update(b"\0");
|
||||
hasher.update(url.as_bytes());
|
||||
hasher.update(b"\0");
|
||||
|
||||
// Collect headers sorted lexicographically (BTreeMap is already sorted),
|
||||
// then hash as `key:value\0`
|
||||
for (k, v) in headers {
|
||||
hasher.update(k.as_bytes());
|
||||
hasher.update(b":");
|
||||
hasher.update(v.as_bytes());
|
||||
hasher.update(b"\0");
|
||||
}
|
||||
|
||||
// Include the request body in the cache key if present
|
||||
if let Some(b) = body {
|
||||
hasher.update(b"BODY\0");
|
||||
hasher.update(b.as_bytes());
|
||||
hasher.update(b"\0");
|
||||
}
|
||||
|
||||
// Hex-encode and prefix so callers can tell this key came from HTTP logic
|
||||
format!("HTTP:{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
/// Parse an HTTP method from a string.
|
||||
pub fn parse_http_method(method_str: &str) -> Result<Method, String> {
|
||||
Method::from_str(method_str).map_err(|_| format!("Invalid HTTP method: {}", method_str))
|
||||
}
|
||||
|
||||
/// Build a reqwest RequestBuilder using the provided parameters.
|
||||
pub fn build_request_builder(
|
||||
client: &Client,
|
||||
method_str: &str,
|
||||
url: &Url,
|
||||
headers: &BTreeMap<String, String>,
|
||||
body: &Option<String>,
|
||||
timeout: Duration,
|
||||
parser: &liquid::Parser,
|
||||
globals: &liquid::Object,
|
||||
) -> Result<RequestBuilder, String> {
|
||||
let method = parse_http_method(method_str).map_err(|err_msg| {
|
||||
debug!("{}", err_msg);
|
||||
err_msg
|
||||
})?;
|
||||
let mut request_builder = client.request(method, url.clone()).timeout(timeout);
|
||||
let custom_headers = process_headers(headers, parser, globals, url)
|
||||
.map_err(|e| format!("Error processing headers: {}", e))?;
|
||||
|
||||
// Prepare a standard set of headers.
|
||||
let user_agent = GLOBAL_USER_AGENT.as_str();
|
||||
let standard_headers = [
|
||||
(header::USER_AGENT, user_agent),
|
||||
(
|
||||
header::ACCEPT,
|
||||
"text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8",
|
||||
),
|
||||
(header::ACCEPT_LANGUAGE, "en-US,en;q=0.5"),
|
||||
(header::ACCEPT_ENCODING, "gzip, deflate, br"),
|
||||
(header::CONNECTION, "keep-alive"),
|
||||
];
|
||||
// Start with the standard headers and then overlay any custom headers so
|
||||
// caller-specified values take precedence over defaults.
|
||||
let mut combined_headers = HeaderMap::new();
|
||||
for (name, value) in &standard_headers {
|
||||
if let Ok(hv) = HeaderValue::from_str(value) {
|
||||
combined_headers.insert(name.clone(), hv);
|
||||
}
|
||||
}
|
||||
for (name, value) in custom_headers.iter() {
|
||||
combined_headers.insert(name.clone(), value.clone());
|
||||
}
|
||||
request_builder = request_builder.headers(combined_headers);
|
||||
|
||||
// If a body template is provided, parse and render it
|
||||
if let Some(body_template) = body {
|
||||
let template = parser
|
||||
.parse(body_template)
|
||||
.map_err(|e| format!("Error parsing body template: {}", e))?;
|
||||
let rendered_body = template
|
||||
.render(globals)
|
||||
.map_err(|e| format!("Error rendering body template: {}", e))?;
|
||||
request_builder = request_builder.body(rendered_body);
|
||||
}
|
||||
|
||||
Ok(request_builder)
|
||||
}
|
||||
|
||||
/// Process headers from a BTreeMap, rendering any Liquid templates.
|
||||
pub fn process_headers(
|
||||
headers: &BTreeMap<String, String>,
|
||||
parser: &liquid::Parser,
|
||||
globals: &Object,
|
||||
url: &Url,
|
||||
) -> Result<HeaderMap> {
|
||||
let mut headers_map = HeaderMap::new();
|
||||
for (key, value) in headers {
|
||||
// Render the template
|
||||
let template = match parser.parse(value) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
debug!("Error parsing Liquid template for '{}': {}", key, e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let header_value = match template.render(globals) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
"Failed to render header template. URL = <{}> | Key '{}': {}",
|
||||
url.as_str(),
|
||||
key,
|
||||
e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
// Clean key and value
|
||||
let cleaned_key = key.trim().replace(&['\n', '\r'][..], "");
|
||||
let cleaned_value = header_value.trim().replace(&['\n', '\r'][..], "");
|
||||
// Validate header name
|
||||
let name = match HeaderName::from_str(&cleaned_key) {
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
"Invalid header name. URL = <{}> | Key '{}': {}",
|
||||
url.as_str(),
|
||||
cleaned_key,
|
||||
e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
// Validate header value
|
||||
let value = match HeaderValue::from_str(&cleaned_value) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
"Invalid header value. URL = <{}> | Value '{}': {}",
|
||||
url.as_str(),
|
||||
cleaned_value,
|
||||
e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
headers_map.insert(name, value);
|
||||
}
|
||||
Ok(headers_map)
|
||||
}
|
||||
|
||||
/// Exponential‐backoff retry helper that always returns `Result<T, anyhow::Error>`
|
||||
async fn retry_with_backoff<F, Fut, T>(
|
||||
mut operation: F,
|
||||
is_retryable: impl Fn(&Result<T, Error>, usize) -> bool,
|
||||
max_retries: usize,
|
||||
backoff_min: Duration,
|
||||
backoff_max: Duration,
|
||||
) -> Result<T, Error>
|
||||
where
|
||||
F: FnMut() -> Fut,
|
||||
Fut: Future<Output = Result<T, Error>>,
|
||||
{
|
||||
let mut retries = 0;
|
||||
while retries <= max_retries {
|
||||
let result = operation().await;
|
||||
// If this result is *not* retryable, return it directly (Ok or Err).
|
||||
if !is_retryable(&result, retries) {
|
||||
return result;
|
||||
}
|
||||
retries += 1;
|
||||
if retries > max_retries {
|
||||
break;
|
||||
}
|
||||
let backoff = backoff_min.saturating_mul(2u32.pow(retries as u32)).min(backoff_max);
|
||||
sleep(backoff).await;
|
||||
}
|
||||
Err(anyhow!("Max retries reached"))
|
||||
}
|
||||
|
||||
pub async fn retry_multipart_request<F, Fut>(
|
||||
mut build_request: F,
|
||||
max_retries: usize,
|
||||
backoff_min: Duration,
|
||||
backoff_max: Duration,
|
||||
) -> Result<Response, Error>
|
||||
where
|
||||
F: FnMut() -> Fut,
|
||||
Fut: Future<Output = RequestBuilder>,
|
||||
{
|
||||
retry_with_backoff(
|
||||
// 1) operation: build + send
|
||||
move || {
|
||||
let fut = build_request();
|
||||
async move {
|
||||
let rb = fut.await;
|
||||
rb.send().await.map_err(Error::from)
|
||||
}
|
||||
},
|
||||
// 2) same retry logic
|
||||
|res: &Result<_, Error>, _attempt| match res {
|
||||
Ok(resp)
|
||||
if matches!(
|
||||
resp.status(),
|
||||
StatusCode::BAD_GATEWAY
|
||||
| StatusCode::SERVICE_UNAVAILABLE
|
||||
| StatusCode::GATEWAY_TIMEOUT
|
||||
) =>
|
||||
{
|
||||
true
|
||||
}
|
||||
Err(_) => true,
|
||||
_ => false,
|
||||
},
|
||||
max_retries,
|
||||
backoff_min,
|
||||
backoff_max,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn retry_request(
|
||||
request_builder: RequestBuilder,
|
||||
max_retries: u32,
|
||||
backoff_min: Duration,
|
||||
backoff_max: Duration,
|
||||
) -> Result<Response, Error> {
|
||||
retry_with_backoff(
|
||||
// 1) operation: clone + send, yielding Result<Response, Error>
|
||||
move || {
|
||||
let rb =
|
||||
request_builder.try_clone().expect("retry_request: failed to clone RequestBuilder");
|
||||
async move { rb.send().await.map_err(Error::from) }
|
||||
},
|
||||
// 2) is_retryable: transient HTTP status or network error
|
||||
|res: &Result<_, Error>, _attempt| match res {
|
||||
Ok(resp)
|
||||
if matches!(
|
||||
resp.status(),
|
||||
StatusCode::BAD_GATEWAY
|
||||
| StatusCode::SERVICE_UNAVAILABLE
|
||||
| StatusCode::GATEWAY_TIMEOUT
|
||||
) =>
|
||||
{
|
||||
true
|
||||
}
|
||||
Err(_) => true,
|
||||
_ => false,
|
||||
},
|
||||
max_retries as usize,
|
||||
backoff_min,
|
||||
backoff_max,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Return `true` when the body is very likely HTML.
|
||||
fn body_looks_like_html(body: &str, headers: &HeaderMap) -> bool {
|
||||
// ---- 1. header heuristic ---------------------------------------------
|
||||
let header_says_html = headers
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|ct| {
|
||||
let ct = ct.to_ascii_lowercase();
|
||||
ct.contains("text/html") || ct.contains("application/xhtml")
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
// ---- 2. early-body scan (<=1024 bytes) --------------------------------
|
||||
// Find the last character boundary at or before 1024 bytes to avoid UTF-8 boundary issues
|
||||
// Walk backward at most 3 bytes (UTF-8 max char size is 4 bytes) to find valid boundary
|
||||
let mut end = 1024.min(body.len());
|
||||
while end > 0 && !body.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
let probe = &body[..end];
|
||||
// Trim any leading whitespace so we still catch HTML that starts after newlines/indentation.
|
||||
let trimmed = probe.trim_start_matches(|c: char| c.is_whitespace());
|
||||
let probe = trimmed.to_ascii_lowercase();
|
||||
let body_looks_htmlish = probe.starts_with('<') && probe.contains("<html");
|
||||
|
||||
// ⇒ Only HTML if **both** header and body agree
|
||||
header_says_html && body_looks_htmlish
|
||||
}
|
||||
|
||||
/// Validate the response by checking word and status matchers.
|
||||
pub fn validate_response(
|
||||
matchers: &[ResponseMatcher],
|
||||
body: &str,
|
||||
status: &StatusCode,
|
||||
headers: &HeaderMap,
|
||||
html_allowed: bool,
|
||||
) -> bool {
|
||||
// Since match_all_types is always true here, we simply require all word and status conditions
|
||||
// to hold.
|
||||
let word_ok = matchers
|
||||
.iter()
|
||||
.filter_map(|m| {
|
||||
if let ResponseMatcher::WordMatch { words, match_all_words, negative, .. } = m {
|
||||
let raw = if *match_all_words {
|
||||
words.iter().all(|w| body.contains(w))
|
||||
} else {
|
||||
words.iter().any(|w| body.contains(w))
|
||||
};
|
||||
Some(if *negative { !raw } else { raw })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.all(|b| b);
|
||||
|
||||
let status_ok = matchers
|
||||
.iter()
|
||||
.filter_map(|m| {
|
||||
if let ResponseMatcher::StatusMatch {
|
||||
status: expected,
|
||||
match_all_status,
|
||||
negative,
|
||||
..
|
||||
} = m
|
||||
{
|
||||
let raw = if *match_all_status {
|
||||
expected.iter().all(|s| s.to_string() == status.as_str())
|
||||
} else {
|
||||
expected.iter().any(|s| s.to_string() == status.as_str())
|
||||
};
|
||||
Some(if *negative { !raw } else { raw })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.all(|b| b);
|
||||
|
||||
// ── Header checks ──────────────────────────────────────────
|
||||
let header_ok = matchers
|
||||
.iter()
|
||||
.filter_map(|m| {
|
||||
if let ResponseMatcher::HeaderMatch { header, expected, match_all_values, .. } = m {
|
||||
// header names are case-insensitive
|
||||
let val = headers
|
||||
.get(header)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or_default()
|
||||
.to_ascii_lowercase();
|
||||
Some(if *match_all_values {
|
||||
expected.iter().all(|e| val.contains(&e.to_ascii_lowercase()))
|
||||
} else {
|
||||
expected.iter().any(|e| val.contains(&e.to_ascii_lowercase()))
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.all(|b| b);
|
||||
|
||||
// ----- JsonValid ----------------------------------------------------------
|
||||
let json_ok = matchers
|
||||
.iter()
|
||||
.filter_map(|m| {
|
||||
if matches!(m, ResponseMatcher::JsonValid { .. }) {
|
||||
Some(serde_json::from_str::<serde_json::Value>(body).is_ok())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.all(|b| b);
|
||||
|
||||
let xml_ok = matchers
|
||||
.iter()
|
||||
.filter_map(|m| {
|
||||
if matches!(m, ResponseMatcher::XmlValid { .. }) {
|
||||
// succeeds if `body` is well-formed XML
|
||||
Some(xml_from_str::<IgnoredAny>(body).is_ok())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.all(|b| b);
|
||||
|
||||
let html_detected = body_looks_like_html(body, headers);
|
||||
let html_ok = html_allowed || !html_detected;
|
||||
|
||||
// // ── debug line ─-
|
||||
// debug!(
|
||||
// "validate_response -- word:{}, status:{}, header:{}, json:{}, xml:{} ⇒ {}",
|
||||
// word_ok, status_ok, header_ok, json_ok, xml_ok, all_ok
|
||||
// );
|
||||
// // ──────────────────────────────────────────────────────────────
|
||||
|
||||
let all_ok = word_ok && status_ok && header_ok && json_ok && xml_ok && html_ok;
|
||||
all_ok
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Once;
|
||||
|
||||
use wiremock::{
|
||||
matchers::{method, path},
|
||||
Mock, MockServer, ResponseTemplate,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
static INIT: Once = Once::new();
|
||||
fn init() {
|
||||
INIT.call_once(|| {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_request_builder() {
|
||||
init();
|
||||
let client = Client::builder()
|
||||
.gzip(true) // enable gzip
|
||||
.deflate(true) // enable deflate
|
||||
.brotli(true) // enable brotli
|
||||
.build()
|
||||
.expect("building reqwest client");
|
||||
let parser = liquid::ParserBuilder::with_stdlib().build().unwrap();
|
||||
let globals = liquid::Object::new();
|
||||
let headers = BTreeMap::from([
|
||||
("Content-Type".to_string(), "application/json".to_string()),
|
||||
("Accept".to_string(), "application/custom".to_string()),
|
||||
]);
|
||||
let url = Url::from_str("https://example.com").unwrap();
|
||||
let result = build_request_builder(
|
||||
&client,
|
||||
"GET",
|
||||
&url,
|
||||
&headers,
|
||||
&None,
|
||||
Duration::from_secs(10),
|
||||
&parser,
|
||||
&globals,
|
||||
)
|
||||
.expect("building request");
|
||||
let req = result.build().expect("finalizing request");
|
||||
assert_eq!(
|
||||
req.headers().get(header::ACCEPT).and_then(|v| v.to_str().ok()),
|
||||
Some("application/custom"),
|
||||
);
|
||||
}
|
||||
#[tokio::test]
|
||||
async fn test_retry_request() {
|
||||
init();
|
||||
let mock_server = MockServer::start().await;
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/test"))
|
||||
.respond_with(ResponseTemplate::new(200))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
let client = Client::builder()
|
||||
.gzip(true) // enable gzip
|
||||
.deflate(true) // enable deflate
|
||||
.brotli(true) // enable brotli
|
||||
.build()
|
||||
.expect("building reqwest client");
|
||||
let request_builder = client.get(&format!("{}/test", mock_server.uri()));
|
||||
let response = retry_request(
|
||||
request_builder,
|
||||
3,
|
||||
Duration::from_millis(50),
|
||||
Duration::from_millis(200),
|
||||
)
|
||||
.await;
|
||||
assert!(response.is_ok());
|
||||
}
|
||||
#[test]
|
||||
fn test_validate_response() {
|
||||
// --- arrange ----------------------------------------------------------
|
||||
let matchers = vec![ResponseMatcher::WordMatch {
|
||||
r#type: "word-match".to_string(),
|
||||
words: vec!["test".to_string()],
|
||||
match_all_words: true,
|
||||
negative: false,
|
||||
}];
|
||||
let status = StatusCode::OK;
|
||||
let body = "This is a test";
|
||||
let headers = HeaderMap::new(); // empty header map
|
||||
let html_allowed = false;
|
||||
|
||||
// --- act --------------------------------------------------------------
|
||||
let result = validate_response(&matchers, body, &status, &headers, html_allowed);
|
||||
|
||||
// --- assert -----------------------------------------------------------
|
||||
assert!(result);
|
||||
}
|
||||
#[test]
|
||||
fn test_validate_response_slack_webhook() {
|
||||
// Build matchers equivalent to rule kingfisher.slack.4
|
||||
let matchers = vec![
|
||||
ResponseMatcher::WordMatch {
|
||||
r#type: "word-match".to_string(),
|
||||
words: vec!["invalid_payload".to_string()],
|
||||
match_all_words: false, // rule omits this → default is false
|
||||
negative: false,
|
||||
},
|
||||
ResponseMatcher::WordMatch {
|
||||
r#type: "word-match".to_string(),
|
||||
words: vec!["invalid_token".to_string()],
|
||||
match_all_words: false,
|
||||
negative: true, // body must *not* contain “invalid_token”
|
||||
},
|
||||
];
|
||||
|
||||
// Simulate the real Slack response you posted
|
||||
let body = "invalid_payload";
|
||||
let status = StatusCode::BAD_REQUEST; // 400
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/plain"));
|
||||
|
||||
// Call validate_response with html_allowed = false
|
||||
let ok = validate_response(&matchers, body, &status, &headers, false);
|
||||
|
||||
// 4It *should* be valid (true) because all matcher conditions hold
|
||||
assert!(ok, "Slack webhook response should be considered ACTIVE");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_body_looks_like_html_trims_whitespace() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/html; charset=utf-8"));
|
||||
|
||||
let body = "\n\n \n<!DOCTYPE html>\n<html lang=\"en\"><body>page</body></html>";
|
||||
|
||||
assert!(body_looks_like_html(body, &headers));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_html_response_rejected_when_not_allowed() {
|
||||
let matchers = vec![ResponseMatcher::StatusMatch {
|
||||
r#type: "status-match".to_string(),
|
||||
status: vec![StatusCode::OK.into()],
|
||||
match_all_status: false,
|
||||
negative: false,
|
||||
}];
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/html; charset=utf-8"));
|
||||
|
||||
let body = "\n<html><body>Sign in</body></html>";
|
||||
|
||||
let ok = validate_response(&matchers, body, &StatusCode::OK, &headers, false);
|
||||
|
||||
assert!(!ok, "HTML responses should be rejected unless explicitly allowed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_body_looks_like_html_utf8_boundary() {
|
||||
// Test case for UTF-8 boundary issue: multi-byte character at 1024-byte boundary
|
||||
// This reproduces the bug where slicing at byte 1024 would panic if it's in the middle
|
||||
// of a multi-byte character (e.g., Chinese character '业' spans bytes 1023..1026)
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/html; charset=utf-8"));
|
||||
|
||||
// HTML at the start, with padding to push a multi-byte char to byte 1024
|
||||
// This mirrors the real crash: HTML response from Gitee with Chinese chars
|
||||
let html_start = "<!DOCTYPE html><html lang=\"zh-CN\"><head><title>";
|
||||
let padding_len = 1023 - html_start.len();
|
||||
let body = format!(
|
||||
"{}{}业</title></head><body>Gitee</body></html>",
|
||||
html_start,
|
||||
"x".repeat(padding_len)
|
||||
);
|
||||
|
||||
// Verify our test setup: multi-byte char should be at byte 1023
|
||||
assert_eq!(body.as_bytes()[1023], 0xE4, "Expected first byte of '业' at position 1023");
|
||||
|
||||
// This should not panic AND should correctly identify HTML
|
||||
let result = body_looks_like_html(&body, &headers);
|
||||
assert!(
|
||||
result,
|
||||
"Should correctly identify HTML even with multi-byte characters at boundary"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_includes_body() {
|
||||
let url = Url::from_str("https://example.com/api").unwrap();
|
||||
let headers =
|
||||
BTreeMap::from([("Content-Type".to_string(), "application/json".to_string())]);
|
||||
|
||||
// Same method, url, headers but different bodies should produce different cache keys
|
||||
let key_no_body = generate_http_cache_key_parts("POST", &url, &headers, None);
|
||||
let key_body_a =
|
||||
generate_http_cache_key_parts("POST", &url, &headers, Some(r#"{"value": "abc"}"#));
|
||||
let key_body_b =
|
||||
generate_http_cache_key_parts("POST", &url, &headers, Some(r#"{"value": "xyz"}"#));
|
||||
|
||||
// All three should be different
|
||||
assert_ne!(
|
||||
key_no_body, key_body_a,
|
||||
"Cache key with body should differ from key without body"
|
||||
);
|
||||
assert_ne!(
|
||||
key_no_body, key_body_b,
|
||||
"Cache key with body should differ from key without body"
|
||||
);
|
||||
assert_ne!(key_body_a, key_body_b, "Cache keys with different bodies should be different");
|
||||
|
||||
// Same body should produce same key
|
||||
let key_body_a_dup =
|
||||
generate_http_cache_key_parts("POST", &url, &headers, Some(r#"{"value": "abc"}"#));
|
||||
assert_eq!(key_body_a, key_body_a_dup, "Same inputs should produce same cache key");
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue