diff --git a/Cargo.lock b/Cargo.lock index 4dc6b45..0b23e81 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,7 +1,5 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -# SPDX-FileCopyrightText: 2023-2025 erdnaxe -# SPDX-License-Identifier: CC0-1.0 version = 4 [[package]] @@ -626,7 +624,7 @@ dependencies = [ [[package]] name = "sossette" -version = "0.1.1" +version = "0.2.0" dependencies = [ "anyhow", "clap", diff --git a/Cargo.toml b/Cargo.toml index 950044e..2f96155 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,20 +8,59 @@ keywords = ["tcp", "ctf", "socat", "fcsc"] categories = ["command-line-utilities"] repository = "https://github.com/erdnaxe/sossette" authors = ["erdnaxe "] -license = "MIT" -version = "0.1.1" +version = "0.2.0" edition = "2024" [dependencies] anyhow = "1.0" -clap = { version = "4.5", features = ["derive", "env"] } +clap = { version = "4.6", features = ["derive", "env"] } clap-verbosity-flag = "3.0" command-group = { version = "5.0", features = ["with-tokio"] } env_logger = "0.11" log = "0.4" rand = "0.10" sha2 = "0.10" -tokio = { version = "1.50", features = ["rt-multi-thread", "io-util", "signal", "net", "time", "process", "macros"] } +tokio = { version = "1.5", features = [ + "rt-multi-thread", + "io-util", + "signal", + "net", + "time", + "process", + "macros", +] } + +[lints.rust] +arithmetic_overflow = { level = "deny", priority = -1 } + +[lints.clippy] +pedantic = { level = "deny", priority = -1 } +nursery = { level = "deny", priority = -1 } +missing-errors-doc = "allow" + +indexing_slicing = { level = "deny", priority = -1 } +fallible_impl_from = { level = "deny", priority = -1 } +wildcard_enum_match_arm = { level = "deny", priority = -1 } +unneeded_field_pattern = { level = "deny", priority = -1 } +fn_params_excessive_bools = { level = "deny", priority = -1 } +must_use_candidate = { level = "deny", priority = -1 } +checked_conversions = { level = "deny", priority = -1 } +cast_possible_truncation = { level = "deny", priority = -1 } +cast_sign_loss = { level = "deny", priority = -1 } +cast_possible_wrap = { level = "deny", priority = -1 } +cast_precision_loss = { level = "deny", priority = -1 } +integer_division = { level = "deny", priority = -1 } +arithmetic_side_effects = { level = "deny", priority = -1 } +unchecked_duration_subtraction = { level = "deny", priority = -1 } +unwrap_used = "warn" +expect_used = "warn" +panicking_unwrap = { level = "deny", priority = -1 } +option_env_unwrap = { level = "deny", priority = -1 } +join_absolute_paths = { level = "deny", priority = -1 } +serde_api_misuse = { level = "deny", priority = -1 } +uninit_vec = { level = "deny", priority = -1 } +transmute_ptr_to_ref = { level = "deny", priority = -1 } +transmute_undefined_repr = { level = "deny", priority = -1 } [profile.release] codegen-units = 1 diff --git a/README.md b/README.md index 3d94725..f083f04 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,71 @@ world ^C ``` +## PROXY protocol support + +Sossette supports the [PROXY protocol v2](https://github.com/haproxy/haproxy/blob/master/doc/proxy-protocol.txt) to preserve client IP addresses when running behind a load balancer or reverse proxy. + +### Usage + +Enable PROXY protocol v2 with the `--proxy-protocol` flag. When enabled, a valid PROXY protocol v2 header is **required** and connections without one are rejected: + +```bash +$ sossette --proxy-protocol -l 0.0.0.0:4000 cat +``` + +Or using the environment variable: + +```bash +$ WRAPPER_PROXY_PROTOCOL=true sossette -l 0.0.0.0:4000 cat +``` + +### Accessing client information + +When PROXY protocol is enabled and a valid header is received, sossette: + +1. **Logs the real client IP** instead of the proxy's IP: + ``` + [2024-03-09T10:15:23Z INFO sossette] Client [::1]:55438 connected + [2024-03-09T10:15:23Z INFO sossette] Real client: 192.0.2.123:54321 (via proxy [::1]:55438) + ``` + +### Load balancer configuration + +#### HAProxy + +Configure HAProxy to send PROXY protocol v2 headers: + +```haproxy +frontend tcp_front + bind *:443 + mode tcp + default_backend tcp_back + +backend tcp_back + mode tcp + server sossette 127.0.0.1:4000 send-proxy-v2 +``` + +#### nginx + +Configure nginx stream module with PROXY protocol: + +```nginx +stream { + upstream sossette { + server 127.0.0.1:4000; + } + + server { + listen 443; + proxy_pass sossette; + proxy_protocol on; + } +} +``` + +**Security note**: When using PROXY protocol, ensure that only trusted load balancers can connect to sossette (e.g., using firewall rules). Otherwise, clients could spoof their IP addresses by sending fake PROXY protocol headers. + ## Applying transformations to stdin `process_stdin` in [src/main.rs](./src/main.rs) can be easily patched to apply diff --git a/src/handler.rs b/src/handler.rs index ed9ba15..572462c 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,15 +1,17 @@ // SPDX-FileCopyrightText: 2023-2025 erdnaxe // SPDX-License-Identifier: MIT -use crate::pow; use crate::Args; +use crate::pow; +use crate::proxy; +use std::net::SocketAddr; use std::process::Stdio; use std::time::Duration; use anyhow::{Context, Result}; use command_group::AsyncCommandGroup; -use log::debug; +use log::{debug, info, warn}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; use tokio::process::Command; @@ -21,19 +23,16 @@ async fn process_stdin( mut socket: R, mut child_stdin: W, ) -> Result<()> { - let mut in_buf = [0; 1024]; + let mut in_buf = [0u8; 1024]; loop { let n = socket.read(&mut in_buf).await?; if n == 0 { return Ok(()); // socket closed } - if in_buf[0] == 3 { - debug!("Client sent Ctrl-C"); - return Ok(()); - } - debug!("Writting to stdin: {:?}", &in_buf[0..n]); + let data = in_buf.get(..n).context("stdin read index out of bounds")?; + debug!("Writting to stdin: {data:?}"); child_stdin - .write_all(&in_buf[0..n]) + .write_all(data) .await .context("Failed to write to stdin")?; } @@ -44,14 +43,18 @@ async fn process_stdout( mut socket: W, mut child_stdout: R, ) -> Result<()> { - let mut out_buf = [0; 1024]; + let mut out_buf = [0u8; 1024]; loop { let n = child_stdout.read(&mut out_buf).await?; if n == 0 { return Ok(()); // process closed } + let data = out_buf + .get(..n) + .context("stdout read index out of bounds")?; + debug!("Reading from stdout: {data:?}"); socket - .write_all(&out_buf[0..n]) + .write_all(data) .await .context("Failed to write to socket")?; } @@ -61,18 +64,46 @@ async fn process_stdout( /// /// Spawn one process and then spawn 3 tasks to manage input, output and /// timeout. If one of these tasks reach its end, kill the process. -pub async fn handle_client(mut socket: TcpStream, args: Args) -> Result<()> { - // Send message of the day +pub async fn handle_client( + mut socket: TcpStream, + peer_addr: SocketAddr, + args: Args, +) -> Result> { + // Parse PROXY protocol header if enabled + let proxy_info = if args.proxy_protocol { + match proxy::parse_proxy_v2_header(&mut socket).await { + Ok(proxy::ProxyHeader::Proxied(info)) => { + info!( + "Client: {}:{} -> {}:{} (via proxy {}) connected", + info.src_addr, info.src_port, info.dst_addr, info.dst_port, peer_addr + ); + Some(info) + } + Ok(proxy::ProxyHeader::Local) => { + debug!("PROXY protocol LOCAL command"); + None + } + Err(e) => { + warn!("Rejecting connection from {peer_addr} due to PROXY protocol error: {e:?}"); + return Err(e); + } + } + } else { + None + }; + + // MOTD if let Some(motd) = &args.motd { socket.write_all(motd.as_bytes()).await?; - socket.write_all(b"\r\n").await?; + socket.write_all(&b"\r\n"[..]).await?; } - // Proof-of-work prompt + // Proof-of-work if args.pow > 0 { - let valid = pow::proof_of_work_prompt(&mut socket, args.pow, args.pow_backdoor).await?; + let valid = + pow::proof_of_work_prompt(&mut socket, args.pow, args.pow_backdoor.as_ref()).await?; if !valid { - return Ok(()); + return Ok(proxy_info); } } @@ -80,7 +111,9 @@ pub async fn handle_client(mut socket: TcpStream, args: Args) -> Result<()> { let mut command = Command::new(&args.command); command.args(&args.arguments); command.stdin(Stdio::piped()).stdout(Stdio::piped()); + let mut child = command.group_spawn().context("Failed to run command")?; + let child_stdin = child.inner().stdin.take().context("Failed to open stdin")?; let child_stdout = child .inner() @@ -88,22 +121,37 @@ pub async fn handle_client(mut socket: TcpStream, args: Args) -> Result<()> { .take() .context("Failed to open stdout")?; - // Start tasks - let mut set = JoinSet::new(); + // Split socket let (read_half, write_half) = socket.into_split(); + + let mut set = JoinSet::new(); + set.spawn(async move { process_stdin(read_half, child_stdin).await }); + set.spawn(async move { process_stdout(write_half, child_stdout).await }); - if let Some(timeout) = args.timeout { + + let session_timeout = args.timeout.map(Duration::from_secs); + + if let Some(timeout) = session_timeout { set.spawn(async move { - sleep(Duration::from_secs(timeout)).await; + sleep(timeout).await; debug!("Timeout reached"); Ok(()) }); } - // If one task exits, drop the others - // Child group should always be killed before dropping child handle. + // Wait for first task to finish let res = set.join_next().await; + + // Cancel remaining tasks immediately + set.abort_all(); + + // Kill the process group child.kill().await.context("Failed to kill process group")?; - res.unwrap_or(Ok(Ok(())))? + + // Await child to avoid zombie process + let _ = child.wait().await; + + res.unwrap_or(Ok(Ok(())))??; + Ok(proxy_info) } diff --git a/src/main.rs b/src/main.rs index e493689..ae942b8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,7 @@ mod handler; mod pow; +mod proxy; use anyhow::{Context, Result}; use clap::Parser; @@ -35,6 +36,10 @@ struct Args { #[arg(long, value_name = "STRING", env = "WRAPPER_POW_BACKDOOR")] pow_backdoor: Option, + /// Require PROXY protocol v2 header, reject connections without it + #[arg(long, env = "WRAPPER_PROXY_PROTOCOL")] + proxy_protocol: bool, + #[command(flatten)] verbose: Verbosity, @@ -56,16 +61,24 @@ async fn serve(args: Args) -> Result<()> { match listener.accept().await { Ok((socket, peer_addr)) => { info!("Client {peer_addr:?} connected"); + info!("Client {peer_addr:?} connected"); // Spawn task to handle this client let my_args = args.clone(); tokio::spawn(async move { - match handler::handle_client(socket, my_args).await { - Ok(()) => { + match handler::handle_client(socket, peer_addr, my_args).await { + Ok(Some(proxy_info)) => { + info!( + "Client: {}:{} (via proxy {}) disconnected", + proxy_info.src_addr, proxy_info.src_port, peer_addr + ); + } + Ok(None) => { info!("Client {peer_addr:?} disconnected"); } Err(e) => { warn!("Handling client {peer_addr:?} failed: {e:?}"); + warn!("Handling client {peer_addr:?} failed: {e:?}"); } } }); @@ -95,6 +108,7 @@ async fn main() { Ok(()) => {} Err(err) => { warn!("Unable to listen for shutdown signal: {err}"); + warn!("Unable to listen for shutdown signal: {err}"); } } } @@ -104,6 +118,6 @@ mod tests { #[test] fn verify_cli() { use clap::CommandFactory; - crate::Args::command().debug_assert() + crate::Args::command().debug_assert(); } } diff --git a/src/pow.rs b/src/pow.rs index fddc30e..253a4a6 100644 --- a/src/pow.rs +++ b/src/pow.rs @@ -1,9 +1,9 @@ // SPDX-FileCopyrightText: 2023-2025 erdnaxe // SPDX-License-Identifier: MIT -use anyhow::Result; +use anyhow::{Context, Result}; +use rand::RngExt; use rand::distr::Alphanumeric; -use rand::{RngExt, rng}; use sha2::{Digest, Sha256}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -17,73 +17,92 @@ More details can be found on .\r\n"; pub async fn proof_of_work_prompt( socket: &mut S, difficulty: u32, - backdoor: Option, + backdoor: Option<&String>, ) -> Result { // Generate prefix using OS random - let prefix: [u8; 16] = rng() + let prefix: [u8; 16] = rand::rng() .sample_iter(Alphanumeric) .take(16) .collect::>() .as_slice() .try_into() - .unwrap(); + .context("Failed to generate random prefix")?; // Prompt user socket.write_all(POW_HEADER_MESSAGE).await?; - let prompt = format!("Please provide an ASCII printable string S such that SHA256({} || S) starts with {} bits equal to 0 (the string concatenation is denoted ||): ", String::from_utf8(prefix.into())?, difficulty); + let prompt = format!( + "Please provide an ASCII printable string S such that SHA256({} || S) starts with {} bits equal to 0 (the string concatenation is denoted ||): ", + String::from_utf8(prefix.into())?, + difficulty + ); socket.write_all(prompt.as_bytes()).await?; - let mut buf = [0; 256]; - let mut buf_n = 0; + let mut buf = [0u8; 256]; + let mut buf_n: usize = 0; while buf_n < 256 { - let n = socket.read(&mut buf[buf_n..=buf_n]).await?; - if n == 0 || buf[buf_n] == b'\x03' { - return Ok(false); // socket closed or Ctrl-C + let byte = buf + .get_mut(buf_n..=buf_n) + .context("read index out of bounds")?; + let n = socket.read(byte).await?; + if n == 0 { + return Ok(false); // socket closed } - if buf[buf_n] == b'\0' || buf[buf_n] == b'\n' { + let current = *buf.get(buf_n).context("index out of bounds")?; + if current == b'\0' || current == b'\n' { break; // telnet uses \r\0, netcat \r\n } - if buf[buf_n] >= 127 || buf[buf_n] < 32 { + if !(32..127).contains(¤t) { continue; // ignore non ascii printable } - buf_n += n; - } - while buf_n > 0 - && (buf[buf_n - 1] == b'\n' || buf[buf_n - 1] == b'\r' || buf[buf_n - 1] == b'\0') - { - buf_n -= 1; // trim input + buf_n = buf_n.checked_add(n).context("buffer index overflow")?; } - // Backdoor for staff testing - if let Some(backdoor_str) = backdoor - && backdoor_str.as_bytes() == &buf[..buf_n] { - return Ok(true); + // Trim trailing carriage return + if buf_n > 0 { + let last = *buf + .get(buf_n.checked_sub(1).context("underflow")?) + .context("index out of bounds")?; + if last == b'\r' { + buf_n = buf_n.checked_sub(1).context("underflow")?; } + } - // Compute hash + // Get the user input as a slice + let suffix = buf.get(..buf_n).context("slice out of bounds")?; + + // Check backdoor + if let Some(bd) = backdoor + && suffix == bd.as_bytes() + { + return Ok(true); + } + + // Verify proof of work let mut hasher = Sha256::new(); hasher.update(prefix); - hasher.update(&buf[..buf_n]); - let hash: [u8; 32] = hasher.finalize().into(); + hasher.update(suffix); + let hash = hasher.finalize(); + Ok(check_leading_zeros(&hash, difficulty)) +} - // Count zeros - let mut measured_difficulty = 0; - for hash_byte in &hash { - if *hash_byte == 0 { - measured_difficulty += 8; +/// Check that the hash starts with at least `difficulty` zero bits +fn check_leading_zeros(hash: &[u8], difficulty: u32) -> bool { + let mut remaining = difficulty; + for &byte in hash { + if remaining == 0 { + return true; + } + if remaining >= 8 { + if byte != 0 { + return false; + } + remaining = remaining.saturating_sub(8); } else { - measured_difficulty += hash_byte.leading_zeros(); - break; + // Check the top `remaining` bits of this byte + let mask = 0xFF_u8 + .checked_shl(8_u32.saturating_sub(remaining)) + .unwrap_or(0); + return byte & mask == 0; } } - - if measured_difficulty < difficulty { - let message = format!( - "Wrong proof-of-work, hash starts with only {measured_difficulty} bits equal to 0.\r\n" - ); - socket.write_all(message.as_bytes()).await?; - Ok(false) - } else { - socket.write_all(b"Thank you for solving our proof-of-work, we hope you had a great time! Launching challenge...\r\n\r\n").await?; - Ok(true) - } + remaining == 0 } diff --git a/src/proxy.rs b/src/proxy.rs new file mode 100644 index 0000000..70f1577 --- /dev/null +++ b/src/proxy.rs @@ -0,0 +1,366 @@ +// SPDX-FileCopyrightText: 2023-2026 erdnaxe +// SPDX-License-Identifier: MIT + +use anyhow::{Result, anyhow}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use tokio::io::{AsyncRead, AsyncReadExt}; +use tokio::time::{Duration, timeout}; + +/// PROXY protocol v2 signature (12 bytes) +const PROXY_V2_SIGNATURE: [u8; 12] = [ + 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, +]; + +const VERSION_MASK: u8 = 0xF0; +const VERSION_2: u8 = 0x20; +const LOW_NIBBLE_MASK: u8 = 0x0F; +const HIGH_NIBBLE_MASK: u8 = 0xF0; + +const PROXY_V2_HEADER_LEN: usize = 16; +const IPV4_BLOCK_LEN: usize = 12; +const IPV6_BLOCK_LEN: usize = 36; + +const MAX_PROXY_ADDR_LEN: usize = 512; +const READ_TIMEOUT: Duration = Duration::from_secs(2); + +#[derive(Debug, Clone)] +pub struct ProxyInfo { + pub src_addr: IpAddr, + pub src_port: u16, + pub dst_addr: IpAddr, + pub dst_port: u16, +} + +impl ProxyInfo { + pub const fn new(src_addr: IpAddr, src_port: u16, dst_addr: IpAddr, dst_port: u16) -> Self { + Self { + src_addr, + src_port, + dst_addr, + dst_port, + } + } +} + +#[derive(Debug, Clone)] +pub enum ProxyHeader { + Local, + Proxied(ProxyInfo), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Command { + Local, + Proxy, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum AddressFamily { + Unspec, + Inet, + Inet6, + Unix, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TransportProtocol { + Unspec, + Stream, + Datagram, +} + +fn parse_command(version_command: u8) -> Result { + if version_command & VERSION_MASK != VERSION_2 { + return Err(anyhow!("Unsupported PROXY protocol version")); + } + + match version_command & LOW_NIBBLE_MASK { + 0x00 => Ok(Command::Local), + 0x01 => Ok(Command::Proxy), + command => Err(anyhow!("Unsupported PROXY command: {command}")), + } +} + +fn parse_family_protocol(family_protocol: u8) -> Result<(AddressFamily, TransportProtocol)> { + let family = match family_protocol & HIGH_NIBBLE_MASK { + 0x00 => AddressFamily::Unspec, + 0x10 => AddressFamily::Inet, + 0x20 => AddressFamily::Inet6, + 0x30 => AddressFamily::Unix, + family => return Err(anyhow!("Unknown address family: {family}")), + }; + + let protocol = match family_protocol & LOW_NIBBLE_MASK { + 0x00 => TransportProtocol::Unspec, + 0x01 => TransportProtocol::Stream, + 0x02 => TransportProtocol::Datagram, + protocol => return Err(anyhow!("Unknown transport protocol: {protocol}")), + }; + + Ok((family, protocol)) +} + +async fn read_exact_with_timeout( + stream: &mut R, + buf: &mut [u8], +) -> Result<()> { + timeout(READ_TIMEOUT, stream.read_exact(buf)).await??; + Ok(()) +} + +async fn drain_bytes(stream: &mut R, mut len: usize) -> Result<()> { + let mut scratch = [0u8; 256]; + + while len > 0 { + let chunk_len = len.min(scratch.len()); + let chunk = scratch + .get_mut(..chunk_len) + .ok_or_else(|| anyhow!("Read chunk length out of bounds"))?; + read_exact_with_timeout(stream, chunk).await?; + len = len + .checked_sub(chunk_len) + .ok_or_else(|| anyhow!("Drained length underflow"))?; + } + + Ok(()) +} + +/// Parse PROXY protocol v2 header +pub async fn parse_proxy_v2_header(stream: &mut R) -> Result { + let mut header = [0u8; PROXY_V2_HEADER_LEN]; + read_exact_with_timeout(stream, &mut header).await?; + + // Signature check + if header[..12] != PROXY_V2_SIGNATURE { + return Err(anyhow!("Invalid PROXY protocol v2 signature")); + } + + let command = parse_command(header[12])?; + let addr_len = usize::from(u16::from_be_bytes([header[14], header[15]])); + + if addr_len > MAX_PROXY_ADDR_LEN { + return Err(anyhow!("PROXY header too large: {addr_len}")); + } + + // Handle LOCAL command + if command == Command::Local { + if addr_len > 0 { + drain_bytes(stream, addr_len).await?; + } + return Ok(ProxyHeader::Local); + } + + let (family, protocol) = parse_family_protocol(header[13])?; + + if protocol != TransportProtocol::Stream { + return Err(anyhow!("Unsupported transport protocol: {protocol:?}")); + } + + match family { + AddressFamily::Inet => parse_ipv4(stream, addr_len).await, + AddressFamily::Inet6 => parse_ipv6(stream, addr_len).await, + AddressFamily::Unspec => { + if addr_len > 0 { + drain_bytes(stream, addr_len).await?; + } + Ok(ProxyHeader::Local) + } + AddressFamily::Unix => Err(anyhow!("UNIX addresses not supported")), + } +} + +/// Parse IPv4 address block (12 bytes) + skip TLVs +async fn parse_ipv4(stream: &mut R, addr_len: usize) -> Result { + if addr_len < IPV4_BLOCK_LEN { + return Err(anyhow!("IPv4 address block too short: {addr_len}")); + } + + let mut addr = [0u8; IPV4_BLOCK_LEN]; + read_exact_with_timeout(stream, &mut addr).await?; + + let tlv_len = addr_len + .checked_sub(IPV4_BLOCK_LEN) + .ok_or_else(|| anyhow!("IPv4 TLV length underflow"))?; + if tlv_len > 0 { + drain_bytes(stream, tlv_len).await?; + } + + let src_addr = Ipv4Addr::new(addr[0], addr[1], addr[2], addr[3]); + let dst_addr = Ipv4Addr::new(addr[4], addr[5], addr[6], addr[7]); + let src_port = u16::from_be_bytes([addr[8], addr[9]]); + let dst_port = u16::from_be_bytes([addr[10], addr[11]]); + + Ok(ProxyHeader::Proxied(ProxyInfo::new( + IpAddr::V4(src_addr), + src_port, + IpAddr::V4(dst_addr), + dst_port, + ))) +} + +/// Parse IPv6 address block (36 bytes) + skip TLVs +async fn parse_ipv6(stream: &mut R, addr_len: usize) -> Result { + if addr_len < IPV6_BLOCK_LEN { + return Err(anyhow!("IPv6 address block too short: {addr_len}")); + } + + let mut addr = [0u8; IPV6_BLOCK_LEN]; + read_exact_with_timeout(stream, &mut addr).await?; + + let tlv_len = addr_len + .checked_sub(IPV6_BLOCK_LEN) + .ok_or_else(|| anyhow!("IPv6 TLV length underflow"))?; + if tlv_len > 0 { + drain_bytes(stream, tlv_len).await?; + } + + let mut src_addr_bytes = [0u8; 16]; + src_addr_bytes.copy_from_slice(&addr[..16]); + let src_addr = Ipv6Addr::from(src_addr_bytes); + + let mut dst_addr_bytes = [0u8; 16]; + dst_addr_bytes.copy_from_slice(&addr[16..32]); + let dst_addr = Ipv6Addr::from(dst_addr_bytes); + + let src_port = u16::from_be_bytes([addr[32], addr[33]]); + let dst_port = u16::from_be_bytes([addr[34], addr[35]]); + + Ok(ProxyHeader::Proxied(ProxyInfo::new( + IpAddr::V6(src_addr), + src_port, + IpAddr::V6(dst_addr), + dst_port, + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + const IPV4_BLOCK_LEN_U16: u16 = 12; + const IPV6_BLOCK_WITH_TLV_LEN_U16: u16 = 38; + + async fn parse_from_bytes(data: &[u8]) -> Result<(Result, Vec)> { + let (mut writer, mut reader) = tokio::io::duplex(2048); + writer.write_all(data).await?; + drop(writer); + + let header = parse_proxy_v2_header(&mut reader).await; + let mut remaining = Vec::new(); + reader.read_to_end(&mut remaining).await?; + + Ok((header, remaining)) + } + + fn build_header(version_command: u8, family_protocol: u8, addr_len: u16) -> Vec { + let mut data = Vec::with_capacity(PROXY_V2_HEADER_LEN); + data.extend_from_slice(&PROXY_V2_SIGNATURE); + data.push(version_command); + data.push(family_protocol); + data.extend_from_slice(&addr_len.to_be_bytes()); + data + } + + #[tokio::test] + async fn parses_local_header_and_discards_payload() -> Result<()> { + let mut data = build_header(0x20, 0x00, 3); + data.extend_from_slice(&[0xAA, 0xBB, 0xCC]); + + let (header, remaining) = parse_from_bytes(&data).await?; + + assert!(matches!(header, Ok(ProxyHeader::Local))); + assert!(remaining.is_empty()); + Ok(()) + } + + #[tokio::test] + async fn parses_local_header_with_unknown_family_and_protocol() -> Result<()> { + let mut data = build_header(0x20, 0xFF, 1); + data.push(0xAA); + + let (header, remaining) = parse_from_bytes(&data).await?; + + assert!(matches!(header, Ok(ProxyHeader::Local))); + assert!(remaining.is_empty()); + Ok(()) + } + + #[tokio::test] + async fn parses_ipv4_proxy_header() -> Result<()> { + let mut data = build_header(0x21, 0x11, IPV4_BLOCK_LEN_U16); + data.extend_from_slice(&[ + 192, 0, 2, 10, // source IPv4 + 198, 51, 100, 7, // destination IPv4 + 0x30, 0x39, // source port 12345 + 0x00, 0x50, // destination port 80 + ]); + + let (header, remaining) = parse_from_bytes(&data).await?; + let header = match header { + Ok(header) => header, + Err(error) => panic!("header should parse: {error}"), + }; + assert!(remaining.is_empty()); + + match header { + ProxyHeader::Proxied(info) => { + assert_eq!(info.src_addr, IpAddr::V4(Ipv4Addr::new(192, 0, 2, 10))); + assert_eq!(info.dst_addr, IpAddr::V4(Ipv4Addr::new(198, 51, 100, 7))); + assert_eq!(info.src_port, 12345); + assert_eq!(info.dst_port, 80); + } + ProxyHeader::Local => panic!("expected proxied header"), + } + Ok(()) + } + + #[tokio::test] + async fn parses_ipv6_proxy_header_and_discards_tlv() -> Result<()> { + let mut data = build_header(0x21, 0x21, IPV6_BLOCK_WITH_TLV_LEN_U16); + data.extend_from_slice(&[ + 0x20, 0x01, 0x0D, 0xB8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, // source IPv6 + 0x20, 0x01, 0x0D, 0xB8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x02, // destination IPv6 + 0x01, 0xBB, // source port 443 + 0x82, 0x35, // destination port 33333 + 0xEE, 0xFF, // TLV payload to skip + ]); + + let (header, remaining) = parse_from_bytes(&data).await?; + let header = match header { + Ok(header) => header, + Err(error) => panic!("header should parse: {error}"), + }; + assert!(remaining.is_empty()); + + match header { + ProxyHeader::Proxied(info) => { + assert_eq!( + info.src_addr, + IpAddr::V6(Ipv6Addr::new(0x2001, 0x0DB8, 0, 0, 0, 0, 0, 1)) + ); + assert_eq!( + info.dst_addr, + IpAddr::V6(Ipv6Addr::new(0x2001, 0x0DB8, 0, 0, 0, 0, 0, 2)) + ); + assert_eq!(info.src_port, 443); + assert_eq!(info.dst_port, 33333); + } + ProxyHeader::Local => panic!("expected proxied header"), + } + Ok(()) + } + + #[tokio::test] + async fn rejects_unknown_command() -> Result<()> { + let data = build_header(0x22, 0x11, 0); + let (header, _) = parse_from_bytes(&data).await?; + let Err(error) = header else { + panic!("header should fail"); + }; + + assert!(error.to_string().contains("Unsupported PROXY command")); + Ok(()) + } +}