diff --git a/sdk/rt/speechmatics/rt/_async_client.py b/sdk/rt/speechmatics/rt/_async_client.py index 5e581e15..3a97df6d 100644 --- a/sdk/rt/speechmatics/rt/_async_client.py +++ b/sdk/rt/speechmatics/rt/_async_client.py @@ -97,6 +97,10 @@ def __init__( self.on(ServerMessageType.WARNING, self._on_warning) self.on(ServerMessageType.AUDIO_ADDED, self._on_audio_added) + # Audio format is set when start_session is called with an explicit format. + # Deliberately None until then to avoid silently using incorrect defaults. + self._audio_format: Optional[AudioFormat] = None + self._logger.debug("AsyncClient initialized (request_id=%s)", self._session.request_id) async def start_session( @@ -133,7 +137,10 @@ async def start_session( ... await client.start_session() ... await client.send_audio(frame) """ - await self._start_recognition_session( + + # _start_recognition_session resolves defaults (e.g. AudioFormat() if None), + # so we capture the resolved format to keep _audio_format in sync. + _, self._audio_format = await self._start_recognition_session( transcription_config=transcription_config, audio_format=audio_format, translation_config=translation_config, @@ -161,16 +168,27 @@ async def stop_session(self) -> None: await self._session_done_evt.wait() # Wait for end of transcript event to indicate we can stop listening await self.close() - async def force_end_of_utterance(self) -> None: + async def force_end_of_utterance(self, timestamp: Optional[float] = None) -> float: """ This method sends a ForceEndOfUtterance message to the server to signal the end of an utterance. Forcing end of utterance will cause the final transcript to be sent to the client early. + Takes an optional timestamp parameter to specify a marker for the engine + to use for timing of the end of the utterance. If not provided, the timestamp + will be calculated based on the cumulative audio sent to the server. + + Args: + timestamp: Optional timestamp for the request. 
+ + Returns: + The timestamp that was used for the request. + Raises: ConnectionError: If the WebSocket connection fails. TranscriptionError: If the server reports an error during teardown. TimeoutError: If the connection or teardown times out. + ValueError: If the audio format does not have an encoding set. Examples: Basic streaming: @@ -179,7 +197,26 @@ async def force_end_of_utterance(self) -> None: ... await client.send_audio(frame) ... await client.force_end_of_utterance() """ - await self.send_message({"message": ClientMessageType.FORCE_END_OF_UTTERANCE}) + if timestamp is None: + timestamp = self.audio_seconds_sent + + await self.send_message({"message": ClientMessageType.FORCE_END_OF_UTTERANCE, "timestamp": timestamp}) + + return timestamp + + @property + def audio_seconds_sent(self) -> float: + """Number of audio seconds sent to the server. + + Raises: + ValueError: If called before start_session has set the audio format, + or if the audio format does not have an encoding set. + """ + # _audio_format is only set once start_session receives an explicit AudioFormat. + # Failing here prevents silently computing with wrong defaults (e.g. 44100Hz). 
+ if self._audio_format is None: + raise ValueError("audio_seconds_sent is not available before start_session is called with an audio format") + return self._audio_bytes_sent / (self._audio_format.sample_rate * self._audio_format.bytes_per_sample) async def transcribe( self, diff --git a/sdk/rt/speechmatics/rt/_base_client.py b/sdk/rt/speechmatics/rt/_base_client.py index 0ac6d085..89167e20 100644 --- a/sdk/rt/speechmatics/rt/_base_client.py +++ b/sdk/rt/speechmatics/rt/_base_client.py @@ -42,6 +42,7 @@ def __init__(self, transport: Transport) -> None: self._recv_task: Optional[asyncio.Task[None]] = None self._closed_evt = asyncio.Event() self._eos_sent = False + self._audio_bytes_sent = 0 self._seq_no = 0 self._logger = get_logger("speechmatics.rt.base_client") @@ -122,11 +123,17 @@ async def send_audio(self, payload: bytes) -> None: try: await self._transport.send_message(payload) + self._audio_bytes_sent += len(payload) self._seq_no += 1 except Exception: self._closed_evt.set() raise + @property + def audio_bytes_sent(self) -> int: + """Number of audio bytes sent to the server.""" + return self._audio_bytes_sent + async def send_message(self, message: dict[str, Any]) -> None: """ Send a message through the WebSocket. diff --git a/sdk/rt/speechmatics/rt/_models.py b/sdk/rt/speechmatics/rt/_models.py index 84e57204..d1f6acbf 100644 --- a/sdk/rt/speechmatics/rt/_models.py +++ b/sdk/rt/speechmatics/rt/_models.py @@ -183,6 +183,29 @@ class AudioFormat: sample_rate: int = 44100 chunk_size: int = 4096 + _BYTES_PER_SAMPLE = { + AudioEncoding.PCM_F32LE: 4, + AudioEncoding.PCM_S16LE: 2, + AudioEncoding.MULAW: 1, + } + + @property + def bytes_per_sample(self) -> int: + """Number of bytes per audio sample based on encoding. + + Raises: + ValueError: If encoding is None (file type) or unrecognized. + """ + if self.encoding is None: + raise ValueError( + "Cannot determine bytes per sample for file-type audio format. " + "Set an explicit encoding on AudioFormat." 
+ ) + try: + return self._BYTES_PER_SAMPLE[self.encoding] + except KeyError: + raise ValueError(f"Unknown encoding: {self.encoding}") + def to_dict(self) -> dict[str, Any]: """ Convert audio format to dictionary. diff --git a/sdk/voice/pyproject.toml b/sdk/voice/pyproject.toml index 9006bd1f..a8decb29 100644 --- a/sdk/voice/pyproject.toml +++ b/sdk/voice/pyproject.toml @@ -11,7 +11,7 @@ authors = [{ name = "Speechmatics", email = "support@speechmatics.com" }] license = "MIT" requires-python = ">=3.9" dependencies = [ - "speechmatics-rt>=0.5.3", + "speechmatics-rt==0.5.3", "pydantic>=2.10.6,<3", "numpy>=1.26.4,<3" ] diff --git a/sdk/voice/speechmatics/voice/_client.py b/sdk/voice/speechmatics/voice/_client.py index c0988dd3..763fb585 100644 --- a/sdk/voice/speechmatics/voice/_client.py +++ b/sdk/voice/speechmatics/voice/_client.py @@ -176,6 +176,10 @@ def __init__( preset_config = VoiceAgentConfigPreset.load(preset) config = VoiceAgentConfigPreset._merge_configs(preset_config, config) + # Validate the final config (deferred to allow overlay/preset merging first) + if config is not None: + config.validate_config() + # Process the config self._config, self._transcription_config, self._audio_format = self._prepare_config(config) @@ -310,20 +314,18 @@ def __init__( self._turn_handler: TurnTaskProcessor = TurnTaskProcessor(name="turn_handler", done_callback=self.finalize) self._eot_calculation_task: Optional[asyncio.Task] = None - # Uses fixed EndOfUtterance message from STT - self._uses_fixed_eou: bool = ( - self._eou_mode == EndOfUtteranceMode.FIXED - and not self._silero_detector - and not self._config.end_of_turn_config.use_forced_eou - ) - - # Uses ForceEndOfUtterance message - self._uses_forced_eou: bool = not self._uses_fixed_eou + # Forced end of utterance handling + # FEOU is not used in FIXED mode, unless VAD has been enabled. 
It can / should + # also be disabled during testing when not connected to an endpoint, as the + # waiting for FEOU response will block the test. + self._use_forced_eou: bool = self._eou_mode is not EndOfUtteranceMode.FIXED or self._uses_silero_vad self._forced_eou_active: bool = False self._last_forced_eou_latency: float = 0.0 - # Emit EOT prediction (uses _uses_forced_eou) - self._uses_eot_prediction: bool = self._eou_mode not in [ + # Emit EOT prediction + # EOT predictions are only relevant when not using the FIXED or EXTERNAL modes, + # as these use different triggers to finalize the turn. + self._emit_eot_predictions: bool = self._eou_mode not in [ EndOfUtteranceMode.FIXED, EndOfUtteranceMode.EXTERNAL, ] @@ -360,8 +362,8 @@ def __init__( AudioEncoding.PCM_S16LE: 2, }.get(self._audio_format.encoding, 1) - # Default audio buffer - if not self._config.audio_buffer_length and (self._uses_smart_turn or self._uses_silero_vad): + # Default audio buffer (used when Silero VAD is enabled and with Smart Turn) + if not self._config.audio_buffer_length and self._uses_silero_vad: self._config.audio_buffer_length = 15.0 # Audio buffer @@ -447,9 +449,7 @@ def _prepare_config( ) # Fixed end of Utterance - if bool( - config.end_of_utterance_mode == EndOfUtteranceMode.FIXED and not config.end_of_turn_config.use_forced_eou - ): + if config.end_of_utterance_mode == EndOfUtteranceMode.FIXED: transcription_config.conversation_config = ConversationConfig( end_of_utterance_silence_trigger=config.end_of_utterance_silence_trigger, ) @@ -659,8 +659,14 @@ async def send_audio(self, payload: bytes) -> None: return # Process with Silero VAD - if self._silero_detector: - asyncio.create_task(self._silero_detector.process_audio(payload)) + if self._uses_silero_vad and self._silero_detector is not None: + asyncio.create_task( + self._silero_detector.process_audio( + payload, + sample_rate=self._audio_sample_rate, + sample_width=self._audio_sample_width, + ) + ) # Add to audio buffer (use 
put_bytes to handle variable chunk sizes) if self._config.audio_buffer_length > 0: @@ -717,14 +723,11 @@ def update_diarization_config(self, config: SpeakerFocusConfig) -> None: # PUBLIC UTTERANCE / TURN MANAGEMENT # ============================================================================ - def finalize(self, end_of_turn: bool = False) -> None: + def finalize(self) -> None: """Finalize segments. This function will emit segments in the buffer without any further checks on the contents of the segments. - - Args: - end_of_turn: Whether to emit an end of turn message. """ # Clear smart turn cutoff @@ -738,7 +741,7 @@ async def emit() -> None: """Wait for EndOfUtterance if needed, then emit segments.""" # Forced end of utterance message (only when no speaker is detected) - if self._config.end_of_turn_config.use_forced_eou: + if self._use_forced_eou: await self._await_forced_eou() # Check if the turn has changed @@ -749,7 +752,7 @@ async def emit() -> None: self._stt_message_queue.put_nowait(lambda: self._emit_segments(finalize=True, is_eou=True)) # Call async task (only if not already waiting for forced EOU) - if not (self._config.end_of_turn_config.use_forced_eou and self._forced_eou_active): + if not self._forced_eou_active: asyncio.create_task(emit()) # ============================================================================ @@ -788,8 +791,8 @@ def _evt_on_final_transcript(message: dict[str, Any]) -> None: return self._stt_message_queue.put_nowait(lambda: self._handle_transcript(message, is_final=True)) - # End of Utterance (FIXED mode only) - if self._uses_fixed_eou: + # End of Utterance - only when not using ForceEndOfUtterance messages + if not self._use_forced_eou: @self.on(ServerMessageType.END_OF_UTTERANCE) # type: ignore[misc] def _evt_on_end_of_utterance(message: dict[str, Any]) -> None: @@ -1121,7 +1124,7 @@ async def _add_speech_fragments(self, message: dict[str, Any], is_final: bool = self._last_fragment_end_time = max(self._last_fragment_end_time, 
fragment.end_time) # Evaluate for VAD (only done on partials) - await self._vad_evaluation(fragments, is_final=is_final) + await self._speaker_start_stop_evaluation(fragments, is_final=is_final) # Fragments to retain retained_fragments = [ @@ -1205,18 +1208,8 @@ async def _process_speech_fragments(self, change_filter: Optional[list[Annotatio if change_filter and not changes.any(*change_filter): return - # Skip re-evaluation if transcripts are older than smart turn cutoff - if self._smart_turn_pending_cutoff is not None and self._current_view: - latest_end_time = max( - (f.end_time for f in self._current_view.fragments if f.end_time is not None), default=0.0 - ) - - # If all fragments end before or at the cutoff, skip re-evaluation - if latest_end_time <= self._smart_turn_pending_cutoff: - return - # Turn prediction - if self._uses_eot_prediction and self._uses_forced_eou and not self._forced_eou_active: + if self._emit_eot_predictions and not self._forced_eou_active and self._use_forced_eou: async def fn() -> None: ttl = await self._calculate_finalize_delay() @@ -1518,7 +1511,7 @@ async def _calculate_finalize_delay( annotation = annotation or AnnotationResult() # VAD enabled - if self._silero_detector: + if self._uses_silero_vad: annotation.add(AnnotationFlags.VAD_ACTIVE) else: annotation.add(AnnotationFlags.VAD_INACTIVE) @@ -1526,6 +1519,12 @@ async def _calculate_finalize_delay( # Smart Turn enabled if self._smart_turn_detector: annotation.add(AnnotationFlags.SMART_TURN_ACTIVE) + # If Smart Turn hasn't returned a result yet but is enabled, add NO_SIGNAL annotation. + # This covers the case where the TTL fires before VAD triggers Smart Turn inference. 
+ if not annotation.has(AnnotationFlags.SMART_TURN_TRUE) and not annotation.has( + AnnotationFlags.SMART_TURN_FALSE + ): + annotation.add(AnnotationFlags.SMART_TURN_NO_SIGNAL) else: annotation.add(AnnotationFlags.SMART_TURN_INACTIVE) @@ -1551,8 +1550,7 @@ async def _calculate_finalize_delay( delay = round(self._config.end_of_utterance_silence_trigger * multiplier, 3) # Trim off the most recent forced EOU delay if we're in forced EOU mode - if self._uses_forced_eou: - delay -= self._last_forced_eou_latency + delay -= self._last_forced_eou_latency # Clamp to max delay and adjust for TTFB clamped_delay = min(delay, self._config.end_of_utterance_max_delay) @@ -1586,7 +1584,10 @@ async def _eot_prediction( # Wait for Smart Turn result if self._smart_turn_detector and end_time is not None: result = await self._smart_turn_prediction(end_time, self._config.language, speaker=speaker) - if result.prediction: + if result.error: + # No valid prediction — SMART_TURN_NO_SIGNAL will be applied by _calculate_finalize_delay + pass + elif result.prediction: annotation.add(AnnotationFlags.SMART_TURN_TRUE) else: annotation.add(AnnotationFlags.SMART_TURN_FALSE) @@ -1676,9 +1677,6 @@ async def _await_forced_eou(self, timeout: float = 1.0) -> None: # Add listener self.once(AgentServerMessageType.END_OF_UTTERANCE, lambda message: eou_received.set()) - # Trigger EOU message - self._emit_diagnostic_message("ForceEndOfUtterance sent - waiting for EndOfUtterance") - # Wait for EOU try: # Track the start time @@ -1686,7 +1684,10 @@ async def _await_forced_eou(self, timeout: float = 1.0) -> None: self._forced_eou_active = True # Send the force EOU and wait for the response - await self.force_end_of_utterance() + timestamp = await self.force_end_of_utterance() + self._emit_diagnostic_message(f"ForceEndOfUtterance sent - waiting for EndOfUtterance ({timestamp=})") + + # Wait for the response await asyncio.wait_for(eou_received.wait(), timeout=timeout) # Record the latency @@ -1702,7 +1703,7 @@ 
async def _await_forced_eou(self, timeout: float = 1.0) -> None: # VAD (VOICE ACTIVITY DETECTION) / SPEAKER DETECTION # ============================================================================ - async def _vad_evaluation(self, fragments: list[SpeechFragment], is_final: bool) -> None: + async def _speaker_start_stop_evaluation(self, fragments: list[SpeechFragment], is_final: bool) -> None: """Emit a VAD event. This will emit `SPEAKER_STARTED` and `SPEAKER_ENDED` events to the client and is @@ -1850,18 +1851,20 @@ def _handle_silero_vad_result(self, result: SileroVADResult) -> None: annotation.add(AnnotationFlags.VAD_STARTED) # If speech has ended, we need to predict the end of turn - if result.speech_ended and self._uses_eot_prediction: + if self._emit_eot_predictions and result.speech_ended: """VAD-based end of turn prediction.""" # Set cutoff to prevent late transcripts from cancelling finalization self._smart_turn_pending_cutoff = event_time + # Async callback async def fn() -> None: ttl = await self._eot_prediction( end_time=event_time, speaker=self._current_speaker, annotation=annotation ) self._turn_handler.update_timer(ttl) + # Call the eot calculation asynchronously self._run_background_eot_calculation(fn, "silero_vad") async def _handle_speaker_started(self, speaker: Optional[str], event_time: float) -> None: @@ -1878,8 +1881,7 @@ async def _handle_speaker_started(self, speaker: Optional[str], event_time: floa await self._emit_start_of_turn(event_time) # Update the turn handler - if self._uses_forced_eou: - self._turn_handler.reset() + self._turn_handler.reset() # Emit the event self._emit_message( @@ -1902,7 +1904,7 @@ async def _handle_speaker_stopped(self, speaker: Optional[str], event_time: floa self._last_speak_end_latency = self._total_time - event_time # Turn prediction - if self._uses_eot_prediction and not self._forced_eou_active: + if self._emit_eot_predictions and not self._forced_eou_active: async def fn() -> None: ttl = await 
self._eot_prediction(event_time, speaker) diff --git a/sdk/voice/speechmatics/voice/_models.py b/sdk/voice/speechmatics/voice/_models.py index b4a432c2..c58a7ca6 100644 --- a/sdk/voice/speechmatics/voice/_models.py +++ b/sdk/voice/speechmatics/voice/_models.py @@ -13,7 +13,6 @@ from pydantic import BaseModel as PydanticBaseModel from pydantic import ConfigDict from pydantic import Field -from pydantic import model_validator from typing_extensions import Self from speechmatics.rt import AudioEncoding @@ -261,6 +260,7 @@ class AnnotationFlags(str, Enum): SMART_TURN_INACTIVE = "smart_turn_inactive" SMART_TURN_TRUE = "smart_turn_true" SMART_TURN_FALSE = "smart_turn_false" + SMART_TURN_NO_SIGNAL = "smart_turn_no_signal" # ============================================================================== @@ -410,35 +410,57 @@ class EndOfTurnConfig(BaseModel): base_multiplier: Base multiplier for end of turn delay. min_end_of_turn_delay: Minimum end of turn delay. penalties: List of end of turn penalty items. - use_forced_eou: Whether to use forced end of utterance detection. + use_forced_eou: Whether to use forced end of utterance detection. 
(SHOULD ONLY EVER BE TRUE) """ base_multiplier: float = 1.0 min_end_of_turn_delay: float = 0.01 penalties: list[EndOfTurnPenaltyItem] = Field( default_factory=lambda: [ - # Increase delay + # + # Speaker rate increases expected TTL EndOfTurnPenaltyItem(penalty=3.0, annotation=[AnnotationFlags.VERY_SLOW_SPEAKER]), EndOfTurnPenaltyItem(penalty=2.0, annotation=[AnnotationFlags.SLOW_SPEAKER]), + # + # High / low rate of disfluencies EndOfTurnPenaltyItem(penalty=2.5, annotation=[AnnotationFlags.ENDS_WITH_DISFLUENCY]), EndOfTurnPenaltyItem(penalty=1.1, annotation=[AnnotationFlags.HAS_DISFLUENCY]), + # + # We do NOT have an end of sentence character EndOfTurnPenaltyItem( penalty=2.0, annotation=[AnnotationFlags.ENDS_WITH_EOS], is_not=True, ), - # Decrease delay + # + # We have finals and end of sentence EndOfTurnPenaltyItem( penalty=0.5, annotation=[AnnotationFlags.ENDS_WITH_FINAL, AnnotationFlags.ENDS_WITH_EOS] ), - # Smart Turn + VAD - EndOfTurnPenaltyItem(penalty=0.2, annotation=[AnnotationFlags.SMART_TURN_TRUE]), + # + # Smart Turn - when false, wait longer to prevent premature end of turn EndOfTurnPenaltyItem( - penalty=0.2, annotation=[AnnotationFlags.VAD_STOPPED, AnnotationFlags.SMART_TURN_INACTIVE] + penalty=0.2, annotation=[AnnotationFlags.SMART_TURN_TRUE, AnnotationFlags.SMART_TURN_ACTIVE] + ), + EndOfTurnPenaltyItem( + penalty=2.0, annotation=[AnnotationFlags.SMART_TURN_FALSE, AnnotationFlags.SMART_TURN_ACTIVE] + ), + EndOfTurnPenaltyItem( + penalty=1.5, annotation=[AnnotationFlags.SMART_TURN_NO_SIGNAL, AnnotationFlags.SMART_TURN_ACTIVE] + ), + # + # VAD - only applied when smart turn is not in use and on the speaker stopping + EndOfTurnPenaltyItem( + penalty=0.2, + annotation=[ + AnnotationFlags.VAD_STOPPED, + AnnotationFlags.VAD_ACTIVE, + AnnotationFlags.SMART_TURN_INACTIVE, + ], ), ] ) - use_forced_eou: bool = False + use_forced_eou: bool = True class VoiceActivityConfig(BaseModel): @@ -711,10 +733,16 @@ class VoiceAgentConfig(BaseModel): audio_encoding: 
AudioEncoding = AudioEncoding.PCM_S16LE chunk_size: int = 160 - # Validation - @model_validator(mode="after") # type: ignore[misc] - def validate_config(self) -> Self: - """Validate the configuration.""" + def validate_config(self) -> None: + """Validate the configuration. + + Cross-field validation is deferred to this method so that configs can be + constructed as overlays (e.g. for presets) without triggering validation + on intermediate states. Call this once the final config is ready. + + Raises: + ValueError: If any validation errors are found. + """ # Validation errors errors: list[str] = [] @@ -723,12 +751,6 @@ def validate_config(self) -> Self: if self.end_of_utterance_mode == EndOfUtteranceMode.EXTERNAL and self.smart_turn_config: errors.append("EXTERNAL mode cannot be used in conjunction with SmartTurnConfig") - # Cannot have FIXED and forced end of utterance enabled without VAD being enabled - if (self.end_of_utterance_mode == EndOfUtteranceMode.FIXED and self.end_of_turn_config.use_forced_eou) and not ( - self.vad_config and self.vad_config.enabled - ): - errors.append("FIXED mode cannot be used in conjunction with forced end of utterance without VAD enabled") - # Cannot use VAD with external end of utterance mode if self.end_of_utterance_mode == EndOfUtteranceMode.EXTERNAL and (self.vad_config and self.vad_config.enabled): errors.append("EXTERNAL mode cannot be used in conjunction with VAD being enabled") @@ -751,13 +773,14 @@ def validate_config(self) -> Self: if self.sample_rate not in [8000, 16000]: errors.append("sample_rate must be 8000 or 16000") + # Check that forced end of utterance is set to True + if not self.end_of_turn_config.use_forced_eou: + errors.append("EndOfTurnConfig.use_forced_eou cannot be False") + # Raise error if any validation errors if errors: raise ValueError(f"{len(errors)} config error(s): {'; '.join(errors)}") - # Return validated config - return self - # 
============================================================================== # SESSION & INFO MODELS diff --git a/sdk/voice/speechmatics/voice/_presets.py b/sdk/voice/speechmatics/voice/_presets.py index 2bcb092f..09d47953 100644 --- a/sdk/voice/speechmatics/voice/_presets.py +++ b/sdk/voice/speechmatics/voice/_presets.py @@ -6,7 +6,6 @@ from typing import Optional -from ._models import EndOfTurnConfig from ._models import EndOfUtteranceMode from ._models import OperatingPoint from ._models import SmartTurnConfig @@ -82,7 +81,6 @@ def ADAPTIVE(overlay: Optional[VoiceAgentConfig] = None) -> VoiceAgentConfig: # end_of_utterance_mode=EndOfUtteranceMode.ADAPTIVE, speech_segment_config=SpeechSegmentConfig(emit_sentences=False), vad_config=VoiceActivityConfig(enabled=True), - end_of_turn_config=EndOfTurnConfig(use_forced_eou=True), ), overlay, ) @@ -114,7 +112,6 @@ def SMART_TURN(overlay: Optional[VoiceAgentConfig] = None) -> VoiceAgentConfig: enabled=True, ), vad_config=VoiceActivityConfig(enabled=True), - end_of_turn_config=EndOfTurnConfig(use_forced_eou=True), ), overlay, ) @@ -175,7 +172,6 @@ def EXTERNAL(overlay: Optional[VoiceAgentConfig] = None) -> VoiceAgentConfig: # max_delay=2.0, end_of_utterance_mode=EndOfUtteranceMode.EXTERNAL, speech_segment_config=SpeechSegmentConfig(emit_sentences=False), - end_of_turn_config=EndOfTurnConfig(use_forced_eou=True), ), overlay, ) @@ -232,4 +228,10 @@ def _merge_configs(base: VoiceAgentConfig, overlay: Optional[VoiceAgentConfig]) **base.model_dump(exclude_unset=True, exclude_none=True), **overlay.model_dump(exclude_unset=True, exclude_none=True), } - return VoiceAgentConfig.from_dict(merged_dict) + config = VoiceAgentConfig.from_dict(merged_dict) + + # Validate the merged config + config.validate_config() + + # Return the merged config + return config diff --git a/sdk/voice/speechmatics/voice/_smart_turn.py b/sdk/voice/speechmatics/voice/_smart_turn.py index 9ce44a03..529f4653 100644 --- 
a/sdk/voice/speechmatics/voice/_smart_turn.py +++ b/sdk/voice/speechmatics/voice/_smart_turn.py @@ -196,13 +196,21 @@ async def predict( # Convert int16 to float32 in range [-1, 1] (same as reference implementation) float32_array: np.ndarray = int16_array.astype(np.float32) / 32768.0 + # Whisper's feature extractor requires 16kHz audio. Resample if needed. + target_rate = 16000 + if sample_rate != target_rate: + float32_array = self._resample(float32_array, sample_rate, target_rate) + + # After resampling, max_samples is relative to 16kHz + max_samples_16k = 8 * target_rate + # Process audio using Whisper's feature extractor inputs = self.feature_extractor( float32_array, - sampling_rate=sample_rate, + sampling_rate=target_rate, return_tensors="np", padding="max_length", - max_length=max_samples, + max_length=max_samples_16k, truncation=True, do_normalize=True, ) @@ -230,6 +238,44 @@ async def predict( processing_time=round(float((end_time - start_time).total_seconds()), 3), ) + @staticmethod + def _resample(audio: np.ndarray, orig_rate: int, target_rate: int) -> np.ndarray: + """Resample audio using FFT-based method (zero-pad in frequency domain). + + This produces higher quality resampling than linear interpolation by + preserving the original spectral content without aliasing artifacts. + + Args: + audio: Float32 numpy array of audio samples. + orig_rate: Original sample rate. + target_rate: Target sample rate. + + Returns: + Resampled float32 numpy array. 
+ """ + if orig_rate == target_rate: + return audio + + n_orig = len(audio) + n_target = int(n_orig * target_rate / orig_rate) + + # FFT of original signal + fft = np.fft.rfft(audio) + + # Create zero-padded FFT array for target length + n_fft_target = n_target // 2 + 1 + new_fft = np.zeros(n_fft_target, dtype=complex) + + # Copy original frequency bins (preserves spectral content) + copy_len = min(len(fft), n_fft_target) + new_fft[:copy_len] = fft[:copy_len] + + # Inverse FFT at target length, scale to preserve amplitude + resampled = np.fft.irfft(new_fft, n=n_target) + resampled *= n_target / n_orig + + return resampled.astype(np.float32) + @staticmethod def truncate_audio_to_last_n_seconds( audio_array: np.ndarray, n_seconds: float = 8.0, sample_rate: int = 16000 diff --git a/sdk/voice/speechmatics/voice/_vad.py b/sdk/voice/speechmatics/voice/_vad.py index e5a7b1e8..a65502d5 100644 --- a/sdk/voice/speechmatics/voice/_vad.py +++ b/sdk/voice/speechmatics/voice/_vad.py @@ -38,12 +38,16 @@ # Hint for when dependencies are not available SILERO_INSTALL_HINT = "Silero VAD unavailable. Install `speechmatics-voice[smart]` to enable VAD." -# Silero VAD constants -SILERO_SAMPLE_RATE = 16000 -SILERO_CHUNK_SIZE = 512 # Silero expects 512 samples at 16kHz (32ms chunks) -SILERO_CONTEXT_SIZE = 64 # Silero uses 64-sample context +# Silero VAD supported sample rates (see https://github.com/snakers4/silero-vad) +SILERO_SUPPORTED_SAMPLE_RATES = [8000, 16000] + +# Chunk and context sizes differ by sample rate. 
+# Both result in ~32ms chunks: 512/16000 = 256/8000 = 0.032s +SILERO_CHUNK_SIZES = {16000: 512, 8000: 256} +SILERO_CONTEXT_SIZES = {16000: 64, 8000: 32} + MODEL_RESET_STATES_TIME = 5.0 # Reset state every 5 seconds -SILERO_CHUNK_DURATION_MS = (SILERO_CHUNK_SIZE / SILERO_SAMPLE_RATE) * 1000 # 32ms per chunk +SILERO_CHUNK_DURATION_MS = 32.0 # Both sample rates produce 32ms chunks class SileroVADResult(BaseModel): @@ -70,7 +74,7 @@ class SileroVAD: """Silero Voice Activity Detector. Uses Silero's opensource VAD model for detecting speech vs silence. - Processes audio in 512-sample chunks at 16kHz. + Supports 8kHz (256-sample chunks) and 16kHz (512-sample chunks). Further information at https://github.com/snakers4/silero-vad """ @@ -172,56 +176,72 @@ def build_session(self, onnx_path: str) -> ort.InferenceSession: # Return the new session return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"], sess_options=so) - def _init_states(self) -> None: - """Initialize or reset internal VAD states.""" + def _init_states(self, sample_rate: int = 16000) -> None: + """Initialize or reset internal VAD states. + + Args: + sample_rate: Audio sample rate, used to determine context size. + """ + context_size = SILERO_CONTEXT_SIZES.get(sample_rate, 64) self._state = np.zeros((2, 1, 128), dtype=np.float32) - self._context = np.zeros((1, SILERO_CONTEXT_SIZE), dtype=np.float32) + self._context = np.zeros((1, context_size), dtype=np.float32) + self._last_sr: int = sample_rate self._last_reset_time = time.time() - def _maybe_reset_states(self) -> None: + def _maybe_reset_states(self, sample_rate: int) -> None: """Reset ONNX model states periodically to prevent drift. + Also resets if the sample rate changes between calls. + Note: Does NOT reset prediction window or speech state tracking. 
""" - if (time.time() - self._last_reset_time) >= MODEL_RESET_STATES_TIME: - self._state = np.zeros((2, 1, 128), dtype=np.float32) - self._context = np.zeros((1, SILERO_CONTEXT_SIZE), dtype=np.float32) - self._last_reset_time = time.time() + # Reset if sample rate changed (context size depends on it) + sr_changed = hasattr(self, "_last_sr") and self._last_sr != sample_rate + time_expired = (time.time() - self._last_reset_time) >= MODEL_RESET_STATES_TIME + + if sr_changed or time_expired: + self._init_states(sample_rate) + + def process_chunk(self, chunk_f32: np.ndarray, sample_rate: int = 16000) -> float: + """Process a single audio chunk and return speech probability. - def process_chunk(self, chunk_f32: np.ndarray) -> float: - """Process a single 512-sample chunk and return speech probability. + Chunk size depends on sample rate: 512 samples at 16kHz, 256 at 8kHz. Args: - chunk_f32: Float32 numpy array of exactly 512 samples. + chunk_f32: Float32 numpy array of audio samples. + sample_rate: Sample rate of the audio (8000 or 16000). Returns: Speech probability (0.0-1.0). Raises: - ValueError: If chunk is not exactly 512 samples. + ValueError: If chunk size doesn't match expected size for sample rate. 
""" - # Ensure shape (1, 512) + # Expected sizes depend on sample rate (512 @ 16kHz, 256 @ 8kHz) + expected_chunk_size = SILERO_CHUNK_SIZES.get(sample_rate, 512) + context_size = SILERO_CONTEXT_SIZES.get(sample_rate, 64) + x = np.reshape(chunk_f32, (1, -1)) - if x.shape[1] != SILERO_CHUNK_SIZE: - raise ValueError(f"Expected {SILERO_CHUNK_SIZE} samples, got {x.shape[1]}") + if x.shape[1] != expected_chunk_size: + raise ValueError(f"Expected {expected_chunk_size} samples for {sample_rate}Hz, got {x.shape[1]}") - # Concatenate with context (previous 64 samples) + # Concatenate with context (previous N samples, where N depends on sample rate) if self._context is not None: x = np.concatenate((self._context, x), axis=1) - # Run ONNX inference + # Run ONNX inference — pass actual sample rate so the model uses correct internal params ort_inputs = { "input": x.astype(np.float32), "state": self._state, - "sr": np.array(SILERO_SAMPLE_RATE, dtype=np.int64), + "sr": np.array(sample_rate, dtype=np.int64), } out, self._state = self.session.run(None, ort_inputs) - # Update context (keep last 64 samples) - self._context = x[:, -SILERO_CONTEXT_SIZE:] + # Update context (keep last N samples for next chunk) + self._context = x[:, -context_size:] # Maybe reset states periodically - self._maybe_reset_states() + self._maybe_reset_states(sample_rate) # Return probability (out shape is (1, 1)) return float(out[0][0]) @@ -229,12 +249,13 @@ def process_chunk(self, chunk_f32: np.ndarray) -> float: async def process_audio(self, audio_bytes: bytes, sample_rate: int = 16000, sample_width: int = 2) -> None: """Process incoming audio bytes and invoke callback on state changes. - This method buffers incomplete chunks and processes all complete 512-sample chunks. + This method buffers incomplete chunks and processes all complete chunks. + Chunk size depends on sample rate: 512 samples at 16kHz, 256 at 8kHz. The callback is invoked only once at the end if the VAD state changed during processing. 
Args: audio_bytes: Raw audio bytes (int16 PCM). - sample_rate: Sample rate of the audio (must be 16000). + sample_rate: Sample rate of the audio (8000 or 16000). sample_width: Sample width in bytes (2 for int16). """ @@ -242,15 +263,17 @@ async def process_audio(self, audio_bytes: bytes, sample_rate: int = 16000, samp logger.error("SileroVAD is not initialized") return - if sample_rate != SILERO_SAMPLE_RATE: - logger.error(f"Sample rate must be {SILERO_SAMPLE_RATE}Hz, got {sample_rate}Hz") + # Silero VAD only supports 8kHz and 16kHz natively + if sample_rate not in SILERO_SUPPORTED_SAMPLE_RATES: + logger.error(f"Sample rate must be one of {SILERO_SUPPORTED_SAMPLE_RATES}Hz, got {sample_rate}Hz") return # Add new bytes to buffer self._audio_buffer += audio_bytes - # Calculate bytes per chunk (512 samples * 2 bytes for int16) - bytes_per_chunk = SILERO_CHUNK_SIZE * sample_width + # Chunk size depends on sample rate (512 @ 16kHz, 256 @ 8kHz) + chunk_samples = SILERO_CHUNK_SIZES[sample_rate] + bytes_per_chunk = chunk_samples * sample_width # Process all complete chunks in buffer while len(self._audio_buffer) >= bytes_per_chunk: @@ -266,8 +289,8 @@ async def process_audio(self, audio_bytes: bytes, sample_rate: int = 16000, samp float32_array: np.ndarray = int16_array.astype(np.float32) / 32768.0 try: - # Process the chunk and add probability to rolling window - probability = self.process_chunk(float32_array) + # Process the chunk with the correct sample rate + probability = self.process_chunk(float32_array, sample_rate=sample_rate) self._prediction_window.append(probability) except Exception as e: @@ -307,10 +330,26 @@ async def process_audio(self, audio_bytes: bytes, sample_rate: int = 16000, samp # Update state after emitting self._last_is_speech = is_speech - def reset(self) -> None: - """Reset the VAD state and clear audio buffer.""" + @property + def is_speech_likely(self) -> bool: + """Quick check if the most recent raw prediction suggests speech. 
+ + Unlike _last_is_speech which uses a smoothed rolling average (slower to + react), this checks the latest chunk prediction directly — giving faster + speech-onset detection at the cost of more false positives. + """ + if not self._prediction_window: + return self._last_is_speech + return float(self._prediction_window[-1]) >= self._threshold + + def reset(self, sample_rate: int = 16000) -> None: + """Reset the VAD state and clear audio buffer. + + Args: + sample_rate: Sample rate to reinitialise context size for. + """ if self._is_initialized: - self._init_states() + self._init_states(sample_rate) self._audio_buffer = b"" self._prediction_window.clear() self._last_is_speech = False diff --git a/tests/voice/_utils.py b/tests/voice/_utils.py index 8308e905..ad49128a 100644 --- a/tests/voice/_utils.py +++ b/tests/voice/_utils.py @@ -18,7 +18,7 @@ async def get_client( api_key: Optional[str] = None, url: Optional[str] = None, - app: Optional[str] = None, + app: str = "sdk-test", config: Optional[VoiceAgentConfig] = None, connect: bool = True, ) -> VoiceAgentClient: diff --git a/tests/voice/assets/audio_10_16kHz.wav b/tests/voice/assets/audio_10_16kHz.wav new file mode 100644 index 00000000..a6fe0267 Binary files /dev/null and b/tests/voice/assets/audio_10_16kHz.wav differ diff --git a/tests/voice/test_05_utterance.py b/tests/voice/test_05_utterance.py index 9c3c6604..d184ac9a 100644 --- a/tests/voice/test_05_utterance.py +++ b/tests/voice/test_05_utterance.py @@ -10,7 +10,6 @@ from _utils import log_client_messages from speechmatics.voice import AgentServerMessageType -from speechmatics.voice import EndOfTurnConfig from speechmatics.voice import EndOfUtteranceMode from speechmatics.voice import SpeechSegmentConfig from speechmatics.voice import VoiceAgentConfig @@ -232,11 +231,13 @@ async def test_external_vad(): config=VoiceAgentConfig( end_of_utterance_silence_trigger=adaptive_timeout, end_of_utterance_mode=EndOfUtteranceMode.EXTERNAL, - 
end_of_turn_config=EndOfTurnConfig(use_forced_eou=False), ), ) assert client is not None + # Set FEOU to disabled for offline tests + client._use_forced_eou = False + # Start the queue client._start_stt_queue() @@ -335,7 +336,6 @@ async def test_end_of_utterance_adaptive_vad(): end_of_utterance_silence_trigger=adaptive_timeout, end_of_utterance_mode=EndOfUtteranceMode.ADAPTIVE, speech_segment_config=SpeechSegmentConfig(emit_sentences=False), - end_of_turn_config=EndOfTurnConfig(use_forced_eou=False), ), ) assert client is not None @@ -344,6 +344,9 @@ async def test_end_of_utterance_adaptive_vad(): if SHOW_LOG: log_client_messages(client) + # Set FEOU to disabled for offline tests + client._use_forced_eou = False + # Start the queue client._start_stt_queue() diff --git a/tests/voice/test_07_languages.py b/tests/voice/test_07_languages.py index c83428d5..3dc15f0a 100644 --- a/tests/voice/test_07_languages.py +++ b/tests/voice/test_07_languages.py @@ -14,7 +14,6 @@ from speechmatics.voice import AdditionalVocabEntry from speechmatics.voice import AgentServerMessageType -from speechmatics.voice import EndOfTurnConfig from speechmatics.voice import EndOfUtteranceMode from speechmatics.voice import SpeechSegmentConfig from speechmatics.voice import VoiceAgentConfig @@ -25,7 +24,7 @@ # Constants API_KEY = os.getenv("SPEECHMATICS_API_KEY") -URL = "wss://eu2.rt.speechmatics.com/v2" +URL = os.getenv("SPEECHMATICS_RT_URL", "wss://eu2.rt.speechmatics.com/v2") SHOW_LOG = os.getenv("SPEECHMATICS_SHOW_LOG", "0").lower() in ["1", "true"] @@ -113,22 +112,24 @@ async def test_transcribe_languages(sample: AudioSample): if not API_KEY: pytest.skip("Valid API key required for test") + # Config + config = VoiceAgentConfig( + max_delay=1.2, + end_of_utterance_mode=EndOfUtteranceMode.FIXED, + end_of_utterance_silence_trigger=1.2, + language=sample.language, + additional_vocab=[AdditionalVocabEntry(content=vocab) for vocab in sample.vocab], + speech_segment_config=SpeechSegmentConfig( + 
emit_sentences=False, + ), + ) + # Client client = await get_client( api_key=API_KEY, url=URL, connect=False, - config=VoiceAgentConfig( - max_delay=1.2, - end_of_utterance_mode=EndOfUtteranceMode.FIXED, - end_of_utterance_silence_trigger=1.2, - language=sample.language, - additional_vocab=[AdditionalVocabEntry(content=vocab) for vocab in sample.vocab], - end_of_turn_config=EndOfTurnConfig(use_forced_eou=False), - speech_segment_config=SpeechSegmentConfig( - emit_sentences=False, - ), - ), + config=config, ) assert client is not None @@ -188,6 +189,10 @@ def log_segment(message): # Extract the last message assert last_message.get("message") == AgentServerMessageType.ADD_SEGMENT + # Close session + await client.disconnect() + assert not client._is_connected + # Check the segment assert len(segments) >= 1 seg0 = segments[0] @@ -216,7 +221,3 @@ def log_segment(message): print(f"Transcribed: [{str_transcribed}]") print(f"CER: {str_cer}") raise AssertionError("Transcription does not match original") - - # Close session - await client.disconnect() - assert not client._is_connected diff --git a/tests/voice/test_08_multiple_speakers.py b/tests/voice/test_08_multiple_speakers.py index fa662aa5..73a7d299 100644 --- a/tests/voice/test_08_multiple_speakers.py +++ b/tests/voice/test_08_multiple_speakers.py @@ -24,6 +24,7 @@ # Constants API_KEY = os.getenv("SPEECHMATICS_API_KEY") +URL = os.getenv("SPEECHMATICS_RT_URL", "wss://eu2.rt.speechmatics.com/v2") SHOW_LOG = os.getenv("SPEECHMATICS_SHOW_LOG", "0").lower() in ["1", "true"] @@ -116,11 +117,17 @@ async def test_multiple_speakers(sample: SpeakerTest): # Client client = await get_client( + url=URL, api_key=API_KEY, connect=False, config=config, ) + # Debug + if SHOW_LOG: + print(config.to_json(exclude_none=True, exclude_defaults=True, exclude_unset=True, indent=2)) + print(json.dumps(client._transcription_config.to_dict(), indent=2)) + # Create an event to track when the callback is called messages: list[str] = [] bytes_sent: 
int = 0 @@ -148,19 +155,35 @@ def log_final_segment(message): segments: list[SpeakerSegment] = message["segments"] final_segments.extend(segments) + # Log end of turn + def log_end_of_turn(message): + final_segments.extend([{"speaker_id": "--", "text": "_TURN_"}]) + # Add listeners client.once(AgentServerMessageType.RECOGNITION_STARTED, log_message) client.once(AgentServerMessageType.INFO, log_message) client.on(AgentServerMessageType.WARNING, log_message) client.on(AgentServerMessageType.ERROR, log_message) + client.on(AgentServerMessageType.DIAGNOSTICS, log_message) + + # Transcript + client.on(AgentServerMessageType.ADD_PARTIAL_TRANSCRIPT, log_message) + client.on(AgentServerMessageType.ADD_TRANSCRIPT, log_message) client.on(AgentServerMessageType.ADD_PARTIAL_SEGMENT, log_message) client.on(AgentServerMessageType.ADD_SEGMENT, log_message) + + # Turn events + client.on(AgentServerMessageType.VAD_STATUS, log_message) client.on(AgentServerMessageType.SPEAKER_STARTED, log_message) client.on(AgentServerMessageType.SPEAKER_ENDED, log_message) + client.on(AgentServerMessageType.START_OF_TURN, log_message) client.on(AgentServerMessageType.END_OF_TURN, log_message) + client.on(AgentServerMessageType.END_OF_TURN_PREDICTION, log_message) + client.on(AgentServerMessageType.END_OF_UTTERANCE, log_message) - # Log ADD_SEGMENT + # Log ADD_SEGMENT + END_OF_TURN client.on(AgentServerMessageType.ADD_SEGMENT, log_final_segment) + client.on(AgentServerMessageType.END_OF_TURN, log_end_of_turn) # HEADER if SHOW_LOG: @@ -187,22 +210,44 @@ def log_final_segment(message): progress_callback=log_bytes_sent, ) + # Close session + await client.disconnect() + # FOOTER if SHOW_LOG: print("---") print() print() + # Print all final_segments + if SHOW_LOG: + print("Final segments:") + for idx, segment in enumerate(final_segments): + print(f"{idx}: [{segment.get('speaker_id')}] {segment.get('text')}") + print() + + # Accumulate errors + errors: list[str] = [] + + # Check number of final segments + 
if len(final_segments) < len(sample.segment_regex): + errors.append(f"Expected at least {len(sample.segment_regex)} segments, got {len(final_segments)}") + # Check final segments against regex + if SHOW_LOG: + print("Checking final segments against regex:") for idx, _test in enumerate(sample.segment_regex): + text = final_segments[idx].get("text") if idx < len(final_segments) else None + match = text and re.search(_test, text, flags=re.IGNORECASE | re.MULTILINE) if SHOW_LOG: - print(f"`{_test}` -> `{final_segments[idx].get('text')}`") - assert re.search(_test, final_segments[idx].get("text"), flags=re.IGNORECASE | re.MULTILINE) + print(f'{idx}: {"✅" if match else "❌"} - `{_test}` -> `{text}`') + if not match: + errors.append(f"Segment {idx}: expected /{_test}/ but got '{text}'") # Check only speakers present speakers = [segment.get("speaker_id") for segment in final_segments] - assert set(speakers) == set(sample.speakers_present) + if set(speakers) != set(sample.speakers_present): + errors.append(f"Speakers: expected {set(sample.speakers_present)} but got {set(speakers)}") - # Close session - await client.disconnect() - assert not client._is_connected + # Report all errors + assert not errors, "\n".join(errors) diff --git a/tests/voice/test_09_speaker_id.py b/tests/voice/test_09_speaker_id.py index 6e8dc0bc..a8bb7ccc 100644 --- a/tests/voice/test_09_speaker_id.py +++ b/tests/voice/test_09_speaker_id.py @@ -11,7 +11,6 @@ from speechmatics.rt import ClientMessageType from speechmatics.voice import AdditionalVocabEntry from speechmatics.voice import AgentServerMessageType -from speechmatics.voice import EndOfTurnConfig from speechmatics.voice import EndOfUtteranceMode from speechmatics.voice import SpeakerIdentifier from speechmatics.voice import SpeechSegmentConfig @@ -23,7 +22,7 @@ # Constants API_KEY = os.getenv("SPEECHMATICS_API_KEY") -URL: Optional[str] = "wss://eu2.rt.speechmatics.com/v2" +URL = os.getenv("SPEECHMATICS_RT_URL", 
"wss://eu2.rt.speechmatics.com/v2") SHOW_LOG = os.getenv("SPEECHMATICS_SHOW_LOG", "0").lower() in ["1", "true"] # List of know speakers during tests @@ -59,7 +58,6 @@ async def test_extract_speaker_ids(): additional_vocab=[ AdditionalVocabEntry(content="GeoRouter"), ], - end_of_turn_config=EndOfTurnConfig(use_forced_eou=False), ), ) @@ -192,7 +190,6 @@ async def test_known_speakers(): additional_vocab=[ AdditionalVocabEntry(content="GeoRouter"), ], - end_of_turn_config=EndOfTurnConfig(use_forced_eou=False), ), ) @@ -227,9 +224,6 @@ def log_final_segment(message): speakers = [segment.get("speaker_id") for segment in final_segments] assert set(speakers) == set({"Assistant", "John Doe"}) - # Should be 5 segments - assert len(final_segments) == 5 - # Close session await client.disconnect() assert not client._is_connected @@ -270,7 +264,6 @@ async def test_ignoring_assistant(): additional_vocab=[ AdditionalVocabEntry(content="GeoRouter"), ], - end_of_turn_config=EndOfTurnConfig(use_forced_eou=False), ), ) diff --git a/tests/voice/test_11_audio_buffer.py b/tests/voice/test_11_audio_buffer.py index a10834e9..6472859c 100644 --- a/tests/voice/test_11_audio_buffer.py +++ b/tests/voice/test_11_audio_buffer.py @@ -14,7 +14,6 @@ from speechmatics.voice import AdditionalVocabEntry from speechmatics.voice import AgentServerMessageType -from speechmatics.voice import EndOfTurnConfig from speechmatics.voice import EndOfUtteranceMode from speechmatics.voice import SmartTurnConfig from speechmatics.voice import VoiceAgentConfig @@ -263,7 +262,6 @@ async def save_slice( AdditionalVocabEntry(content="Speechmatics", sounds_like=["speech matics"]), ], smart_turn_config=SmartTurnConfig(enabled=True), - end_of_turn_config=EndOfTurnConfig(use_forced_eou=False), ), ) @@ -369,7 +367,6 @@ async def save_slice( AdditionalVocabEntry(content="Speechmatics", sounds_like=["speech matics"]), ], smart_turn_config=SmartTurnConfig(enabled=True), - 
end_of_turn_config=EndOfTurnConfig(use_forced_eou=False), ), ) diff --git a/tests/voice/test_17_eou_feou.py b/tests/voice/test_17_eou_feou.py index f78c6abe..fb554c95 100644 --- a/tests/voice/test_17_eou_feou.py +++ b/tests/voice/test_17_eou_feou.py @@ -48,41 +48,41 @@ class TranscriptionTests(BaseModel): SAMPLES: TranscriptionTests = TranscriptionTests.from_dict( { "samples": [ - # { - # "id": "07b", - # "path": "./assets/audio_07b_16kHz.wav", - # "sample_rate": 16000, - # "language": "en", - # "segments": [ - # {"text": "Hello.", "start_time": 1.05, "end_time": 1.67}, - # {"text": "Tomorrow.", "start_time": 3.5, "end_time": 4.1}, - # {"text": "Wednesday.", "start_time": 6.05, "end_time": 6.73}, - # {"text": "Of course. That's fine.", "start_time": 8.8, "end_time": 9.96}, - # {"text": "Behind.", "start_time": 12.03, "end_time": 12.73}, - # {"text": "In front.", "start_time": 14.84, "end_time": 15.52}, - # {"text": "Do you think so?", "start_time": 17.54, "end_time": 18.32}, - # {"text": "Brilliant.", "start_time": 20.55, "end_time": 21.08}, - # {"text": "Banana.", "start_time": 22.98, "end_time": 23.53}, - # {"text": "When?", "start_time": 25.49, "end_time": 25.96}, - # {"text": "Today.", "start_time": 27.66, "end_time": 28.15}, - # {"text": "This morning.", "start_time": 29.91, "end_time": 30.47}, - # {"text": "Goodbye.", "start_time": 32.21, "end_time": 32.68}, - # ], - # }, - # { - # "id": "08", - # "path": "./assets/audio_08_16kHz.wav", - # "sample_rate": 16000, - # "language": "en", - # "segments": [ - # {"text": "Hello.", "start_time": 0.4, "end_time": 0.75}, - # {"text": "Goodbye.", "start_time": 2.12, "end_time": 2.5}, - # {"text": "Banana.", "start_time": 3.84, "end_time": 4.27}, - # {"text": "Breakaway.", "start_time": 5.62, "end_time": 6.42}, - # {"text": "Before.", "start_time": 7.76, "end_time": 8.16}, - # {"text": "After.", "start_time": 9.56, "end_time": 10.05}, - # ], - # }, + { + "id": "07b", + "path": "./assets/audio_07b_16kHz.wav", + 
"sample_rate": 16000, + "language": "en", + "segments": [ + {"text": "Hello.", "start_time": 1.05, "end_time": 1.67}, + {"text": "Tomorrow.", "start_time": 3.5, "end_time": 4.1}, + {"text": "Wednesday.", "start_time": 6.05, "end_time": 6.73}, + {"text": "Of course. That's fine.", "start_time": 8.8, "end_time": 9.96}, + {"text": "Behind.", "start_time": 12.03, "end_time": 12.73}, + {"text": "In front.", "start_time": 14.84, "end_time": 15.52}, + {"text": "Do you think so?", "start_time": 17.54, "end_time": 18.32}, + {"text": "Brilliant.", "start_time": 20.55, "end_time": 21.08}, + {"text": "Banana.", "start_time": 22.98, "end_time": 23.53}, + {"text": "When?", "start_time": 25.49, "end_time": 25.96}, + {"text": "Today.", "start_time": 27.66, "end_time": 28.15}, + {"text": "This morning.", "start_time": 29.91, "end_time": 30.47}, + {"text": "Goodbye.", "start_time": 32.21, "end_time": 32.68}, + ], + }, + { + "id": "08", + "path": "./assets/audio_08_16kHz.wav", + "sample_rate": 16000, + "language": "en", + "segments": [ + {"text": "Hello.", "start_time": 0.4, "end_time": 0.75}, + {"text": "Goodbye.", "start_time": 2.12, "end_time": 2.5}, + {"text": "Banana.", "start_time": 3.84, "end_time": 4.27}, + {"text": "Breakaway.", "start_time": 5.62, "end_time": 6.42}, + {"text": "Before.", "start_time": 7.76, "end_time": 8.16}, + {"text": "After.", "start_time": 9.56, "end_time": 10.05}, + ], + }, { "id": "09", "path": "./assets/audio_09_16kHz.wav", @@ -97,12 +97,12 @@ class TranscriptionTests(BaseModel): ) # VAD delay -VAD_DELAY_S: list[float] = [0.18, 0.22] +VAD_DELAY_S: list[float] = [0.18] # , 0.22] # Endpoints ENDPOINTS: list[str] = [ - # "wss://eu-west-2-research.speechmatics.cloud/v2", - "wss://eu.rt.speechmatics.com/v2", + "wss://preview.rt.speechmatics.com/v2", + # "wss://eu.rt.speechmatics.com/v2", # "wss://us.rt.speechmatics.com/v2", ] @@ -177,6 +177,11 @@ async def run_test(endpoint: str, sample: TranscriptionTest, config: VoiceAgentC # Start time start_time = 
datetime.datetime.now()
+
+    # Zero time
+    def zero_time(message):
+        nonlocal start_time
+        start_time = datetime.datetime.now()
+
     # Finalized segment
     def add_segments(message):
         segments = message["segments"]
@@ -213,6 +218,13 @@ def log_message(message):
             log = json.dumps({"ts": round(ts, 3), "payload": message})
             print(log)
 
+    # Custom listeners
+    client.on(AgentServerMessageType.RECOGNITION_STARTED, zero_time)
+    client.on(AgentServerMessageType.END_OF_TURN, eot_detected)
+    client.on(AgentServerMessageType.ADD_SEGMENT, add_segments)
+    client.on(AgentServerMessageType.ADD_PARTIAL_TRANSCRIPT, rx_partial)
+    client.on(AgentServerMessageType.ADD_TRANSCRIPT, rx_partial)
+
     # Add listeners
     if SHOW_LOG:
         message_types = [m for m in AgentServerMessageType if m != AgentServerMessageType.AUDIO_ADDED]
@@ -220,12 +232,6 @@ def log_message(message):
         for message_type in message_types:
             client.on(message_type, log_message)
 
-    # Custom listeners
-    client.on(AgentServerMessageType.END_OF_TURN, eot_detected)
-    client.on(AgentServerMessageType.ADD_SEGMENT, add_segments)
-    client.on(AgentServerMessageType.ADD_PARTIAL_TRANSCRIPT, rx_partial)
-    client.on(AgentServerMessageType.ADD_TRANSCRIPT, rx_partial)
-
     # HEADER
     if SHOW_LOG:
         print()
@@ -326,7 +332,9 @@ def log_message(message):
         # Calculate the CER
         cer = TextUtils.cer(normalized_expected, normalized_received)
 
-        print(f"[{idx}] `{normalized_expected}` -> `{normalized_received}` (CER: {cer:.1%})")
+        # Debug metrics
+        if SHOW_LOG:
+            print(f"[{idx}] `{normalized_expected}` -> `{normalized_received}` (CER: {cer:.1%})")
 
         # Check CER
         if cer > CER_THRESHOLD:
diff --git a/tests/voice/test_18_feou_timestamp.py b/tests/voice/test_18_feou_timestamp.py
new file mode 100644
index 00000000..39d85bfe
--- /dev/null
+++ b/tests/voice/test_18_feou_timestamp.py
@@ -0,0 +1,73 @@
+import os
+
+import pytest
+from _utils import get_client
+from _utils import send_silence
+
+from speechmatics.rt import AudioEncoding
+from speechmatics.voice import VoiceAgentConfig
+
+# 
Constants
+API_KEY = os.getenv("SPEECHMATICS_API_KEY")
+
+# Skip for CI testing
+pytestmark = [pytest.mark.skipif(os.getenv("CI") == "true", reason="Skipping in CI"),
+              pytest.mark.skipif(API_KEY is None, reason="Skipping when no API key is provided")]
+
+# How much silence to send (seconds)
+SILENCE_DURATION = 3.0
+
+# Tolerance for the timestamp check
+TOLERANCE = 0.00
+
+# Audio format configurations to test: (encoding, chunk_size, bytes_per_sample)
+AUDIO_FORMATS = [
+    pytest.param(AudioEncoding.PCM_S16LE, 160, 2, id="s16-chunk160"),
+    pytest.param(AudioEncoding.PCM_S16LE, 320, 2, id="s16-chunk320"),
+    pytest.param(AudioEncoding.PCM_F32LE, 160, 4, id="f32-chunk160"),
+    pytest.param(AudioEncoding.PCM_F32LE, 320, 4, id="f32-chunk320"),
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("encoding,chunk_size,sample_size", AUDIO_FORMATS)
+async def test_feou_timestamp(encoding: AudioEncoding, chunk_size: int, sample_size: int):
+    """Test that audio_seconds_sent correctly computes elapsed audio time.
+
+    Sends 3 seconds of silence (zero bytes) with different audio encodings
+    and chunk sizes, then verifies that audio_seconds_sent returns the
+    correct duration. 
+ """ + + # Create and connect client + config = VoiceAgentConfig(audio_encoding=encoding, chunk_size=chunk_size) + client = await get_client( + api_key=API_KEY, + connect=False, + config=config, + ) + + try: + await client.connect() + except Exception: + pytest.skip("Failed to connect to server") + + assert client._is_connected + + # Send 3 seconds of silence + await send_silence( + client, + duration=SILENCE_DURATION, + chunk_size=chunk_size, + sample_size=sample_size, + ) + + # Check the computed audio seconds + actual_seconds = client.audio_seconds_sent + assert ( + abs(actual_seconds - SILENCE_DURATION) <= TOLERANCE + ), f"Expected ~{SILENCE_DURATION}s but got {actual_seconds:.4f}s" + + # Clean up + await client.disconnect() + assert not client._is_connected diff --git a/tests/voice/test_19_no_feou_fix.py b/tests/voice/test_19_no_feou_fix.py new file mode 100644 index 00000000..ad903865 --- /dev/null +++ b/tests/voice/test_19_no_feou_fix.py @@ -0,0 +1,139 @@ +import json +import os +import shutil +import time + +import pytest +from _utils import get_client +from _utils import send_audio_file + +from speechmatics.voice import AgentServerMessageType +from speechmatics.voice import EndOfTurnConfig +from speechmatics.voice import EndOfUtteranceMode +from speechmatics.voice import SmartTurnConfig +from speechmatics.voice import VoiceActivityConfig +from speechmatics.voice import VoiceAgentConfig + +# Skip for CI testing +pytestmark = pytest.mark.skipif(os.getenv("CI") == "true", reason="Skipping transcription tests in CI") + +# Constants +API_KEY = os.getenv("SPEECHMATICS_API_KEY") +SHOW_LOG = os.getenv("SPEECHMATICS_SHOW_LOG", "0").lower() in ["1", "true"] + + +@pytest.mark.asyncio +async def test_no_feou_fix(): + """Test for when FEOU is disabled.""" + + # API key + if not API_KEY: + pytest.skip("Valid API key required for test") + + # Config + config = VoiceAgentConfig( + language="en", + end_of_utterance_mode=EndOfUtteranceMode.ADAPTIVE, + 
end_of_utterance_silence_trigger=0.5, + smart_turn_config=SmartTurnConfig(enabled=True, smart_turn_threshold=0.80), + vad_config=VoiceActivityConfig(enabled=True), + end_of_turn_config=EndOfTurnConfig(base_multiplier=1.0), + ) + + # Debug config + print( + config.to_json( + indent=2, + exclude_none=True, + exclude_defaults=True, + exclude_unset=True, + ) + ) + + # Client + client = await get_client( + api_key=API_KEY, + connect=False, + config=config, + ) + + # Disable FEOU + client._use_forced_eou = False + + # Add listeners + messages = [message for message in AgentServerMessageType if message != AgentServerMessageType.AUDIO_ADDED] + + # Colors for messages + colors = { + "StartOfTurn": "\033[94m", # Blue + "EndOfTurn": "\033[92m", # Green + "AddSegment": "\033[93m", # Yellow + "AddPartialSegment": "\033[38;5;208m", # Orange + "SpeakerStarted": "\033[96m", # Cyan + "SpeakerEnded": "\033[95m", # Magenta + "VadStatus": "\033[91m", # Red + } + + # Callback for each message + term_width = shutil.get_terminal_size().columns + log_start_time = time.monotonic() + + def log_message(message): + """Log a message with color and formatting.""" + + # Elapsed time in seconds (right-aligned, capacity for 100s) + elapsed = time.monotonic() - log_start_time + timestamp = f"{elapsed:>7.3f}" + + # Extract message type and remaining payload (drop noisy keys) + msg_type = message.get("message", "") + rest = {k: v for k, v in message.items() if k not in ("message", "format")} + + # Color based on message type (default: dark gray) + color = colors.get(msg_type, "\033[90m") + reset = "\033[0m" + + # Format: timestamp - fixed-width type label + JSON payload + label = f"{msg_type:<20}" + payload = json.dumps(rest, default=str) + visible = f"{timestamp} - {label} - {payload}" + + # Truncate to terminal width to prevent wrapping + if len(visible) > term_width: + visible = visible[: term_width - 1] + "…" + + # Print with color + print(f"{color}{visible}{reset}") + + # Add listeners + for 
message_type in messages: + client.on(message_type, log_message) + + # Load the audio file `./assets/audio_01_16kHz.wav` + # audio_file = "../../tmp/feou/recording-appointment.wav" + audio_file = "./assets/audio_10_16kHz.wav" + + # HEADER + if SHOW_LOG: + print() + print() + print("---") + + # Connect + await client.connect() + + # Check we are connected + assert client._is_connected + + # Individual payloads + await send_audio_file(client, audio_file) + + # Close session + await client.disconnect() + assert not client._is_connected + + # FOOTER + if SHOW_LOG: + print("---") + print() + print()