diff --git a/sdk/voice/speechmatics/voice/_client.py b/sdk/voice/speechmatics/voice/_client.py index f4e96607..44a5ca46 100644 --- a/sdk/voice/speechmatics/voice/_client.py +++ b/sdk/voice/speechmatics/voice/_client.py @@ -472,12 +472,15 @@ def _prepare_config( # LIFECYCLE METHODS # ============================================================================ - async def connect(self) -> None: + async def connect(self, ws_headers: Optional[dict] = None) -> None: """Connect to the Speechmatics API. Establishes WebSocket connection and starts the transcription session. This must be called before sending audio. + Args: + ws_headers: Optional headers to pass to the WebSocket connection. + Raises: Exception: If connection fails. @@ -521,6 +524,7 @@ async def connect(self) -> None: await self.start_session( transcription_config=self._transcription_config, audio_format=self._audio_format, + ws_headers=ws_headers, ) self._is_connected = True self._start_metrics_task() diff --git a/tests/voice/test_06_stt_config.py b/tests/voice/test_06_stt_config.py index cc7b3cac..77e44ab4 100644 --- a/tests/voice/test_06_stt_config.py +++ b/tests/voice/test_06_stt_config.py @@ -1,8 +1,78 @@ +import os + import pytest +from _utils import get_client + +# Constants +API_KEY = os.getenv("SPEECHMATICS_API_KEY") + + +@pytest.mark.asyncio +async def test_with_headers(): + """Tests that a client can be created. + + - Checks for a valid session + - Checks that 'English' is the language pack info + """ + + # API key + if not API_KEY: + pytest.skip("Valid API key required for test") + + # Create client + client = await get_client( + api_key=API_KEY, + connect=False, + ) + + # Headers + ws_headers = {"Z-TEST-HEADER-1": "ValueOne", "Z-TEST-HEADER-2": "ValueTwo"} + + # Check we are connected OK + await client.connect(ws_headers=ws_headers) + + # Check we are connected + assert client._is_connected + + # Disconnect + await client.disconnect() + + # Check we are disconnected + assert not client._is_connected @pytest.mark.asyncio -async def test_no_partials(): - """Tests for STT config (no partials).""" +async def test_with_corrupted_headers(): + """Tests that a client can be created. + + - Checks for a valid session + - Checks that 'English' is the language pack info + """ + + # API key + if not API_KEY: + pytest.skip("Valid API key required for test") + + # Create client + client = await get_client( + api_key=API_KEY, + connect=False, + ) + + # Headers + ws_headers = ["ItemOne", "ItemTwo"] + + # Check we are connected OK + try: + await client.connect(ws_headers=ws_headers) + except AttributeError: + pass + + # Check we are connected + assert not client._is_connected + + # Disconnect (in case connected) + await client.disconnect() - pass + # Check we are disconnected + assert not client._is_connected