diff --git a/src/cli/commands/inputs.rs b/src/cli/commands/inputs.rs index 4f2dc5b..3e9e107 100644 --- a/src/cli/commands/inputs.rs +++ b/src/cli/commands/inputs.rs @@ -169,12 +169,13 @@ pub struct InputSpecifierArgs { #[derive(Args, Debug, Clone)] pub struct ContentFilteringArgs { /// Ignore files larger than the given size in MB - #[arg(long("max-file-size"), default_value_t = 64.0)] + #[arg( + long("max-file-size"), + long("max-filesize"), + default_value_t = 256.0 + )] pub max_file_size_mb: f64, - // /// Use custom path-based ignore rules from the given file(s) - // #[arg(long, short, value_hint = ValueHint::FilePath)] - // pub ignore: Vec, /// Skip any file or directory whose path matches this glob pattern. Multiple /// patterns may be provided by repeating the flag. #[arg(long, value_name = "PATTERN")] @@ -197,7 +198,7 @@ impl ContentFilteringArgs { /// Convert the maximum file size in MB to bytes pub fn max_file_size_bytes(&self) -> Option { if self.max_file_size_mb < 0.0 { - Some(25 * 1024 * 1024) // default 25 MB if negative + Some(256 * 1024 * 1024) // default 256 MB if negative } else { Some((self.max_file_size_mb * 1024.0 * 1024.0) as u64) } diff --git a/src/validation.rs b/src/validation.rs index 8cf97ba..7212da0 100644 --- a/src/validation.rs +++ b/src/validation.rs @@ -70,6 +70,7 @@ static IN_FLIGHT: OnceCell>> = OnceCell::new(); pub fn init_validation_caches() { VALIDATION_CACHE.set(DashMap::new()).ok(); IN_FLIGHT.set(DashMap::new()).ok(); + aws::set_aws_validation_concurrency(15); } #[derive(Clone)] @@ -766,16 +767,30 @@ async fn timed_validate_single_match<'a>( return; } - match aws::validate_aws_credentials(&akid, &secret, cache).await { - Ok((ok, arn)) => { + match aws::validate_aws_credentials(&akid, &secret).await { + Ok((ok, msg)) => { m.validation_success = ok; - m.validation_response_body = format!("{} --- ARN: {}", akid, arn); - m.validation_response_status = - if ok { StatusCode::OK } else { StatusCode::UNAUTHORIZED }; - if let Ok(acct) = aws::aws_key_to_account_number(&akid) { - m.validation_response_body - .push_str(&format!(" --- AWS Account Number: {:012}", acct)); + if ok { + m.validation_response_body = format!("{} --- ARN: {}", akid, msg); + m.validation_response_status = StatusCode::OK; + if let Ok(acct) = aws::aws_key_to_account_number(&akid) { + m.validation_response_body + .push_str(&format!(" --- AWS Account Number: {:012}", acct)); + } + } else { + m.validation_response_body = + format!("AWS validation error ({}): {}", akid, msg); + m.validation_response_status = StatusCode::UNAUTHORIZED; } + cache.insert( + cache_key, + CachedResponse { + body: m.validation_response_body.clone(), + status: m.validation_response_status, + is_valid: m.validation_success, + timestamp: Instant::now(), + }, + ); } Err(e) => { m.validation_success = false; @@ -783,15 +798,6 @@ async fn timed_validate_single_match<'a>( m.validation_response_status = StatusCode::BAD_GATEWAY; } } - cache.insert( - cache_key, - CachedResponse { - body: m.validation_response_body.clone(), - status: m.validation_response_status, - is_valid: m.validation_success, - timestamp: Instant::now(), - }, - ); } // ----------------------------------------------------- GCP validator diff --git a/src/validation/aws.rs b/src/validation/aws.rs index 9b6079a..7e1073d 100644 --- a/src/validation/aws.rs +++ b/src/validation/aws.rs @@ -1,9 +1,12 @@ use std::time::Duration; use anyhow::{anyhow, Result}; -use aws_config::BehaviorVersion; +use aws_config::{retry::RetryConfig, BehaviorVersion}; use aws_credential_types::Credentials; -use aws_sdk_sts::{config::Builder as StsConfigBuilder, Client as StsClient}; +use aws_sdk_sts::{ + config::Builder as StsConfigBuilder, error::SdkError, + operation::get_caller_identity::GetCallerIdentityError, Client as StsClient, +}; use aws_smithy_http_client::{ proxy::ProxyConfig, tls, Builder as HttpClientBuilder, ConnectorBuilder, }; @@ -23,10 +26,25 @@ use http::{ header::{HeaderValue, USER_AGENT}, StatusCode, }; +use once_cell::sync::OnceCell; +use rand::{rng, Rng}; +use tokio::{ + sync::Semaphore, + time::{sleep, timeout}, +}; use crate::validation::GLOBAL_USER_AGENT; -use crate::validation::{Cache, CachedResponse, VALIDATION_CACHE_SECONDS}; +static AWS_VALIDATION_SEMAPHORE: OnceCell = OnceCell::new(); + +/// Set the maximum number of concurrent AWS validations. Call before first use. +pub fn set_aws_validation_concurrency(max: usize) { + AWS_VALIDATION_SEMAPHORE.set(Semaphore::new(max)).ok(); +} + +fn aws_validation_semaphore() -> &'static Semaphore { + AWS_VALIDATION_SEMAPHORE.get_or_init(|| Semaphore::new(15)) +} #[derive(Debug)] struct UaInterceptor; @@ -82,19 +100,30 @@ pub fn validate_aws_credentials_input(access_key_id: &str, secret_key: &str) -> Ok(()) } +fn is_throttling_or_transient(e: &SdkError) -> bool { + match e { + SdkError::ServiceError(ctx) => { + let code = ctx.err().meta().code().unwrap_or_default(); + let status: StatusCode = ctx.raw().status().into(); + code.contains("Throttl") + || status == StatusCode::TOO_MANY_REQUESTS + || status == StatusCode::SERVICE_UNAVAILABLE + } + SdkError::DispatchFailure(df) => df.is_timeout() || df.is_io(), + SdkError::ResponseError(ctx) => { + let status: StatusCode = ctx.raw().status().into(); + status == StatusCode::TOO_MANY_REQUESTS || status == StatusCode::SERVICE_UNAVAILABLE + } + _ => false, + } +} + pub async fn validate_aws_credentials( aws_access_key_id: &str, aws_secret_access_key: &str, - cache: &Cache, ) -> Result<(bool, String)> { - let cache_key = generate_aws_cache_key(aws_access_key_id, aws_secret_access_key); - // Check cache first - if let Some(cached) = cache.get(&cache_key) { - let cached_response = cached.value(); - if cached_response.timestamp.elapsed() < Duration::from_secs(VALIDATION_CACHE_SECONDS) { - return Ok((cached_response.is_valid, cached_response.body.clone())); - } - } + let _permit = aws_validation_semaphore().acquire().await.expect("semaphore closed"); + // Create static credentials let credentials = Credentials::new( aws_access_key_id, @@ -117,31 +146,50 @@ pub async fn validate_aws_credentials( conn_builder.build() }); - // Create AWS config + // Create AWS config with adaptive retries + let retry_config = RetryConfig::adaptive().with_max_attempts(3); let config = aws_config::defaults(BehaviorVersion::latest()) .region(Region::new("us-east-1")) .credentials_provider(credentials) .http_client(http_client) + .retry_config(retry_config) .load() .await; + // Create STS client let sts_config = StsConfigBuilder::from(&config).interceptor(UaInterceptor).build(); let sts_client = StsClient::from_conf(sts_config); - // Call get-caller-identity - match sts_client.get_caller_identity().send().await { - Ok(identity) => { - let arn = identity.arn.unwrap_or_else(|| "Unknown".to_string()); - // let acct = identity.account.unwrap_or_else(|| "Unknown".to_string()); - let response = CachedResponse::new(arn.clone(), StatusCode::OK, true); - cache.insert(cache_key, response); - Ok((true, arn)) - } - Err(e) => { - let response = CachedResponse::new(e.to_string(), StatusCode::UNAUTHORIZED, false); - cache.insert(cache_key, response); - Err(anyhow!("AWS validation failed: {}", e)) + + const MAX_ATTEMPTS: usize = 3; + const ATTEMPT_TIMEOUT: Duration = Duration::from_secs(5); + + for attempt in 1..=MAX_ATTEMPTS { + let result = timeout(ATTEMPT_TIMEOUT, sts_client.get_caller_identity().send()).await; + match result { + Ok(Ok(identity)) => { + let arn = identity.arn.unwrap_or_else(|| "Unknown".to_string()); + return Ok((true, arn)); + } + Ok(Err(e)) => { + if is_throttling_or_transient(&e) { + if attempt == MAX_ATTEMPTS { + return Err(anyhow!("AWS validation failed: {}", e)); + } + } else { + return Ok((false, e.to_string())); + } + } + Err(_) => { + if attempt == MAX_ATTEMPTS { + return Err(anyhow!("AWS validation timed out")); + } + } } + let max_delay = 100u64 * 2u64.pow((attempt - 1) as u32); + let sleep_ms = rng().random_range(0..=max_delay); + sleep(Duration::from_millis(sleep_ms)).await; } + Err(anyhow!("AWS validation failed")) } /// Converts an AWS Key ID to an AWS Account Number.