diff --git a/src/scanner/validation.rs b/src/scanner/validation.rs index fbdb7d2..11b4abe 100644 --- a/src/scanner/validation.rs +++ b/src/scanner/validation.rs @@ -1,4 +1,6 @@ use std::{ + future::Future, + panic::AssertUnwindSafe, sync::{ Arc, Mutex, atomic::{AtomicUsize, Ordering}, @@ -901,25 +903,44 @@ async fn validate_single( // Perform validation let outcome = timeout( validation_timeout, - validate_single_match( - om, - parser, - clients, - dep_vars, - missing_deps, - cache2, - validation_timeout, - validation_retries, - rate_limiter, - provider_endpoints.as_ref(), - max_body_len, - ) - .boxed(), + catch_validation_panic( + validate_single_match( + om, + parser, + clients, + dep_vars, + missing_deps, + cache2, + validation_timeout, + validation_retries, + rate_limiter, + provider_endpoints.as_ref(), + max_body_len, + ) + .boxed(), + ), ) .await; - // Store result in cache + apply_validation_outcome(om, &cache_key, outcome, success_count, fail_count, cache); + maybe_record_access_map(om, access_map); + // Remove from `in_progress` + // in_progress.remove(&cache_key); + in_progress.remove(&cache_key); + if let Some(n) = NOTIFY.remove(&cache_key) { + n.1.notify_waiters(); // wake everyone + } +} + +fn apply_validation_outcome( + om: &mut OwnedBlobMatch, + cache_key: &str, + outcome: std::result::Result, tokio::time::error::Elapsed>, + success_count: &AtomicUsize, + fail_count: &AtomicUsize, + cache: &DashMap, +) { match outcome { - Ok(_) => { + Ok(Ok(())) => { if om.validation_success && is_counted_validation_status(om.validation_response_status) { success_count.fetch_add(1, Ordering::Relaxed); @@ -927,7 +948,26 @@ async fn validate_single( fail_count.fetch_add(1, Ordering::Relaxed); } cache.insert( - cache_key.clone(), + cache_key.to_owned(), + CachedResponse { + is_valid: om.validation_success, + status: om.validation_response_status, + body: om.validation_response_body.clone(), + timestamp: Instant::now(), + }, + ); + } + Ok(Err(panic_message)) => { + om.validation_success = false; + om.validation_response_body = validation_body::from_string(format!( + "Validation panicked for rule {}: {}", + om.rule.id(), + panic_message + )); + om.validation_response_status = http::StatusCode::INTERNAL_SERVER_ERROR; + fail_count.fetch_add(1, Ordering::Relaxed); + cache.insert( + cache_key.to_owned(), CachedResponse { is_valid: om.validation_success, status: om.validation_response_status, @@ -943,19 +983,32 @@ async fn validate_single( fail_count.fetch_add(1, Ordering::Relaxed); } } - maybe_record_access_map(om, access_map); - // Remove from `in_progress` - // in_progress.remove(&cache_key); - in_progress.remove(&cache_key); - if let Some(n) = NOTIFY.remove(&cache_key) { - n.1.notify_waiters(); // wake everyone - } } fn is_counted_validation_status(status: StatusCode) -> bool { !matches!(status, StatusCode::CONTINUE | StatusCode::PRECONDITION_REQUIRED) } +async fn catch_validation_panic(future: F) -> std::result::Result<(), String> +where + F: Future, +{ + match AssertUnwindSafe(future).catch_unwind().await { + Ok(()) => Ok(()), + Err(payload) => Err(describe_panic_payload(payload)), + } +} + +fn describe_panic_payload(payload: Box) -> String { + if let Some(message) = payload.downcast_ref::<&str>() { + (*message).to_string() + } else if let Some(message) = payload.downcast_ref::() { + message.clone() + } else { + "non-string panic payload".to_string() + } +} + // Helper to compute the cache key for an OwnedBlobMatch. fn build_cache_key(om: &OwnedBlobMatch) -> String { let capture0 = om.captures.captures.get(0).map_or(String::new(), |c| c.raw_value().to_string()); @@ -1553,6 +1606,53 @@ fn extract_azure_devops_org_from_body( #[cfg(test)] mod tests { use super::*; + use crate::{ + blob::BlobId, + matcher::{OwnedBlobMatch, SerializableCapture, SerializableCaptures}, + rules::rule::{Confidence, Rule, RuleSyntax}, + util::intern, + }; + use smallvec::smallvec; + use std::sync::Arc; + + fn make_owned_blob_match() -> OwnedBlobMatch { + OwnedBlobMatch { + rule: Arc::new(Rule::new(RuleSyntax { + name: "panic-test".to_string(), + id: "test.panic".to_string(), + pattern: "panic".to_string(), + min_entropy: 0.0, + confidence: Confidence::Low, + visible: true, + examples: vec![], + negative_examples: vec![], + references: vec![], + validation: None, + revocation: None, + depends_on_rule: vec![], + pattern_requirements: None, + tls_mode: None, + })), + blob_id: BlobId::new(b"panic-test-blob"), + finding_fingerprint: 1, + matching_input_offset_span: OffsetSpan { start: 0, end: 5 }, + captures: SerializableCaptures { + captures: smallvec![SerializableCapture { + name: None, + match_number: 0, + start: 0, + end: 5, + value: intern("panic"), + }], + }, + validation_response_body: None, + validation_response_status: StatusCode::CONTINUE, + validation_success: false, + calculated_entropy: 0.0, + is_base64: false, + dependent_captures: std::collections::BTreeMap::new(), + } + } #[test] fn counted_validation_status_excludes_skipped_statuses() { @@ -1604,4 +1704,47 @@ mod tests { other => panic!("unexpected request: {other:?}"), } } + + #[tokio::test] + async fn catch_validation_panic_returns_panic_message() { + let result = catch_validation_panic(async { + panic!("validator blew up"); + }) + .await; + + assert_eq!(result.unwrap_err(), "validator blew up"); + } + + #[tokio::test] + async fn panic_outcome_is_reported_as_failure_and_cached() { + let mut om = make_owned_blob_match(); + let cache_key = build_cache_key(&om); + let cache = DashMap::new(); + let success_count = AtomicUsize::new(0); + let fail_count = AtomicUsize::new(0); + + let outcome = Ok(catch_validation_panic(async { + panic!("validator blew up"); + }) + .await); + + apply_validation_outcome(&mut om, &cache_key, outcome, &success_count, &fail_count, &cache); + + assert!(!om.validation_success); + assert_eq!(om.validation_response_status, StatusCode::INTERNAL_SERVER_ERROR); + assert!( + validation_body::clone_as_string(&om.validation_response_body) + .contains("Validation panicked for rule test.panic: validator blew up") + ); + assert_eq!(success_count.load(Ordering::Relaxed), 0); + assert_eq!(fail_count.load(Ordering::Relaxed), 1); + + let cached = cache.get(&cache_key).expect("panic result should be cached"); + assert!(!cached.is_valid); + assert_eq!(cached.status, StatusCode::INTERNAL_SERVER_ERROR); + assert!( + validation_body::clone_as_string(&cached.body) + .contains("Validation panicked for rule test.panic: validator blew up") + ); + } }