kingfisher/src/grpc_validation.rs
Mick Grove 4a74e95756 v1.81.0
2026-02-10 19:43:34 -08:00

214 lines
7.5 KiB
Rust

use std::{collections::BTreeMap, sync::Arc, time::Duration};
use anyhow::{anyhow, Context, Result};
use bytes::Bytes;
use h2::client;
use http::{header::HeaderName, HeaderMap, HeaderValue, Request, Uri};
use liquid::Object;
use once_cell::sync::OnceCell;
use reqwest::Url;
use rustls::{ClientConfig, RootCertStore};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
/// Result of a gRPC unary call over HTTP/2.
pub struct GrpcCallResult {
pub http_status: http::StatusCode,
/// Response headers + trailers merged into one map.
pub headers: HeaderMap,
pub body_bytes: Vec<u8>,
}
fn build_root_store() -> Result<RootCertStore> {
let mut roots = RootCertStore::empty();
let native = rustls_native_certs::load_native_certs();
if !native.errors.is_empty() {
// Best-effort: still proceed if we got any certs.
// (Some platforms may have a few unparsable roots.)
}
for cert in native.certs {
roots.add(cert).map_err(|e| anyhow!("Failed to add native root cert: {e:?}"))?;
}
Ok(roots)
}
fn cached_h2_tls_config() -> Result<Arc<ClientConfig>> {
static TLS_CONFIG: OnceCell<Arc<ClientConfig>> = OnceCell::new();
let cfg = TLS_CONFIG.get_or_try_init(|| -> Result<Arc<ClientConfig>> {
// Loading native roots can be relatively expensive; do it once and reuse.
let mut cfg = ClientConfig::builder()
.with_root_certificates(build_root_store()?)
.with_no_client_auth();
cfg.alpn_protocols = vec![b"h2".to_vec()];
Ok(Arc::new(cfg))
})?;
Ok(Arc::clone(cfg))
}
fn url_to_h2_uri(url: &Url) -> Result<Uri> {
let scheme = url.scheme();
if scheme != "https" {
return Err(anyhow!("gRPC validation only supports https URLs, got: {scheme}"));
}
let host = url.host_str().ok_or_else(|| anyhow!("URL is missing host: {url}"))?;
let authority = match url.port() {
Some(p) => format!("{host}:{p}"),
None => host.to_string(),
};
let path_and_query = &url[url::Position::BeforePath..];
Uri::builder()
.scheme("https")
.authority(authority)
.path_and_query(path_and_query)
.build()
.context("Failed to build HTTP/2 URI for gRPC request")
}
fn header_map_from_templates(
templates: &BTreeMap<String, String>,
parser: &liquid::Parser,
globals: &Object,
) -> Result<HeaderMap> {
let mut out = HeaderMap::new();
for (k, v_template) in templates {
// Header names in YAML are expected to be static.
let name = HeaderName::from_bytes(k.as_bytes())
.with_context(|| format!("Invalid header name in GrpcValidation: '{k}'"))?;
let tmpl = parser
.parse(v_template)
.map_err(|e| anyhow!("Failed to parse header template '{k}': {e}"))?;
let rendered = tmpl
.render(globals)
.map_err(|e| anyhow!("Failed to render header template '{k}': {e}"))?;
let value = HeaderValue::from_str(&rendered)
.with_context(|| format!("Invalid header value for '{k}'"))?;
out.append(name, value);
}
Ok(out)
}
/// Execute a single unary gRPC request over HTTP/2 and return headers + trailers.
///
/// This is intentionally low-level so that rules can validate gRPC-only APIs
/// without pretending they are REST endpoints.
pub async fn grpc_unary_call(
url: &Url,
headers: HeaderMap,
body: Vec<u8>,
timeout: Duration,
) -> Result<GrpcCallResult> {
let host = url.host_str().ok_or_else(|| anyhow!("URL is missing host: {url}"))?;
let port = url.port_or_known_default().unwrap_or(443);
let addr = format!("{host}:{port}");
let tcp = tokio::time::timeout(timeout, TcpStream::connect(addr))
.await
.context("Timed out connecting to gRPC host")?
.context("Failed to connect to gRPC host")?;
let connector = TlsConnector::from(cached_h2_tls_config()?);
let server_name = rustls::pki_types::ServerName::try_from(host.to_string())
.map_err(|_| anyhow!("Invalid TLS server name: {host}"))?;
let tls = tokio::time::timeout(timeout, connector.connect(server_name, tcp))
.await
.context("Timed out during TLS handshake")?
.context("TLS handshake failed")?;
let (mut h2_client, connection) = tokio::time::timeout(timeout, client::handshake(tls))
.await
.context("Timed out during HTTP/2 handshake")?
.context("HTTP/2 handshake failed")?;
// Drive the HTTP/2 connection in the background.
tokio::spawn(async move {
let _ = connection.await;
});
let uri = url_to_h2_uri(url)?;
let mut req_builder = Request::builder().method("POST").uri(uri);
{
let hdrs = req_builder.headers_mut().expect("headers_mut should exist");
for (k, v) in headers.iter() {
hdrs.append(k, v.clone());
}
}
let request = req_builder.body(()).context("Failed to build HTTP/2 request")?;
let (response_future, mut send_stream) =
h2_client.send_request(request, false).context("Failed to send gRPC request headers")?;
// Send gRPC request bytes (including the 5-byte gRPC frame prefix).
send_stream.send_data(Bytes::from(body), true).context("Failed to send gRPC request body")?;
let response = tokio::time::timeout(timeout, response_future)
.await
.context("Timed out waiting for gRPC response headers")?
.context("Failed to receive gRPC response headers")?;
let http_status = response.status();
let (parts, mut recv_stream) = response.into_parts();
let mut merged_headers = parts.headers;
// Read data frames (may be empty).
let mut body_bytes: Vec<u8> = Vec::new();
loop {
// h2 returns `Option<Result<Bytes, h2::Error>>` here:
// - None => end of stream
// - Some(Ok(bytes)) => a data chunk
// - Some(Err(err)) => stream error
let next_opt = tokio::time::timeout(timeout, recv_stream.data())
.await
.context("Timed out reading gRPC response data")?;
match next_opt {
Some(Ok(b)) => body_bytes.extend_from_slice(b.as_ref()),
Some(Err(e)) => return Err(anyhow!("Error reading gRPC response data: {e}")),
None => break,
}
}
// Read trailers (where grpc-status is typically reported).
if let Some(trailers) = tokio::time::timeout(timeout, recv_stream.trailers())
.await
.context("Timed out reading gRPC response trailers")?
.context("Error reading gRPC response trailers")?
{
for (k, v) in trailers.iter() {
merged_headers.append(k, v.clone());
}
}
Ok(GrpcCallResult { http_status, headers: merged_headers, body_bytes })
}
/// Helper to render & execute a gRPC request from rule templates.
pub async fn grpc_unary_call_from_rule(
url: &Url,
header_templates: &BTreeMap<String, String>,
body_template: &Option<String>,
parser: &liquid::Parser,
globals: &Object,
timeout: Duration,
) -> Result<GrpcCallResult> {
let headers = header_map_from_templates(header_templates, parser, globals)?;
let body = match body_template {
Some(t) => {
let tmpl =
parser.parse(t).map_err(|e| anyhow!("Failed to parse gRPC body template: {e}"))?;
let rendered = tmpl
.render(globals)
.map_err(|e| anyhow!("Failed to render gRPC body template: {e}"))?;
rendered.into_bytes()
}
None => Vec::new(),
};
grpc_unary_call(url, headers, body, timeout).await
}