kingfisher/src/validation_rate_limit.rs

187 lines
5.8 KiB
Rust

use std::{sync::Arc, time::Duration};
use anyhow::{bail, Result};
use dashmap::DashMap;
use tokio::{
sync::Mutex,
time::{sleep_until, Instant},
};
use crate::rules::rule::Validation;
/// Bucket key used for rules that match no per-rule selector and therefore
/// share the limiter's default rate.
const DEFAULT_BUCKET: &str = "__default__";
/// Request-rate limiter for validation calls, keyed by rule selector.
///
/// Each selector (plus one default bucket) owns a gate holding the earliest
/// instant at which the next request in that bucket may proceed. The gate map
/// lives behind an `Arc`, so clones of the limiter share the same gates.
#[derive(Clone, Debug)]
pub struct ValidationRateLimiter {
/// Requests-per-second applied when no per-rule selector matches; `None`
/// means such rules are not limited.
default_rps: Option<f64>,
/// `(normalized selector, requests-per-second)` pairs, in the order given.
per_rule: Vec<(String, f64)>,
/// Per-bucket gate: the instant stored inside is when the next request in
/// that bucket is allowed to start.
next_allowed: Arc<DashMap<String, Arc<Mutex<Instant>>>>,
}
impl ValidationRateLimiter {
    /// Largest spacing ever enforced between two requests of one bucket.
    ///
    /// About 30 years: long enough to mean "effectively never", yet small
    /// enough that adding it to the current instant stays within range.
    const MAX_INTERVAL: Duration = Duration::from_secs(86_400 * 365 * 30);

    /// Builds a limiter from command-line style arguments.
    ///
    /// * `default_rps` - optional rate applied to rules without a matching selector.
    /// * `per_rule` - `RULE=RPS` mappings, parsed with [`parse_rule_rps_mapping`].
    ///
    /// Returns `Ok(None)` when neither a default rate nor any mapping is given.
    ///
    /// # Errors
    ///
    /// Fails when a rate is not a positive, finite number or when a mapping
    /// is malformed.
    pub fn from_cli(default_rps: Option<f64>, per_rule: &[String]) -> Result<Option<Self>> {
        let default_rps = default_rps.map(validate_rps).transpose()?;
        // Stops at the first malformed mapping, like an explicit loop with `?`.
        let per_rule = per_rule
            .iter()
            .map(|item| parse_rule_rps_mapping(item))
            .collect::<Result<Vec<_>>>()?;

        if default_rps.is_none() && per_rule.is_empty() {
            return Ok(None);
        }

        Ok(Some(Self {
            default_rps,
            per_rule,
            next_allowed: Arc::new(DashMap::new()),
        }))
    }

    /// Returns the requests-per-second that applies to `rule_id`, or `None`
    /// when the rule is not rate limited.
    pub fn effective_rps(&self, rule_id: &str) -> Option<f64> {
        self.effective_limit(rule_id).map(|(_, rps)| rps)
    }

    /// Waits until `rule_id` is allowed to issue its next request.
    ///
    /// Rules without an applicable limit return immediately. Callers that
    /// resolve to the same bucket are serialized: the bucket's gate stays
    /// locked while a caller sleeps, and on wake-up that caller schedules the
    /// following slot one interval (`1 / rps` seconds) later.
    pub async fn wait_for_rule(&self, rule_id: &str) {
        let Some((bucket, rps)) = self.effective_limit(rule_id) else {
            return;
        };

        // `Duration::from_secs_f64` panics when the value does not fit in a
        // `Duration`, which tiny-but-valid rates (e.g. 1e-300) would trigger.
        // Convert fallibly and clamp so the addition below cannot overflow.
        let interval = Duration::try_from_secs_f64(1.0 / rps)
            .unwrap_or(Self::MAX_INTERVAL)
            .min(Self::MAX_INTERVAL);

        // The map guard is a temporary dropped at the end of this statement,
        // so no map lock is held across the awaits below.
        let gate = self
            .next_allowed
            .entry(bucket)
            .or_insert_with(|| Arc::new(Mutex::new(Instant::now())))
            .clone();

        let mut next_slot = gate.lock().await;
        if *next_slot > Instant::now() {
            sleep_until(*next_slot).await;
        }
        *next_slot = Instant::now() + interval;
    }

    /// Resolves the bucket name and rate for `rule_id`.
    ///
    /// The longest matching per-rule selector wins; among equally long
    /// matches the earliest entry is kept. Without any match the default rate
    /// (if configured) is returned under [`DEFAULT_BUCKET`].
    fn effective_limit(&self, rule_id: &str) -> Option<(String, f64)> {
        let mut best: Option<(&str, f64)> = None;
        for (selector, rps) in &self.per_rule {
            if !selector_matches(rule_id, selector) {
                continue;
            }
            // Strict `>` keeps the first of several equally specific selectors.
            if best.is_none_or(|(current, _)| selector.len() > current.len()) {
                best = Some((selector.as_str(), *rps));
            }
        }

        best.map(|(selector, rps)| (selector.to_string(), rps))
            .or_else(|| self.default_rps.map(|rps| (DEFAULT_BUCKET.to_string(), rps)))
    }
}
/// Parses a `RULE=RPS` mapping into a normalized selector and its rate.
///
/// The input is split at the first `=`. Whitespace around either half is
/// ignored, so `"github = 2"` is treated like `"github=2"`.
///
/// # Errors
///
/// Fails when the `=` is missing, the selector is empty, or the rate is not
/// a positive, finite number.
pub fn parse_rule_rps_mapping(input: &str) -> Result<(String, f64)> {
    let (raw_selector, raw_rps) = input
        .split_once('=')
        .ok_or_else(|| anyhow::anyhow!("Invalid value '{input}'. Expected RULE=RPS"))?;
    let selector = normalize_rule_selector(raw_selector)?;
    // `f64::from_str` rejects surrounding whitespace; trim so the rate half
    // is as forgiving as the selector half.
    let parsed = raw_rps.trim().parse::<f64>().map_err(|_| {
        anyhow::anyhow!("Invalid RPS value '{raw_rps}' for selector '{raw_selector}'")
    })?;
    let rps = validate_rps(parsed)?;
    Ok((selector, rps))
}
/// Normalizes a user-supplied rule selector to its fully qualified form.
///
/// Surrounding whitespace is removed. A selector that is exactly
/// `kingfisher` or already starts with `kingfisher.` is returned as is;
/// anything else gets the `kingfisher.` prefix.
///
/// # Errors
///
/// Fails when the selector is empty after trimming.
pub fn normalize_rule_selector(input: &str) -> Result<String> {
    let trimmed = input.trim();
    if trimmed.is_empty() {
        bail!("Rule selector cannot be empty");
    }

    let already_qualified = trimmed == "kingfisher" || trimmed.starts_with("kingfisher.");
    let normalized = if already_qualified {
        trimmed.to_string()
    } else {
        format!("kingfisher.{trimmed}")
    };
    Ok(normalized)
}
/// Accepts `value` only when it is a finite number strictly greater than zero.
///
/// # Errors
///
/// Fails for NaN, infinities, zero and negative values.
fn validate_rps(value: f64) -> Result<f64> {
    if value.is_finite() && value > 0.0 {
        return Ok(value);
    }
    bail!("RPS must be a positive number");
}
/// Reports whether `selector` applies to `rule_id`.
///
/// A selector matches when it equals the rule id or is a prefix of it that
/// ends on a `.` boundary, so `a.b` matches `a.b` and `a.b.c` but not `a.bc`.
fn selector_matches(rule_id: &str, selector: &str) -> bool {
    match rule_id.strip_prefix(selector) {
        // An empty remainder is an exact match; otherwise the remainder must
        // begin a new dotted segment.
        Some(rest) => rest.is_empty() || rest.starts_with('.'),
        None => false,
    }
}
/// Returns `true` for every validation kind except [`Validation::Raw`].
pub fn should_rate_limit_validation(validation: &Validation) -> bool {
    let is_raw = matches!(validation, Validation::Raw(_));
    !is_raw
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Short names gain the `kingfisher.` prefix; qualified names are only trimmed.
    #[test]
    fn normalize_rule_selector_allows_short_names() {
        let cases = [
            ("github", "kingfisher.github"),
            (" kingfisher.github ", "kingfisher.github"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_rule_selector(input).unwrap(), expected);
        }
    }

    /// A `RULE=RPS` mapping yields the normalized selector and the parsed rate.
    #[test]
    fn parse_rule_rps_mapping_parses_rule_and_rate() {
        let parsed = parse_rule_rps_mapping("github=2.5").unwrap();
        assert_eq!(parsed, ("kingfisher.github".to_string(), 2.5));
    }

    /// The most specific selector wins; unmatched rules use the default rate.
    #[test]
    fn effective_rps_uses_longest_prefix_match() {
        let mappings = ["github=2".to_string(), "kingfisher.github.1=1".to_string()];
        let limiter = ValidationRateLimiter::from_cli(Some(10.0), &mappings)
            .unwrap()
            .unwrap();

        let expectations = [
            ("kingfisher.github.1", 1.0),
            ("kingfisher.github.9", 2.0),
            ("kingfisher.gitlab.1", 10.0),
        ];
        for (rule_id, expected) in expectations {
            assert_eq!(limiter.effective_rps(rule_id), Some(expected));
        }
    }

    /// Two rules that share the default bucket are spaced by the interval.
    #[tokio::test]
    async fn wait_for_rule_spaces_requests_for_same_bucket() {
        let limiter = ValidationRateLimiter::from_cli(Some(50.0), &[]).unwrap().unwrap();

        // The first call claims the initial slot without waiting.
        limiter.wait_for_rule("kingfisher.github.1").await;

        let started = std::time::Instant::now();
        limiter.wait_for_rule("kingfisher.github.2").await;
        let waited = started.elapsed();

        // Allow timing jitter from runtime scheduling while still asserting spacing happened.
        assert!(waited >= Duration::from_millis(15));
    }

    /// Every non-raw validator kind is subject to rate limiting.
    #[test]
    fn should_rate_limit_non_http_validators() {
        let limited = [
            Validation::AWS,
            Validation::GCP,
            Validation::MongoDB,
            Validation::Postgres,
            Validation::Coinbase,
        ];
        for validation in &limited {
            assert!(should_rate_limit_validation(validation));
        }
    }

    /// Raw validations bypass the limiter.
    #[test]
    fn should_skip_rate_limit_for_raw_validation() {
        let raw = Validation::Raw("custom".to_string());
        assert!(!should_rate_limit_validation(&raw));
    }
}