@@ -761,25 +761,35 @@ def presence(self, pubnub, presence):
761761 """
762762 self .presence_queue .put_nowait (presence )
763763
764- async def _wait_for (self , coro ):
764+ async def _wait_for (self , coro , timeout = 30 ):
765765 """Wait for a coroutine to complete.
766766
767767 Args:
768768 coro: The coroutine to wait for
769+ timeout: Maximum time to wait in seconds (default: 30)
769770
770771 Returns:
771772 The result of the coroutine
772773
773774 Raises:
775+ asyncio.TimeoutError: If the operation times out
774776 Exception: If an error occurs while waiting
775777 """
776778 scc_task = asyncio .ensure_future (coro )
777779 err_task = asyncio .ensure_future (self .error_queue .get ())
778780
779- await asyncio .wait ([
781+ done , pending = await asyncio .wait ([
780782 scc_task ,
781783 err_task
782- ], return_when = asyncio .FIRST_COMPLETED )
784+ ], return_when = asyncio .FIRST_COMPLETED , timeout = timeout )
785+
786+ # Handle timeout
787+ if not done :
788+ if not scc_task .cancelled ():
789+ scc_task .cancel ()
790+ if not err_task .cancelled ():
791+ err_task .cancel ()
792+ raise asyncio .TimeoutError (f"Operation timed out after { timeout } seconds" )
783793
784794 if err_task .done () and not scc_task .done ():
785795 if not scc_task .cancelled ():
@@ -790,44 +800,73 @@ async def _wait_for(self, coro):
790800 err_task .cancel ()
791801 return scc_task .result ()
792802
793- async def wait_for_connect (self ):
794- """Wait for a connection to be established."""
803+ async def wait_for_connect (self , timeout = 30 ):
804+ """Wait for a connection to be established.
805+
806+ Args:
807+ timeout: Maximum time to wait in seconds (default: 30)
808+
809+ Raises:
810+ asyncio.TimeoutError: If connection is not established within timeout
811+ """
795812 if not self .connected_event .is_set ():
796- await self ._wait_for (self .connected_event .wait ())
813+ await self ._wait_for (self .connected_event .wait (), timeout = timeout )
814+
815+ async def wait_for_disconnect (self , timeout = 30 ):
816+ """Wait for a disconnection to occur.
817+
818+ Args:
819+ timeout: Maximum time to wait in seconds (default: 30)
797820
798- async def wait_for_disconnect (self ):
799- """Wait for a disconnection to occur."""
821+ Raises:
822+ asyncio.TimeoutError: If disconnection does not occur within timeout
823+ """
800824 if not self .disconnected_event .is_set ():
801- await self ._wait_for (self .disconnected_event .wait ())
825+ await self ._wait_for (self .disconnected_event .wait (), timeout = timeout )
802826
803- async def wait_for_message_on (self , * channel_names ):
827+ async def wait_for_message_on (self , * channel_names , timeout = 30 ):
804828 """Wait for a message on specific channels.
805829
806830 Args:
807831 *channel_names: Channel names to wait for
832+ timeout: Maximum time to wait in seconds (default: 30)
808833
809834 Returns:
810835 The message envelope when received
811836
812837 Raises:
838+ asyncio.TimeoutError: If no message is received within timeout
813839 Exception: If an error occurs while waiting
814840 """
815841 channel_names = list (channel_names )
816842 while True :
817843 try :
818- env = await self ._wait_for (self .message_queue .get ())
844+ env = await self ._wait_for (self .message_queue .get (), timeout = timeout )
819845 if env .channel in channel_names :
820846 return env
821847 else :
822848 continue
823849 finally :
824850 self .message_queue .task_done ()
825851
826- async def wait_for_presence_on (self , * channel_names ):
852+ async def wait_for_presence_on (self , * channel_names , timeout = 30 ):
853+ """Wait for a presence event on specific channels.
854+
855+ Args:
856+ *channel_names: Channel names to wait for
857+ timeout: Maximum time to wait in seconds (default: 30)
858+
859+ Returns:
860+ The presence envelope when received
861+
862+ Raises:
863+ asyncio.TimeoutError: If no presence event is received within timeout
864+ Exception: If an error occurs while waiting
865+ """
827866 channel_names = list (channel_names )
828867 while True :
829868 try :
830- env = await self ._wait_for (self .presence_queue .get ())
869+ env = await self ._wait_for (self .presence_queue .get (), timeout = timeout )
831870 if env .channel in channel_names :
832871 return env
833872 else :
0 commit comments