diff --git a/protocol/src/futures.rs b/protocol/src/futures.rs index b8d6526..dd6af62 100644 --- a/protocol/src/futures.rs +++ b/protocol/src/futures.rs @@ -406,7 +406,15 @@ where bytes_read, } => { while *bytes_read < NUM_LENGTH_BYTES { - *bytes_read += self.reader.read(&mut length_bytes[*bytes_read..]).await?; + let len = self.reader.read(&mut length_bytes[*bytes_read..]).await?; + if len == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + "read zero bytes", + ) + .into()); + } + *bytes_read += len; } let packet_bytes_len = self.inbound_cipher.decrypt_packet_len(*length_bytes); @@ -417,7 +425,15 @@ where bytes_read, } => { while *bytes_read < packet_bytes.len() { - *bytes_read += self.reader.read(&mut packet_bytes[*bytes_read..]).await?; + let len = self.reader.read(&mut packet_bytes[*bytes_read..]).await?; + if len == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + "read zero bytes", + ) + .into()); + } + *bytes_read += len; } let plaintext_len = InboundCipher::decryption_buffer_len(packet_bytes.len()); diff --git a/protocol/tests/round_trips.rs b/protocol/tests/round_trips.rs index 8b07daa..206aac3 100644 --- a/protocol/tests/round_trips.rs +++ b/protocol/tests/round_trips.rs @@ -154,6 +154,83 @@ fn hello_world_happy_path() { assert_eq!(message, decrypted_message[1..].to_vec()); // Skip header byte } +#[tokio::test] +#[cfg(feature = "tokio")] +async fn pingpong_with_closed_connection_async() { + use bip324::{futures::Protocol, io::Payload}; + use bitcoin::consensus; + use p2p::message::{NetworkMessage, V2NetworkMessage}; + use tokio::net::TcpListener; + use tokio::net::TcpStream; + + // Start a server that responds to exactly one Ping(x) message with a + // Pong(x) message and then stops. This allows testing to read from a closed stream. + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (reader, writer) = stream.into_split(); + let mut protocol = Protocol::new( + p2p::Magic::REGTEST, + bip324::Role::Responder, + None, // no garbage + None, // no decoys + reader, + writer, + ) + .await + .unwrap(); + + let payload = protocol.read().await.unwrap(); + let received_message = + consensus::deserialize::(payload.contents()).unwrap(); + if let NetworkMessage::Ping(x) = received_message.payload() { + let pong = V2NetworkMessage::new(NetworkMessage::Pong(*x)); + let message = consensus::serialize(&pong); + protocol.write(&Payload::genuine(message)).await.unwrap(); + println!("Pong sent, stopping server.") + } else { + panic!("Expected Ping, but received: {received_message:?}"); + } + }); + + let stream = TcpStream::connect(addr).await.unwrap(); + + let (reader, writer) = stream.into_split(); + + // Initialize high-level async protocol with handshake + println!("Starting async BIP-324 handshake"); + let mut protocol = Protocol::new( + p2p::Magic::REGTEST, + bip324::Role::Initiator, + None, // no garbage + None, // no decoys + reader, + writer, + ) + .await + .unwrap(); + + println!("Sending Ping using async Protocol::write()"); + let ping = V2NetworkMessage::new(NetworkMessage::Ping(45324)); + let message = consensus::serialize(&ping); + protocol.write(&Payload::genuine(message)).await.unwrap(); + + println!("Reading response using async Protocol::read()"); + let payload = protocol.read().await.unwrap(); + let response_message = consensus::deserialize::(payload.contents()).unwrap(); + + assert_eq!(NetworkMessage::Pong(45324), *response_message.payload()); + + println!("Successfully ping-pong message using async Protocol API!"); + server.await.unwrap(); + + println!( + "Trying to read another message from the server, while the connection is already closed." + ); + assert!(protocol.read().await.is_err()); +} + #[test] #[cfg(feature = "std")] fn regtest_handshake() {