diff --git a/Cargo.toml b/Cargo.toml index a45a6bf..0877129 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,7 +60,7 @@ futures-util = "0.3.31" rand = "0.9.0" tokio = { version = "1.43.0", features = ["full"], optional = true } tokio-serial = { version = "5.4.5", optional = true } -tokio-util = { version = "0.7.13", optional = true } +tokio-util = { version = "0.7.13", optional = true, features = ["rt"] } prost = "0.14" log = "0.4.25" diff --git a/src/connections/stream_api.rs b/src/connections/stream_api.rs index 6be9dbe..bc79cca 100644 --- a/src/connections/stream_api.rs +++ b/src/connections/stream_api.rs @@ -1,13 +1,11 @@ -use futures_util::future::join3; use log::trace; use prost::Message; use std::{fmt::Display, marker::PhantomData}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, sync::mpsc::UnboundedSender, - task::JoinHandle, }; -use tokio_util::sync::CancellationToken; +use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; use crate::{errors_internal::Error, protobufs, types::EncodedToRadioPacketWithHeader, utils}; use crate::{ @@ -71,10 +69,13 @@ pub struct StreamApi; pub struct ConnectedStreamApi { write_input_tx: UnboundedSender, - read_handle: JoinHandle>, - write_handle: JoinHandle>, - processing_handle: JoinHandle>, - heartbeat_handle: JoinHandle>, + read_handle: AbortOnDropHandle>, + write_handle: AbortOnDropHandle>, + processing_handle: AbortOnDropHandle>, + heartbeat_handle: AbortOnDropHandle>, + /// An optional handle to a background task that bridges data between a high-level + /// stream and a low-level transport (e.g. BLE). + bridge_handle: Option>>, cancellation_token: CancellationToken, @@ -87,7 +88,7 @@ pub struct StreamHandle { /// The underlying stream. pub stream: T, /// An optional join handle that processes data on the other side of the stream. - pub join_handle: Option>>, + pub join_handle: Option>>, } impl StreamHandle { @@ -422,7 +423,7 @@ impl StreamApi { /// pub async fn connect( self, - stream_handle: StreamHandle, + mut stream_handle: StreamHandle, ) -> (PacketReceiver, ConnectedStreamApi) where S: AsyncReadExt + AsyncWriteExt + Send + 'static, @@ -440,28 +441,32 @@ impl StreamApi { // Spawn worker threads with kill switch + let bridge_handle = stream_handle.join_handle.take(); let (read_stream, write_stream) = tokio::io::split(stream_handle.stream); let cancellation_token = CancellationToken::new(); - let read_handle = - handlers::spawn_read_handler(cancellation_token.clone(), read_stream, read_output_tx); + let read_handle = AbortOnDropHandle::new(handlers::spawn_read_handler( + cancellation_token.clone(), + read_stream, + read_output_tx, + )); - let write_handle = - handlers::spawn_write_handler(cancellation_token.clone(), write_stream, write_input_rx); + let write_handle = AbortOnDropHandle::new(handlers::spawn_write_handler( + cancellation_token.clone(), + write_stream, + write_input_rx, + )); - let processing_handle = handlers::spawn_processing_handler( + let processing_handle = AbortOnDropHandle::new(handlers::spawn_processing_handler( cancellation_token.clone(), read_output_rx, decoded_packet_tx, - ); - - let heartbeat_handle = - handlers::spawn_heartbeat_handler(cancellation_token.clone(), write_input_tx.clone()); - - // Persist channels and kill switch to struct + )); - let write_input_tx = write_input_tx; - let cancellation_token = cancellation_token; + let heartbeat_handle = AbortOnDropHandle::new(handlers::spawn_heartbeat_handler( + cancellation_token.clone(), + write_input_tx.clone(), + )); // Return channel for receiving decoded packets @@ -473,6 +478,7 @@ impl StreamApi { write_handle, processing_handle, heartbeat_handle, + bridge_handle, cancellation_token, typestate: PhantomData, }, @@ -549,13 +555,14 @@ impl ConnectedStreamApi { write_handle: self.write_handle, processing_handle: self.processing_handle, heartbeat_handle: self.heartbeat_handle, + bridge_handle: self.bridge_handle, cancellation_token: self.cancellation_token, typestate: PhantomData, }) } } -impl ConnectedStreamApi { +impl ConnectedStreamApi { /// A method to disconnect from a radio. This method will close all channels and /// join all worker threads. If connected via serial or TCP, this will also trigger /// the radio to terminate its current connection. @@ -600,13 +607,23 @@ impl ConnectedStreamApi { // Close worker threads - let (read_result, write_result, processing_result) = - join3(self.read_handle, self.write_handle, self.processing_handle).await; + let (read_result, write_result, processing_result, heartbeat_result) = tokio::join!( + self.read_handle, + self.write_handle, + self.processing_handle, + self.heartbeat_handle + ); + + if let Some(handle) = self.bridge_handle { + handle.abort(); + let _ = handle.await; + } // Note: we only return the first error. read_result??; write_result??; processing_result??; + heartbeat_result??; trace!("Handlers fully disconnected"); diff --git a/src/connections/wrappers.rs b/src/connections/wrappers.rs index 49a64f8..b136b4e 100644 --- a/src/connections/wrappers.rs +++ b/src/connections/wrappers.rs @@ -240,3 +240,6 @@ pub mod mesh_channel { } } } + +#[cfg(test)] +mod tests {} diff --git a/src/utils_internal.rs b/src/utils_internal.rs index 45dc577..3575612 100644 --- a/src/utils_internal.rs +++ b/src/utils_internal.rs @@ -351,7 +351,7 @@ where Ok(StreamHandle { stream: client, - join_handle: Some(handle), + join_handle: Some(tokio_util::task::AbortOnDropHandle::new(handle)), }) }