kingfisher/src/util.rs
2025-11-11 13:24:06 -08:00

278 lines
9.6 KiB
Rust

use std::{
borrow::Cow,
fs::File,
io::{stdin, stdout, BufReader, BufWriter},
path::Path,
sync::atomic::{AtomicBool, Ordering},
};
use blake3::Hasher;
use dashmap::DashSet;
use once_cell::sync::Lazy;
use path_dedot::ParseDot;
use ring::rand::{SecureRandom, SystemRandom};
// Generate a random salt once and use it for the entire application runtime
static APP_SALT: Lazy<String> = Lazy::new(|| generate_salt());
static REDACTION_ENABLED: AtomicBool = AtomicBool::new(false);
/// Interns a string once and returns a `'static` reference to it.
pub fn intern(s: &str) -> &'static str {
static INTERN: Lazy<DashSet<&'static str>> = Lazy::new(|| DashSet::with_capacity(512));
// Fast path: string already interned?
if let Some(existing) = INTERN.get(s) {
return *existing;
}
// Slow path: allocate one new copy for eternity.
let static_str: &'static str = Box::leak(s.to_owned().into_boxed_str());
INTERN.insert(static_str);
static_str
}
pub fn is_safe_path(path: &Path) -> std::io::Result<bool> {
Ok(path
.parse_dot()
.map(|p| !p.components().any(|c| matches!(c, std::path::Component::ParentDir)))
.unwrap_or(false))
}
pub fn redact_value(value: &str) -> String {
let mut hasher = Hasher::new();
hasher.update(APP_SALT.as_bytes());
hasher.update(value.as_bytes());
let hash = hasher.finalize();
format!("[REDACTED:{}]", hash_to_short_id(&hash))
}
/// Enables or disables global output redaction.
pub fn set_redaction_enabled(enabled: bool) {
REDACTION_ENABLED.store(enabled, Ordering::Relaxed);
}
/// Returns true if redaction is enabled for user-facing output.
pub fn redaction_enabled() -> bool {
REDACTION_ENABLED.load(Ordering::Relaxed)
}
/// Returns either the original value or a redacted placeholder depending on
/// the current redaction setting.
pub fn display_value(value: &'static str) -> Cow<'static, str> {
if redaction_enabled() {
Cow::Owned(redact_value(value))
} else {
Cow::Borrowed(value)
}
}
// Generate a random salt (16-character alphanumeric string)
fn generate_salt() -> String {
let rng = SystemRandom::new();
let mut bytes = [0u8; 16];
rng.fill(&mut bytes).unwrap();
hex::encode(bytes)
}
// Convert full hash to shorter identifier
fn hash_to_short_id(hash: &blake3::Hash) -> String {
hash.to_hex().chars().take(8).collect()
}
/// Represents a countable item with properly pluralized log messages.
pub enum Counted<'a> {
Regular { singular: &'a str, count: usize },
Explicit { singular: &'a str, count: usize, plural: &'a str },
}
impl<'a> Counted<'a> {
/// Creates a `Counted` with explicit singular and plural forms.
pub fn new(count: usize, singular: &'a str, plural: &'a str) -> Self {
Counted::Explicit { singular, plural, count }
}
/// Creates a `Counted` with a singular form, automatically pluralizing by
/// adding "s".
pub fn regular(count: usize, singular: &'a str) -> Self {
Counted::Regular { singular, count }
}
}
impl<'a> std::fmt::Display for Counted<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Counted::Explicit { singular, plural, count } => {
write!(f, "{} {}", count, if *count == 1 { singular } else { plural })
}
Counted::Regular { singular, count } => {
write!(f, "{} {}{}", count, singular, if *count == 1 { "" } else { "s" })
}
}
}
}
/// Returns a buffered writer for a specified file path or stdout if none is
/// provided.
pub fn get_writer_for_file_or_stdout<P: AsRef<Path>>(
path: Option<P>,
) -> std::io::Result<Box<dyn std::io::Write>> {
match path {
None => Ok(Box::new(BufWriter::new(stdout()))),
Some(p) => Ok(Box::new(BufWriter::new(File::create(p)?))),
}
}
/// Returns a buffered reader for a specified file path or stdin if none is
/// provided.
pub fn get_reader_for_file_or_stdin<P: AsRef<Path>>(
path: Option<P>,
) -> std::io::Result<Box<dyn std::io::Read>> {
match path {
None => Ok(Box::new(BufReader::new(stdin()))),
Some(p) => Ok(Box::new(BufReader::new(File::open(p)?))),
}
}
/// Determines whether the input string is valid Base64.
pub fn is_base64(input: &str) -> bool {
input.len() % 4 == 0
&& input
.bytes()
.all(|b| matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'='))
}
/// Heuristic check whether a path points to test files or directories.
///
/// Looks for common substrings like "test", "tests", "spec", "fixture", or
/// "example" in any path component. Case-insensitive.
pub fn is_test_like_path(path: &Path) -> bool {
path.components().any(|c| {
if let std::path::Component::Normal(os) = c {
if let Some(name) = os.to_str() {
let name = name.to_ascii_lowercase();
return name.contains("test")
|| name.contains("spec")
|| name.contains("fixture")
|| name.contains("example")
|| name.contains("sample");
}
}
false
})
}
#[cfg(test)]
mod tests {
use std::{
io::{Cursor, Read, Write},
path::PathBuf,
};
use super::{is_test_like_path, *};
/// Paths that **should** be classified as test-like.
#[test]
fn test_is_test_like_path_positive() {
let positives = [
"src/tests/helpers.rs",
"/project/spec/controllers/user_spec.rb",
"C:\\repo\\fixtures\\config.json",
"examples/hello_world/main.go",
"/home/user/scripts/local-testCert.pem",
"samples/data/sample_input.txt",
];
for p in positives {
assert!(
is_test_like_path(Path::new(p)),
"Path {p:?} was expected to be test-like but was not"
);
}
}
/// Paths that **should not** be classified as test-like.
#[test]
fn test_is_test_like_path_negative() {
let negatives = [
"src/main.rs",
"/opt/service/config/production.yml",
"C:\\Program Files\\app\\README.md",
"docs/architecture/overview.md",
"assets/images/logo.png",
];
for p in negatives {
assert!(
!is_test_like_path(Path::new(p)),
"Path {p:?} was incorrectly classified as test-like"
);
}
}
#[test]
fn test_counted_display_regular() {
let single = Counted::regular(1, "rule");
let multiple = Counted::regular(3, "rule");
assert_eq!(format!("{}", single), "1 rule");
assert_eq!(format!("{}", multiple), "3 rules");
}
#[test]
fn test_counted_display_explicit() {
let single = Counted::new(1, "person", "people");
let multiple = Counted::new(5, "person", "people");
assert_eq!(format!("{}", single), "1 person");
assert_eq!(format!("{}", multiple), "5 people");
}
#[test]
fn test_get_writer_for_file_or_stdout_stdout() {
use std::io::Write;
// Test writing to stdout
let mut writer = get_writer_for_file_or_stdout::<PathBuf>(None).unwrap();
// Write a test string to ensure it's writing to stdout without errors
let result = writer.write(b"Test output to stdout\n");
assert!(result.is_ok(), "Failed to write to stdout");
}
#[test]
fn test_get_writer_for_file_or_stdout_file() {
let temp_file = tempfile::NamedTempFile::new().unwrap();
let path = temp_file.path().to_path_buf();
// Test writing to a file
let mut writer = get_writer_for_file_or_stdout(Some(&path)).unwrap();
writer.write_all(b"Test content").unwrap();
writer.flush().unwrap();
// Verify file content
let mut file_content = String::new();
std::fs::File::open(&path).unwrap().read_to_string(&mut file_content).unwrap();
assert_eq!(file_content, "Test content");
}
#[test]
fn test_get_reader_for_file_or_stdin_stdin() {
// Test reading from stdin (mocked)
let input = b"stdin test content";
let mut stdin_mock = Cursor::new(input);
let mut reader = BufReader::new(&mut stdin_mock);
let mut buffer = String::new();
reader.read_to_string(&mut buffer).unwrap();
assert_eq!(buffer, "stdin test content");
}
#[test]
fn test_get_reader_for_file_or_stdin_file() {
let temp_file = tempfile::NamedTempFile::new().unwrap();
let path = temp_file.path().to_path_buf();
std::fs::write(&path, "File test content").unwrap();
// Test reading from a file
let mut reader = get_reader_for_file_or_stdin(Some(&path)).unwrap();
let mut buffer = String::new();
reader.read_to_string(&mut buffer).unwrap();
assert_eq!(buffer, "File test content");
}
#[test]
fn test_is_base64_valid() {
let valid_base64 = "SGVsbG8gV29ybGQh"; // "Hello World!" in Base64
let valid_base64_with_padding = "SGVsbG8gdGhpcyB3b3JsZAo=";
let valid_empty = "";
assert!(is_base64(valid_base64));
assert!(is_base64(valid_base64_with_padding));
assert!(is_base64(valid_empty));
}
#[test]
fn test_is_base64_invalid() {
let invalid_base64 = "Hello World!";
let invalid_length = "SGVsbG8"; // Not divisible by 4
let invalid_characters = "SGVsbG8$V29ybGQh";
assert!(!is_base64(invalid_base64));
assert!(!is_base64(invalid_length));
assert!(!is_base64(invalid_characters));
}
}