diff --git a/crates/stackforge-core/src/layer/ipv4/builder.rs b/crates/stackforge-core/src/layer/ipv4/builder.rs new file mode 100644 index 0000000..09cd4f3 --- /dev/null +++ b/crates/stackforge-core/src/layer/ipv4/builder.rs @@ -0,0 +1,727 @@ +//! IPv4 packet builder. +//! +//! Provides a fluent API for constructing IPv4 packets with automatic +//! field calculation (checksum, length, IHL). + +use std::net::Ipv4Addr; + +use super::checksum::ipv4_checksum; +use super::header::{IPV4_MIN_HEADER_LEN, Ipv4Flags, Ipv4Layer, offsets}; +use super::options::{Ipv4Option, Ipv4Options, Ipv4OptionsBuilder}; +use super::protocol; +use crate::layer::field::FieldError; + +/// Builder for IPv4 packets. +/// +/// # Example +/// +/// ```rust +/// use stackforge_core::layer::ipv4::{Ipv4Builder, protocol}; +/// use std::net::Ipv4Addr; +/// +/// let packet = Ipv4Builder::new() +/// .src(Ipv4Addr::new(192, 168, 1, 1)) +/// .dst(Ipv4Addr::new(192, 168, 1, 2)) +/// .ttl(64) +/// .protocol(protocol::TCP) +/// .dont_fragment() +/// .build(); +/// +/// assert_eq!(packet.len(), 20); // Minimum header, no payload +/// ``` +#[derive(Debug, Clone)] +pub struct Ipv4Builder { + // Header fields + version: u8, + ihl: Option, + tos: u8, + total_len: Option, + id: u16, + flags: Ipv4Flags, + frag_offset: u16, + ttl: u8, + protocol: u8, + checksum: Option, + src: Ipv4Addr, + dst: Ipv4Addr, + + // Options + options: Ipv4Options, + + // Payload + payload: Vec, + + // Build options + auto_checksum: bool, + auto_length: bool, + auto_ihl: bool, +} + +impl Default for Ipv4Builder { + fn default() -> Self { + Self { + version: 4, + ihl: None, + tos: 0, + total_len: None, + id: 1, + flags: Ipv4Flags::NONE, + frag_offset: 0, + ttl: 64, + protocol: 0, + checksum: None, + src: Ipv4Addr::new(127, 0, 0, 1), + dst: Ipv4Addr::new(127, 0, 0, 1), + options: Ipv4Options::new(), + payload: Vec::new(), + auto_checksum: true, + auto_length: true, + auto_ihl: true, + } + } +} + +impl Ipv4Builder { + /// Create a new IPv4 builder with default values. + pub fn new() -> Self { + Self::default() + } + + /// Create a builder initialized from an existing packet. + pub fn from_bytes(data: &[u8]) -> Result { + let layer = Ipv4Layer::at_offset_dynamic(data, 0)?; + + let mut builder = Self::new(); + builder.version = layer.version(data)?; + builder.ihl = Some(layer.ihl(data)?); + builder.tos = layer.tos(data)?; + builder.total_len = Some(layer.total_len(data)?); + builder.id = layer.id(data)?; + builder.flags = layer.flags(data)?; + builder.frag_offset = layer.frag_offset(data)?; + builder.ttl = layer.ttl(data)?; + builder.protocol = layer.protocol(data)?; + builder.checksum = Some(layer.checksum(data)?); + builder.src = layer.src(data)?; + builder.dst = layer.dst(data)?; + + // Parse options if present + if layer.options_len(data) > 0 { + builder.options = layer.options(data)?; + } + + // Copy payload + let header_len = layer.calculate_header_len(data); + let total_len = layer.total_len(data)? as usize; + if total_len > header_len && data.len() >= total_len { + builder.payload = data[header_len..total_len].to_vec(); + } + + // Disable auto-calculation since we're copying exact values + builder.auto_checksum = false; + builder.auto_length = false; + builder.auto_ihl = false; + + Ok(builder) + } + + // ========== Header Field Setters ========== + + /// Set the IP version (should normally be 4). + pub fn version(mut self, version: u8) -> Self { + self.version = version; + self + } + + /// Set the Internet Header Length (in 32-bit words). + /// If not set, will be calculated automatically. + pub fn ihl(mut self, ihl: u8) -> Self { + self.ihl = Some(ihl); + self.auto_ihl = false; + self + } + + /// Set the Type of Service field. + pub fn tos(mut self, tos: u8) -> Self { + self.tos = tos; + self + } + + /// Set the DSCP (Differentiated Services Code Point). + pub fn dscp(mut self, dscp: u8) -> Self { + self.tos = (self.tos & 0x03) | ((dscp & 0x3F) << 2); + self + } + + /// Set the ECN (Explicit Congestion Notification). + pub fn ecn(mut self, ecn: u8) -> Self { + self.tos = (self.tos & 0xFC) | (ecn & 0x03); + self + } + + /// Set the total length field. + /// If not set, will be calculated automatically. + pub fn total_len(mut self, len: u16) -> Self { + self.total_len = Some(len); + self.auto_length = false; + self + } + + /// Alias for total_len (Scapy compatibility). + pub fn len(self, len: u16) -> Self { + self.total_len(len) + } + + /// Set the identification field. + pub fn id(mut self, id: u16) -> Self { + self.id = id; + self + } + + /// Set the flags field. + pub fn flags(mut self, flags: Ipv4Flags) -> Self { + self.flags = flags; + self + } + + /// Set the Don't Fragment flag. + pub fn dont_fragment(mut self) -> Self { + self.flags.df = true; + self + } + + /// Clear the Don't Fragment flag. + pub fn allow_fragment(mut self) -> Self { + self.flags.df = false; + self + } + + /// Set the More Fragments flag. + pub fn more_fragments(mut self) -> Self { + self.flags.mf = true; + self + } + + /// Set the reserved/evil bit. + pub fn evil(mut self) -> Self { + self.flags.reserved = true; + self + } + + /// Set the fragment offset (in 8-byte units). + pub fn frag_offset(mut self, offset: u16) -> Self { + self.frag_offset = offset & 0x1FFF; + self + } + + /// Set the fragment offset in bytes (will be divided by 8). + pub fn frag_offset_bytes(mut self, offset: u32) -> Self { + self.frag_offset = ((offset / 8) & 0x1FFF) as u16; + self + } + + /// Set the TTL (Time to Live). + pub fn ttl(mut self, ttl: u8) -> Self { + self.ttl = ttl; + self + } + + /// Set the protocol number. + pub fn protocol(mut self, protocol: u8) -> Self { + self.protocol = protocol; + self + } + + /// Alias for protocol (Scapy compatibility). + pub fn proto(self, protocol: u8) -> Self { + self.protocol(protocol) + } + + /// Set the checksum manually. + /// If not set, will be calculated automatically. + pub fn checksum(mut self, checksum: u16) -> Self { + self.checksum = Some(checksum); + self.auto_checksum = false; + self + } + + /// Alias for checksum (Scapy compatibility). + pub fn chksum(self, checksum: u16) -> Self { + self.checksum(checksum) + } + + /// Set the source IP address. + pub fn src(mut self, src: Ipv4Addr) -> Self { + self.src = src; + self + } + + /// Set the destination IP address. + pub fn dst(mut self, dst: Ipv4Addr) -> Self { + self.dst = dst; + self + } + + // ========== Options ========== + + /// Set the options. + pub fn options(mut self, options: Ipv4Options) -> Self { + self.options = options; + self + } + + /// Add a single option. + pub fn option(mut self, option: Ipv4Option) -> Self { + self.options.push(option); + self + } + + /// Add options using a builder function. + pub fn with_options(mut self, f: F) -> Self + where + F: FnOnce(Ipv4OptionsBuilder) -> Ipv4OptionsBuilder, + { + self.options = f(Ipv4OptionsBuilder::new()).build(); + self + } + + /// Add a Record Route option. + pub fn record_route(mut self, slots: usize) -> Self { + self.options.push(Ipv4Option::RecordRoute { + pointer: 4, + route: vec![Ipv4Addr::UNSPECIFIED; slots], + }); + self + } + + /// Add a Loose Source Route option. + pub fn lsrr(mut self, route: Vec) -> Self { + self.options.push(Ipv4Option::Lsrr { pointer: 4, route }); + self + } + + /// Add a Strict Source Route option. + pub fn ssrr(mut self, route: Vec) -> Self { + self.options.push(Ipv4Option::Ssrr { pointer: 4, route }); + self + } + + /// Add a Router Alert option. + pub fn router_alert(mut self, value: u16) -> Self { + self.options.push(Ipv4Option::RouterAlert { value }); + self + } + + // ========== Payload ========== + + /// Set the payload data. + pub fn payload(mut self, payload: impl Into>) -> Self { + self.payload = payload.into(); + self + } + + /// Append data to the payload. + pub fn append_payload(mut self, data: &[u8]) -> Self { + self.payload.extend_from_slice(data); + self + } + + // ========== Build Options ========== + + /// Enable or disable automatic checksum calculation. + pub fn auto_checksum(mut self, enabled: bool) -> Self { + self.auto_checksum = enabled; + self + } + + /// Enable or disable automatic length calculation. + pub fn auto_length(mut self, enabled: bool) -> Self { + self.auto_length = enabled; + self + } + + /// Enable or disable automatic IHL calculation. + pub fn auto_ihl(mut self, enabled: bool) -> Self { + self.auto_ihl = enabled; + self + } + + // ========== Build Methods ========== + + /// Calculate the header size (including options). + pub fn header_size(&self) -> usize { + if let Some(ihl) = self.ihl { + (ihl as usize) * 4 + } else { + let opts_len = self.options.padded_len(); + IPV4_MIN_HEADER_LEN + opts_len + } + } + + /// Calculate the total packet size. + pub fn packet_size(&self) -> usize { + self.header_size() + self.payload.len() + } + + /// Build the IPv4 packet. + pub fn build(&self) -> Vec { + let header_size = self.header_size(); + let total_size = self.packet_size(); + + let mut buf = vec![0u8; total_size]; + self.build_into(&mut buf) + .expect("buffer is correctly sized"); + buf + } + + /// Build the IPv4 packet into an existing buffer. + pub fn build_into(&self, buf: &mut [u8]) -> Result { + let header_size = self.header_size(); + let total_size = self.packet_size(); + + if buf.len() < total_size { + return Err(FieldError::BufferTooShort { + offset: 0, + need: total_size, + have: buf.len(), + }); + } + + // Calculate IHL + let ihl = if self.auto_ihl { + (header_size / 4) as u8 + } else { + self.ihl.unwrap_or(5) + }; + + // Calculate total length + let total_len = if self.auto_length { + total_size as u16 + } else { + self.total_len.unwrap_or(total_size as u16) + }; + + // Version + IHL + buf[offsets::VERSION_IHL] = ((self.version & 0x0F) << 4) | (ihl & 0x0F); + + // TOS + buf[offsets::TOS] = self.tos; + + // Total Length + buf[offsets::TOTAL_LEN] = (total_len >> 8) as u8; + buf[offsets::TOTAL_LEN + 1] = (total_len & 0xFF) as u8; + + // ID + buf[offsets::ID] = (self.id >> 8) as u8; + buf[offsets::ID + 1] = (self.id & 0xFF) as u8; + + // Flags + Fragment Offset + let flags_frag = (self.flags.to_byte() as u16) << 8 | self.frag_offset; + buf[offsets::FLAGS_FRAG] = (flags_frag >> 8) as u8; + buf[offsets::FLAGS_FRAG + 1] = (flags_frag & 0xFF) as u8; + + // TTL + buf[offsets::TTL] = self.ttl; + + // Protocol + buf[offsets::PROTOCOL] = self.protocol; + + // Checksum (initially 0) + buf[offsets::CHECKSUM] = 0; + buf[offsets::CHECKSUM + 1] = 0; + + // Source IP + let src_octets = self.src.octets(); + buf[offsets::SRC..offsets::SRC + 4].copy_from_slice(&src_octets); + + // Destination IP + let dst_octets = self.dst.octets(); + buf[offsets::DST..offsets::DST + 4].copy_from_slice(&dst_octets); + + // Options + if !self.options.is_empty() { + let opts_bytes = self.options.to_bytes(); + let opts_end = offsets::OPTIONS + opts_bytes.len(); + if opts_end <= header_size { + buf[offsets::OPTIONS..opts_end].copy_from_slice(&opts_bytes); + } + } + + // Payload + if !self.payload.is_empty() { + buf[header_size..header_size + self.payload.len()].copy_from_slice(&self.payload); + } + + // Checksum (computed last) + let checksum = if self.auto_checksum { + ipv4_checksum(&buf[..header_size]) + } else { + self.checksum.unwrap_or(0) + }; + buf[offsets::CHECKSUM] = (checksum >> 8) as u8; + buf[offsets::CHECKSUM + 1] = (checksum & 0xFF) as u8; + + Ok(total_size) + } + + /// Build only the header (no payload). + pub fn build_header(&self) -> Vec { + let header_size = self.header_size(); + let mut buf = vec![0u8; header_size]; + + // Temporarily clear payload for header-only build + let payload = std::mem::take(&mut self.payload.clone()); + let builder = Self { + payload: Vec::new(), + ..self.clone() + }; + builder + .build_into(&mut buf) + .expect("buffer is correctly sized"); + + // Don't actually need to restore since we cloned + drop(payload); + + buf + } +} + +// ========== Convenience Constructors ========== + +impl Ipv4Builder { + /// Create an ICMP packet builder. + pub fn icmp() -> Self { + Self::new().protocol(protocol::ICMP) + } + + /// Create a TCP packet builder. + pub fn tcp() -> Self { + Self::new().protocol(protocol::TCP) + } + + /// Create a UDP packet builder. + pub fn udp() -> Self { + Self::new().protocol(protocol::UDP) + } + + /// Create an IP-in-IP tunnel packet builder. + pub fn ipip() -> Self { + Self::new().protocol(protocol::IPV4) + } + + /// Create a GRE tunnel packet builder. + pub fn gre() -> Self { + Self::new().protocol(protocol::GRE) + } + + /// Create a packet destined for a specific address. + pub fn to(dst: Ipv4Addr) -> Self { + Self::new().dst(dst) + } + + /// Create a packet from a specific source. + pub fn from(src: Ipv4Addr) -> Self { + Self::new().src(src) + } +} + +// ========== Random Values ========== + +#[cfg(feature = "rand")] +impl Ipv4Builder { + /// Set a random ID. + pub fn random_id(mut self) -> Self { + use rand::Rng; + self.id = rand::rng().random(); + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic_build() { + let pkt = Ipv4Builder::new() + .src(Ipv4Addr::new(192, 168, 1, 1)) + .dst(Ipv4Addr::new(192, 168, 1, 2)) + .ttl(64) + .protocol(protocol::TCP) + .build(); + + assert_eq!(pkt.len(), 20); + + let layer = Ipv4Layer::at_offset(0); + assert_eq!(layer.version(&pkt).unwrap(), 4); + assert_eq!(layer.ihl(&pkt).unwrap(), 5); + assert_eq!(layer.ttl(&pkt).unwrap(), 64); + assert_eq!(layer.protocol(&pkt).unwrap(), protocol::TCP); + assert_eq!(layer.src(&pkt).unwrap(), Ipv4Addr::new(192, 168, 1, 1)); + assert_eq!(layer.dst(&pkt).unwrap(), Ipv4Addr::new(192, 168, 1, 2)); + + // Verify checksum + assert!(layer.verify_checksum(&pkt).unwrap()); + } + + #[test] + fn test_with_payload() { + let payload = vec![1, 2, 3, 4, 5]; + let pkt = Ipv4Builder::new() + .src(Ipv4Addr::new(10, 0, 0, 1)) + .dst(Ipv4Addr::new(10, 0, 0, 2)) + .protocol(protocol::UDP) + .payload(payload.clone()) + .build(); + + assert_eq!(pkt.len(), 25); // 20 + 5 + + let layer = Ipv4Layer::at_offset(0); + assert_eq!(layer.total_len(&pkt).unwrap(), 25); + assert_eq!(layer.payload(&pkt).unwrap(), &payload[..]); + } + + #[test] + fn test_with_options() { + let pkt = Ipv4Builder::new() + .src(Ipv4Addr::new(10, 0, 0, 1)) + .dst(Ipv4Addr::new(10, 0, 0, 2)) + .router_alert(0) + .build(); + + // Router Alert is 4 bytes, header should be 24 bytes + assert_eq!(pkt.len(), 24); + + let layer = Ipv4Layer::at_offset(0); + assert_eq!(layer.ihl(&pkt).unwrap(), 6); // 24/4 = 6 + assert!(layer.verify_checksum(&pkt).unwrap()); + } + + #[test] + fn test_flags() { + let pkt = Ipv4Builder::new() + .dst(Ipv4Addr::new(8, 8, 8, 8)) + .dont_fragment() + .build(); + + let layer = Ipv4Layer::at_offset(0); + let flags = layer.flags(&pkt).unwrap(); + assert!(flags.df); + assert!(!flags.mf); + } + + #[test] + fn test_fragment() { + let pkt = Ipv4Builder::new() + .dst(Ipv4Addr::new(8, 8, 8, 8)) + .more_fragments() + .frag_offset(100) + .build(); + + let layer = Ipv4Layer::at_offset(0); + let flags = layer.flags(&pkt).unwrap(); + assert!(flags.mf); + assert_eq!(layer.frag_offset(&pkt).unwrap(), 100); + } + + #[test] + fn test_dscp_ecn() { + let pkt = Ipv4Builder::new() + .dst(Ipv4Addr::new(8, 8, 8, 8)) + .dscp(46) // EF + .ecn(2) // ECT(0) + .build(); + + let layer = Ipv4Layer::at_offset(0); + assert_eq!(layer.dscp(&pkt).unwrap(), 46); + assert_eq!(layer.ecn(&pkt).unwrap(), 2); + } + + #[test] + fn test_from_bytes() { + let original = Ipv4Builder::new() + .src(Ipv4Addr::new(192, 168, 1, 100)) + .dst(Ipv4Addr::new(192, 168, 1, 200)) + .ttl(128) + .id(0xABCD) + .protocol(protocol::ICMP) + .payload(vec![8, 0, 0, 0, 0, 1, 0, 1]) // ICMP echo + .build(); + + let rebuilt = Ipv4Builder::from_bytes(&original) + .unwrap() + .auto_checksum(true) + .build(); + + // Should be identical + assert_eq!(original.len(), rebuilt.len()); + + let layer = Ipv4Layer::at_offset(0); + assert_eq!(layer.src(&original).unwrap(), layer.src(&rebuilt).unwrap()); + assert_eq!(layer.dst(&original).unwrap(), layer.dst(&rebuilt).unwrap()); + assert_eq!(layer.ttl(&original).unwrap(), layer.ttl(&rebuilt).unwrap()); + assert_eq!(layer.id(&original).unwrap(), layer.id(&rebuilt).unwrap()); + } + + #[test] + fn test_convenience_constructors() { + let icmp = Ipv4Builder::icmp().build(); + let layer = Ipv4Layer::at_offset(0); + assert_eq!(layer.protocol(&icmp).unwrap(), protocol::ICMP); + + let tcp = Ipv4Builder::tcp().build(); + assert_eq!(layer.protocol(&tcp).unwrap(), protocol::TCP); + + let udp = Ipv4Builder::udp().build(); + assert_eq!(layer.protocol(&udp).unwrap(), protocol::UDP); + } + + #[test] + fn test_manual_fields() { + let pkt = Ipv4Builder::new() + .dst(Ipv4Addr::new(8, 8, 8, 8)) + .total_len(100) + .checksum(0x1234) + .ihl(5) + .build(); + + let layer = Ipv4Layer::at_offset(0); + assert_eq!(layer.total_len(&pkt).unwrap(), 100); + assert_eq!(layer.checksum(&pkt).unwrap(), 0x1234); + assert_eq!(layer.ihl(&pkt).unwrap(), 5); + } + + #[test] + fn test_source_route_option() { + let route = vec![ + Ipv4Addr::new(10, 0, 0, 1), + Ipv4Addr::new(10, 0, 0, 2), + Ipv4Addr::new(10, 0, 0, 3), + ]; + + let pkt = Ipv4Builder::new() + .dst(Ipv4Addr::new(10, 0, 0, 4)) + .lsrr(route.clone()) + .build(); + + let layer = Ipv4Layer::at_offset(0); + let options = layer.options(&pkt).unwrap(); + + // Check that options are parsed correctly. + // We might get extra options (padding/NOP), so we just look for LSRR + let lsrr_option = options + .options + .iter() + .find(|opt| matches!(opt, Ipv4Option::Lsrr { .. })); + + assert!(lsrr_option.is_some(), "Expected LSRR option"); + + if let Some(Ipv4Option::Lsrr { + route: parsed_route, + .. + }) = lsrr_option + { + assert_eq!(parsed_route, &route); + } + } +} diff --git a/crates/stackforge-core/src/layer/ipv4/checksum.rs b/crates/stackforge-core/src/layer/ipv4/checksum.rs new file mode 100644 index 0000000..4219e3e --- /dev/null +++ b/crates/stackforge-core/src/layer/ipv4/checksum.rs @@ -0,0 +1,435 @@ +//! IPv4 header checksum calculation. +//! +//! Implements RFC 1071 Internet checksum algorithm used for IPv4 headers. +//! The checksum is computed lazily - only when the packet is serialized. + +/// Compute the Internet checksum (RFC 1071) over a byte slice. +/// +/// This is the standard one's complement sum used for IP, ICMP, TCP, and UDP. +/// +/// # Algorithm +/// +/// 1. Sum all 16-bit words in the data +/// 2. Add any odd byte as the high byte of a word +/// 3. Fold 32-bit sum to 16 bits by adding carry bits +/// 4. Take one's complement of the result +/// +/// # Example +/// +/// ```rust +/// use stackforge_core::layer::ipv4::checksum::ipv4_checksum; +/// +/// let header = [ +/// 0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, +/// 0x40, 0x06, 0x00, 0x00, 0xac, 0x10, 0x0a, 0x63, +/// 0xac, 0x10, 0x0a, 0x0c, +/// ]; +/// let checksum = ipv4_checksum(&header); +/// ``` +#[inline] +pub fn ipv4_checksum(data: &[u8]) -> u16 { + internet_checksum(data) +} + +/// Generic Internet checksum implementation (RFC 1071). +/// +/// Can be used for any protocol that uses the Internet checksum. +pub fn internet_checksum(data: &[u8]) -> u16 { + let mut sum: u32 = 0; + + // Process 16-bit words + let mut chunks = data.chunks_exact(2); + for chunk in chunks.by_ref() { + sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32; + } + + // Handle odd byte (pad with zero) + if let Some(&last) = chunks.remainder().first() { + sum += (last as u32) << 8; + } + + // Fold 32-bit sum to 16 bits (add carry) + while (sum >> 16) != 0 { + sum = (sum & 0xFFFF) + (sum >> 16); + } + + // One's complement + !sum as u16 +} + +/// Verify that a checksum is valid. +/// +/// When computed over data that includes a valid checksum, the result +/// should be 0x0000 or 0xFFFF. +#[inline] +pub fn verify_ipv4_checksum(data: &[u8]) -> bool { + let sum = internet_checksum(data); + sum == 0 || sum == 0xFFFF +} + +/// Incrementally update a checksum when a 16-bit field changes. +/// +/// This is more efficient than recomputing the entire checksum when +/// only a single field is modified. +/// +/// # Arguments +/// +/// * `old_checksum` - The existing checksum value +/// * `old_value` - The old 16-bit field value +/// * `new_value` - The new 16-bit field value +/// +/// # Returns +/// +/// The updated checksum value. +#[inline] +pub fn incremental_update_checksum(old_checksum: u16, old_value: u16, new_value: u16) -> u16 { + // RFC 1624: HC' = ~(~HC + ~m + m') + let hc = !old_checksum as u32; + let m = !old_value as u32; + let m_prime = new_value as u32; + + let mut sum = hc + m + m_prime; + + // Fold carry + while (sum >> 16) != 0 { + sum = (sum & 0xFFFF) + (sum >> 16); + } + + !sum as u16 +} + +/// Incrementally update checksum when a 32-bit field changes. +/// +/// Useful for updating checksum after IP address changes. +#[inline] +pub fn incremental_update_checksum_32(old_checksum: u16, old_value: u32, new_value: u32) -> u16 { + // Update for high 16 bits, then low 16 bits + let old_high = (old_value >> 16) as u16; + let old_low = (old_value & 0xFFFF) as u16; + let new_high = (new_value >> 16) as u16; + let new_low = (new_value & 0xFFFF) as u16; + + let tmp = incremental_update_checksum(old_checksum, old_high, new_high); + incremental_update_checksum(tmp, old_low, new_low) +} + +/// Compute the pseudo-header checksum for TCP/UDP. +/// +/// The pseudo-header includes: +/// - Source IP (4 bytes) +/// - Destination IP (4 bytes) +/// - Zero (1 byte) +/// - Protocol (1 byte) +/// - TCP/UDP length (2 bytes) +/// +/// # Arguments +/// +/// * `src_ip` - Source IP address (4 bytes) +/// * `dst_ip` - Destination IP address (4 bytes) +/// * `protocol` - IP protocol number +/// * `transport_len` - Length of the transport layer data +/// +/// # Returns +/// +/// The partial checksum from the pseudo-header. +pub fn pseudo_header_checksum( + src_ip: &[u8; 4], + dst_ip: &[u8; 4], + protocol: u8, + transport_len: u16, +) -> u32 { + let mut sum: u32 = 0; + + // Source IP + sum += u16::from_be_bytes([src_ip[0], src_ip[1]]) as u32; + sum += u16::from_be_bytes([src_ip[2], src_ip[3]]) as u32; + + // Destination IP + sum += u16::from_be_bytes([dst_ip[0], dst_ip[1]]) as u32; + sum += u16::from_be_bytes([dst_ip[2], dst_ip[3]]) as u32; + + // Zero + Protocol + sum += protocol as u32; + + // Transport length + sum += transport_len as u32; + + sum +} + +/// Compute complete transport layer checksum (TCP or UDP). +/// +/// # Arguments +/// +/// * `src_ip` - Source IP address +/// * `dst_ip` - Destination IP address +/// * `protocol` - IP protocol number (6 for TCP, 17 for UDP) +/// * `transport_data` - The complete transport layer header and payload +/// +/// # Returns +/// +/// The computed checksum value. +pub fn transport_checksum( + src_ip: &[u8; 4], + dst_ip: &[u8; 4], + protocol: u8, + transport_data: &[u8], +) -> u16 { + // Start with pseudo-header checksum + let mut sum = pseudo_header_checksum(src_ip, dst_ip, protocol, transport_data.len() as u16); + + // Add transport data + let mut chunks = transport_data.chunks_exact(2); + for chunk in chunks.by_ref() { + sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32; + } + + // Handle odd byte + if let Some(&last) = chunks.remainder().first() { + sum += (last as u32) << 8; + } + + // Fold and complement + while (sum >> 16) != 0 { + sum = (sum & 0xFFFF) + (sum >> 16); + } + + // For UDP, if checksum is 0, use 0xFFFF instead (RFC 768) + let result = !sum as u16; + if result == 0 && protocol == 17 { + 0xFFFF + } else { + result + } +} + +/// Helper to compute partial checksum (before folding and complement). +/// +/// Useful for computing checksum across multiple data segments. +pub fn partial_checksum(data: &[u8], initial: u32) -> u32 { + let mut sum = initial; + + let mut chunks = data.chunks_exact(2); + for chunk in chunks.by_ref() { + sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32; + } + + if let Some(&last) = chunks.remainder().first() { + sum += (last as u32) << 8; + } + + sum +} + +/// Finalize a partial checksum. +/// +/// Folds the 32-bit sum to 16 bits and takes one's complement. +#[inline] +pub fn finalize_checksum(sum: u32) -> u16 { + let mut s = sum; + while (s >> 16) != 0 { + s = (s & 0xFFFF) + (s >> 16); + } + !s as u16 +} + +/// Zero out checksum field in a buffer at the specified offset. +#[inline] +pub fn zero_checksum(buf: &mut [u8], offset: usize) { + if buf.len() >= offset + 2 { + buf[offset] = 0; + buf[offset + 1] = 0; + } +} + +/// Write checksum to buffer at the specified offset. +#[inline] +pub fn write_checksum(buf: &mut [u8], offset: usize, checksum: u16) { + if buf.len() >= offset + 2 { + let bytes = checksum.to_be_bytes(); + buf[offset] = bytes[0]; + buf[offset + 1] = bytes[1]; + } +} + +/// Read checksum from buffer at the specified offset. +#[inline] +pub fn read_checksum(buf: &[u8], offset: usize) -> Option { + if buf.len() >= offset + 2 { + Some(u16::from_be_bytes([buf[offset], buf[offset + 1]])) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ipv4_checksum() { + // Sample IPv4 header from RFC 1071 + let header = [ + 0x45, 0x00, 0x00, 0x3c, // Version, IHL, TOS, Total Length + 0x1c, 0x46, 0x40, 0x00, // ID, Flags, Fragment Offset + 0x40, 0x06, 0x00, 0x00, // TTL, Protocol, Checksum (zeroed) + 0xac, 0x10, 0x0a, 0x63, // Source: 172.16.10.99 + 0xac, 0x10, 0x0a, 0x0c, // Dest: 172.16.10.12 + ]; + + let checksum = ipv4_checksum(&header); + + // Place checksum in header and verify + let mut header_with_cksum = header; + header_with_cksum[10] = (checksum >> 8) as u8; + header_with_cksum[11] = (checksum & 0xFF) as u8; + + assert!(verify_ipv4_checksum(&header_with_cksum)); + } + + #[test] + fn test_verify_valid_checksum() { + // Header with pre-computed valid checksum + let header = [ + 0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0xb1, + 0xe6, // checksum + 0xac, 0x10, 0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c, + ]; + + assert!(verify_ipv4_checksum(&header)); + } + + #[test] + fn test_verify_invalid_checksum() { + // Header with corrupted checksum + let header = [ + 0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0xFF, + 0xFF, // bad checksum + 0xac, 0x10, 0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c, + ]; + + assert!(!verify_ipv4_checksum(&header)); + } + + #[test] + fn test_incremental_update() { + // Original header with valid checksum + let mut header = [ + 0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xac, 0x10, + 0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c, + ]; + + // Compute initial checksum + let initial_checksum = ipv4_checksum(&header); + header[10] = (initial_checksum >> 8) as u8; + header[11] = (initial_checksum & 0xFF) as u8; + + // Change TTL from 0x40 to 0x3F using incremental update + let old_ttl_word = u16::from_be_bytes([header[8], header[9]]); + header[8] = 0x3F; + let new_ttl_word = u16::from_be_bytes([header[8], header[9]]); + + let new_checksum = + incremental_update_checksum(initial_checksum, old_ttl_word, new_ttl_word); + header[10] = (new_checksum >> 8) as u8; + header[11] = (new_checksum & 0xFF) as u8; + + // Verify the incrementally updated checksum is valid + assert!(verify_ipv4_checksum(&header)); + } + + #[test] + fn test_pseudo_header_checksum() { + let src = [192, 168, 1, 1]; + let dst = [192, 168, 1, 2]; + let protocol = 6; // TCP + let length = 20; // TCP header only + + let sum = pseudo_header_checksum(&src, &dst, protocol, length); + + // Verify sum contains expected components + assert!(sum > 0); + } + + #[test] + fn test_transport_checksum_tcp() { + let src_ip = [192, 168, 1, 1]; + let dst_ip = [192, 168, 1, 2]; + + // Minimal TCP header with zeroed checksum + let tcp_header = [ + 0x00, 0x50, // Source port: 80 + 0x1F, 0x90, // Dest port: 8080 + 0x00, 0x00, 0x00, 0x01, // Seq number + 0x00, 0x00, 0x00, 0x00, // Ack number + 0x50, 0x02, // Data offset + flags (SYN) + 0xFF, 0xFF, // Window + 0x00, 0x00, // Checksum (zeroed) + 0x00, 0x00, // Urgent pointer + ]; + + let checksum = transport_checksum(&src_ip, &dst_ip, 6, &tcp_header); + assert_ne!(checksum, 0); + } + + #[test] + fn test_transport_checksum_udp_zero() { + let src_ip = [0, 0, 0, 0]; + let dst_ip = [0, 0, 0, 0]; + + // UDP header that would result in zero checksum + // For UDP, zero should become 0xFFFF + let udp_header = [ + 0x00, 0x00, // Source port + 0x00, 0x00, // Dest port + 0x00, 0x08, // Length + 0x00, 0x00, // Checksum (zeroed) + ]; + + let checksum = transport_checksum(&src_ip, &dst_ip, 17, &udp_header); + // UDP checksum should never be 0 (use 0xFFFF instead) + assert_ne!(checksum, 0); + } + + #[test] + fn test_partial_checksum() { + let data1 = [0x01, 0x02, 0x03, 0x04]; + let data2 = [0x05, 0x06, 0x07, 0x08]; + + // Compute separately and combine + let sum1 = partial_checksum(&data1, 0); + let sum2 = partial_checksum(&data2, sum1); + let checksum1 = finalize_checksum(sum2); + + // Compute together + let combined = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]; + let checksum2 = internet_checksum(&combined); + + assert_eq!(checksum1, checksum2); + } + + #[test] + fn test_odd_length_data() { + // Test with odd number of bytes + let data = [0x45, 0x00, 0x00, 0x3c, 0x1c]; + let checksum = internet_checksum(&data); + + // Should handle odd byte correctly + assert_ne!(checksum, 0); + } + + #[test] + fn test_zero_and_write_checksum() { + let mut buf = [0x45, 0x00, 0xAB, 0xCD, 0x00, 0x00]; + + zero_checksum(&mut buf, 2); + assert_eq!(buf[2], 0); + assert_eq!(buf[3], 0); + + write_checksum(&mut buf, 2, 0x1234); + assert_eq!(buf[2], 0x12); + assert_eq!(buf[3], 0x34); + + assert_eq!(read_checksum(&buf, 2), Some(0x1234)); + } +} diff --git a/crates/stackforge-core/src/layer/ipv4/fragmentation.rs b/crates/stackforge-core/src/layer/ipv4/fragmentation.rs new file mode 100644 index 0000000..d53aa51 --- /dev/null +++ b/crates/stackforge-core/src/layer/ipv4/fragmentation.rs @@ -0,0 +1,740 @@ +//! IPv4 fragmentation and reassembly. +//! +//! Provides utilities for fragmenting large IP packets and reassembling +//! fragments back into complete packets. + +use std::net::Ipv4Addr; + +use crate::layer::field::FieldError; + +use super::builder::Ipv4Builder; +use super::checksum::ipv4_checksum; +use super::header::{IPV4_MIN_HEADER_LEN, Ipv4Flags, Ipv4Layer, offsets}; +use super::options::Ipv4Options; + +/// Default MTU for fragmentation. +pub const DEFAULT_MTU: usize = 1500; + +/// Minimum fragment payload size (must be multiple of 8). +pub const MIN_FRAGMENT_PAYLOAD: usize = 8; + +/// Maximum fragment offset (in 8-byte units). +pub const MAX_FRAGMENT_OFFSET: u16 = 0x1FFF; + +/// Information about a fragment. +#[derive(Debug, Clone)] +pub struct FragmentInfo { + /// Fragment offset in bytes. + pub offset: u32, + /// Fragment payload length. + pub length: usize, + /// Whether this is the last fragment. + pub last: bool, + /// The fragment data (header + payload). + pub data: Vec, +} + +impl FragmentInfo { + /// Get the end offset (offset + length). + pub fn end_offset(&self) -> u32 { + self.offset + self.length as u32 + } +} + +/// A single fragment ready for transmission. +#[derive(Debug, Clone)] +pub struct Fragment { + /// The complete fragment packet (header + payload). + pub packet: Vec, + /// Fragment offset in bytes. + pub offset: u32, + /// Whether this is the last fragment. + pub last: bool, +} + +/// Fragmenter for IPv4 packets. +#[derive(Debug, Clone)] +pub struct Ipv4Fragmenter { + /// Maximum fragment size (including IP header). + pub mtu: usize, + /// Whether to copy options to all fragments. + pub copy_options: bool, +} + +impl Default for Ipv4Fragmenter { + fn default() -> Self { + Self { + mtu: DEFAULT_MTU, + copy_options: true, + } + } +} + +impl Ipv4Fragmenter { + /// Create a new fragmenter with default MTU. + pub fn new() -> Self { + Self::default() + } + + /// Create a fragmenter with a specific MTU. + pub fn with_mtu(mtu: usize) -> Self { + Self { + mtu, + ..Self::default() + } + } + + /// Set the MTU. + pub fn mtu(mut self, mtu: usize) -> Self { + self.mtu = mtu; + self + } + + /// Set whether to copy options to non-first fragments. + pub fn copy_options(mut self, copy: bool) -> Self { + self.copy_options = copy; + self + } + + /// Check if a packet needs fragmentation. + pub fn needs_fragmentation(&self, packet: &[u8]) -> bool { + packet.len() > self.mtu + } + + /// Fragment an IPv4 packet. + /// + /// Returns a vector of fragment packets, or None if the packet + /// has the Don't Fragment flag set and exceeds MTU. + pub fn fragment(&self, packet: &[u8]) -> Result, FragmentError> { + let layer = Ipv4Layer::at_offset_dynamic(packet, 0) + .map_err(|e| FragmentError::ParseError(e.to_string()))?; + + // Check DF flag + let flags = layer.flags(packet).unwrap_or(Ipv4Flags::NONE); + if flags.df && packet.len() > self.mtu { + return Err(FragmentError::DontFragmentSet { + packet_size: packet.len(), + mtu: self.mtu, + }); + } + + // No fragmentation needed + if packet.len() <= self.mtu { + return Ok(vec![Fragment { + packet: packet.to_vec(), + offset: 0, + last: true, + }]); + } + + // Get original header info + let header_len = layer.calculate_header_len(packet); + let total_len = layer.total_len(packet).unwrap_or(packet.len() as u16) as usize; + let payload_start = header_len; + let payload_len = total_len.saturating_sub(header_len); + + // Parse options for copying + let options = if header_len > IPV4_MIN_HEADER_LEN { + layer.options(packet).ok() + } else { + None + }; + + // Calculate fragment sizes + // Fragment payload must be multiple of 8 + let first_header_len = header_len; // First fragment keeps all options + let other_header_len = if self.copy_options { + if let Some(ref opts) = options { + IPV4_MIN_HEADER_LEN + opts.copied_options().padded_len() + } else { + IPV4_MIN_HEADER_LEN + } + } else { + IPV4_MIN_HEADER_LEN + }; + + let first_payload_max = ((self.mtu - first_header_len) / 8) * 8; + let other_payload_max = ((self.mtu - other_header_len) / 8) * 8; + + if first_payload_max < MIN_FRAGMENT_PAYLOAD || other_payload_max < MIN_FRAGMENT_PAYLOAD { + return Err(FragmentError::MtuTooSmall { + mtu: self.mtu, + min_required: other_header_len + MIN_FRAGMENT_PAYLOAD, + }); + } + + let mut fragments = Vec::new(); + let mut offset: u32 = 0; + let mut remaining = payload_len; + let mut is_first = true; + + // Original fragment offset (in case we're fragmenting a fragment) + let original_offset = layer.frag_offset(packet).unwrap_or(0) as u32 * 8; + let original_mf = flags.mf; + + while remaining > 0 { + let header_size = if is_first { + first_header_len + } else { + other_header_len + }; + let max_payload = if is_first { + first_payload_max + } else { + other_payload_max + }; + + let frag_payload_len = remaining.min(max_payload); + let is_last = frag_payload_len == remaining && !original_mf; + + // Ensure non-last fragments have payload that's multiple of 8 + let actual_payload_len = if !is_last { + (frag_payload_len / 8) * 8 + } else { + frag_payload_len + }; + + if actual_payload_len == 0 { + break; + } + + // Build fragment + let frag_packet = self.build_fragment( + packet, + &layer, + &options, + offset, + actual_payload_len, + !is_last, + is_first, + original_offset, + )?; + + fragments.push(Fragment { + packet: frag_packet, + offset: original_offset + offset, + last: is_last, + }); + + offset += actual_payload_len as u32; + remaining -= actual_payload_len; + is_first = false; + } + + Ok(fragments) + } + + /// Build a single fragment packet. + fn build_fragment( + &self, + original: &[u8], + layer: &Ipv4Layer, + options: &Option, + offset: u32, + payload_len: usize, + more_fragments: bool, + is_first: bool, + original_offset: u32, + ) -> Result, FragmentError> { + let original_header_len = layer.calculate_header_len(original); + + // Determine header for this fragment + let frag_options = if is_first { + options.clone() + } else if self.copy_options { + options.as_ref().map(|o| o.copied_options()) + } else { + None + }; + + let frag_header_len = if let Some(ref opts) = frag_options { + IPV4_MIN_HEADER_LEN + opts.padded_len() + } else { + IPV4_MIN_HEADER_LEN + }; + + let total_len = frag_header_len + payload_len; + let mut buf = vec![0u8; total_len]; + + // Copy and modify header + buf[..IPV4_MIN_HEADER_LEN].copy_from_slice(&original[..IPV4_MIN_HEADER_LEN]); + + // Update IHL + let ihl = (frag_header_len / 4) as u8; + buf[offsets::VERSION_IHL] = (buf[offsets::VERSION_IHL] & 0xF0) | (ihl & 0x0F); + + // Update total length + buf[offsets::TOTAL_LEN] = (total_len >> 8) as u8; + buf[offsets::TOTAL_LEN + 1] = (total_len & 0xFF) as u8; + + // Update flags and fragment offset + let frag_offset_units = ((original_offset + offset) / 8) as u16; + let mut flags_byte = if more_fragments { 0x20 } else { 0x00 }; // MF flag + + // Preserve DF flag? No - if we're fragmenting, DF must be clear + // Preserve evil bit from original + let orig_flags = layer.flags(original).unwrap_or(Ipv4Flags::NONE); + if orig_flags.reserved { + flags_byte |= 0x80; + } + + let flags_frag = ((flags_byte as u16) << 8) | frag_offset_units; + buf[offsets::FLAGS_FRAG] = (flags_frag >> 8) as u8; + buf[offsets::FLAGS_FRAG + 1] = (flags_frag & 0xFF) as u8; + + // Copy options (if any) + if let Some(ref opts) = frag_options { + let opts_bytes = opts.to_bytes(); + buf[offsets::OPTIONS..offsets::OPTIONS + opts_bytes.len()].copy_from_slice(&opts_bytes); + } + + // Copy payload portion + let payload_start = original_header_len + offset as usize; + let payload_end = payload_start + payload_len; + if payload_end <= original.len() { + buf[frag_header_len..].copy_from_slice(&original[payload_start..payload_end]); + } + + // Recompute checksum + buf[offsets::CHECKSUM] = 0; + buf[offsets::CHECKSUM + 1] = 0; + let checksum = ipv4_checksum(&buf[..frag_header_len]); + buf[offsets::CHECKSUM] = (checksum >> 8) as u8; + buf[offsets::CHECKSUM + 1] = (checksum & 0xFF) as u8; + + Ok(buf) + } +} + +/// Errors that can occur during fragmentation. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum FragmentError { + /// Don't Fragment flag is set but packet exceeds MTU. + DontFragmentSet { packet_size: usize, mtu: usize }, + /// MTU is too small to fragment. + MtuTooSmall { mtu: usize, min_required: usize }, + /// Error parsing the packet. + ParseError(String), +} + +impl std::fmt::Display for FragmentError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::DontFragmentSet { packet_size, mtu } => { + write!( + f, + "packet size {} exceeds MTU {} but DF flag is set", + packet_size, mtu + ) + } + Self::MtuTooSmall { mtu, min_required } => { + write!( + f, + "MTU {} is too small, minimum required is {}", + mtu, min_required + ) + } + Self::ParseError(msg) => write!(f, "parse error: {}", msg), + } + } +} + +impl std::error::Error for FragmentError {} + +/// Key for identifying a fragment group. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct FragmentKey { + pub src: Ipv4Addr, + pub dst: Ipv4Addr, + pub id: u16, + pub protocol: u8, +} + +impl FragmentKey { + /// Create a key from a packet. + pub fn from_packet(packet: &[u8]) -> Result { + let layer = Ipv4Layer::at_offset_dynamic(packet, 0)?; + Ok(Self { + src: layer.src(packet)?, + dst: layer.dst(packet)?, + id: layer.id(packet)?, + protocol: layer.protocol(packet)?, + }) + } +} + +/// A collection of fragments being reassembled. +#[derive(Debug, Clone)] +pub struct FragmentGroup { + /// The fragment key. + pub key: FragmentKey, + /// Collected fragments. + pub fragments: Vec, + /// Total expected length (known when last fragment received). + pub total_length: Option, + /// First fragment header (for reconstruction). + pub first_header: Option>, + /// Timestamp when first fragment was received. + pub first_received: std::time::Instant, +} + +impl FragmentGroup { + /// Create a new fragment group. + pub fn new(key: FragmentKey) -> Self { + Self { + key, + fragments: Vec::new(), + total_length: None, + first_header: None, + first_received: std::time::Instant::now(), + } + } + + /// Add a fragment to the group. + pub fn add_fragment(&mut self, packet: &[u8]) -> Result<(), FieldError> { + let layer = Ipv4Layer::at_offset_dynamic(packet, 0)?; + let header_len = layer.calculate_header_len(packet); + let total_len = layer.total_len(packet)? as usize; + let flags = layer.flags(packet)?; + let offset = layer.frag_offset(packet)? as u32 * 8; + let payload_len = total_len.saturating_sub(header_len); + + // Store first fragment header + if offset == 0 { + self.first_header = Some(packet[..header_len].to_vec()); + } + + // If this is the last fragment, we know the total length + if !flags.mf { + self.total_length = Some(offset + payload_len as u32); + } + + // Add fragment info + self.fragments.push(FragmentInfo { + offset, + length: payload_len, + last: !flags.mf, + data: packet.to_vec(), + }); + + Ok(()) + } + + /// Check if all fragments have been received. + pub fn is_complete(&self) -> bool { + let total = match self.total_length { + Some(t) => t, + None => return false, + }; + + // Sort fragments by offset + let mut sorted: Vec<_> = self.fragments.iter().collect(); + sorted.sort_by_key(|f| f.offset); + + // Check for gaps + let mut expected_offset = 0u32; + for frag in sorted { + if frag.offset != expected_offset { + return false; + } + expected_offset = frag.end_offset(); + } + + expected_offset >= total + } + + /// Reassemble the fragments into a complete packet. + pub fn reassemble(&self) -> Result, ReassemblyError> { + if !self.is_complete() { + return Err(ReassemblyError::Incomplete); + } + + let total_length = self.total_length.ok_or(ReassemblyError::Incomplete)?; + let first_header = self + .first_header + .as_ref() + .ok_or(ReassemblyError::MissingFirstFragment)?; + + let header_len = first_header.len(); + let mut result = vec![0u8; header_len + total_length as usize]; + + // Copy header + result[..header_len].copy_from_slice(first_header); + + // Sort and copy payloads + let mut sorted: Vec<_> = self.fragments.iter().collect(); + sorted.sort_by_key(|f| f.offset); + + for frag in sorted { + let layer = Ipv4Layer::at_offset_dynamic(&frag.data, 0) + .map_err(|e| ReassemblyError::ParseError(e.to_string()))?; + let frag_header_len = layer.calculate_header_len(&frag.data); + + let src_start = frag_header_len; + let src_end = src_start + frag.length; + let dst_start = header_len + frag.offset as usize; + let dst_end = dst_start + frag.length; + + if src_end <= frag.data.len() && dst_end <= result.len() { + result[dst_start..dst_end].copy_from_slice(&frag.data[src_start..src_end]); + } + } + + // Update header fields + let new_total_len = (header_len + total_length as usize) as u16; + result[offsets::TOTAL_LEN] = (new_total_len >> 8) as u8; + result[offsets::TOTAL_LEN + 1] = (new_total_len & 0xFF) as u8; + + // Clear MF flag and fragment offset + result[offsets::FLAGS_FRAG] &= 0xC0; // Preserve DF and reserved + result[offsets::FLAGS_FRAG + 1] = 0; + + // Recompute checksum + result[offsets::CHECKSUM] = 0; + result[offsets::CHECKSUM + 1] = 0; + let checksum = ipv4_checksum(&result[..header_len]); + result[offsets::CHECKSUM] = (checksum >> 8) as u8; + result[offsets::CHECKSUM + 1] = (checksum & 0xFF) as u8; + + Ok(result) + } +} + +/// Errors during reassembly. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ReassemblyError { + /// Not all fragments received. + Incomplete, + /// First fragment (offset 0) not received. + MissingFirstFragment, + /// Fragment overlaps with another. + Overlap, + /// Error parsing fragment. + ParseError(String), + /// Timeout waiting for fragments. + Timeout, +} + +impl std::fmt::Display for ReassemblyError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Incomplete => write!(f, "not all fragments received"), + Self::MissingFirstFragment => write!(f, "first fragment not received"), + Self::Overlap => write!(f, "fragment overlap detected"), + Self::ParseError(msg) => write!(f, "parse error: {}", msg), + Self::Timeout => write!(f, "timeout waiting for fragments"), + } + } +} + +impl std::error::Error for ReassemblyError {} + +/// Reassemble fragments from a list of packets. +/// +/// This is a convenience function for simple reassembly. +/// For stateful reassembly across multiple calls, use `FragmentGroup`. +pub fn reassemble_fragments(fragments: &[Vec]) -> Result, ReassemblyError> { + if fragments.is_empty() { + return Err(ReassemblyError::Incomplete); + } + + // Get key from first fragment + let key = FragmentKey::from_packet(&fragments[0]) + .map_err(|e| ReassemblyError::ParseError(e.to_string()))?; + + let mut group = FragmentGroup::new(key); + + for frag in fragments { + group + .add_fragment(frag) + .map_err(|e| ReassemblyError::ParseError(e.to_string()))?; + } + + group.reassemble() +} + +/// Fragment a packet into multiple fragments. +/// +/// Convenience function using default fragmenter. +pub fn fragment_packet(packet: &[u8], mtu: usize) -> Result>, FragmentError> { + let fragmenter = Ipv4Fragmenter::with_mtu(mtu); + let fragments = fragmenter.fragment(packet)?; + Ok(fragments.into_iter().map(|f| f.packet).collect()) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn build_large_packet(payload_size: usize) -> Vec { + Ipv4Builder::new() + .src(Ipv4Addr::new(192, 168, 1, 1)) + .dst(Ipv4Addr::new(192, 168, 1, 2)) + .id(0x1234) + .protocol(17) // UDP + .payload(vec![0xAA; payload_size]) + .build() + } + + #[test] + fn test_no_fragmentation_needed() { + let packet = build_large_packet(100); + let fragmenter = Ipv4Fragmenter::with_mtu(1500); + + assert!(!fragmenter.needs_fragmentation(&packet)); + + let frags = fragmenter.fragment(&packet).unwrap(); + assert_eq!(frags.len(), 1); + assert!(frags[0].last); + assert_eq!(frags[0].offset, 0); + } + + #[test] + fn test_basic_fragmentation() { + let packet = build_large_packet(3000); + let fragmenter = Ipv4Fragmenter::with_mtu(1500); + + assert!(fragmenter.needs_fragmentation(&packet)); + + let frags = fragmenter.fragment(&packet).unwrap(); + assert!(frags.len() >= 2); + + // First fragment + assert_eq!(frags[0].offset, 0); + assert!(!frags[0].last); + + // Last fragment + assert!(frags.last().unwrap().last); + + // Verify each fragment is within MTU + for frag in &frags { + assert!(frag.packet.len() <= 1500); + } + } + + #[test] + fn test_dont_fragment_flag() { + let packet = Ipv4Builder::new() + .src(Ipv4Addr::new(192, 168, 1, 1)) + .dst(Ipv4Addr::new(192, 168, 1, 2)) + .dont_fragment() + .payload(vec![0; 2000]) + .build(); + + let fragmenter = Ipv4Fragmenter::with_mtu(1500); + let result = fragmenter.fragment(&packet); + + assert!(matches!(result, Err(FragmentError::DontFragmentSet { .. }))); + } + + #[test] + fn test_reassembly() { + let original = build_large_packet(3000); + let fragmenter = Ipv4Fragmenter::with_mtu(1000); + + let frags = fragmenter.fragment(&original).unwrap(); + let frag_packets: Vec> = frags.into_iter().map(|f| f.packet).collect(); + + let reassembled = reassemble_fragments(&frag_packets).unwrap(); + + // Check payload matches + let orig_layer = Ipv4Layer::at_offset(0); + let reasm_layer = Ipv4Layer::at_offset(0); + + let orig_payload = orig_layer.payload(&original).unwrap(); + let reasm_payload = reasm_layer.payload(&reassembled).unwrap(); + + assert_eq!(orig_payload, reasm_payload); + } + + #[test] + fn test_fragment_key() { + let packet = build_large_packet(100); + let key = FragmentKey::from_packet(&packet).unwrap(); + + assert_eq!(key.src, Ipv4Addr::new(192, 168, 1, 1)); + assert_eq!(key.dst, Ipv4Addr::new(192, 168, 1, 2)); + assert_eq!(key.id, 0x1234); + assert_eq!(key.protocol, 17); + } + + #[test] + fn test_fragment_group_complete() { + let packet = build_large_packet(2000); + let fragmenter = Ipv4Fragmenter::with_mtu(1000); + + let frags = fragmenter.fragment(&packet).unwrap(); + let key = FragmentKey::from_packet(&frags[0].packet).unwrap(); + + let mut group = FragmentGroup::new(key); + + // Add fragments in random order + for frag in frags.iter().rev() { + group.add_fragment(&frag.packet).unwrap(); + } + + assert!(group.is_complete()); + } + + #[test] + fn test_fragment_group_incomplete() { + let packet = build_large_packet(2000); + let fragmenter = Ipv4Fragmenter::with_mtu(1000); + + let frags = fragmenter.fragment(&packet).unwrap(); + let key = FragmentKey::from_packet(&frags[0].packet).unwrap(); + + let mut group = FragmentGroup::new(key); + + // Add only first fragment + group.add_fragment(&frags[0].packet).unwrap(); + + assert!(!group.is_complete()); + } + + #[test] + fn test_small_mtu() { + let packet = build_large_packet(1000); + let fragmenter = Ipv4Fragmenter::with_mtu(100); + + let frags = fragmenter.fragment(&packet).unwrap(); + + // Should create many small fragments + assert!(frags.len() > 10); + + // All should be within MTU + for frag in &frags { + assert!(frag.packet.len() <= 100); + } + } + + #[test] + fn test_fragment_offset_alignment() { + let packet = build_large_packet(1000); + let fragmenter = Ipv4Fragmenter::with_mtu(500); + + let frags = fragmenter.fragment(&packet).unwrap(); + + // All non-last fragments should have payloads that are multiples of 8 + for frag in &frags[..frags.len() - 1] { + let layer = Ipv4Layer::at_offset(0); + let header_len = layer.calculate_header_len(&frag.packet); + let payload_len = frag.packet.len() - header_len; + assert_eq!( + payload_len % 8, + 0, + "payload len {} not multiple of 8", + payload_len + ); + } + } + + #[test] + fn test_mtu_too_small() { + let packet = build_large_packet(100); + let fragmenter = Ipv4Fragmenter::with_mtu(20); // Too small even for header + + let result = fragmenter.fragment(&packet); + assert!(matches!(result, Err(FragmentError::MtuTooSmall { .. }))); + } +} diff --git a/crates/stackforge-core/src/layer/ipv4/header.rs b/crates/stackforge-core/src/layer/ipv4/header.rs new file mode 100644 index 0000000..6a9bbf0 --- /dev/null +++ b/crates/stackforge-core/src/layer/ipv4/header.rs @@ -0,0 +1,1007 @@ +//! IPv4 header layer implementation. +//! +//! Provides zero-copy access to IPv4 header fields. + +use std::net::Ipv4Addr; + +use crate::layer::field::{Field, FieldDesc, FieldError, FieldType, FieldValue}; +use crate::layer::{Layer, LayerIndex, LayerKind}; + +use super::checksum::ipv4_checksum; +use super::options::{Ipv4Options, parse_options}; +use super::protocol; +use super::routing::Ipv4Route; + +/// Minimum IPv4 header length (no options). +pub const IPV4_MIN_HEADER_LEN: usize = 20; + +/// Maximum IPv4 header length (with maximum options). +pub const IPV4_MAX_HEADER_LEN: usize = 60; + +/// Maximum total length of an IPv4 packet. +pub const IPV4_MAX_PACKET_LEN: usize = 65535; + +/// Field offsets within the IPv4 header. +pub mod offsets { + /// Version (4 bits) + IHL (4 bits) + pub const VERSION_IHL: usize = 0; + /// DSCP (6 bits) + ECN (2 bits) - also known as TOS + pub const TOS: usize = 1; + /// Total length (16 bits) + pub const TOTAL_LEN: usize = 2; + /// Identification (16 bits) + pub const ID: usize = 4; + /// Flags (3 bits) + Fragment offset (13 bits) + pub const FLAGS_FRAG: usize = 6; + /// Time to live (8 bits) + pub const TTL: usize = 8; + /// Protocol (8 bits) + pub const PROTOCOL: usize = 9; + /// Header checksum (16 bits) + pub const CHECKSUM: usize = 10; + /// Source address (32 bits) + pub const SRC: usize = 12; + /// Destination address (32 bits) + pub const DST: usize = 16; + /// Options start (if IHL > 5) + pub const OPTIONS: usize = 20; +} + +/// Field descriptors for dynamic access. +pub static FIELDS: &[FieldDesc] = &[ + FieldDesc::new("version", offsets::VERSION_IHL, 1, FieldType::U8), + FieldDesc::new("ihl", offsets::VERSION_IHL, 1, FieldType::U8), + FieldDesc::new("tos", offsets::TOS, 1, FieldType::U8), + FieldDesc::new("dscp", offsets::TOS, 1, FieldType::U8), + FieldDesc::new("ecn", offsets::TOS, 1, FieldType::U8), + FieldDesc::new("len", offsets::TOTAL_LEN, 2, FieldType::U16), + FieldDesc::new("id", offsets::ID, 2, FieldType::U16), + FieldDesc::new("flags", offsets::FLAGS_FRAG, 1, FieldType::U8), + FieldDesc::new("frag", offsets::FLAGS_FRAG, 2, FieldType::U16), + FieldDesc::new("ttl", offsets::TTL, 1, FieldType::U8), + FieldDesc::new("proto", offsets::PROTOCOL, 1, FieldType::U8), + FieldDesc::new("chksum", offsets::CHECKSUM, 2, FieldType::U16), + FieldDesc::new("src", offsets::SRC, 4, FieldType::Ipv4), + FieldDesc::new("dst", offsets::DST, 4, FieldType::Ipv4), +]; + +/// IPv4 flags in the flags/fragment offset field. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct Ipv4Flags { + /// Reserved/Evil bit (should be 0) + pub reserved: bool, + /// Don't Fragment + pub df: bool, + /// More Fragments + pub mf: bool, +} + +impl Ipv4Flags { + pub const NONE: Self = Self { + reserved: false, + df: false, + mf: false, + }; + + pub const DF: Self = Self { + reserved: false, + df: true, + mf: false, + }; + + pub const MF: Self = Self { + reserved: false, + df: false, + mf: true, + }; + + /// Create flags from a raw byte value (upper 3 bits). + #[inline] + pub fn from_byte(byte: u8) -> Self { + Self { + reserved: (byte & 0x80) != 0, + df: (byte & 0x40) != 0, + mf: (byte & 0x20) != 0, + } + } + + /// Convert to a raw byte value (upper 3 bits). + #[inline] + pub fn to_byte(self) -> u8 { + let mut b = 0u8; + if self.reserved { + b |= 0x80; + } + if self.df { + b |= 0x40; + } + if self.mf { + b |= 0x20; + } + b + } + + /// Check if this is a fragment (MF set or offset > 0). + #[inline] + pub fn is_fragment(self, offset: u16) -> bool { + self.mf || offset > 0 + } +} + +impl std::fmt::Display for Ipv4Flags { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut parts = Vec::new(); + if self.reserved { + parts.push("evil"); + } + if self.df { + parts.push("DF"); + } + if self.mf { + parts.push("MF"); + } + if parts.is_empty() { + write!(f, "-") + } else { + write!(f, "{}", parts.join("+")) + } + } +} + +/// A view into an IPv4 packet header. +#[derive(Debug, Clone)] +pub struct Ipv4Layer { + pub index: LayerIndex, +} + +impl Ipv4Layer { + /// Create a new IPv4 layer view with specified bounds. + #[inline] + pub const fn new(start: usize, end: usize) -> Self { + Self { + index: LayerIndex::new(LayerKind::Ipv4, start, end), + } + } + + /// Create a layer at offset 0 with minimum header length. + #[inline] + pub const fn at_start() -> Self { + Self::new(0, IPV4_MIN_HEADER_LEN) + } + + /// Create a layer at the specified offset with minimum header length. + #[inline] + pub const fn at_offset(offset: usize) -> Self { + Self::new(offset, offset + IPV4_MIN_HEADER_LEN) + } + + /// Create a layer at offset, calculating actual header length from IHL. + pub fn at_offset_dynamic(buf: &[u8], offset: usize) -> Result { + if buf.len() < offset + 1 { + return Err(FieldError::BufferTooShort { + offset, + need: 1, + have: buf.len().saturating_sub(offset), + }); + } + + let ihl = (buf[offset] & 0x0F) as usize; + let header_len = ihl * 4; + + if header_len < IPV4_MIN_HEADER_LEN { + return Err(FieldError::InvalidValue(format!( + "IHL {} is less than minimum (5)", + ihl + ))); + } + + if buf.len() < offset + header_len { + return Err(FieldError::BufferTooShort { + offset, + need: header_len, + have: buf.len().saturating_sub(offset), + }); + } + + Ok(Self::new(offset, offset + header_len)) + } + + /// Validate that the buffer contains a valid IPv4 header at the offset. + pub fn validate(buf: &[u8], offset: usize) -> Result<(), FieldError> { + if buf.len() < offset + IPV4_MIN_HEADER_LEN { + return Err(FieldError::BufferTooShort { + offset, + need: IPV4_MIN_HEADER_LEN, + have: buf.len().saturating_sub(offset), + }); + } + + let version = (buf[offset] >> 4) & 0x0F; + if version != 4 { + return Err(FieldError::InvalidValue(format!( + "not IPv4: version = {}", + version + ))); + } + + let ihl = (buf[offset] & 0x0F) as usize; + if ihl < 5 { + return Err(FieldError::InvalidValue(format!( + "IHL {} is less than minimum (5)", + ihl + ))); + } + + let header_len = ihl * 4; + if buf.len() < offset + header_len { + return Err(FieldError::BufferTooShort { + offset, + need: header_len, + have: buf.len().saturating_sub(offset), + }); + } + + Ok(()) + } + + /// Calculate the actual header length from the buffer. + pub fn calculate_header_len(&self, buf: &[u8]) -> usize { + self.ihl(buf).map(|ihl| (ihl as usize) * 4).unwrap_or(20) + } + + /// Get the options length (header length - 20). + pub fn options_len(&self, buf: &[u8]) -> usize { + self.calculate_header_len(buf) + .saturating_sub(IPV4_MIN_HEADER_LEN) + } + + // ========== Field Readers ========== + + /// Read the version field (should be 4). + #[inline] + pub fn version(&self, buf: &[u8]) -> Result { + let b = u8::read(buf, self.index.start + offsets::VERSION_IHL)?; + Ok((b >> 4) & 0x0F) + } + + /// Read the Internet Header Length (in 32-bit words). + #[inline] + pub fn ihl(&self, buf: &[u8]) -> Result { + let b = u8::read(buf, self.index.start + offsets::VERSION_IHL)?; + Ok(b & 0x0F) + } + + /// Read the header length in bytes. + #[inline] + pub fn header_len_bytes(&self, buf: &[u8]) -> Result { + Ok((self.ihl(buf)? as usize) * 4) + } + + /// Read the Type of Service (TOS) field. + #[inline] + pub fn tos(&self, buf: &[u8]) -> Result { + u8::read(buf, self.index.start + offsets::TOS) + } + + /// Read the DSCP (Differentiated Services Code Point). + #[inline] + pub fn dscp(&self, buf: &[u8]) -> Result { + Ok((self.tos(buf)? >> 2) & 0x3F) + } + + /// Read the ECN (Explicit Congestion Notification). + #[inline] + pub fn ecn(&self, buf: &[u8]) -> Result { + Ok(self.tos(buf)? & 0x03) + } + + /// Read the total length field. + #[inline] + pub fn total_len(&self, buf: &[u8]) -> Result { + u16::read(buf, self.index.start + offsets::TOTAL_LEN) + } + + /// Read the identification field. + #[inline] + pub fn id(&self, buf: &[u8]) -> Result { + u16::read(buf, self.index.start + offsets::ID) + } + + /// Read the flags as a structured type. + #[inline] + pub fn flags(&self, buf: &[u8]) -> Result { + let b = u8::read(buf, self.index.start + offsets::FLAGS_FRAG)?; + Ok(Ipv4Flags::from_byte(b)) + } + + /// Read the raw flags byte (upper 3 bits of flags/frag field). + #[inline] + pub fn flags_raw(&self, buf: &[u8]) -> Result { + let b = u8::read(buf, self.index.start + offsets::FLAGS_FRAG)?; + Ok((b >> 5) & 0x07) + } + + /// Read the fragment offset (in 8-byte units). + #[inline] + pub fn frag_offset(&self, buf: &[u8]) -> Result { + let val = u16::read(buf, self.index.start + offsets::FLAGS_FRAG)?; + Ok(val & 0x1FFF) + } + + /// Read the fragment offset in bytes. + #[inline] + pub fn frag_offset_bytes(&self, buf: &[u8]) -> Result { + Ok((self.frag_offset(buf)? as u32) * 8) + } + + /// Read the Time to Live field. + #[inline] + pub fn ttl(&self, buf: &[u8]) -> Result { + u8::read(buf, self.index.start + offsets::TTL) + } + + /// Read the protocol field. + #[inline] + pub fn protocol(&self, buf: &[u8]) -> Result { + u8::read(buf, self.index.start + offsets::PROTOCOL) + } + + /// Read the header checksum. + #[inline] + pub fn checksum(&self, buf: &[u8]) -> Result { + u16::read(buf, self.index.start + offsets::CHECKSUM) + } + + /// Read the source IP address. + #[inline] + pub fn src(&self, buf: &[u8]) -> Result { + Ipv4Addr::read(buf, self.index.start + offsets::SRC) + } + + /// Read the destination IP address. + #[inline] + pub fn dst(&self, buf: &[u8]) -> Result { + Ipv4Addr::read(buf, self.index.start + offsets::DST) + } + + /// Read the options bytes (if any). + pub fn options_bytes<'a>(&self, buf: &'a [u8]) -> Result<&'a [u8], FieldError> { + let header_len = self.calculate_header_len(buf); + let opts_start = self.index.start + IPV4_MIN_HEADER_LEN; + let opts_end = self.index.start + header_len; + + if buf.len() < opts_end { + return Err(FieldError::BufferTooShort { + offset: opts_start, + need: header_len - IPV4_MIN_HEADER_LEN, + have: buf.len().saturating_sub(opts_start), + }); + } + + Ok(&buf[opts_start..opts_end]) + } + + /// Parse and return the options. + pub fn options(&self, buf: &[u8]) -> Result { + let opts_bytes = self.options_bytes(buf)?; + parse_options(opts_bytes) + } + + // ========== Field Writers ========== + + /// Set the version field. + #[inline] + pub fn set_version(&self, buf: &mut [u8], version: u8) -> Result<(), FieldError> { + let offset = self.index.start + offsets::VERSION_IHL; + let current = u8::read(buf, offset)?; + let new_val = (current & 0x0F) | ((version & 0x0F) << 4); + new_val.write(buf, offset) + } + + /// Set the IHL field. + #[inline] + pub fn set_ihl(&self, buf: &mut [u8], ihl: u8) -> Result<(), FieldError> { + let offset = self.index.start + offsets::VERSION_IHL; + let current = u8::read(buf, offset)?; + let new_val = (current & 0xF0) | (ihl & 0x0F); + new_val.write(buf, offset) + } + + /// Set the TOS field. + #[inline] + pub fn set_tos(&self, buf: &mut [u8], tos: u8) -> Result<(), FieldError> { + tos.write(buf, self.index.start + offsets::TOS) + } + + /// Set the DSCP field. + #[inline] + pub fn set_dscp(&self, buf: &mut [u8], dscp: u8) -> Result<(), FieldError> { + let offset = self.index.start + offsets::TOS; + let current = u8::read(buf, offset)?; + let new_val = (current & 0x03) | ((dscp & 0x3F) << 2); + new_val.write(buf, offset) + } + + /// Set the ECN field. + #[inline] + pub fn set_ecn(&self, buf: &mut [u8], ecn: u8) -> Result<(), FieldError> { + let offset = self.index.start + offsets::TOS; + let current = u8::read(buf, offset)?; + let new_val = (current & 0xFC) | (ecn & 0x03); + new_val.write(buf, offset) + } + + /// Set the total length field. + #[inline] + pub fn set_total_len(&self, buf: &mut [u8], len: u16) -> Result<(), FieldError> { + len.write(buf, self.index.start + offsets::TOTAL_LEN) + } + + /// Set the identification field. + #[inline] + pub fn set_id(&self, buf: &mut [u8], id: u16) -> Result<(), FieldError> { + id.write(buf, self.index.start + offsets::ID) + } + + /// Set the flags field. + #[inline] + pub fn set_flags(&self, buf: &mut [u8], flags: Ipv4Flags) -> Result<(), FieldError> { + let offset = self.index.start + offsets::FLAGS_FRAG; + let current = u8::read(buf, offset)?; + let new_val = (current & 0x1F) | flags.to_byte(); + new_val.write(buf, offset) + } + + /// Set the fragment offset (in 8-byte units). + #[inline] + pub fn set_frag_offset(&self, buf: &mut [u8], offset_val: u16) -> Result<(), FieldError> { + let offset = self.index.start + offsets::FLAGS_FRAG; + let current = u16::read(buf, offset)?; + let new_val = (current & 0xE000) | (offset_val & 0x1FFF); + new_val.write(buf, offset) + } + + /// Set the TTL field. + #[inline] + pub fn set_ttl(&self, buf: &mut [u8], ttl: u8) -> Result<(), FieldError> { + ttl.write(buf, self.index.start + offsets::TTL) + } + + /// Set the protocol field. + #[inline] + pub fn set_protocol(&self, buf: &mut [u8], proto: u8) -> Result<(), FieldError> { + proto.write(buf, self.index.start + offsets::PROTOCOL) + } + + /// Set the checksum field. + #[inline] + pub fn set_checksum(&self, buf: &mut [u8], checksum: u16) -> Result<(), FieldError> { + checksum.write(buf, self.index.start + offsets::CHECKSUM) + } + + /// Set the source IP address. + #[inline] + pub fn set_src(&self, buf: &mut [u8], src: Ipv4Addr) -> Result<(), FieldError> { + src.write(buf, self.index.start + offsets::SRC) + } + + /// Set the destination IP address. + #[inline] + pub fn set_dst(&self, buf: &mut [u8], dst: Ipv4Addr) -> Result<(), FieldError> { + dst.write(buf, self.index.start + offsets::DST) + } + + /// Compute and set the header checksum. + pub fn compute_checksum(&self, buf: &mut [u8]) -> Result { + // Zero out existing checksum first + self.set_checksum(buf, 0)?; + + // Calculate header length + let header_len = self.calculate_header_len(buf); + let header_end = self.index.start + header_len; + + if buf.len() < header_end { + return Err(FieldError::BufferTooShort { + offset: self.index.start, + need: header_len, + have: buf.len().saturating_sub(self.index.start), + }); + } + + // Compute checksum over header bytes + let checksum = ipv4_checksum(&buf[self.index.start..header_end]); + + // Write checksum + self.set_checksum(buf, checksum)?; + + Ok(checksum) + } + + /// Verify the header checksum. + pub fn verify_checksum(&self, buf: &[u8]) -> Result { + let header_len = self.calculate_header_len(buf); + let header_end = self.index.start + header_len; + + if buf.len() < header_end { + return Err(FieldError::BufferTooShort { + offset: self.index.start, + need: header_len, + have: buf.len().saturating_sub(self.index.start), + }); + } + + let checksum = ipv4_checksum(&buf[self.index.start..header_end]); + Ok(checksum == 0 || checksum == 0xFFFF) + } + + // ========== Dynamic Field Access ========== + + /// Get a field value by name. + pub fn get_field(&self, buf: &[u8], name: &str) -> Option> { + match name { + "version" => Some(self.version(buf).map(FieldValue::U8)), + "ihl" => Some(self.ihl(buf).map(FieldValue::U8)), + "tos" => Some(self.tos(buf).map(FieldValue::U8)), + "dscp" => Some(self.dscp(buf).map(FieldValue::U8)), + "ecn" => Some(self.ecn(buf).map(FieldValue::U8)), + "len" | "total_len" => Some(self.total_len(buf).map(FieldValue::U16)), + "id" => Some(self.id(buf).map(FieldValue::U16)), + "flags" => Some(self.flags_raw(buf).map(FieldValue::U8)), + "frag" | "frag_offset" => Some(self.frag_offset(buf).map(FieldValue::U16)), + "ttl" => Some(self.ttl(buf).map(FieldValue::U8)), + "proto" | "protocol" => Some(self.protocol(buf).map(FieldValue::U8)), + "chksum" | "checksum" => Some(self.checksum(buf).map(FieldValue::U16)), + "src" => Some(self.src(buf).map(FieldValue::Ipv4)), + "dst" => Some(self.dst(buf).map(FieldValue::Ipv4)), + _ => None, + } + } + + /// Set a field value by name. + pub fn set_field( + &self, + buf: &mut [u8], + name: &str, + value: FieldValue, + ) -> Option> { + match (name, value) { + ("version", FieldValue::U8(v)) => Some(self.set_version(buf, v)), + ("ihl", FieldValue::U8(v)) => Some(self.set_ihl(buf, v)), + ("tos", FieldValue::U8(v)) => Some(self.set_tos(buf, v)), + ("dscp", FieldValue::U8(v)) => Some(self.set_dscp(buf, v)), + ("ecn", FieldValue::U8(v)) => Some(self.set_ecn(buf, v)), + ("len" | "total_len", FieldValue::U16(v)) => Some(self.set_total_len(buf, v)), + ("id", FieldValue::U16(v)) => Some(self.set_id(buf, v)), + ("frag" | "frag_offset", FieldValue::U16(v)) => Some(self.set_frag_offset(buf, v)), + ("ttl", FieldValue::U8(v)) => Some(self.set_ttl(buf, v)), + ("proto" | "protocol", FieldValue::U8(v)) => Some(self.set_protocol(buf, v)), + ("chksum" | "checksum", FieldValue::U16(v)) => Some(self.set_checksum(buf, v)), + ("src", FieldValue::Ipv4(v)) => Some(self.set_src(buf, v)), + ("dst", FieldValue::Ipv4(v)) => Some(self.set_dst(buf, v)), + _ => None, + } + } + + /// Get list of field names. + pub fn field_names() -> &'static [&'static str] { + &[ + "version", "ihl", "tos", "dscp", "ecn", "len", "id", "flags", "frag", "ttl", "proto", + "chksum", "src", "dst", + ] + } + + // ========== Utility Methods ========== + + /// Check if this is a fragment. + pub fn is_fragment(&self, buf: &[u8]) -> bool { + let flags = self.flags(buf).unwrap_or(Ipv4Flags::NONE); + let offset = self.frag_offset(buf).unwrap_or(0); + flags.is_fragment(offset) + } + + /// Check if this is the first fragment. + pub fn is_first_fragment(&self, buf: &[u8]) -> bool { + let flags = self.flags(buf).unwrap_or(Ipv4Flags::NONE); + let offset = self.frag_offset(buf).unwrap_or(0); + flags.mf && offset == 0 + } + + /// Check if this is the last fragment. + pub fn is_last_fragment(&self, buf: &[u8]) -> bool { + let flags = self.flags(buf).unwrap_or(Ipv4Flags::NONE); + let offset = self.frag_offset(buf).unwrap_or(0); + !flags.mf && offset > 0 + } + + /// Check if the Don't Fragment flag is set. + pub fn is_dont_fragment(&self, buf: &[u8]) -> bool { + self.flags(buf).map(|f| f.df).unwrap_or(false) + } + + /// Get the payload length (total_len - header_len). + pub fn payload_len(&self, buf: &[u8]) -> Result { + let total = self.total_len(buf)? as usize; + let header = self.calculate_header_len(buf); + Ok(total.saturating_sub(header)) + } + + /// Get a slice of the payload data. + pub fn payload<'a>(&self, buf: &'a [u8]) -> Result<&'a [u8], FieldError> { + let header_len = self.calculate_header_len(buf); + let total_len = self.total_len(buf)? as usize; + let payload_start = self.index.start + header_len; + let payload_end = (self.index.start + total_len).min(buf.len()); + + if payload_start > buf.len() { + return Err(FieldError::BufferTooShort { + offset: payload_start, + need: 0, + have: buf.len().saturating_sub(payload_start), + }); + } + + Ok(&buf[payload_start..payload_end]) + } + + /// Get the header bytes. + #[inline] + pub fn header_bytes<'a>(&self, buf: &'a [u8]) -> &'a [u8] { + let header_len = self.calculate_header_len(buf); + let end = (self.index.start + header_len).min(buf.len()); + &buf[self.index.start..end] + } + + /// Copy the header bytes. + #[inline] + pub fn header_copy(&self, buf: &[u8]) -> Vec { + self.header_bytes(buf).to_vec() + } + + /// Get the protocol name. + pub fn protocol_name(&self, buf: &[u8]) -> &'static str { + self.protocol(buf) + .map(protocol::to_name) + .unwrap_or("Unknown") + } + + /// Determine the next layer kind based on protocol. + pub fn next_layer(&self, buf: &[u8]) -> Option { + self.protocol(buf).ok().and_then(|proto| match proto { + protocol::TCP => Some(LayerKind::Tcp), + protocol::UDP => Some(LayerKind::Udp), + protocol::ICMP => Some(LayerKind::Icmp), + protocol::IPV4 => Some(LayerKind::Ipv4), + protocol::IPV6 => Some(LayerKind::Ipv6), + _ => None, + }) + } + + /// Compute hash for packet matching (like Scapy's hashret). + pub fn hashret(&self, buf: &[u8]) -> Vec { + let proto = self.protocol(buf).unwrap_or(0); + + // For ICMP error messages, delegate to inner packet + if proto == protocol::ICMP { + // Check if it's an ICMP error (type 3, 4, 5, 11, 12) + let header_len = self.calculate_header_len(buf); + let icmp_start = self.index.start + header_len; + if buf.len() > icmp_start { + let icmp_type = buf[icmp_start]; + if matches!(icmp_type, 3 | 4 | 5 | 11 | 12) { + // Return hash of the embedded packet + // For now, just use src/dst XOR + proto + } + } + } + + // For IP-in-IP tunnels, delegate to inner packet + if matches!(proto, protocol::IPV4 | protocol::IPV6) { + // Could recurse here + } + + // Standard hash: XOR of src and dst + protocol + let src = self.src(buf).map(|ip| ip.octets()).unwrap_or([0; 4]); + let dst = self.dst(buf).map(|ip| ip.octets()).unwrap_or([0; 4]); + + let mut result = Vec::with_capacity(5); + for i in 0..4 { + result.push(src[i] ^ dst[i]); + } + result.push(proto); + result + } + + /// Check if this packet answers another (for sr() matching). + pub fn answers(&self, buf: &[u8], other: &Ipv4Layer, other_buf: &[u8]) -> bool { + // Protocol must match + let self_proto = self.protocol(buf).unwrap_or(0); + let other_proto = other.protocol(other_buf).unwrap_or(0); + + // Handle ICMP errors + if self_proto == protocol::ICMP { + let header_len = self.calculate_header_len(buf); + let icmp_start = self.index.start + header_len; + if buf.len() > icmp_start { + let icmp_type = buf[icmp_start]; + if matches!(icmp_type, 3 | 4 | 5 | 11 | 12) { + // ICMP error - check embedded packet + // The embedded packet should match the original + return true; // Simplified - real impl would check embedded + } + } + } + + // Handle IP-in-IP tunnels + if matches!(other_proto, protocol::IPV4 | protocol::IPV6) { + // Delegate to inner packet + } + + if self_proto != other_proto { + return false; + } + + // Check addresses + let self_src = self.src(buf).ok(); + let self_dst = self.dst(buf).ok(); + let other_src = other.src(other_buf).ok(); + let other_dst = other.dst(other_buf).ok(); + + // Response src should match request dst + if self_src != other_dst { + return false; + } + + // Response dst should match request src + if self_dst != other_src { + return false; + } + + true + } + + /// Extract padding from the packet. + /// Returns (payload, padding) tuple. + pub fn extract_padding<'a>(&self, buf: &'a [u8]) -> (&'a [u8], &'a [u8]) { + let header_len = self.calculate_header_len(buf); + let total_len = self.total_len(buf).unwrap_or(0) as usize; + + let payload_start = self.index.start + header_len; + let payload_end = (self.index.start + total_len).min(buf.len()); + + if payload_start >= buf.len() { + return (&[], &buf[buf.len()..]); + } + + let payload = &buf[payload_start..payload_end]; + let padding = &buf[payload_end..]; + + (payload, padding) + } + + /// Get routing information for this packet. + pub fn route(&self, buf: &[u8]) -> Ipv4Route { + use crate::layer::ipv4::routing::get_route; + let dst = self.dst(buf).unwrap_or(Ipv4Addr::UNSPECIFIED); + get_route(dst) + } + + /// Estimate the original TTL. + pub fn original_ttl(&self, buf: &[u8]) -> u8 { + let current = self.ttl(buf).unwrap_or(0); + super::ttl::estimate_original(current) + } + + /// Estimate the number of hops. + pub fn hops(&self, buf: &[u8]) -> u8 { + let current = self.ttl(buf).unwrap_or(0); + super::ttl::estimate_hops(current) + } +} + +impl Layer for Ipv4Layer { + fn kind(&self) -> LayerKind { + LayerKind::Ipv4 + } + + fn summary(&self, buf: &[u8]) -> String { + let src = self + .src(buf) + .map(|ip| ip.to_string()) + .unwrap_or_else(|_| "?".into()); + let dst = self + .dst(buf) + .map(|ip| ip.to_string()) + .unwrap_or_else(|_| "?".into()); + let proto = self.protocol_name(buf); + let ttl = self.ttl(buf).unwrap_or(0); + + let mut s = format!("IP {} > {} {} ttl={}", src, dst, proto, ttl); + + // Add fragment info if fragmented + if self.is_fragment(buf) { + let flags = self.flags(buf).unwrap_or(Ipv4Flags::NONE); + let offset = self.frag_offset(buf).unwrap_or(0); + s.push_str(&format!( + " frag:{}+{}", + offset, + if flags.mf { "MF" } else { "" } + )); + } + + s + } + + fn header_len(&self, buf: &[u8]) -> usize { + self.calculate_header_len(buf) + } + + fn hashret(&self, buf: &[u8]) -> Vec { + self.hashret(buf) + } + + fn answers(&self, buf: &[u8], other: &Self, other_buf: &[u8]) -> bool { + self.answers(buf, other, other_buf) + } + + fn extract_padding<'a>(&self, buf: &'a [u8]) -> (&'a [u8], &'a [u8]) { + self.extract_padding(buf) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn sample_ipv4_header() -> Vec { + vec![ + 0x45, // Version=4, IHL=5 + 0x00, // TOS=0 + 0x00, 0x3c, // Total length = 60 + 0x1c, 0x46, // ID = 0x1c46 + 0x40, 0x00, // Flags=DF, Frag offset=0 + 0x40, // TTL=64 + 0x06, // Protocol=TCP + 0x00, 0x00, // Checksum (to be computed) + 0xc0, 0xa8, 0x01, 0x01, // Src: 192.168.1.1 + 0xc0, 0xa8, 0x01, 0x02, // Dst: 192.168.1.2 + ] + } + + #[test] + fn test_field_readers() { + let buf = sample_ipv4_header(); + let layer = Ipv4Layer::at_offset(0); + + assert_eq!(layer.version(&buf).unwrap(), 4); + assert_eq!(layer.ihl(&buf).unwrap(), 5); + assert_eq!(layer.tos(&buf).unwrap(), 0); + assert_eq!(layer.total_len(&buf).unwrap(), 60); + assert_eq!(layer.id(&buf).unwrap(), 0x1c46); + assert!(layer.flags(&buf).unwrap().df); + assert_eq!(layer.frag_offset(&buf).unwrap(), 0); + assert_eq!(layer.ttl(&buf).unwrap(), 64); + assert_eq!(layer.protocol(&buf).unwrap(), protocol::TCP); + assert_eq!(layer.src(&buf).unwrap(), Ipv4Addr::new(192, 168, 1, 1)); + assert_eq!(layer.dst(&buf).unwrap(), Ipv4Addr::new(192, 168, 1, 2)); + } + + #[test] + fn test_field_writers() { + let mut buf = sample_ipv4_header(); + let layer = Ipv4Layer::at_offset(0); + + layer.set_ttl(&mut buf, 128).unwrap(); + assert_eq!(layer.ttl(&buf).unwrap(), 128); + + layer.set_src(&mut buf, Ipv4Addr::new(10, 0, 0, 1)).unwrap(); + assert_eq!(layer.src(&buf).unwrap(), Ipv4Addr::new(10, 0, 0, 1)); + + layer.set_flags(&mut buf, Ipv4Flags::MF).unwrap(); + assert!(layer.flags(&buf).unwrap().mf); + assert!(!layer.flags(&buf).unwrap().df); + } + + #[test] + fn test_checksum() { + let mut buf = sample_ipv4_header(); + let layer = Ipv4Layer::at_offset(0); + + // Compute checksum + let checksum = layer.compute_checksum(&mut buf).unwrap(); + assert_ne!(checksum, 0); + + // Verify it + assert!(layer.verify_checksum(&buf).unwrap()); + + // Corrupt and verify fails + buf[0] ^= 0x01; + assert!(!layer.verify_checksum(&buf).unwrap()); + } + + #[test] + fn test_flags() { + let flags = Ipv4Flags::from_byte(0x40); // DF + assert!(flags.df); + assert!(!flags.mf); + assert!(!flags.reserved); + assert_eq!(flags.to_byte(), 0x40); + + let flags = Ipv4Flags::from_byte(0x20); // MF + assert!(!flags.df); + assert!(flags.mf); + assert!(!flags.reserved); + + let flags = Ipv4Flags::from_byte(0xE0); // All set + assert!(flags.df); + assert!(flags.mf); + assert!(flags.reserved); + } + + #[test] + fn test_is_fragment() { + let mut buf = sample_ipv4_header(); + let layer = Ipv4Layer::at_offset(0); + + // DF set, not a fragment + assert!(!layer.is_fragment(&buf)); + + // Set MF + layer.set_flags(&mut buf, Ipv4Flags::MF).unwrap(); + assert!(layer.is_fragment(&buf)); + + // Clear MF, set offset + layer.set_flags(&mut buf, Ipv4Flags::NONE).unwrap(); + layer.set_frag_offset(&mut buf, 100).unwrap(); + assert!(layer.is_fragment(&buf)); + } + + #[test] + fn test_dynamic_field_access() { + let buf = sample_ipv4_header(); + let layer = Ipv4Layer::at_offset(0); + + let ttl = layer.get_field(&buf, "ttl").unwrap().unwrap(); + assert_eq!(ttl.as_u8(), Some(64)); + + let src = layer.get_field(&buf, "src").unwrap().unwrap(); + assert_eq!(src.as_ipv4(), Some(Ipv4Addr::new(192, 168, 1, 1))); + } + + #[test] + fn test_validate() { + let buf = sample_ipv4_header(); + assert!(Ipv4Layer::validate(&buf, 0).is_ok()); + + // Too short + let short = vec![0x45, 0x00]; + assert!(Ipv4Layer::validate(&short, 0).is_err()); + + // Wrong version + let mut wrong_version = sample_ipv4_header(); + wrong_version[0] = 0x65; // Version 6 + assert!(Ipv4Layer::validate(&wrong_version, 0).is_err()); + + // Invalid IHL + let mut bad_ihl = sample_ipv4_header(); + bad_ihl[0] = 0x43; // IHL=3 (< minimum 5) + assert!(Ipv4Layer::validate(&bad_ihl, 0).is_err()); + } + + #[test] + fn test_extract_padding() { + let mut buf = sample_ipv4_header(); + // Set total_len to 30 (header=20 + payload=10) + buf[2] = 0x00; + buf[3] = 0x1e; // 30 + + // Add some payload and padding + buf.extend_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); // 10 bytes payload + buf.extend_from_slice(&[0, 0, 0, 0]); // 4 bytes padding + + let layer = Ipv4Layer::at_offset(0); + let (payload, padding) = layer.extract_padding(&buf); + + assert_eq!(payload.len(), 10); + assert_eq!(padding.len(), 4); + } +} diff --git a/crates/stackforge-core/src/layer/ipv4/mod.rs b/crates/stackforge-core/src/layer/ipv4/mod.rs new file mode 100644 index 0000000..f46d43c --- /dev/null +++ b/crates/stackforge-core/src/layer/ipv4/mod.rs @@ -0,0 +1,24 @@ +//! IPv4 layer module. +//! +//! This module implements the IPv4 protocol, providing packet parsing (via `Ipv4Layer`), +//! construction (via `Ipv4Builder`), fragmentation, options handling, and checksum verification. + +// Register submodules +pub mod builder; +pub mod checksum; +pub mod fragmentation; +pub mod header; +pub mod options; +pub mod protocol; +pub mod routing; +pub mod ttl; + +// Re-export primary types for easier access +pub use builder::Ipv4Builder; +pub use checksum::ipv4_checksum; +pub use fragmentation::{DEFAULT_MTU, Fragment, FragmentInfo, Ipv4Fragmenter}; +pub use header::{ + IPV4_MAX_HEADER_LEN, IPV4_MIN_HEADER_LEN, Ipv4Flags, Ipv4Layer, offsets as ipv4_offsets, +}; +pub use options::{Ipv4Option, Ipv4OptionClass, Ipv4OptionType, Ipv4Options, Ipv4OptionsBuilder}; +pub use routing::Ipv4Route; diff --git a/crates/stackforge-core/src/layer/ipv4/options.rs b/crates/stackforge-core/src/layer/ipv4/options.rs new file mode 100644 index 0000000..a70fb60 --- /dev/null +++ b/crates/stackforge-core/src/layer/ipv4/options.rs @@ -0,0 +1,1024 @@ +//! IPv4 options parsing and building. +//! +//! IP options follow the header and are variable-length. +//! Maximum options length is 40 bytes (60 byte header - 20 byte minimum). + +use std::net::Ipv4Addr; + +use crate::layer::field::FieldError; + +/// Maximum length of IP options (in bytes). +pub const MAX_OPTIONS_LEN: usize = 40; + +/// IP option class (bits 6-7 of option type). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum Ipv4OptionClass { + /// Control options + Control = 0, + /// Reserved (1) + Reserved1 = 1, + /// Debugging and measurement + DebuggingMeasurement = 2, + /// Reserved (3) + Reserved3 = 3, +} + +impl Ipv4OptionClass { + #[inline] + pub fn from_type(opt_type: u8) -> Self { + match (opt_type >> 5) & 0x03 { + 0 => Self::Control, + 1 => Self::Reserved1, + 2 => Self::DebuggingMeasurement, + 3 => Self::Reserved3, + _ => unreachable!(), + } + } +} + +/// Well-known IP option types. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum Ipv4OptionType { + /// End of Option List + EndOfList = 0, + /// No Operation (padding) + Nop = 1, + /// Security (RFC 1108) + Security = 130, + /// Loose Source and Record Route + Lsrr = 131, + /// Internet Timestamp + Timestamp = 68, + /// Extended Security (RFC 1108) + ExtendedSecurity = 133, + /// Commercial Security + CommercialSecurity = 134, + /// Record Route + RecordRoute = 7, + /// Stream ID + StreamId = 136, + /// Strict Source and Record Route + Ssrr = 137, + /// Experimental Measurement + ExperimentalMeasurement = 10, + /// MTU Probe + MtuProbe = 11, + /// MTU Reply + MtuReply = 12, + /// Traceroute + Traceroute = 82, + /// Address Extension + AddressExtension = 147, + /// Router Alert + RouterAlert = 148, + /// Selective Directed Broadcast + SelectiveDirectedBroadcast = 149, + /// Unknown option + Unknown(u8), +} + +impl Ipv4OptionType { + /// Create from raw option type byte. + pub fn from_byte(b: u8) -> Self { + match b { + 0 => Self::EndOfList, + 1 => Self::Nop, + 7 => Self::RecordRoute, + 10 => Self::ExperimentalMeasurement, + 11 => Self::MtuProbe, + 12 => Self::MtuReply, + 68 => Self::Timestamp, + 82 => Self::Traceroute, + 130 => Self::Security, + 131 => Self::Lsrr, + 133 => Self::ExtendedSecurity, + 134 => Self::CommercialSecurity, + 136 => Self::StreamId, + 137 => Self::Ssrr, + 147 => Self::AddressExtension, + 148 => Self::RouterAlert, + 149 => Self::SelectiveDirectedBroadcast, + x => Self::Unknown(x), + } + } + + /// Convert to raw option type byte. + pub fn to_byte(self) -> u8 { + match self { + Self::EndOfList => 0, + Self::Nop => 1, + Self::RecordRoute => 7, + Self::ExperimentalMeasurement => 10, + Self::MtuProbe => 11, + Self::MtuReply => 12, + Self::Timestamp => 68, + Self::Traceroute => 82, + Self::Security => 130, + Self::Lsrr => 131, + Self::ExtendedSecurity => 133, + Self::CommercialSecurity => 134, + Self::StreamId => 136, + Self::Ssrr => 137, + Self::AddressExtension => 147, + Self::RouterAlert => 148, + Self::SelectiveDirectedBroadcast => 149, + Self::Unknown(x) => x, + } + } + + /// Get the name of the option. + pub fn name(&self) -> &'static str { + match self { + Self::EndOfList => "EOL", + Self::Nop => "NOP", + Self::Security => "Security", + Self::Lsrr => "LSRR", + Self::Timestamp => "Timestamp", + Self::ExtendedSecurity => "Extended Security", + Self::CommercialSecurity => "Commercial Security", + Self::RecordRoute => "Record Route", + Self::StreamId => "Stream ID", + Self::Ssrr => "SSRR", + Self::ExperimentalMeasurement => "Experimental Measurement", + Self::MtuProbe => "MTU Probe", + Self::MtuReply => "MTU Reply", + Self::Traceroute => "Traceroute", + Self::AddressExtension => "Address Extension", + Self::RouterAlert => "Router Alert", + Self::SelectiveDirectedBroadcast => "Selective Directed Broadcast", + Self::Unknown(_) => "Unknown", + } + } + + /// Check if this option should be copied on fragmentation. + #[inline] + pub fn is_copied(&self) -> bool { + (self.to_byte() & 0x80) != 0 + } + + /// Check if this is a single-byte option (no length/value). + #[inline] + pub fn is_single_byte(&self) -> bool { + matches!(self, Self::EndOfList | Self::Nop) + } +} + +impl std::fmt::Display for Ipv4OptionType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Unknown(x) => write!(f, "Unknown({})", x), + _ => write!(f, "{}", self.name()), + } + } +} + +/// A parsed IP option. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Ipv4Option { + /// End of Option List (type 0) + EndOfList, + + /// No Operation / Padding (type 1) + Nop, + + /// Record Route (type 7) + RecordRoute { pointer: u8, route: Vec }, + + /// Loose Source and Record Route (type 131) + Lsrr { pointer: u8, route: Vec }, + + /// Strict Source and Record Route (type 137) + Ssrr { pointer: u8, route: Vec }, + + /// Internet Timestamp (type 68) + Timestamp { + pointer: u8, + overflow: u8, + flag: u8, + data: Vec<(Option, u32)>, + }, + + /// Security (type 130) + Security { + security: u16, + compartment: u16, + handling_restrictions: u16, + transmission_control_code: [u8; 3], + }, + + /// Stream ID (type 136) + StreamId { id: u16 }, + + /// MTU Probe (type 11) + MtuProbe { mtu: u16 }, + + /// MTU Reply (type 12) + MtuReply { mtu: u16 }, + + /// Traceroute (type 82) + Traceroute { + id: u16, + outbound_hops: u16, + return_hops: u16, + originator: Ipv4Addr, + }, + + /// Router Alert (type 148) + RouterAlert { value: u16 }, + + /// Address Extension (type 147) + AddressExtension { + src_ext: Ipv4Addr, + dst_ext: Ipv4Addr, + }, + + /// Unknown or unimplemented option + Unknown { option_type: u8, data: Vec }, +} + +impl Ipv4Option { + /// Get the option type. + pub fn option_type(&self) -> Ipv4OptionType { + match self { + Self::EndOfList => Ipv4OptionType::EndOfList, + Self::Nop => Ipv4OptionType::Nop, + Self::RecordRoute { .. } => Ipv4OptionType::RecordRoute, + Self::Lsrr { .. } => Ipv4OptionType::Lsrr, + Self::Ssrr { .. } => Ipv4OptionType::Ssrr, + Self::Timestamp { .. } => Ipv4OptionType::Timestamp, + Self::Security { .. } => Ipv4OptionType::Security, + Self::StreamId { .. } => Ipv4OptionType::StreamId, + Self::MtuProbe { .. } => Ipv4OptionType::MtuProbe, + Self::MtuReply { .. } => Ipv4OptionType::MtuReply, + Self::Traceroute { .. } => Ipv4OptionType::Traceroute, + Self::RouterAlert { .. } => Ipv4OptionType::RouterAlert, + Self::AddressExtension { .. } => Ipv4OptionType::AddressExtension, + Self::Unknown { option_type, .. } => Ipv4OptionType::Unknown(*option_type), + } + } + + /// Get the serialized length of this option. + pub fn len(&self) -> usize { + match self { + Self::EndOfList => 1, + Self::Nop => 1, + Self::RecordRoute { route, .. } + | Self::Lsrr { route, .. } + | Self::Ssrr { route, .. } => 3 + route.len() * 4, + Self::Timestamp { data, flag, .. } => 4 + data.len() * if *flag == 0 { 4 } else { 8 }, + Self::Security { .. } => 11, + Self::StreamId { .. } => 4, + Self::MtuProbe { .. } | Self::MtuReply { .. } => 4, + Self::Traceroute { .. } => 12, + Self::RouterAlert { .. } => 4, + Self::AddressExtension { .. } => 10, + Self::Unknown { data, .. } => 2 + data.len(), + } + } + + /// Check if the option is empty (for variants with data). + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Serialize the option to bytes. + pub fn to_bytes(&self) -> Vec { + match self { + Self::EndOfList => vec![0], + Self::Nop => vec![1], + + Self::RecordRoute { pointer, route } => { + let mut buf = vec![7, (3 + route.len() * 4) as u8, *pointer]; + for ip in route { + buf.extend_from_slice(&ip.octets()); + } + buf + } + + Self::Lsrr { pointer, route } => { + let mut buf = vec![131, (3 + route.len() * 4) as u8, *pointer]; + for ip in route { + buf.extend_from_slice(&ip.octets()); + } + buf + } + + Self::Ssrr { pointer, route } => { + let mut buf = vec![137, (3 + route.len() * 4) as u8, *pointer]; + for ip in route { + buf.extend_from_slice(&ip.octets()); + } + buf + } + + Self::Timestamp { + pointer, + overflow, + flag, + data, + } => { + let entry_size = if *flag == 0 { 4 } else { 8 }; + let mut buf = vec![ + 68, + (4 + data.len() * entry_size) as u8, + *pointer, + (*overflow << 4) | (*flag & 0x0F), + ]; + for (ip, ts) in data { + if let Some(addr) = ip { + buf.extend_from_slice(&addr.octets()); + } + buf.extend_from_slice(&ts.to_be_bytes()); + } + buf + } + + Self::Security { + security, + compartment, + handling_restrictions, + transmission_control_code, + } => { + let mut buf = vec![130, 11]; + buf.extend_from_slice(&security.to_be_bytes()); + buf.extend_from_slice(&compartment.to_be_bytes()); + buf.extend_from_slice(&handling_restrictions.to_be_bytes()); + buf.extend_from_slice(transmission_control_code); + buf + } + + Self::StreamId { id } => { + let mut buf = vec![136, 4]; + buf.extend_from_slice(&id.to_be_bytes()); + buf + } + + Self::MtuProbe { mtu } => { + let mut buf = vec![11, 4]; + buf.extend_from_slice(&mtu.to_be_bytes()); + buf + } + + Self::MtuReply { mtu } => { + let mut buf = vec![12, 4]; + buf.extend_from_slice(&mtu.to_be_bytes()); + buf + } + + Self::Traceroute { + id, + outbound_hops, + return_hops, + originator, + } => { + let mut buf = vec![82, 12]; + buf.extend_from_slice(&id.to_be_bytes()); + buf.extend_from_slice(&outbound_hops.to_be_bytes()); + buf.extend_from_slice(&return_hops.to_be_bytes()); + buf.extend_from_slice(&originator.octets()); + buf + } + + Self::RouterAlert { value } => { + let mut buf = vec![148, 4]; + buf.extend_from_slice(&value.to_be_bytes()); + buf + } + + Self::AddressExtension { src_ext, dst_ext } => { + let mut buf = vec![147, 10]; + buf.extend_from_slice(&src_ext.octets()); + buf.extend_from_slice(&dst_ext.octets()); + buf + } + + Self::Unknown { option_type, data } => { + let mut buf = vec![*option_type, (2 + data.len()) as u8]; + buf.extend_from_slice(data); + buf + } + } + } +} + +/// Collection of IP options. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct Ipv4Options { + pub options: Vec, +} + +impl Ipv4Options { + /// Create empty options. + pub fn new() -> Self { + Self::default() + } + + /// Create from a list of options. + pub fn from_vec(options: Vec) -> Self { + Self { options } + } + + /// Check if there are no options. + pub fn is_empty(&self) -> bool { + self.options.is_empty() + } + + /// Get the number of options. + pub fn len(&self) -> usize { + self.options.len() + } + + /// Get the total serialized length. + pub fn byte_len(&self) -> usize { + self.options.iter().map(|o| o.len()).sum() + } + + /// Get the padded length (aligned to 4 bytes). + pub fn padded_len(&self) -> usize { + let len = self.byte_len(); + (len + 3) & !3 + } + + /// Add an option. + pub fn push(&mut self, option: Ipv4Option) { + self.options.push(option); + } + + /// Get source route options (LSRR or SSRR). + pub fn source_route(&self) -> Option<&Ipv4Option> { + self.options + .iter() + .find(|o| matches!(o, Ipv4Option::Lsrr { .. } | Ipv4Option::Ssrr { .. })) + } + + /// Get the final destination from source route options. + pub fn final_destination(&self) -> Option { + self.source_route().and_then(|opt| match opt { + Ipv4Option::Lsrr { route, .. } | Ipv4Option::Ssrr { route, .. } => { + route.last().copied() + } + _ => None, + }) + } + + /// Serialize all options to bytes (with padding). + pub fn to_bytes(&self) -> Vec { + let mut buf = Vec::new(); + for opt in &self.options { + buf.extend_from_slice(&opt.to_bytes()); + } + + // Pad to 4-byte boundary + let pad = (4 - (buf.len() % 4)) % 4; + buf.extend(std::iter::repeat(0u8).take(pad)); + + buf + } + + /// Filter options that should be copied on fragmentation. + pub fn copied_options(&self) -> Self { + Self { + options: self + .options + .iter() + .filter(|o| o.option_type().is_copied()) + .cloned() + .collect(), + } + } +} + +impl IntoIterator for Ipv4Options { + type Item = Ipv4Option; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.options.into_iter() + } +} + +impl<'a> IntoIterator for &'a Ipv4Options { + type Item = &'a Ipv4Option; + type IntoIter = std::slice::Iter<'a, Ipv4Option>; + + fn into_iter(self) -> Self::IntoIter { + self.options.iter() + } +} + +/// Parse IP options from bytes. +pub fn parse_options(data: &[u8]) -> Result { + let mut options = Vec::new(); + let mut offset = 0; + + while offset < data.len() { + let opt_type = data[offset]; + + match opt_type { + // End of Option List + 0 => { + options.push(Ipv4Option::EndOfList); + break; + } + + // NOP + 1 => { + options.push(Ipv4Option::Nop); + offset += 1; + } + + // Multi-byte options + _ => { + if offset + 1 >= data.len() { + return Err(FieldError::InvalidValue( + "option length field missing".to_string(), + )); + } + + let length = data[offset + 1] as usize; + if length < 2 { + return Err(FieldError::InvalidValue(format!( + "option length {} is less than minimum (2)", + length + ))); + } + + if offset + length > data.len() { + return Err(FieldError::BufferTooShort { + offset, + need: length, + have: data.len() - offset, + }); + } + + let opt_data = &data[offset..offset + length]; + let opt = parse_single_option(opt_type, opt_data)?; + options.push(opt); + + offset += length; + } + } + } + + Ok(Ipv4Options { options }) +} + +/// Parse a single multi-byte option. +fn parse_single_option(opt_type: u8, data: &[u8]) -> Result { + let length = data[1] as usize; + + match opt_type { + // Record Route + 7 => { + if length < 3 { + return Err(FieldError::InvalidValue( + "Record Route option too short".to_string(), + )); + } + let pointer = data[2]; + let route = parse_ip_list(&data[3..length]); + Ok(Ipv4Option::RecordRoute { pointer, route }) + } + + // LSRR + 131 => { + if length < 3 { + return Err(FieldError::InvalidValue( + "LSRR option too short".to_string(), + )); + } + let pointer = data[2]; + let route = parse_ip_list(&data[3..length]); + Ok(Ipv4Option::Lsrr { pointer, route }) + } + + // SSRR + 137 => { + if length < 3 { + return Err(FieldError::InvalidValue( + "SSRR option too short".to_string(), + )); + } + let pointer = data[2]; + let route = parse_ip_list(&data[3..length]); + Ok(Ipv4Option::Ssrr { pointer, route }) + } + + // Timestamp + 68 => { + if length < 4 { + return Err(FieldError::InvalidValue( + "Timestamp option too short".to_string(), + )); + } + let pointer = data[2]; + let oflw_flag = data[3]; + let overflow = oflw_flag >> 4; + let flag = oflw_flag & 0x0F; + + let timestamps = parse_timestamps(&data[4..length], flag)?; + Ok(Ipv4Option::Timestamp { + pointer, + overflow, + flag, + data: timestamps, + }) + } + + // Security + 130 => { + if length != 11 { + return Err(FieldError::InvalidValue(format!( + "Security option length {} != 11", + length + ))); + } + let security = u16::from_be_bytes([data[2], data[3]]); + let compartment = u16::from_be_bytes([data[4], data[5]]); + let handling_restrictions = u16::from_be_bytes([data[6], data[7]]); + let mut tcc = [0u8; 3]; + tcc.copy_from_slice(&data[8..11]); + + Ok(Ipv4Option::Security { + security, + compartment, + handling_restrictions, + transmission_control_code: tcc, + }) + } + + // Stream ID + 136 => { + if length != 4 { + return Err(FieldError::InvalidValue(format!( + "Stream ID option length {} != 4", + length + ))); + } + let id = u16::from_be_bytes([data[2], data[3]]); + Ok(Ipv4Option::StreamId { id }) + } + + // MTU Probe + 11 => { + if length != 4 { + return Err(FieldError::InvalidValue(format!( + "MTU Probe option length {} != 4", + length + ))); + } + let mtu = u16::from_be_bytes([data[2], data[3]]); + Ok(Ipv4Option::MtuProbe { mtu }) + } + + // MTU Reply + 12 => { + if length != 4 { + return Err(FieldError::InvalidValue(format!( + "MTU Reply option length {} != 4", + length + ))); + } + let mtu = u16::from_be_bytes([data[2], data[3]]); + Ok(Ipv4Option::MtuReply { mtu }) + } + + // Traceroute + 82 => { + if length != 12 { + return Err(FieldError::InvalidValue(format!( + "Traceroute option length {} != 12", + length + ))); + } + let id = u16::from_be_bytes([data[2], data[3]]); + let outbound_hops = u16::from_be_bytes([data[4], data[5]]); + let return_hops = u16::from_be_bytes([data[6], data[7]]); + let originator = Ipv4Addr::new(data[8], data[9], data[10], data[11]); + + Ok(Ipv4Option::Traceroute { + id, + outbound_hops, + return_hops, + originator, + }) + } + + // Router Alert + 148 => { + if length != 4 { + return Err(FieldError::InvalidValue(format!( + "Router Alert option length {} != 4", + length + ))); + } + let value = u16::from_be_bytes([data[2], data[3]]); + Ok(Ipv4Option::RouterAlert { value }) + } + + // Address Extension + 147 => { + if length != 10 { + return Err(FieldError::InvalidValue(format!( + "Address Extension option length {} != 10", + length + ))); + } + let src_ext = Ipv4Addr::new(data[2], data[3], data[4], data[5]); + let dst_ext = Ipv4Addr::new(data[6], data[7], data[8], data[9]); + Ok(Ipv4Option::AddressExtension { src_ext, dst_ext }) + } + + // Unknown option + _ => Ok(Ipv4Option::Unknown { + option_type: opt_type, + data: data[2..length].to_vec(), + }), + } +} + +/// Parse a list of IP addresses from option data. +fn parse_ip_list(data: &[u8]) -> Vec { + data.chunks_exact(4) + .map(|c| Ipv4Addr::new(c[0], c[1], c[2], c[3])) + .collect() +} + +/// Parse timestamp data based on flag. +fn parse_timestamps(data: &[u8], flag: u8) -> Result, u32)>, FieldError> { + match flag { + // Timestamps only + 0 => { + let timestamps: Vec<_> = data + .chunks_exact(4) + .map(|c| { + let ts = u32::from_be_bytes([c[0], c[1], c[2], c[3]]); + (None, ts) + }) + .collect(); + Ok(timestamps) + } + + // IP + Timestamp pairs + 1 | 3 => { + if data.len() % 8 != 0 { + return Err(FieldError::InvalidValue( + "Timestamp data not aligned to 8 bytes".to_string(), + )); + } + let timestamps: Vec<_> = data + .chunks_exact(8) + .map(|c| { + let ip = Ipv4Addr::new(c[0], c[1], c[2], c[3]); + let ts = u32::from_be_bytes([c[4], c[5], c[6], c[7]]); + (Some(ip), ts) + }) + .collect(); + Ok(timestamps) + } + + _ => Err(FieldError::InvalidValue(format!( + "Unknown timestamp flag: {}", + flag + ))), + } +} + +/// Builder for IP options. +#[derive(Debug, Clone, Default)] +pub struct Ipv4OptionsBuilder { + options: Vec, +} + +impl Ipv4OptionsBuilder { + /// Create a new builder. + pub fn new() -> Self { + Self::default() + } + + /// Add an End of Option List marker. + pub fn eol(mut self) -> Self { + self.options.push(Ipv4Option::EndOfList); + self + } + + /// Add a NOP (padding). + pub fn nop(mut self) -> Self { + self.options.push(Ipv4Option::Nop); + self + } + + /// Add a Record Route option. + pub fn record_route(mut self, route: Vec) -> Self { + self.options.push(Ipv4Option::RecordRoute { + pointer: 4, // First slot + route, + }); + self + } + + /// Add a Loose Source Route option. + pub fn lsrr(mut self, route: Vec) -> Self { + self.options.push(Ipv4Option::Lsrr { pointer: 4, route }); + self + } + + /// Add a Strict Source Route option. + pub fn ssrr(mut self, route: Vec) -> Self { + self.options.push(Ipv4Option::Ssrr { pointer: 4, route }); + self + } + + /// Add a Timestamp option (timestamps only). + pub fn timestamp(mut self) -> Self { + self.options.push(Ipv4Option::Timestamp { + pointer: 5, + overflow: 0, + flag: 0, + data: vec![], + }); + self + } + + /// Add a Timestamp option with IP addresses. + pub fn timestamp_with_addresses(mut self, prespecified: bool) -> Self { + self.options.push(Ipv4Option::Timestamp { + pointer: 9, + overflow: 0, + flag: if prespecified { 3 } else { 1 }, + data: vec![], + }); + self + } + + /// Add a Router Alert option. + pub fn router_alert(mut self, value: u16) -> Self { + self.options.push(Ipv4Option::RouterAlert { value }); + self + } + + /// Add a custom option. + pub fn option(mut self, option: Ipv4Option) -> Self { + self.options.push(option); + self + } + + /// Build the options. + pub fn build(self) -> Ipv4Options { + Ipv4Options { + options: self.options, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_nop_eol() { + let data = [1, 1, 1, 0]; // 3 NOPs + EOL + let opts = parse_options(&data).unwrap(); + + assert_eq!(opts.len(), 4); + assert!(matches!(opts.options[0], Ipv4Option::Nop)); + assert!(matches!(opts.options[1], Ipv4Option::Nop)); + assert!(matches!(opts.options[2], Ipv4Option::Nop)); + assert!(matches!(opts.options[3], Ipv4Option::EndOfList)); + } + + #[test] + fn test_parse_record_route() { + let data = [ + 7, // Type: Record Route + 11, // Length: 11 bytes + 4, // Pointer: first slot + 192, 168, 1, 1, // First IP + 192, 168, 1, 2, // Second IP + ]; + + let opts = parse_options(&data).unwrap(); + assert_eq!(opts.len(), 1); + + if let Ipv4Option::RecordRoute { pointer, route } = &opts.options[0] { + assert_eq!(*pointer, 4); + assert_eq!(route.len(), 2); + assert_eq!(route[0], Ipv4Addr::new(192, 168, 1, 1)); + assert_eq!(route[1], Ipv4Addr::new(192, 168, 1, 2)); + } else { + panic!("Expected RecordRoute option"); + } + } + + #[test] + fn test_parse_timestamp() { + let data = [ + 68, // Type: Timestamp + 12, // Length: 12 bytes + 5, // Pointer + 0x01, // Overflow=0, Flag=1 (IP + timestamp) + 192, 168, 1, 1, // IP + 0x00, 0x00, 0x10, 0x00, // Timestamp + ]; + + let opts = parse_options(&data).unwrap(); + assert_eq!(opts.len(), 1); + + if let Ipv4Option::Timestamp { + pointer, + overflow, + flag, + data: ts_data, + } = &opts.options[0] + { + assert_eq!(*pointer, 5); + assert_eq!(*overflow, 0); + assert_eq!(*flag, 1); + + assert_eq!(ts_data.len(), 1); + assert_eq!(ts_data[0].0, Some(Ipv4Addr::new(192, 168, 1, 1))); + assert_eq!(ts_data[0].1, 0x1000); + } + } + + #[test] + fn test_parse_router_alert() { + let data = [ + 148, // Type: Router Alert + 4, // Length + 0x00, // Value high + 0x00, // Value low + ]; + + let opts = parse_options(&data).unwrap(); + assert_eq!(opts.len(), 1); + + if let Ipv4Option::RouterAlert { value } = opts.options[0] { + assert_eq!(value, 0); + } else { + panic!("Expected RouterAlert option"); + } + } + + #[test] + fn test_serialize_options() { + let opts = Ipv4OptionsBuilder::new() + .nop() + .router_alert(0) + .nop() + .build(); + + let bytes = opts.to_bytes(); + + // Should be padded to 4-byte boundary + assert_eq!(bytes.len() % 4, 0); + + // Parse back + let parsed = parse_options(&bytes).unwrap(); + assert!(matches!(parsed.options[0], Ipv4Option::Nop)); + assert!(matches!( + parsed.options[1], + Ipv4Option::RouterAlert { value: 0 } + )); + } + + #[test] + fn test_option_type_properties() { + // LSRR should be copied + assert!(Ipv4OptionType::Lsrr.is_copied()); + // Record Route should not be copied + assert!(!Ipv4OptionType::RecordRoute.is_copied()); + // NOP is single byte + assert!(Ipv4OptionType::Nop.is_single_byte()); + // Timestamp is not single byte + assert!(!Ipv4OptionType::Timestamp.is_single_byte()); + } + + #[test] + fn test_final_destination() { + let opts = Ipv4OptionsBuilder::new() + .lsrr(vec![ + Ipv4Addr::new(10, 0, 0, 1), + Ipv4Addr::new(10, 0, 0, 2), + Ipv4Addr::new(10, 0, 0, 3), + ]) + .build(); + + assert_eq!(opts.final_destination(), Some(Ipv4Addr::new(10, 0, 0, 3))); + } + + #[test] + fn test_copied_options() { + let opts = Ipv4OptionsBuilder::new() + .record_route(vec![]) // Not copied + .lsrr(vec![Ipv4Addr::new(10, 0, 0, 1)]) // Copied + .nop() // Not copied + .build(); + + let copied = opts.copied_options(); + assert_eq!(copied.len(), 1); + assert!(matches!(copied.options[0], Ipv4Option::Lsrr { .. })); + } +} diff --git a/crates/stackforge-core/src/layer/ipv4/protocol.rs b/crates/stackforge-core/src/layer/ipv4/protocol.rs new file mode 100644 index 0000000..02ed8ee --- /dev/null +++ b/crates/stackforge-core/src/layer/ipv4/protocol.rs @@ -0,0 +1,656 @@ +//! IPv4 Protocol numbers. +//! +//! Registry of common IP protocol numbers assigned by IANA. +//! These values correspond to the "Protocol" field in the IPv4 header. + +/// IPv6 Hop-by-Hop Option (RFC 8200) +pub const HOPOPT: u8 = 0; + +/// Internet Control Message (RFC 792) +pub const ICMP: u8 = 1; + +/// Internet Group Management (RFC 1112) +pub const IGMP: u8 = 2; + +/// Gateway-to-Gateway (RFC 823) +pub const GGP: u8 = 3; + +/// IP in IP (encapsulation) (RFC 2003) +pub const IPIP: u8 = 4; + +/// IPv4 Encapsulation (alias for IPIP) +pub const IPV4: u8 = 4; + +/// Stream Internet (RFC 1190) +pub const ST: u8 = 5; + +/// Transmission Control (RFC 793) +pub const TCP: u8 = 6; + +/// CBT (RFC 2189) +pub const CBT: u8 = 7; + +/// Exterior Gateway Protocol (RFC 888) +pub const EGP: u8 = 8; + +/// IGP (any private interior gateway: used by Cisco for their IGRP) +pub const IGP: u8 = 9; + +/// BBN RCC Monitoring +pub const BBN_RCC_MON: u8 = 10; + +/// Network Voice Protocol (RFC 741) +pub const NVP_II: u8 = 11; + +/// PUP +pub const PUP: u8 = 12; + +/// ARGUS +pub const ARGUS: u8 = 13; + +/// EMCON +pub const EMCON: u8 = 14; + +/// Cross Net Debugger +pub const XNET: u8 = 15; + +/// Chaos +pub const CHAOS: u8 = 16; + +/// User Datagram (RFC 768) +pub const UDP: u8 = 17; + +/// Multiplexing +pub const MUX: u8 = 18; + +/// DCN Measurement Subsystems +pub const DCN_MEAS: u8 = 19; + +/// Host Monitoring (RFC 869) +pub const HMP: u8 = 20; + +/// Packet Radio Measurement +pub const PRM: u8 = 21; + +/// XEROX NS IDP +pub const XNS_IDP: u8 = 22; + +/// Trunk-1 +pub const TRUNK_1: u8 = 23; + +/// Trunk-2 +pub const TRUNK_2: u8 = 24; + +/// Leaf-1 +pub const LEAF_1: u8 = 25; + +/// Leaf-2 +pub const LEAF_2: u8 = 26; + +/// Reliable Data Protocol (RFC 908) +pub const RDP: u8 = 27; + +/// Internet Reliable Transaction (RFC 938) +pub const IRTP: u8 = 28; + +/// ISO Transport Class 4 (RFC 905) +pub const ISO_TP4: u8 = 29; + +/// Bulk Data Transfer Protocol (RFC 969) +pub const NETBLT: u8 = 30; + +/// MFE Network Services Protocol +pub const MFE_NSP: u8 = 31; + +/// MERIT Internodal Protocol +pub const MERIT_INP: u8 = 32; + +/// Datagram Congestion Control Protocol (RFC 4340) +pub const DCCP: u8 = 33; + +/// Third Party Connect Protocol +pub const THIRD_PARTY_CONNECT: u8 = 34; + +/// IDPR (RFC 1479) +pub const IDPR: u8 = 35; + +/// XTP +pub const XTP: u8 = 36; + +/// Datagram Delivery Protocol +pub const DDP: u8 = 37; + +/// IDPR Control Message Transport Proto +pub const IDPR_CMTP: u8 = 38; + +/// TP++ Transport Protocol +pub const TP_PLUS_PLUS: u8 = 39; + +/// IL Transport Protocol +pub const IL: u8 = 40; + +/// IPv6 Encapsulation (RFC 2473) +pub const IPV6: u8 = 41; + +/// Source Demand Routing Protocol +pub const SDRP: u8 = 42; + +/// Routing Header for IPv6 (RFC 8200) +pub const IPV6_ROUTE: u8 = 43; + +/// Fragment Header for IPv6 (RFC 8200) +pub const IPV6_FRAG: u8 = 44; + +/// Inter-Domain Policy Routing Protocol +pub const IDRP: u8 = 45; + +/// Reservation Protocol (RFC 2205) +pub const RSVP: u8 = 46; + +/// Generic Routing Encapsulation (RFC 2784) +pub const GRE: u8 = 47; + +/// Dynamic Source Routing Protocol (RFC 4728) +pub const DSR: u8 = 48; + +/// BNA +pub const BNA: u8 = 49; + +/// Encap Security Payload (RFC 4303) +pub const ESP: u8 = 50; + +/// Authentication Header (RFC 4302) +pub const AH: u8 = 51; + +/// Integrated Net Layer Security TUBA +pub const I_NLSP: u8 = 52; + +/// IP with Encryption +pub const SWIPE: u8 = 53; + +/// NBMA Address Resolution Protocol (RFC 1735) +pub const NARP: u8 = 54; + +/// IP Mobility (RFC 2004) +pub const MOBILE: u8 = 55; + +/// Transport Layer Security Protocol using Kryptonet key management +pub const TLSP: u8 = 56; + +/// SKIP +pub const SKIP: u8 = 57; + +/// ICMP for IPv6 (RFC 4443) +pub const ICMPV6: u8 = 58; + +/// No Next Header for IPv6 (RFC 8200) +pub const IPV6_NONXT: u8 = 59; + +/// Destination Options for IPv6 (RFC 8200) +pub const IPV6_OPTS: u8 = 60; + +/// Any host internal protocol +pub const ANY_HOST_INTERNAL: u8 = 61; + +/// CFTP +pub const CFTP: u8 = 62; + +/// Any local network +pub const ANY_LOCAL_NETWORK: u8 = 63; + +/// SATNET and Backroom EXPAK +pub const SAT_EXPAK: u8 = 64; + +/// Kryptolan +pub const KRYPTOLAN: u8 = 65; + +/// MIT Remote Virtual Disk Protocol +pub const RVD: u8 = 66; + +/// Internet Pluribus Packet Core +pub const IPPC: u8 = 67; + +/// Any distributed file system +pub const ANY_DIST_FS: u8 = 68; + +/// SATNET Monitoring +pub const SAT_MON: u8 = 69; + +/// VISA Protocol +pub const VISA: u8 = 70; + +/// Internet Packet Core Utility +pub const IPCV: u8 = 71; + +/// Computer Protocol Network Executive +pub const CPNX: u8 = 72; + +/// Computer Protocol Heart Beat +pub const CPHB: u8 = 73; + +/// Wang Span Network +pub const WSN: u8 = 74; + +/// Packet Video Protocol +pub const PVP: u8 = 75; + +/// Backroom SATNET Monitoring +pub const BR_SAT_MON: u8 = 76; + +/// SUN ND PROTOCOL-Temporary +pub const SUN_ND: u8 = 77; + +/// WIDEBAND Monitoring +pub const WB_MON: u8 = 78; + +/// WIDEBAND EXPAK +pub const WB_EXPAK: u8 = 79; + +/// ISO Internet Protocol +pub const ISO_IP: u8 = 80; + +/// VMTP +pub const VMTP: u8 = 81; + +/// SECURE-VMTP +pub const SECURE_VMTP: u8 = 82; + +/// VINES +pub const VINES: u8 = 83; + +/// Transaction Transport Protocol (IPTM) +pub const TTP: u8 = 84; + +/// NSFNET-IGP +pub const NSFNET_IGP: u8 = 85; + +/// Dissimilar Gateway Protocol +pub const DGP: u8 = 86; + +/// TCF +pub const TCF: u8 = 87; + +/// EIGRP (RFC 7868) +pub const EIGRP: u8 = 88; + +/// OSPF (RFC 2328) +pub const OSPFIGP: u8 = 89; + +/// Sprite RPC Protocol +pub const SPRITE_RPC: u8 = 90; + +/// Locus Address Resolution Protocol +pub const LARP: u8 = 91; + +/// Multicast Transport Protocol +pub const MTP: u8 = 92; + +/// AX.25 Frames +pub const AX_25: u8 = 93; + +/// IP-within-IP Encapsulation Protocol +pub const IPIP_ENCAP: u8 = 94; + +/// Mobile Internetworking Control Pro. (RFC 2003) +pub const MICP: u8 = 95; + +/// Semaphore Communications Sec. Pro. +pub const SCC_SP: u8 = 96; + +/// Ethernet-within-IP Encapsulation (RFC 3378) +pub const ETHERIP: u8 = 97; + +/// Encapsulation Header (RFC 1241) +pub const ENCAP: u8 = 98; + +/// Any private encryption scheme +pub const ANY_ENC: u8 = 99; + +/// GMTP +pub const GMTP: u8 = 100; + +/// Ipsilon Flow Management Protocol +pub const IFMP: u8 = 101; + +/// PNNI over IP +pub const PNNI: u8 = 102; + +/// Protocol Independent Multicast (RFC 7761) +pub const PIM: u8 = 103; + +/// ARIS +pub const ARIS: u8 = 104; + +/// SCPS +pub const SCPS: u8 = 105; + +/// QNX +pub const QNX: u8 = 106; + +/// Active Networks +pub const AN: u8 = 107; + +/// IP Payload Compression Protocol (RFC 2393) +pub const IPCOMP: u8 = 108; + +/// Sitara Networks Protocol +pub const SNP: u8 = 109; + +/// Compaq Peer Protocol +pub const COMPAQ_PEER: u8 = 110; + +/// IPX in IP +pub const IPX_IN_IP: u8 = 111; + +/// Virtual Router Redundancy Protocol (RFC 5798) +pub const VRRP: u8 = 112; + +/// PGM Reliable Transport Protocol +pub const PGM: u8 = 113; + +/// Any 0-hop protocol +pub const ANY_0_HOP: u8 = 114; + +/// Layer Two Tunneling Protocol (RFC 3931) +pub const L2TP: u8 = 115; + +/// D-II Data Exchange (DX) +pub const DDX: u8 = 116; + +/// Interactive Agent Transfer Protocol +pub const IATP: u8 = 117; + +/// Schedule Transfer Protocol +pub const STP: u8 = 118; + +/// SpectraLink Radio Protocol +pub const SRP: u8 = 119; + +/// UTI +pub const UTI: u8 = 120; + +/// Simple Message Protocol +pub const SMP: u8 = 121; + +/// Simple Multicast Protocol +pub const SM: u8 = 122; + +/// Performance Transparency Protocol +pub const PTP: u8 = 123; + +/// ISIS over IPv4 (RFC 1142) +pub const ISIS_OVER_IPV4: u8 = 124; + +/// CRTP +pub const FIRE: u8 = 125; + +/// Combat Radio Transport Protocol +pub const CRTP: u8 = 126; + +/// Combat Radio User Datagram +pub const CRUDP: u8 = 127; + +/// SSCOPMCE +pub const SSCOPMCE: u8 = 128; + +/// IPLT +pub const IPLT: u8 = 129; + +/// Secure Packet Shield +pub const SPS: u8 = 130; + +/// Private IP Encapsulation within IP +pub const PIPE: u8 = 131; + +/// Stream Control Transmission Protocol (RFC 4960) +pub const SCTP: u8 = 132; + +/// Fibre Channel +pub const FC: u8 = 133; + +/// RSVP-E2E-IGNORE (RFC 3175) +pub const RSVP_E2E_IGNORE: u8 = 134; + +/// Mobility Header (RFC 6275) +pub const MOBILITY_HEADER: u8 = 135; + +/// UDPLite (RFC 3828) +pub const UDPLITE: u8 = 136; + +/// MPLS-in-IP (RFC 4023) +pub const MPLS_IN_IP: u8 = 137; + +/// MANET Protocols (RFC 5498) +pub const MANET: u8 = 138; + +/// Host Identity Protocol (RFC 7401) +pub const HIP: u8 = 139; + +/// Shim6 Protocol (RFC 5533) +pub const SHIM6: u8 = 140; + +/// Wrapped Encapsulating Security Payload (RFC 5840) +pub const WESP: u8 = 141; + +/// RObust Header Compression (RFC 5858) +pub const ROHC: u8 = 142; + +/// Ethernet (RFC 8986) +pub const ETHERNET: u8 = 143; + +/// AGGFRAG encapsulation payload for ESP (RFC 9347) +pub const AGGFRAG: u8 = 144; + +/// Reserved +pub const RESERVED: u8 = 255; + +/// Convert a string name to a protocol number (case-insensitive). +pub fn from_name(name: &str) -> Option { + match name.to_ascii_lowercase().as_str() { + "ip" => Some(0), // Sometimes used as alias for HOPOPT + "ipv4" => Some(IPV4), + "icmp" => Some(ICMP), + "igmp" => Some(IGMP), + "ggp" => Some(GGP), + "ipip" => Some(IPIP), + "tcp" => Some(TCP), + "egp" => Some(EGP), + "igp" => Some(IGP), + "pup" => Some(PUP), + "udp" => Some(UDP), + "hmp" => Some(HMP), + "xns-idp" => Some(XNS_IDP), + "rdp" => Some(RDP), + "iso-tp4" => Some(ISO_TP4), + "dccp" => Some(DCCP), + "xtp" => Some(XTP), + "ddp" => Some(DDP), + "idpr-cmtp" => Some(IDPR_CMTP), + "ipv6" => Some(IPV6), + "ipv6-route" => Some(IPV6_ROUTE), + "ipv6-frag" => Some(IPV6_FRAG), + "idrp" => Some(IDRP), + "rsvp" => Some(RSVP), + "gre" => Some(GRE), + "esp" => Some(ESP), + "ah" => Some(AH), + "skip" => Some(SKIP), + "icmpv6" => Some(ICMPV6), + "ipv6-nonxt" => Some(IPV6_NONXT), + "ipv6-opts" => Some(IPV6_OPTS), + "eigrp" => Some(EIGRP), + "ospf" => Some(OSPFIGP), + "mtp" => Some(MTP), + "encap" => Some(ENCAP), + "pim" => Some(PIM), + "ipcomp" => Some(IPCOMP), + "vrrp" => Some(VRRP), + "l2tp" => Some(L2TP), + "isis" => Some(ISIS_OVER_IPV4), + "sctp" => Some(SCTP), + "fc" => Some(FC), + "mobility-header" => Some(MOBILITY_HEADER), + "udplite" => Some(UDPLITE), + "mpls-in-ip" => Some(MPLS_IN_IP), + "manet" => Some(MANET), + "hip" => Some(HIP), + "shim6" => Some(SHIM6), + "wesp" => Some(WESP), + "rohc" => Some(ROHC), + "ethernet" => Some(ETHERNET), + _ => None, + } +} + +/// Get the string name for a protocol number. +pub fn to_name(proto: u8) -> &'static str { + match proto { + HOPOPT => "HOPOPT", + ICMP => "ICMP", + IGMP => "IGMP", + GGP => "GGP", + IPIP => "IPIP", + ST => "ST", + TCP => "TCP", + CBT => "CBT", + EGP => "EGP", + IGP => "IGP", + BBN_RCC_MON => "BBN_RCC_MON", + NVP_II => "NVP_II", + PUP => "PUP", + ARGUS => "ARGUS", + EMCON => "EMCON", + XNET => "XNET", + CHAOS => "CHAOS", + UDP => "UDP", + MUX => "MUX", + DCN_MEAS => "DCN_MEAS", + HMP => "HMP", + PRM => "PRM", + XNS_IDP => "XNS_IDP", + TRUNK_1 => "TRUNK_1", + TRUNK_2 => "TRUNK_2", + LEAF_1 => "LEAF_1", + LEAF_2 => "LEAF_2", + RDP => "RDP", + IRTP => "IRTP", + ISO_TP4 => "ISO_TP4", + NETBLT => "NETBLT", + MFE_NSP => "MFE_NSP", + MERIT_INP => "MERIT_INP", + DCCP => "DCCP", + THIRD_PARTY_CONNECT => "THIRD_PARTY_CONNECT", + IDPR => "IDPR", + XTP => "XTP", + DDP => "DDP", + IDPR_CMTP => "IDPR_CMTP", + TP_PLUS_PLUS => "TP_PLUS_PLUS", + IL => "IL", + IPV6 => "IPV6", + SDRP => "SDRP", + IPV6_ROUTE => "IPV6_ROUTE", + IPV6_FRAG => "IPV6_FRAG", + IDRP => "IDRP", + RSVP => "RSVP", + GRE => "GRE", + DSR => "DSR", + BNA => "BNA", + ESP => "ESP", + AH => "AH", + I_NLSP => "I_NLSP", + SWIPE => "SWIPE", + NARP => "NARP", + MOBILE => "MOBILE", + TLSP => "TLSP", + SKIP => "SKIP", + ICMPV6 => "ICMPV6", + IPV6_NONXT => "IPV6_NONXT", + IPV6_OPTS => "IPV6_OPTS", + ANY_HOST_INTERNAL => "ANY_HOST_INTERNAL", + CFTP => "CFTP", + ANY_LOCAL_NETWORK => "ANY_LOCAL_NETWORK", + SAT_EXPAK => "SAT_EXPAK", + KRYPTOLAN => "KRYPTOLAN", + RVD => "RVD", + IPPC => "IPPC", + ANY_DIST_FS => "ANY_DIST_FS", + SAT_MON => "SAT_MON", + VISA => "VISA", + IPCV => "IPCV", + CPNX => "CPNX", + CPHB => "CPHB", + WSN => "WSN", + PVP => "PVP", + BR_SAT_MON => "BR_SAT_MON", + SUN_ND => "SUN_ND", + WB_MON => "WB_MON", + WB_EXPAK => "WB_EXPAK", + ISO_IP => "ISO_IP", + VMTP => "VMTP", + SECURE_VMTP => "SECURE_VMTP", + VINES => "VINES", + TTP => "TTP", + NSFNET_IGP => "NSFNET_IGP", + DGP => "DGP", + TCF => "TCF", + EIGRP => "EIGRP", + OSPFIGP => "OSPFIGP", + SPRITE_RPC => "SPRITE_RPC", + LARP => "LARP", + MTP => "MTP", + AX_25 => "AX_25", + IPIP_ENCAP => "IPIP_ENCAP", + MICP => "MICP", + SCC_SP => "SCC_SP", + ETHERIP => "ETHERIP", + ENCAP => "ENCAP", + ANY_ENC => "ANY_ENC", + GMTP => "GMTP", + IFMP => "IFMP", + PNNI => "PNNI", + PIM => "PIM", + ARIS => "ARIS", + SCPS => "SCPS", + QNX => "QNX", + AN => "AN", + IPCOMP => "IPCOMP", + SNP => "SNP", + COMPAQ_PEER => "COMPAQ_PEER", + IPX_IN_IP => "IPX_IN_IP", + VRRP => "VRRP", + PGM => "PGM", + ANY_0_HOP => "ANY_0_HOP", + L2TP => "L2TP", + DDX => "DDX", + IATP => "IATP", + STP => "STP", + SRP => "SRP", + UTI => "UTI", + SMP => "SMP", + SM => "SM", + PTP => "PTP", + ISIS_OVER_IPV4 => "ISIS_OVER_IPV4", + FIRE => "FIRE", + CRTP => "CRTP", + CRUDP => "CRUDP", + SSCOPMCE => "SSCOPMCE", + IPLT => "IPLT", + SPS => "SPS", + PIPE => "PIPE", + SCTP => "SCTP", + FC => "FC", + RSVP_E2E_IGNORE => "RSVP_E2E_IGNORE", + MOBILITY_HEADER => "MOBILITY_HEADER", + UDPLITE => "UDPLITE", + MPLS_IN_IP => "MPLS_IN_IP", + MANET => "MANET", + HIP => "HIP", + SHIM6 => "SHIM6", + WESP => "WESP", + ROHC => "ROHC", + ETHERNET => "ETHERNET", + AGGFRAG => "AGGFRAG", + RESERVED => "RESERVED", + _ => "UNKNOWN", + } +} diff --git a/crates/stackforge-core/src/layer/ipv4/routing.rs b/crates/stackforge-core/src/layer/ipv4/routing.rs new file mode 100644 index 0000000..7e3af22 --- /dev/null +++ b/crates/stackforge-core/src/layer/ipv4/routing.rs @@ -0,0 +1,444 @@ +//! IPv4 routing utilities. +//! +//! Provides routing table lookup and interface selection for IPv4 packets. +//! This is used to determine which interface to send packets on and what +//! the next-hop address should be. + +use std::net::Ipv4Addr; + +/// Routing information for an IPv4 destination. +#[derive(Debug, Clone, Default)] +pub struct Ipv4Route { + /// The outgoing interface name. + pub interface: Option, + /// The source IP address to use. + pub source: Option, + /// The next-hop gateway IP (if not on local network). + pub gateway: Option, + /// Whether the destination is on the local network. + pub is_local: bool, + /// The network mask (prefix length). + pub prefix_len: u8, + /// Route metric (lower is better). + pub metric: u32, +} + +impl Ipv4Route { + /// Create an empty route (no routing info available). + pub fn none() -> Self { + Self::default() + } + + /// Create a route for a local destination. + pub fn local(interface: String, source: Ipv4Addr) -> Self { + Self { + interface: Some(interface), + source: Some(source), + gateway: None, + is_local: true, + prefix_len: 32, + metric: 0, + } + } + + /// Create a route via a gateway. + pub fn via_gateway(interface: String, source: Ipv4Addr, gateway: Ipv4Addr) -> Self { + Self { + interface: Some(interface), + source: Some(source), + gateway: Some(gateway), + is_local: false, + prefix_len: 0, + metric: 0, + } + } + + /// Check if this route is valid (has interface). + pub fn is_valid(&self) -> bool { + self.interface.is_some() + } + + /// Get the next-hop address (gateway or destination if local). + pub fn next_hop(&self, dst: Ipv4Addr) -> Ipv4Addr { + self.gateway.unwrap_or(dst) + } +} + +/// A router that can look up routes for destinations. +pub trait Ipv4Router { + /// Look up the route for a destination. + fn route(&self, dst: Ipv4Addr) -> Ipv4Route; +} + +/// Get routing information for a destination address. +/// +/// This uses the system routing table to determine the interface, +/// source address, and gateway for reaching a destination. +pub fn get_route(dst: Ipv4Addr) -> Ipv4Route { + // Handle special addresses + if dst.is_loopback() { + return Ipv4Route::local("lo".to_string(), Ipv4Addr::LOCALHOST); + } + + if dst.is_broadcast() || dst == Ipv4Addr::BROADCAST { + return get_default_route() + .map(|r| Ipv4Route { + is_local: true, + ..r + }) + .unwrap_or_default(); + } + + if dst.is_multicast() { + // Multicast uses the default route interface + return get_default_route() + .map(|r| Ipv4Route { + gateway: None, + is_local: true, + ..r + }) + .unwrap_or_default(); + } + + // Try to find a matching interface + let interfaces = pnet_datalink::interfaces(); + + // First, check for direct connectivity + for iface in &interfaces { + if !iface.is_up() || iface.is_loopback() { + continue; + } + + for ip_network in &iface.ips { + if let std::net::IpAddr::V4(ip) = ip_network.ip() { + // Check if destination is in this subnet + if ip_network.contains(std::net::IpAddr::V4(dst)) { + return Ipv4Route { + interface: Some(iface.name.clone()), + source: Some(ip), + gateway: None, + is_local: true, + prefix_len: ip_network.prefix(), + metric: 0, + }; + } + } + } + } + + // Use default gateway + if let Some(route) = get_default_route() { + return route; + } + + Ipv4Route::none() +} + +/// Get the default route (gateway). +pub fn get_default_route() -> Option { + // Try to get default interface using default-net crate + let default_iface = default_net::get_default_interface().ok()?; + let gateway = default_net::get_default_gateway().ok()?; + + let interfaces = pnet_datalink::interfaces(); + let iface = interfaces.iter().find(|i| i.name == default_iface.name)?; + + // Find IPv4 address on this interface + let source = iface.ips.iter().find_map(|ip| { + if let std::net::IpAddr::V4(v4) = ip.ip() { + Some(v4) + } else { + None + } + })?; + + let gw_ip = match gateway.ip_addr { + std::net::IpAddr::V4(v4) => v4, + _ => return None, + }; + + Some(Ipv4Route { + interface: Some(iface.name.clone()), + source: Some(source), + gateway: Some(gw_ip), + is_local: false, + prefix_len: 0, + metric: 0, + }) +} + +/// Get the source IP address for a destination. +/// +/// This performs a routing lookup and returns the appropriate source. +pub fn get_source_for_dst(dst: Ipv4Addr) -> Option { + get_route(dst).source +} + +/// Get the interface name for a destination. +pub fn get_interface_for_dst(dst: Ipv4Addr) -> Option { + get_route(dst).interface +} + +/// Check if a destination is on the local network. +pub fn is_local_destination(dst: Ipv4Addr) -> bool { + get_route(dst).is_local +} + +/// Get all available IPv4 interfaces. +pub fn get_ipv4_interfaces() -> Vec { + pnet_datalink::interfaces() + .into_iter() + .filter_map(|iface| { + if !iface.is_up() { + return None; + } + + let addrs: Vec<_> = iface + .ips + .iter() + .filter_map(|ip| { + if let std::net::IpAddr::V4(v4) = ip.ip() { + Some(Ipv4InterfaceAddr { + address: v4, + prefix_len: ip.prefix(), + // Fix for `and_then` error: match directly on IpAddr + broadcast: match ip.broadcast() { + std::net::IpAddr::V4(v4) => Some(v4), + _ => None, + }, + }) + } else { + None + } + }) + .collect(); + + if addrs.is_empty() { + return None; + } + + // Fix for "Value used after being moved": + // Extract boolean flags BEFORE moving `iface.name` + let is_loopback = iface.is_loopback(); + let is_up = iface.is_up(); + let is_running = iface.is_running(); + let is_multicast = iface.is_multicast(); + let is_broadcast = iface.is_broadcast(); + + Some(Ipv4Interface { + name: iface.name, // Move occurs here + index: iface.index, + mac: iface.mac.map(|m| m.octets()), + addresses: addrs, + is_loopback, + is_up, + is_running, + is_multicast, + is_broadcast, + mtu: None, // Not available from pnet + }) + }) + .collect() +} + +/// Information about an IPv4-capable interface. +#[derive(Debug, Clone)] +pub struct Ipv4Interface { + /// Interface name (e.g., "eth0", "en0"). + pub name: String, + /// Interface index. + pub index: u32, + /// MAC address (if available). + pub mac: Option<[u8; 6]>, + /// IPv4 addresses on this interface. + pub addresses: Vec, + /// Whether this is a loopback interface. + pub is_loopback: bool, + /// Whether the interface is up. + pub is_up: bool, + /// Whether the interface is running. + pub is_running: bool, + /// Whether multicast is enabled. + pub is_multicast: bool, + /// Whether broadcast is supported. + pub is_broadcast: bool, + /// Interface MTU. + pub mtu: Option, +} + +impl Ipv4Interface { + /// Get the first (primary) IPv4 address. + pub fn primary_address(&self) -> Option { + self.addresses.first().map(|a| a.address) + } + + /// Check if an address is on this interface's network. + pub fn contains(&self, addr: Ipv4Addr) -> bool { + self.addresses.iter().any(|a| { + let mask = prefix_to_mask(a.prefix_len); + (u32::from(a.address) & mask) == (u32::from(addr) & mask) + }) + } +} + +/// An IPv4 address on an interface. +#[derive(Debug, Clone)] +pub struct Ipv4InterfaceAddr { + /// The IPv4 address. + pub address: Ipv4Addr, + /// Network prefix length (e.g., 24 for /24). + pub prefix_len: u8, + /// Broadcast address (if applicable). + pub broadcast: Option, +} + +impl Ipv4InterfaceAddr { + /// Get the network address. + pub fn network(&self) -> Ipv4Addr { + let mask = prefix_to_mask(self.prefix_len); + Ipv4Addr::from(u32::from(self.address) & mask) + } + + /// Get the subnet mask. + pub fn netmask(&self) -> Ipv4Addr { + Ipv4Addr::from(prefix_to_mask(self.prefix_len)) + } +} + +/// Convert a prefix length to a subnet mask. +pub fn prefix_to_mask(prefix: u8) -> u32 { + if prefix >= 32 { + 0xFFFFFFFF + } else if prefix == 0 { + 0 + } else { + !((1u32 << (32 - prefix)) - 1) + } +} + +/// Convert a subnet mask to a prefix length. +pub fn mask_to_prefix(mask: Ipv4Addr) -> u8 { + let mask_u32 = u32::from(mask); + mask_u32.leading_ones() as u8 +} + +/// Check if an IP address matches a network/prefix. +pub fn ip_in_network(ip: Ipv4Addr, network: Ipv4Addr, prefix: u8) -> bool { + let mask = prefix_to_mask(prefix); + (u32::from(ip) & mask) == (u32::from(network) & mask) +} + +/// Calculate the broadcast address for a network. +pub fn broadcast_address(network: Ipv4Addr, prefix: u8) -> Ipv4Addr { + let mask = prefix_to_mask(prefix); + Ipv4Addr::from(u32::from(network) | !mask) +} + +/// Calculate the number of hosts in a network. +pub fn network_host_count(prefix: u8) -> u32 { + if prefix >= 31 { + // /31 has 2 hosts, /32 has 1 + 2u32.saturating_sub(prefix as u32 - 30) + } else { + (1u32 << (32 - prefix)) - 2 // Subtract network and broadcast + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_prefix_to_mask() { + assert_eq!(prefix_to_mask(0), 0x00000000); + assert_eq!(prefix_to_mask(8), 0xFF000000); + assert_eq!(prefix_to_mask(16), 0xFFFF0000); + assert_eq!(prefix_to_mask(24), 0xFFFFFF00); + assert_eq!(prefix_to_mask(32), 0xFFFFFFFF); + } + + #[test] + fn test_mask_to_prefix() { + assert_eq!(mask_to_prefix(Ipv4Addr::new(255, 0, 0, 0)), 8); + assert_eq!(mask_to_prefix(Ipv4Addr::new(255, 255, 0, 0)), 16); + assert_eq!(mask_to_prefix(Ipv4Addr::new(255, 255, 255, 0)), 24); + assert_eq!(mask_to_prefix(Ipv4Addr::new(255, 255, 255, 255)), 32); + } + + #[test] + fn test_ip_in_network() { + let network = Ipv4Addr::new(192, 168, 1, 0); + assert!(ip_in_network(Ipv4Addr::new(192, 168, 1, 100), network, 24)); + assert!(ip_in_network(Ipv4Addr::new(192, 168, 1, 255), network, 24)); + assert!(!ip_in_network(Ipv4Addr::new(192, 168, 2, 1), network, 24)); + } + + #[test] + fn test_broadcast_address() { + assert_eq!( + broadcast_address(Ipv4Addr::new(192, 168, 1, 0), 24), + Ipv4Addr::new(192, 168, 1, 255) + ); + assert_eq!( + broadcast_address(Ipv4Addr::new(10, 0, 0, 0), 8), + Ipv4Addr::new(10, 255, 255, 255) + ); + } + + #[test] + fn test_network_host_count() { + assert_eq!(network_host_count(24), 254); + assert_eq!(network_host_count(16), 65534); + assert_eq!(network_host_count(8), 16777214); + } + + #[test] + fn test_route_loopback() { + let route = get_route(Ipv4Addr::LOCALHOST); + assert!(route.is_local); + assert_eq!(route.interface, Some("lo".to_string())); + } + + #[test] + fn test_ipv4_route() { + let route = Ipv4Route::local("eth0".to_string(), Ipv4Addr::new(192, 168, 1, 100)); + assert!(route.is_valid()); + assert!(route.is_local); + assert_eq!( + route.next_hop(Ipv4Addr::new(192, 168, 1, 200)), + Ipv4Addr::new(192, 168, 1, 200) + ); + + let route = Ipv4Route::via_gateway( + "eth0".to_string(), + Ipv4Addr::new(192, 168, 1, 100), + Ipv4Addr::new(192, 168, 1, 1), + ); + assert!(!route.is_local); + assert_eq!( + route.next_hop(Ipv4Addr::new(8, 8, 8, 8)), + Ipv4Addr::new(192, 168, 1, 1) + ); + } + + #[test] + fn test_get_ipv4_interfaces() { + let interfaces = get_ipv4_interfaces(); + for iface in &interfaces { + assert!(!iface.name.is_empty()); + assert!(!iface.addresses.is_empty()); + } + } + + #[test] + fn test_interface_addr() { + let addr = Ipv4InterfaceAddr { + address: Ipv4Addr::new(192, 168, 1, 100), + prefix_len: 24, + broadcast: Some(Ipv4Addr::new(192, 168, 1, 255)), + }; + + assert_eq!(addr.network(), Ipv4Addr::new(192, 168, 1, 0)); + assert_eq!(addr.netmask(), Ipv4Addr::new(255, 255, 255, 0)); + } +} diff --git a/crates/stackforge-core/src/layer/ipv4/ttl.rs b/crates/stackforge-core/src/layer/ipv4/ttl.rs new file mode 100644 index 0000000..3ed0995 --- /dev/null +++ b/crates/stackforge-core/src/layer/ipv4/ttl.rs @@ -0,0 +1,65 @@ +//! TTL (Time To Live) utilities. +//! +//! Provides functions to guess the original TTL and number of hops +//! based on the current TTL value. + +/// Common initial TTL values used by various operating systems. +/// 32: Some ancient systems +/// 64: Linux, macOS, modern Windows +/// 128: Older Windows +/// 255: Network infrastructure (Cisco, etc.) +const INITIAL_TTLS: [u8; 4] = [32, 64, 128, 255]; + +/// Estimate the original TTL based on the current TTL. +/// +/// This finds the smallest standard initial TTL that is greater than +/// or equal to the current TTL. +pub fn estimate_original(current_ttl: u8) -> u8 { + for &initial in &INITIAL_TTLS { + if current_ttl <= initial { + return initial; + } + } + // Should be unreachable for u8, but fallback to 255 + 255 +} + +/// Estimate the number of hops the packet has traveled. +/// +/// Calculated as `original_ttl - current_ttl`. +pub fn estimate_hops(current_ttl: u8) -> u8 { + let original = estimate_original(current_ttl); + original.saturating_sub(current_ttl) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_estimate_original() { + assert_eq!(estimate_original(1), 32); + assert_eq!(estimate_original(30), 32); + assert_eq!(estimate_original(32), 32); + + assert_eq!(estimate_original(33), 64); + assert_eq!(estimate_original(60), 64); + assert_eq!(estimate_original(64), 64); + + assert_eq!(estimate_original(65), 128); + assert_eq!(estimate_original(100), 128); + assert_eq!(estimate_original(128), 128); + + assert_eq!(estimate_original(129), 255); + assert_eq!(estimate_original(200), 255); + assert_eq!(estimate_original(255), 255); + } + + #[test] + fn test_estimate_hops() { + assert_eq!(estimate_hops(30), 2); // 32 - 30 + assert_eq!(estimate_hops(63), 1); // 64 - 63 + assert_eq!(estimate_hops(54), 10); // 64 - 54 + assert_eq!(estimate_hops(128), 0); // 128 - 128 + } +} diff --git a/crates/stackforge-core/src/layer/mod.rs b/crates/stackforge-core/src/layer/mod.rs index dd147cb..acfb849 100644 --- a/crates/stackforge-core/src/layer/mod.rs +++ b/crates/stackforge-core/src/layer/mod.rs @@ -7,6 +7,7 @@ pub mod arp; pub mod bindings; pub mod ethernet; pub mod field; +pub mod ipv4; pub mod neighbor; use std::ops::Range; @@ -16,6 +17,7 @@ pub use arp::{ArpBuilder, ArpLayer}; pub use bindings::{LAYER_BINDINGS, LayerBinding}; pub use ethernet::{Dot3Builder, Dot3Layer, EthernetBuilder, EthernetLayer}; pub use field::{BytesField, Field, FieldDesc, FieldError, FieldType, FieldValue, MacAddress}; +pub use ipv4::{Ipv4Builder, Ipv4Flags, Ipv4Layer, Ipv4Options, Ipv4Route}; pub use neighbor::{NeighborCache, NeighborResolver}; /// Identifies the type of network protocol layer. @@ -68,7 +70,7 @@ impl LayerKind { match self { Self::Ethernet | Self::Dot3 => ethernet::ETHERNET_HEADER_LEN, Self::Arp => arp::ARP_HEADER_LEN, - Self::Ipv4 => 20, + Self::Ipv4 => ipv4::IPV4_MIN_HEADER_LEN, Self::Ipv6 => 40, Self::Icmp | Self::Icmpv6 => 8, Self::Tcp => 20, @@ -169,29 +171,16 @@ pub trait Layer { fn header_len(&self, data: &[u8]) -> usize; /// Compute a hash for packet matching. - /// - /// This is used to correlate requests with responses. - /// Packets that should match (e.g., ARP request/reply) should - /// return the same hash value. fn hashret(&self, _data: &[u8]) -> Vec { vec![] } /// Check if this packet answers another packet. - /// - /// Used by sr()/sr1() to match responses to requests. - /// For example, an ARP reply answers an ARP request if: - /// - reply.op == request.op + 1 - /// - reply.psrc matches request.pdst fn answers(&self, _data: &[u8], _other: &Self, _other_data: &[u8]) -> bool { false } /// Extract padding from the packet. - /// - /// Returns (payload, padding) tuple. - /// Some protocols (like ARP) have no payload, so everything - /// after the header is padding. fn extract_padding<'a>(&self, data: &'a [u8]) -> (&'a [u8], &'a [u8]) { let header_len = self.header_len(data); (&data[header_len..], &[]) @@ -204,9 +193,6 @@ pub trait Layer { } /// Enum dispatch for protocol layers. -/// -/// This enum allows efficient static dispatch to layer implementations -/// without the overhead of dynamic dispatch (vtables). #[derive(Debug, Clone)] pub enum LayerEnum { Ethernet(EthernetLayer), @@ -277,6 +263,7 @@ impl LayerEnum { match self { Self::Ethernet(l) => l.hashret(buf), Self::Arp(l) => l.hashret(buf), + Self::Ipv4(l) => l.hashret(buf), _ => vec![], } } @@ -298,34 +285,7 @@ impl LayerEnum { } } -// Placeholder layer structs (to be fully implemented) -#[derive(Debug, Clone)] -pub struct Ipv4Layer { - pub index: LayerIndex, -} - -impl Ipv4Layer { - pub fn summary(&self, buf: &[u8]) -> String { - let slice = self.index.slice(buf); - if slice.len() >= 20 { - let src = std::net::Ipv4Addr::new(slice[12], slice[13], slice[14], slice[15]); - let dst = std::net::Ipv4Addr::new(slice[16], slice[17], slice[18], slice[19]); - format!("IP {} > {}", src, dst) - } else { - "IP (truncated)".to_string() - } - } - - pub fn header_len(&self, buf: &[u8]) -> usize { - let slice = self.index.slice(buf); - if !slice.is_empty() { - ((slice[0] & 0x0F) as usize) * 4 - } else { - 20 - } - } -} - +// Placeholder layer structs (to be fully implemented in later weeks) #[derive(Debug, Clone)] pub struct Ipv6Layer { pub index: LayerIndex, @@ -476,7 +436,6 @@ pub mod ethertype { } } - /// Get LayerKind for EtherType pub fn to_layer_kind(t: u16) -> Option { match t { IPV4 => Some(LayerKind::Ipv4), @@ -489,7 +448,6 @@ pub mod ethertype { } } - /// Get EtherType for LayerKind pub fn from_layer_kind(kind: LayerKind) -> Option { match kind { LayerKind::Ipv4 => Some(IPV4), @@ -505,32 +463,7 @@ pub mod ethertype { /// IP protocol numbers pub mod ip_protocol { - use crate::LayerKind; - - pub const ICMP: u8 = 1; - pub const TCP: u8 = 6; - pub const UDP: u8 = 17; - pub const ICMPV6: u8 = 58; - - pub fn name(p: u8) -> &'static str { - match p { - ICMP => "ICMP", - TCP => "TCP", - UDP => "UDP", - ICMPV6 => "ICMPv6", - _ => "Unknown", - } - } - - pub fn to_layer_kind(p: u8) -> Option { - match p { - ICMP => Some(LayerKind::Icmp), - TCP => Some(LayerKind::Tcp), - UDP => Some(LayerKind::Udp), - ICMPV6 => Some(LayerKind::Icmpv6), - _ => None, - } - } + pub use crate::layer::ipv4::protocol::*; } #[cfg(test)] diff --git a/crates/stackforge-core/src/packet.rs b/crates/stackforge-core/src/packet.rs index 39c4f8b..814ce7d 100644 --- a/crates/stackforge-core/src/packet.rs +++ b/crates/stackforge-core/src/packet.rs @@ -18,6 +18,7 @@ use crate::layer::{ arp::ArpLayer, ethernet::{ETHERNET_HEADER_LEN, EthernetLayer}, ethertype, ip_protocol, + ipv4::Ipv4Layer, }; /// Maximum number of layers to store inline before heap allocation. @@ -166,6 +167,12 @@ impl Packet { .map(|idx| EthernetLayer::new(idx.start, idx.end)) } + /// Get the IPv4 layer view if present. + pub fn ipv4(&self) -> Option { + self.get_layer(LayerKind::Ipv4) + .map(|idx| Ipv4Layer::new(idx.start, idx.end)) + } + /// Get the ARP layer view if present. pub fn arp(&self) -> Option { self.get_layer(LayerKind::Arp) diff --git a/tests/integration/arp.rs b/tests/integration/arp.rs index cc4b261..89e51c6 100644 --- a/tests/integration/arp.rs +++ b/tests/integration/arp.rs @@ -3,6 +3,37 @@ use stackforge_core::prelude::*; use std::net::Ipv4Addr; +fn make_arp_request() -> Vec { + let eth = EthernetBuilder::new() + .dst(MacAddress::BROADCAST) + .src(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55])) + .build_with_payload(LayerKind::Arp); + + let arp = ArpBuilder::who_has(Ipv4Addr::new(192, 168, 1, 100)) + .hwsrc(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55])) + .psrc(Ipv4Addr::new(192, 168, 1, 1)) + .build(); + + [eth, arp].concat() +} + +fn make_arp_reply() -> Vec { + let eth = EthernetBuilder::new() + .dst(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55])) + .src(MacAddress::new([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff])) + .build_with_payload(LayerKind::Arp); + + let arp = ArpBuilder::is_at( + Ipv4Addr::new(192, 168, 1, 100), + MacAddress::new([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]), + ) + .hwdst(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55])) + .pdst(Ipv4Addr::new(192, 168, 1, 1)) + .build(); + + [eth, arp].concat() +} + #[test] fn test_arp_who_has_builder() { let arp = ArpBuilder::who_has(Ipv4Addr::new(192, 168, 1, 100)) @@ -205,3 +236,114 @@ fn test_parse_real_arp_capture() { Ipv4Addr::new(10, 0, 0, 254) ); } + +#[test] +fn test_arp_request_creation() { + let arp_data = ArpBuilder::who_has(Ipv4Addr::new(192, 168, 1, 100)) + .hwsrc(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55])) + .psrc(Ipv4Addr::new(192, 168, 1, 1)) + .build(); + + assert_eq!(arp_data.len(), ARP_HEADER_LEN); + + let arp = ArpLayer::at_offset(0); + assert_eq!(arp.op(&arp_data).unwrap(), arp_opcode::REQUEST); + assert_eq!(arp.hwtype(&arp_data).unwrap(), arp_hardware::ETHERNET); + assert!(arp.is_request(&arp_data)); + assert!(arp.is_who_has(&arp_data)); +} + +#[test] +fn test_arp_reply_creation() { + let arp_data = ArpBuilder::is_at( + Ipv4Addr::new(192, 168, 1, 100), + MacAddress::new([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]), + ) + .build(); + + let arp = ArpLayer::at_offset(0); + assert_eq!(arp.op(&arp_data).unwrap(), arp_opcode::REPLY); + assert!(arp.is_reply(&arp_data)); + assert!(arp.is_is_at(&arp_data)); +} + +#[test] +fn test_arp_ipv6_addresses() { + // Create ARP-like packet with IPv6 addresses + let arp_data = ArpBuilder::new() + .hwtype(arp_hardware::ETHERNET) + .ptype(0x86DD) // IPv6 + .hwlen(6) + .plen(16) + .op(arp_opcode::REQUEST) + .psrc_v6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)) + .pdst_v6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2)) + .build(); + + // Size should be 8 (fixed) + 2*6 (hwaddr) + 2*16 (paddr) = 52 + assert_eq!(arp_data.len(), 52); + + let arp = ArpLayer::at_offset(0); + assert_eq!(arp.plen(&arp_data).unwrap(), 16); + assert_eq!( + arp.psrc_v6(&arp_data).unwrap(), + Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1) + ); +} + +#[test] +fn test_arp_hashret_matching() { + let request = make_arp_request(); + let reply = make_arp_reply(); + + let mut req_pkt = Packet::from_bytes(request); + let mut rep_pkt = Packet::from_bytes(reply); + + req_pkt.parse().unwrap(); + rep_pkt.parse().unwrap(); + + // Request and reply should have same hashret (same op group) + let req_arp = req_pkt.arp().unwrap(); + let rep_arp = rep_pkt.arp().unwrap(); + + let req_hash = req_arp.hashret(req_pkt.as_bytes()); + let rep_hash = rep_arp.hashret(rep_pkt.as_bytes()); + + assert_eq!(req_hash, rep_hash); +} + +#[test] +fn test_arp_extract_padding() { + let mut data = make_arp_request(); + // Add padding to make 64-byte frame + data.resize(64, 0); + + let mut pkt = Packet::from_bytes(data); + pkt.parse().unwrap(); + + let arp = pkt.arp().unwrap(); + let (payload, padding) = arp.extract_padding(pkt.as_bytes()); + + // ARP has no payload + assert!(payload.is_empty()); + // Should have padding + assert!(!padding.is_empty()); +} + +#[test] +fn test_arp_resolve_dst_mac() { + let request_data = ArpBuilder::who_has(Ipv4Addr::new(192, 168, 1, 100)).build(); + + let arp = ArpLayer::at_offset(0); + let dst_mac = arp.resolve_dst_mac(&request_data); + + // ARP requests should resolve to broadcast + assert_eq!(dst_mac, Some(MacAddress::BROADCAST)); + + // ARP replies should not auto-resolve + let reply_data = + ArpBuilder::is_at(Ipv4Addr::new(192, 168, 1, 100), MacAddress::BROADCAST).build(); + + let dst_mac = arp.resolve_dst_mac(&reply_data); + assert_eq!(dst_mac, None); +} diff --git a/tests/integration/ethernet.rs b/tests/integration/ethernet.rs index 6b1593c..8c6f612 100644 --- a/tests/integration/ethernet.rs +++ b/tests/integration/ethernet.rs @@ -2,6 +2,14 @@ use stackforge_core::prelude::*; +fn make_dot3_frame() -> Vec { + Dot3Builder::new() + .dst(MacAddress::BROADCAST) + .src(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55])) + .len(4) + .build() +} + #[test] fn test_ethernet_builder() { let frame = EthernetBuilder::new() @@ -95,3 +103,80 @@ fn test_mac_address_properties() { let local = MacAddress::new([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]); assert!(local.is_local()); } + +#[test] +fn test_ethernet_frame_creation() { + let frame = EthernetBuilder::new() + .dst(MacAddress::BROADCAST) + .src(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55])) + .ethertype(ethertype::IPV4) + .build(); + + assert_eq!(frame.len(), ETHERNET_HEADER_LEN); + + let eth = EthernetLayer::at_start(); + assert_eq!(eth.dst(&frame).unwrap(), MacAddress::BROADCAST); + assert_eq!(eth.ethertype(&frame).unwrap(), ethertype::IPV4); +} + +#[test] +fn test_ethernet_auto_ethertype() { + let frame = EthernetBuilder::new() + .dst(MacAddress::BROADCAST) + .src(MacAddress::ZERO) + .build_with_payload(LayerKind::Arp); + + let eth = EthernetLayer::at_start(); + assert_eq!(eth.ethertype(&frame).unwrap(), ethertype::ARP); +} + +#[test] +fn test_ethernet_dispatch_hook() { + // Ethernet II (EtherType > 1500) + let eth2 = EthernetBuilder::new().ethertype(ethertype::IPV4).build(); + assert!(is_ethernet_ii(ð2, 0)); + assert!(!is_dot3(ð2, 0)); + + // 802.3 (Length <= 1500) + let dot3 = Dot3Builder::new().len(100).build(); + assert!(is_dot3(&dot3, 0)); + assert!(!is_ethernet_ii(&dot3, 0)); +} + +#[test] +fn test_ethernet_hashret_and_answers() { + let frame1 = EthernetBuilder::new().ethertype(ethertype::ARP).build(); + let frame2 = EthernetBuilder::new().ethertype(ethertype::ARP).build(); + + let eth1 = EthernetLayer::at_start(); + let eth2 = EthernetLayer::at_start(); + + // Same EtherType should match + assert!(eth1.answers(&frame1, ð2, &frame2)); + assert_eq!(eth1.hashret(&frame1), eth2.hashret(&frame2)); +} + +#[test] +fn test_dot3_frame_creation() { + let frame = Dot3Builder::new() + .dst(MacAddress::BROADCAST) + .src(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55])) + .len(100) + .build(); + + let dot3 = Dot3Layer::at_start(); + assert_eq!(dot3.len_field(&frame).unwrap(), 100); +} + +#[test] +fn test_dot3_extract_padding() { + let mut frame = Dot3Builder::new().len(4).build(); + frame.extend_from_slice(&[0xde, 0xad, 0xbe, 0xef]); // Payload + frame.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); // Padding + + let dot3 = Dot3Layer::at_start(); + let (payload, padding) = dot3.extract_padding(&frame); + + assert_eq!(payload.len(), 4); + assert_eq!(padding.len(), 4); +} diff --git a/tests/integration/field.rs b/tests/integration/field.rs new file mode 100644 index 0000000..251e437 --- /dev/null +++ b/tests/integration/field.rs @@ -0,0 +1,46 @@ +use stackforge_core::{Field, HardwareAddr, MacAddress, ProtocolAddr}; +use std::net::{Ipv4Addr, Ipv6Addr}; + +#[test] +fn test_ipv6_field_operations() { + let ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1); + let mut buf = [0u8; 20]; + + ip.write(&mut buf, 2).unwrap(); + let read_ip = Ipv6Addr::read(&buf, 2).unwrap(); + + assert_eq!(ip, read_ip); +} + +#[test] +fn test_mac_address_parsing() { + let mac1 = MacAddress::parse("00:11:22:33:44:55").unwrap(); + let mac2 = MacAddress::parse("00-11-22-33-44-55").unwrap(); + + assert_eq!(mac1, mac2); + assert_eq!(mac1.0, [0x00, 0x11, 0x22, 0x33, 0x44, 0x55]); +} + +#[test] +fn test_hardware_addr_enum() { + let mac = MacAddress::new([0xaa; 6]); + let hw_addr = HardwareAddr::from(mac); + + assert_eq!(hw_addr.len(), 6); + assert_eq!(hw_addr.as_mac(), Some(mac)); + + let raw = HardwareAddr::from_bytes(&[0x01, 0x02, 0x03]); + assert_eq!(raw.len(), 3); + assert_eq!(raw.as_mac(), None); +} + +#[test] +fn test_protocol_addr_enum() { + let ipv4 = ProtocolAddr::from(Ipv4Addr::new(192, 168, 1, 1)); + assert_eq!(ipv4.len(), 4); + assert_eq!(ipv4.as_ipv4(), Some(Ipv4Addr::new(192, 168, 1, 1))); + + let ipv6 = ProtocolAddr::from(Ipv6Addr::LOCALHOST); + assert_eq!(ipv6.len(), 16); + assert_eq!(ipv6.as_ipv6(), Some(Ipv6Addr::LOCALHOST)); +} diff --git a/tests/integration/layer.rs b/tests/integration/layer.rs new file mode 100644 index 0000000..618b483 --- /dev/null +++ b/tests/integration/layer.rs @@ -0,0 +1,69 @@ +use stackforge_core::infer_upper_layer; +use stackforge_core::prelude::*; +use std::net::Ipv4Addr; + +#[test] +fn test_layer_bindings() { + // Ethernet -> ARP + let binding = find_binding(LayerKind::Ethernet, LayerKind::Arp); + assert!(binding.is_some()); + assert_eq!(binding.unwrap().field_value, ethertype::ARP); + + // Ethernet -> IPv4 + let binding = find_binding(LayerKind::Ethernet, LayerKind::Ipv4); + assert!(binding.is_some()); + assert_eq!(binding.unwrap().field_value, ethertype::IPV4); + + // Apply binding helper + let (field, value) = apply_binding(LayerKind::Ethernet, LayerKind::Ipv6).unwrap(); + assert_eq!(field, "type"); + assert_eq!(value, ethertype::IPV6); +} + +#[test] +fn test_infer_upper_layer() { + assert_eq!( + infer_upper_layer(LayerKind::Ethernet, "type", ethertype::ARP), + Some(LayerKind::Arp) + ); + assert_eq!( + infer_upper_layer(LayerKind::Ethernet, "type", ethertype::IPV4), + Some(LayerKind::Ipv4) + ); + assert_eq!(infer_upper_layer(LayerKind::Ethernet, "type", 0x9999), None); +} + +// ============================================================================ +// Neighbor Resolution Tests +// ============================================================================ + +#[test] +fn test_neighbor_cache() { + let cache = NeighborCache::new(); + + let ip = Ipv4Addr::new(192, 168, 1, 1); + let mac = MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55]); + + cache.cache_arp(ip, mac); + assert_eq!(cache.lookup_arp(&ip), Some(mac)); +} + +#[test] +fn test_multicast_mac_generation() { + // IPv4 multicast + let ip = Ipv4Addr::new(224, 0, 0, 1); + let mac = ipv4_multicast_mac(ip); + assert!(mac.is_multicast()); + assert!(mac.is_ipv4_multicast()); + assert_eq!(mac.0[0], 0x01); + assert_eq!(mac.0[1], 0x00); + assert_eq!(mac.0[2], 0x5e); + + // IPv6 multicast + let ip6 = Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 1); + let mac6 = ipv6_multicast_mac(ip6); + assert!(mac6.is_multicast()); + assert!(mac6.is_ipv6_multicast()); + assert_eq!(mac6.0[0], 0x33); + assert_eq!(mac6.0[1], 0x33); +} diff --git a/tests/integration/main.rs b/tests/integration/main.rs index ea6b495..bb4bc00 100644 --- a/tests/integration/main.rs +++ b/tests/integration/main.rs @@ -5,4 +5,7 @@ mod arp; mod ethernet; +mod field; +mod layer; mod packet; +mod util; diff --git a/tests/integration/packet.rs b/tests/integration/packet.rs index abc5efd..cd80019 100644 --- a/tests/integration/packet.rs +++ b/tests/integration/packet.rs @@ -3,6 +3,20 @@ use stackforge_core::prelude::*; use std::net::Ipv4Addr; +fn make_arp_request() -> Vec { + let eth = EthernetBuilder::new() + .dst(MacAddress::BROADCAST) + .src(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55])) + .build_with_payload(LayerKind::Arp); + + let arp = ArpBuilder::who_has(Ipv4Addr::new(192, 168, 1, 100)) + .hwsrc(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55])) + .psrc(Ipv4Addr::new(192, 168, 1, 1)) + .build(); + + [eth, arp].concat() +} + #[test] fn test_packet_from_bytes() { let data = vec![1, 2, 3, 4, 5]; @@ -226,3 +240,86 @@ fn test_packet_into_bytes() { let bytes = packet.into_bytes(); assert_eq!(&bytes[..], &data[..]); } + +#[test] +fn test_full_arp_request_packet() { + let data = make_arp_request(); + let mut pkt = Packet::from_bytes(data.clone()); + + pkt.parse().unwrap(); + + // Verify layers + assert_eq!(pkt.layer_count(), 2); + assert!(pkt.get_layer(LayerKind::Ethernet).is_some()); + assert!(pkt.get_layer(LayerKind::Arp).is_some()); + + // Verify Ethernet + let eth = pkt.ethernet().unwrap(); + assert_eq!(eth.dst(pkt.as_bytes()).unwrap(), MacAddress::BROADCAST); + assert_eq!(eth.ethertype(pkt.as_bytes()).unwrap(), ethertype::ARP); + + // Verify ARP + let arp = pkt.arp().unwrap(); + assert!(arp.is_request(pkt.as_bytes())); + assert_eq!( + arp.pdst(pkt.as_bytes()).unwrap(), + Ipv4Addr::new(192, 168, 1, 100) + ); +} + +#[test] +fn test_packet_modification() { + let data = make_arp_request(); + let mut pkt = Packet::from_bytes(data); + pkt.parse().unwrap(); + + // Modify source MAC + let eth = pkt.ethernet().unwrap(); + pkt.with_data_mut(|buf| { + eth.set_src(buf, MacAddress::new([0xde, 0xad, 0xbe, 0xef, 0x00, 0x00])) + .unwrap(); + }); + + assert!(pkt.is_dirty()); + + // Verify modification + let eth = pkt.ethernet().unwrap(); + assert_eq!( + eth.src(pkt.as_bytes()).unwrap(), + MacAddress::new([0xde, 0xad, 0xbe, 0xef, 0x00, 0x00]) + ); +} + +#[test] +fn test_arp_compatibility() { + // Scapy: ARP(op="who-has", pdst="192.168.1.100") + let arp = ArpBuilder::new() + .op_name("who-has") + .pdst(Ipv4Addr::new(192, 168, 1, 100)) + .build(); + + let layer = ArpLayer::at_offset(0); + + // Verify default values match Scapy + assert_eq!(layer.hwtype(&arp).unwrap(), 1); // Ethernet + assert_eq!(layer.ptype(&arp).unwrap(), 0x0800); // IPv4 + assert_eq!(layer.hwlen(&arp).unwrap(), 6); + assert_eq!(layer.plen(&arp).unwrap(), 4); + assert_eq!(layer.op(&arp).unwrap(), 1); // who-has + + // Check summary format + let summary = layer.summary(&arp); + assert!(summary.contains("who has") || summary.contains("ARP")); +} + +#[test] +fn test_opcode_names() { + use stackforge_core::arp_opcode; + + assert_eq!(arp_opcode::name(1), "who-has"); + assert_eq!(arp_opcode::name(2), "is-at"); + assert_eq!(arp_opcode::from_name("who-has"), Some(1)); + assert_eq!(arp_opcode::from_name("is-at"), Some(2)); + assert_eq!(arp_opcode::from_name("request"), Some(1)); + assert_eq!(arp_opcode::from_name("reply"), Some(2)); +} diff --git a/tests/integration/util.rs b/tests/integration/util.rs new file mode 100644 index 0000000..7c4ed54 --- /dev/null +++ b/tests/integration/util.rs @@ -0,0 +1,30 @@ +use stackforge_core::{hexdump, hexstr, internet_checksum}; + +#[test] +fn test_hexdump_output() { + let data = b"Hello, World!"; + let dump = hexdump(data); + + assert!(dump.contains("48 65 6c 6c")); // "Hell" in hex + assert!(dump.contains("|Hello, World!|")); +} + +#[test] +fn test_hexstr() { + let data = [0xde, 0xad, 0xbe, 0xef]; + assert_eq!(hexstr(&data), "deadbeef"); +} + +#[test] +fn test_checksum() { + // Known good IP header + let header = [ + 0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, + 0x00, // checksum field = 0 + 0xac, 0x10, 0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c, + ]; + + let checksum = internet_checksum(&header); + // Checksum should be non-zero for this data + assert_ne!(checksum, 0); +}