diff --git a/alioth/src/virtio/dev/net/vmnet.rs b/alioth/src/virtio/dev/net/vmnet.rs index 389426ef..cd510a4d 100644 --- a/alioth/src/virtio/dev/net/vmnet.rs +++ b/alioth/src/virtio/dev/net/vmnet.rs @@ -404,28 +404,7 @@ impl VirtioMio for Net { fn read_from_vmnet(interface: *mut VmnetInterface) -> impl FnMut(&mut DescChain) -> Result { move |chain: &mut DescChain| { - let mut trim_len = size_of::(); - let mut iov = Vec::with_capacity(chain.writable.len()); - - for buf in chain.writable.iter_mut() { - if trim_len > 0 { - if let Some((_, tail)) = buf.split_at_mut_checked(trim_len) { - iov.push(libc::iovec { - iov_base: tail.as_ptr() as *mut c_void, - iov_len: tail.len(), - }); - trim_len = 0; - } else { - trim_len -= buf.len(); - } - } else { - iov.push(libc::iovec { - iov_base: buf.as_ptr() as *mut c_void, - iov_len: buf.len(), - }); - } - } - + let mut iov = trim_desc_chain(chain.writable.iter().map(|b| &**b)); let size = iov.iter().map(|s| s.iov_len).sum(); let mut packets = VmPktDesc { vm_pkt_size: size, @@ -455,32 +434,11 @@ fn read_from_vmnet(interface: *mut VmnetInterface) -> impl FnMut(&mut DescChain) fn write_to_vmnet(interface: *mut VmnetInterface) -> impl FnMut(&mut DescChain) -> Result { move |chain: &mut DescChain| { - let mut trim_len = size_of::(); - let mut iov = Vec::with_capacity(chain.readable.len()); - - for buf in chain.readable.iter() { - if trim_len > 0 { - if let Some((_, tail)) = buf.split_at_checked(trim_len) { - iov.push(libc::iovec { - iov_base: tail.as_ptr() as *mut c_void, - iov_len: tail.len(), - }); - trim_len = 0; - } else { - trim_len -= buf.len(); - } - } else { - iov.push(libc::iovec { - iov_base: buf.as_ptr() as *mut c_void, - iov_len: buf.len(), - }); - } - } - + let mut iov = trim_desc_chain(chain.readable.iter().map(|b| &**b)); let size = iov.iter().map(|s| s.iov_len).sum(); let mut packets = VmPktDesc { vm_pkt_size: size, - vm_pkt_iov: iov.as_ptr() as *mut libc::iovec, + vm_pkt_iov: iov.as_mut_ptr(), vm_pkt_iovcnt: iov.len() as u32, vm_flags: 0, }; @@ -508,3 +466,29 @@ impl DevParam for NetVmnetParam { Net::new(self, name) } } + +fn trim_desc_chain<'m>(bufs: impl Iterator) -> Vec { + let mut iov = Vec::new(); + let mut trim_len = size_of::(); + + for buf in bufs { + let b = if trim_len > 0 { + if let Some((_, tail)) = buf.split_at_checked(trim_len) + && !tail.is_empty() + { + trim_len = 0; + tail + } else { + trim_len -= buf.len(); + continue; + } + } else { + buf + }; + iov.push(libc::iovec { + iov_base: b.as_ptr() as *mut c_void, + iov_len: b.len(), + }); + } + iov +}