From dd6f4c96e1f6041666a8f856c2e3f237f93da33d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Pol=C3=A1=C4=8Dek?= Date: Sun, 22 Mar 2026 18:55:24 +0100 Subject: [PATCH 1/2] Abort JoinHandle when tasks are aborted Use the new `AbortingJoinHandle` to cancel a JoinHandle (using `JoinHandle::abort()`) when `AbortingJoinHandle` is dropped. --- src/connections/stream_api.rs | 55 +++++++++++++++---------- src/connections/wrappers.rs | 77 +++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + src/utils_internal.rs | 2 +- 4 files changed, 113 insertions(+), 22 deletions(-) diff --git a/src/connections/stream_api.rs b/src/connections/stream_api.rs index 6be9dbe..e60dc8e 100644 --- a/src/connections/stream_api.rs +++ b/src/connections/stream_api.rs @@ -1,11 +1,9 @@ -use futures_util::future::join3; use log::trace; use prost::Message; use std::{fmt::Display, marker::PhantomData}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, sync::mpsc::UnboundedSender, - task::JoinHandle, }; use tokio_util::sync::CancellationToken; @@ -20,7 +18,7 @@ use super::{ wrappers::{ encoded_data::{EncodedMeshPacketData, EncodedToRadioPacket, IncomingStreamData}, mesh_channel::MeshChannel, - NodeId, + AbortingJoinHandle, NodeId, }, PacketDestination, PacketRouter, }; @@ -71,10 +69,13 @@ pub struct StreamApi; pub struct ConnectedStreamApi { write_input_tx: UnboundedSender, - read_handle: JoinHandle>, - write_handle: JoinHandle>, - processing_handle: JoinHandle>, - heartbeat_handle: JoinHandle>, + read_handle: AbortingJoinHandle, + write_handle: AbortingJoinHandle, + processing_handle: AbortingJoinHandle, + heartbeat_handle: AbortingJoinHandle, + /// An optional handle to a background task that bridges data between a high-level + /// stream and a low-level transport (e.g. BLE). + bridge_handle: Option, cancellation_token: CancellationToken, @@ -87,7 +88,7 @@ pub struct StreamHandle { /// The underlying stream. pub stream: T, /// An optional join handle that processes data on the other side of the stream. - pub join_handle: Option>>, + pub join_handle: Option, } impl StreamHandle { @@ -422,7 +423,7 @@ impl StreamApi { /// pub async fn connect( self, - stream_handle: StreamHandle, + mut stream_handle: StreamHandle, ) -> (PacketReceiver, ConnectedStreamApi) where S: AsyncReadExt + AsyncWriteExt + Send + 'static, @@ -440,28 +441,28 @@ impl StreamApi { // Spawn worker threads with kill switch + let bridge_handle = stream_handle.join_handle.take(); let (read_stream, write_stream) = tokio::io::split(stream_handle.stream); let cancellation_token = CancellationToken::new(); let read_handle = - handlers::spawn_read_handler(cancellation_token.clone(), read_stream, read_output_tx); + handlers::spawn_read_handler(cancellation_token.clone(), read_stream, read_output_tx) + .into(); let write_handle = - handlers::spawn_write_handler(cancellation_token.clone(), write_stream, write_input_rx); + handlers::spawn_write_handler(cancellation_token.clone(), write_stream, write_input_rx) + .into(); let processing_handle = handlers::spawn_processing_handler( cancellation_token.clone(), read_output_rx, decoded_packet_tx, - ); + ) + .into(); let heartbeat_handle = - handlers::spawn_heartbeat_handler(cancellation_token.clone(), write_input_tx.clone()); - - // Persist channels and kill switch to struct - - let write_input_tx = write_input_tx; - let cancellation_token = cancellation_token; + handlers::spawn_heartbeat_handler(cancellation_token.clone(), write_input_tx.clone()) + .into(); // Return channel for receiving decoded packets @@ -473,6 +474,7 @@ impl StreamApi { write_handle, processing_handle, heartbeat_handle, + bridge_handle, cancellation_token, typestate: PhantomData, }, @@ -549,13 +551,14 @@ impl ConnectedStreamApi { write_handle: self.write_handle, processing_handle: self.processing_handle, heartbeat_handle: self.heartbeat_handle, + bridge_handle: self.bridge_handle, cancellation_token: self.cancellation_token, typestate: PhantomData, }) } } -impl ConnectedStreamApi { +impl ConnectedStreamApi { /// A method to disconnect from a radio. This method will close all channels and /// join all worker threads. If connected via serial or TCP, this will also trigger /// the radio to terminate its current connection. @@ -600,13 +603,23 @@ impl ConnectedStreamApi { // Close worker threads - let (read_result, write_result, processing_result) = - join3(self.read_handle, self.write_handle, self.processing_handle).await; + let (read_result, write_result, processing_result, heartbeat_result) = tokio::join!( + self.read_handle, + self.write_handle, + self.processing_handle, + self.heartbeat_handle + ); + + if let Some(handle) = self.bridge_handle { + handle.abort(); + let _ = handle.await; + } // Note: we only return the first error. read_result??; write_result??; processing_result??; + heartbeat_result??; trace!("Handlers fully disconnected"); diff --git a/src/connections/wrappers.rs b/src/connections/wrappers.rs index 49a64f8..407bcf5 100644 --- a/src/connections/wrappers.rs +++ b/src/connections/wrappers.rs @@ -1,3 +1,9 @@ +use std::ops::Deref; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::task::{JoinError, JoinHandle}; + use crate::errors_internal::Error; /// A helper struct representing the ID of a node in the mesh. @@ -240,3 +246,74 @@ pub mod mesh_channel { } } } + +/// A wrapper around a [`JoinHandle`] that will automatically abort the underlying task +/// when it is dropped. +#[derive(Debug)] +pub struct AbortingJoinHandle(JoinHandle>); + +impl From>> for AbortingJoinHandle { + fn from(handle: JoinHandle>) -> Self { + Self(handle) + } +} + +impl Deref for AbortingJoinHandle { + type Target = JoinHandle>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Drop for AbortingJoinHandle { + fn drop(&mut self) { + self.0.abort(); + } +} + +impl std::future::Future for AbortingJoinHandle { + type Output = Result, JoinError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_aborting_join_handle() { + let (mut tx, rx) = tokio::sync::oneshot::channel::<()>(); + let handle = tokio::spawn(async move { + let _rx = rx; + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + Ok(()) + }); + + let aborting_handle: AbortingJoinHandle = handle.into(); + drop(aborting_handle); + + // If the task was aborted, the receiver `rx` will be dropped, causing `tx.is_closed()` to + // become true. + tokio::time::timeout(std::time::Duration::from_secs(5), tx.closed()) + .await + .expect("Task was not aborted within timeout"); + } + + #[tokio::test] + async fn test_aborting_join_handle_completion() { + let handle = tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + Ok(()) + }); + + let aborting_handle: AbortingJoinHandle = handle.into(); + + let result = aborting_handle.await; + assert!(result.is_ok()); + assert!(result.unwrap().is_ok()); + } +} diff --git a/src/lib.rs b/src/lib.rs index 3a7e141..1ab45b9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,6 +38,7 @@ pub mod api { pub use crate::connections::stream_api::ConnectedStreamApi; pub use crate::connections::stream_api::StreamApi; pub use crate::connections::stream_api::StreamHandle; + pub use crate::connections::wrappers::AbortingJoinHandle; } /// This module contains the global `Error` type of the library. This enum implements diff --git a/src/utils_internal.rs b/src/utils_internal.rs index 45dc577..5454fea 100644 --- a/src/utils_internal.rs +++ b/src/utils_internal.rs @@ -351,7 +351,7 @@ where Ok(StreamHandle { stream: client, - join_handle: Some(handle), + join_handle: Some(handle.into()), }) } From b04517ba9f502ae4b36c1d8b4a6c70be9273e6cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Pol=C3=A1=C4=8Dek?= Date: Sat, 2 May 2026 23:14:58 +0200 Subject: [PATCH 2/2] Use AbortOnDropHandle from tokio_util No reason to implement our own AbortingJoinHandle --- Cargo.toml | 2 +- src/connections/stream_api.rs | 44 +++++++++++--------- src/connections/wrappers.rs | 76 +---------------------------------- src/lib.rs | 1 - src/utils_internal.rs | 2 +- 5 files changed, 27 insertions(+), 98 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a45a6bf..0877129 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,7 +60,7 @@ futures-util = "0.3.31" rand = "0.9.0" tokio = { version = "1.43.0", features = ["full"], optional = true } tokio-serial = { version = "5.4.5", optional = true } -tokio-util = { version = "0.7.13", optional = true } +tokio-util = { version = "0.7.13", optional = true, features = ["rt"] } prost = "0.14" log = "0.4.25" diff --git a/src/connections/stream_api.rs b/src/connections/stream_api.rs index e60dc8e..bc79cca 100644 --- a/src/connections/stream_api.rs +++ b/src/connections/stream_api.rs @@ -5,7 +5,7 @@ use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, sync::mpsc::UnboundedSender, }; -use tokio_util::sync::CancellationToken; +use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; use crate::{errors_internal::Error, protobufs, types::EncodedToRadioPacketWithHeader, utils}; use crate::{ @@ -18,7 +18,7 @@ use super::{ wrappers::{ encoded_data::{EncodedMeshPacketData, EncodedToRadioPacket, IncomingStreamData}, mesh_channel::MeshChannel, - AbortingJoinHandle, NodeId, + NodeId, }, PacketDestination, PacketRouter, }; @@ -69,13 +69,13 @@ pub struct StreamApi; pub struct ConnectedStreamApi { write_input_tx: UnboundedSender, - read_handle: AbortingJoinHandle, - write_handle: AbortingJoinHandle, - processing_handle: AbortingJoinHandle, - heartbeat_handle: AbortingJoinHandle, + read_handle: AbortOnDropHandle>, + write_handle: AbortOnDropHandle>, + processing_handle: AbortOnDropHandle>, + heartbeat_handle: AbortOnDropHandle>, /// An optional handle to a background task that bridges data between a high-level /// stream and a low-level transport (e.g. BLE). - bridge_handle: Option, + bridge_handle: Option>>, cancellation_token: CancellationToken, @@ -88,7 +88,7 @@ pub struct StreamHandle { /// The underlying stream. pub stream: T, /// An optional join handle that processes data on the other side of the stream. - pub join_handle: Option, + pub join_handle: Option>>, } impl StreamHandle { @@ -445,24 +445,28 @@ impl StreamApi { let (read_stream, write_stream) = tokio::io::split(stream_handle.stream); let cancellation_token = CancellationToken::new(); - let read_handle = - handlers::spawn_read_handler(cancellation_token.clone(), read_stream, read_output_tx) - .into(); + let read_handle = AbortOnDropHandle::new(handlers::spawn_read_handler( + cancellation_token.clone(), + read_stream, + read_output_tx, + )); - let write_handle = - handlers::spawn_write_handler(cancellation_token.clone(), write_stream, write_input_rx) - .into(); + let write_handle = AbortOnDropHandle::new(handlers::spawn_write_handler( + cancellation_token.clone(), + write_stream, + write_input_rx, + )); - let processing_handle = handlers::spawn_processing_handler( + let processing_handle = AbortOnDropHandle::new(handlers::spawn_processing_handler( cancellation_token.clone(), read_output_rx, decoded_packet_tx, - ) - .into(); + )); - let heartbeat_handle = - handlers::spawn_heartbeat_handler(cancellation_token.clone(), write_input_tx.clone()) - .into(); + let heartbeat_handle = AbortOnDropHandle::new(handlers::spawn_heartbeat_handler( + cancellation_token.clone(), + write_input_tx.clone(), + )); // Return channel for receiving decoded packets diff --git a/src/connections/wrappers.rs b/src/connections/wrappers.rs index 407bcf5..b136b4e 100644 --- a/src/connections/wrappers.rs +++ b/src/connections/wrappers.rs @@ -1,9 +1,3 @@ -use std::ops::Deref; -use std::pin::Pin; -use std::task::{Context, Poll}; - -use tokio::task::{JoinError, JoinHandle}; - use crate::errors_internal::Error; /// A helper struct representing the ID of a node in the mesh. @@ -247,73 +241,5 @@ pub mod mesh_channel { } } -/// A wrapper around a [`JoinHandle`] that will automatically abort the underlying task -/// when it is dropped. -#[derive(Debug)] -pub struct AbortingJoinHandle(JoinHandle>); - -impl From>> for AbortingJoinHandle { - fn from(handle: JoinHandle>) -> Self { - Self(handle) - } -} - -impl Deref for AbortingJoinHandle { - type Target = JoinHandle>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl Drop for AbortingJoinHandle { - fn drop(&mut self) { - self.0.abort(); - } -} - -impl std::future::Future for AbortingJoinHandle { - type Output = Result, JoinError>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.0).poll(cx) - } -} - #[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_aborting_join_handle() { - let (mut tx, rx) = tokio::sync::oneshot::channel::<()>(); - let handle = tokio::spawn(async move { - let _rx = rx; - tokio::time::sleep(std::time::Duration::from_secs(60)).await; - Ok(()) - }); - - let aborting_handle: AbortingJoinHandle = handle.into(); - drop(aborting_handle); - - // If the task was aborted, the receiver `rx` will be dropped, causing `tx.is_closed()` to - // become true. - tokio::time::timeout(std::time::Duration::from_secs(5), tx.closed()) - .await - .expect("Task was not aborted within timeout"); - } - - #[tokio::test] - async fn test_aborting_join_handle_completion() { - let handle = tokio::spawn(async move { - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - Ok(()) - }); - - let aborting_handle: AbortingJoinHandle = handle.into(); - - let result = aborting_handle.await; - assert!(result.is_ok()); - assert!(result.unwrap().is_ok()); - } -} +mod tests {} diff --git a/src/lib.rs b/src/lib.rs index 1ab45b9..3a7e141 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,7 +38,6 @@ pub mod api { pub use crate::connections::stream_api::ConnectedStreamApi; pub use crate::connections::stream_api::StreamApi; pub use crate::connections::stream_api::StreamHandle; - pub use crate::connections::wrappers::AbortingJoinHandle; } /// This module contains the global `Error` type of the library. This enum implements diff --git a/src/utils_internal.rs b/src/utils_internal.rs index 5454fea..3575612 100644 --- a/src/utils_internal.rs +++ b/src/utils_internal.rs @@ -351,7 +351,7 @@ where Ok(StreamHandle { stream: client, - join_handle: Some(handle.into()), + join_handle: Some(tokio_util::task::AbortOnDropHandle::new(handle)), }) }