#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.12"
# dependencies = ["rich>=14.0.0", "typer>=0.24.0"]
# ///
#MISE description="Sanity-check a .zip file, optionally extract to tmpdir and scan with trivy"
#USAGE arg "<zip_path>" help="Path to the .zip file to inspect"
#USAGE flag "--dangerously-unzip-to-tmp" help="Extract even if warnings are present (errors still block)"
"""Inspect a .zip file for safety, then optionally extract and scan with trivy.

Phase 1 — Structural checks (nothing is written to disk):
  - Valid zip format (magic bytes)
  - Path traversal (zip slip) via .. components, and absolute paths
  - Symlinks
  - Suspicious file extensions (.exe, .bat, .ps1, .scr, .dll, etc.)
  - Zip bomb indicators (extreme compression ratios, nested archives)
  - Executable bit set on entries
  - Unreasonably large individual files
  - Total uncompressed size
  - CRC integrity of entries (members are decompressed in memory only)

Phase 2 — If clean (or --dangerously-unzip-to-tmp with warnings only):
  - Extract to a temporary directory
  - Run `trivy fs` on the extracted contents
  - Print the tmpdir path for manual inspection
"""

import shutil
import stat
import subprocess
import tempfile
import zipfile
from pathlib import Path, PurePosixPath, PureWindowsPath
from typing import Annotated

import typer
from rich.console import Console
from rich.markup import escape
from rich.table import Table

# Single rich console shared by every report; stderr=True sends this script's
# diagnostics to stderr rather than stdout.
console = Console(stderr=True)

# Lower-case file extensions (compared against the lower-cased entry suffix)
# that can execute code, auto-run, carry macros, or nest further archives.
# Matching entries are reported as warnings.
SUSPICIOUS_EXTENSIONS = {
    # Windows executables / installers / auto-run handlers
    ".exe", ".bat", ".cmd", ".com", ".scr", ".pif", ".msi", ".msp",
    ".hta", ".cpl", ".msc", ".jar", ".reg",
    # Scripts
    ".ps1", ".psm1", ".vbs", ".vbe", ".js", ".jse", ".wsf", ".wsh",
    # Libraries / code that can auto-execute
    ".dll", ".sys", ".drv", ".ocx",
    # Office macros (documents, templates, add-ins)
    ".xlsm", ".xltm", ".xlam", ".docm", ".dotm", ".pptm", ".potm", ".ppsm", ".ppam",
    # Shortcuts / links
    ".lnk", ".url", ".desktop",
    # Archives and disk images (nested — zip bomb vector, can hide contents)
    ".zip", ".7z", ".rar", ".tar", ".gz", ".tgz", ".bz2", ".xz", ".zst",
    ".iso", ".img",
    # Interpreted
    ".py", ".rb", ".pl", ".sh", ".bash", ".zsh", ".fish", ".command",
}

# Per-entry ratio file_size / compress_size above this is suspicious
# (zip bomb indicator); reported as a warning.
MAX_COMPRESSION_RATIO = 100

# Individual file size warning threshold (500 MB)
MAX_FILE_SIZE = 500 * 1024 * 1024

# Total uncompressed size limit (2 GB). Exceeding it is an ERROR that blocks
# extraction (possible zip bomb), not merely a warning.
MAX_TOTAL_SIZE = 2 * 1024 * 1024 * 1024


def check_magic_bytes(path: Path) -> bool:
    """Return True if the first four bytes of *path* are a recognised ZIP signature.

    Accepted signatures: local file header ``PK\\x03\\x04``, end-of-central-
    directory record ``PK\\x05\\x06`` (an empty archive) and ``PK\\x07\\x08``
    (spanned/split archive marker). Files shorter than four bytes never match.
    """
    signatures = (b"PK\x03\x04", b"PK\x05\x06", b"PK\x07\x08")
    with open(path, "rb") as handle:
        header = handle.read(4)
    return header in signatures


def _inspect_entry(info: zipfile.ZipInfo) -> tuple[list[str], list[str], list[str]]:
    """Apply the per-entry safety checks to a single archive member.

    Args:
        info: The member's central-directory record.

    Returns:
        ``(flags, warnings, errors)``: *flags* are short rich-markup badges for
        the contents table; *warnings* and *errors* are plain-text messages
        (not markup-escaped).
    """
    name = info.filename
    flags: list[str] = []
    warnings: list[str] = []
    errors: list[str] = []

    # Parse the name with Windows rules as well: PureWindowsPath splits on both
    # "/" and "\\" and recognises drive letters / UNC roots, so "..\\evil.dll"
    # or "C:\\evil.dll" are caught; such names are dangerous if the archive is
    # ever unpacked on Windows.
    win_path = PureWindowsPath(name)
    if ".." in win_path.parts:
        errors.append(f"PATH TRAVERSAL (zip slip): {name}")
        flags.append("[red]!!![/]")

    if win_path.anchor:
        errors.append(f"ABSOLUTE PATH: {name}")
        flags.append("[red]!!![/]")

    # The high 16 bits of external_attr hold a Unix st_mode when the archiver
    # recorded one (0 otherwise). The stat helpers compare the file-type field
    # under the S_IFMT mask, which a plain "& S_IFLNK" test does not.
    mode = info.external_attr >> 16
    if stat.S_ISLNK(mode):
        errors.append(f"SYMLINK: {name}")
        flags.append("[red]LNK[/]")

    # Bit 0 of the general-purpose flags marks an encrypted member; its
    # contents cannot be inspected or CRC-checked without the password.
    if info.flag_bits & 0x1:
        errors.append(f"ENCRYPTED: {name}")
        flags.append("[red]ENC[/]")

    if PurePosixPath(name).suffix.lower() in SUSPICIOUS_EXTENSIONS:
        warnings.append(f"Suspicious extension: {name}")
        flags.append("[yellow]SUS[/]")

    if info.compress_size > 0:
        ratio = info.file_size / info.compress_size
        if ratio > MAX_COMPRESSION_RATIO:
            warnings.append(f"High compression ratio ({ratio:.0f}x): {name}")
            flags.append("[yellow]BIG[/]")

    if info.file_size > MAX_FILE_SIZE:
        warnings.append(f"Large file ({info.file_size / 1024 / 1024:.0f} MB): {name}")
        flags.append("[yellow]BIG[/]")

    # Directories normally carry +x (search permission, e.g. 0o755); only
    # non-directory entries are worth reporting.
    is_directory = info.is_dir() or stat.S_ISDIR(mode)
    if mode & 0o111 and not is_directory:
        warnings.append(f"Executable bit set: {name}")
        flags.append("[yellow]+x[/]")

    return flags, warnings, errors


def check_zip(zip_path: Path) -> tuple[int, list[str], list[str]]:
    """Run structural safety checks on a zip archive without writing to disk.

    Prints a report (contents table, findings, verdict) to ``console``.

    Args:
        zip_path: Path to the archive to inspect.

    Returns:
        ``(exit_code, warnings, errors)``. ``exit_code`` is 1 when any error was
        found (extraction must be blocked), otherwise 0.
    """
    path = zip_path.resolve()

    if not path.exists():
        console.print(f"[red bold]ERROR:[/] File not found: {escape(str(path))}")
        return 1, [], ["File not found"]
    if not path.is_file():
        console.print(f"[red bold]ERROR:[/] Not a regular file: {escape(str(path))}")
        return 1, [], ["Not a regular file"]

    if path.suffix.lower() != ".zip":
        console.print(f"[yellow]WARNING:[/] File extension is '{escape(path.suffix)}', not '.zip'")

    if not check_magic_bytes(path):
        console.print("[red bold]FAIL:[/] File does not have valid ZIP magic bytes — possibly not a zip or is corrupted/disguised")
        return 1, [], ["Invalid magic bytes"]

    console.print("[green]OK:[/] Valid ZIP magic bytes")

    try:
        zf = zipfile.ZipFile(path, "r")
    except zipfile.BadZipFile as e:
        console.print(f"[red bold]FAIL:[/] Not a valid zip file: {escape(str(e))}")
        return 1, [], [str(e)]

    warnings: list[str] = []
    errors: list[str] = []

    with zf:
        infos = zf.infolist()
        console.print(f"[blue]INFO:[/] {len(infos)} entries, archive size {path.stat().st_size:,} bytes")

        table = Table(title="Archive Contents", show_lines=False)
        table.add_column("Flags", style="bold", width=6)
        table.add_column("Name")
        table.add_column("Size", justify="right")
        table.add_column("Compressed", justify="right")
        table.add_column("Ratio", justify="right")

        total_uncompressed = 0
        seen_names: set[str] = set()
        for info in infos:
            name = info.filename
            total_uncompressed += info.file_size

            flags, entry_warnings, entry_errors = _inspect_entry(info)
            # Duplicate names are ambiguous: different zip tools may pick
            # different copies, so what one tool inspects may not be what
            # another extracts.
            if name in seen_names:
                entry_warnings.append(f"Duplicate entry name: {name}")
                flags.append("[yellow]DUP[/]")
            seen_names.add(name)
            warnings.extend(entry_warnings)
            errors.extend(entry_errors)

            table.add_row(
                " ".join(flags) if flags else "[green]ok[/]",
                # Member names come from the (untrusted) archive: escape them
                # so brackets render literally instead of being parsed as
                # rich markup (spoofed output, or MarkupError on "[/]").
                escape(name),
                f"{info.file_size:,}",
                f"{info.compress_size:,}",
                f"{info.file_size / info.compress_size:.1f}x" if info.compress_size > 0 else "-",
            )

        console.print(table)
        console.print(f"[blue]INFO:[/] Total uncompressed size: {total_uncompressed:,} bytes ({total_uncompressed / 1024 / 1024:.1f} MB)")

        if total_uncompressed > MAX_TOTAL_SIZE:
            errors.append(f"Total uncompressed size exceeds {MAX_TOTAL_SIZE / 1024 / 1024 / 1024:.0f} GB — possible zip bomb")

        if errors:
            # testzip() decompresses every member; once the archive is already
            # blocked (possibly a zip bomb), skip that work.
            console.print("[yellow]SKIPPED:[/] CRC verification (archive already has errors)")
        else:
            try:
                bad = zf.testzip()
            except Exception as exc:  # noqa: BLE001
                # testzip() only converts BadZipFile into a return value. The
                # decoders it drives raise unrelated types on hostile input
                # (zlib.error, EOFError, NotImplementedError for unsupported
                # methods, RuntimeError for encryption, ...). Any of them means
                # the contents cannot be verified, so block instead of crashing.
                errors.append(f"Could not verify entry contents: {type(exc).__name__}: {exc}")
            else:
                if bad is not None:
                    errors.append(f"CRC mismatch or corrupt entry: {bad}")
                else:
                    console.print("[green]OK:[/] All CRC checksums valid")

    console.print()
    for e in errors:
        console.print(f"[red bold]ERROR:[/] {escape(e)}")
    for w in warnings:
        console.print(f"[yellow]WARNING:[/] {escape(w)}")

    if errors:
        console.print(f"\n[red bold]BLOCKED:[/] {len(errors)} error(s) found — do NOT extract this file")
        return 1, warnings, errors
    if warnings:
        console.print(f"\n[yellow]CAUTION:[/] {len(warnings)} warning(s) — review before extracting")
        return 0, warnings, errors
    console.print("\n[green bold]CLEAN:[/] No issues detected")
    return 0, warnings, errors


def extract_and_scan(zip_path: Path) -> int:
    """Extract *zip_path* into a fresh temp directory and run ``trivy fs`` on it.

    The extracted tree is intentionally left in place for manual inspection;
    its location and a cleanup command are printed. If extraction itself
    fails, the partially populated temp directory is removed and the
    exception propagates.

    Returns:
        trivy's exit status, or 0 when trivy is not on PATH (nothing scanned).
        NOTE(review): trivy normally exits 0 even when it reports findings
        unless ``--exit-code`` is passed, so 0 means "scan ran", not "clean".
    """
    path = zip_path.resolve()
    tmpdir = Path(tempfile.mkdtemp(prefix="check-zip-"))

    console.print(f"\n[blue]Extracting to:[/] {tmpdir}")

    try:
        with zipfile.ZipFile(path, "r") as zf:
            zf.extractall(tmpdir)
    except BaseException:
        # Don't leave a half-extracted tree behind on failure (or Ctrl-C).
        shutil.rmtree(tmpdir, ignore_errors=True)
        raise

    # rglob("*") also yields directories; only count regular files.
    file_count = sum(1 for p in tmpdir.rglob("*") if p.is_file())
    console.print(f"[blue]Extracted {file_count} files[/]")

    trivy = shutil.which("trivy")
    if trivy is None:
        console.print("[yellow]WARNING:[/] trivy not found in PATH — install via `mise install trivy`")
        console.print(f"[blue]Files extracted to:[/] {tmpdir}")
        console.print(f"[blue]Run manually:[/] trivy fs {tmpdir}")
        return 0

    console.print("\n[blue]Running trivy fs scan...[/]\n")

    # Argument list (no shell), so the tmpdir path is never shell-interpreted.
    result = subprocess.run(
        [trivy, "fs", "--scanners", "vuln,secret,misconfig", str(tmpdir)],
    )

    console.print(f"\n[blue]Extracted files remain at:[/] {tmpdir}")
    console.print(f"[blue]Clean up with:[/] rm -rf {tmpdir}")

    return result.returncode


# Typer application; `main` is registered on it via the @app.command() decorator.
app = typer.Typer(help="Sanity-check a .zip file for suspicious content before extracting.")


@app.command()
def main(
    zip_path: Annotated[Path, typer.Argument(help="Path to the .zip file to inspect")],
    dangerously_unzip_to_tmp: Annotated[
        bool,
        typer.Option(
            "--dangerously-unzip-to-tmp",
            help="Extract to tmpdir and scan with trivy even if warnings are present (errors still block)",
        ),
    ] = False,
) -> None:
    """Inspect a zip file for structural safety issues, then extract and scan with trivy."""
    # (Docstring kept verbatim: typer uses it as the command's --help text.)
    # Flow: errors always abort with status 1; warnings abort unless the
    # opt-in flag is set; otherwise extract + scan, and a non-zero trivy
    # status becomes the process exit status.
    status, warnings, errors = check_zip(zip_path)

    if errors:
        console.print("\n[red bold]Extraction blocked due to errors.[/]")
        if dangerously_unzip_to_tmp:
            console.print("[red]--dangerously-unzip-to-tmp does not override errors, only warnings.[/]")
        raise typer.Exit(1)

    if warnings and not dangerously_unzip_to_tmp:
        console.print("\n[yellow]Skipping extraction due to warnings.[/]")
        console.print("[yellow]Use --dangerously-unzip-to-tmp to extract anyway.[/]")
        raise typer.Exit(status)

    if warnings:
        # Only reachable when the opt-in flag was passed.
        console.print("\n[yellow bold]--dangerously-unzip-to-tmp: proceeding despite warnings[/]")

    scan_status = extract_and_scan(zip_path)
    raise typer.Exit(scan_status if scan_status != 0 else status)


# Script entry point (the shebang runs this file via `uv run --script`).
if __name__ == "__main__":
    app()
