diff --git a/.gitignore b/.gitignore index cc4e55f20..ef2ff8280 100644 --- a/.gitignore +++ b/.gitignore @@ -216,3 +216,21 @@ sdks/typescript/node_modules/ sdks/kotlin/build/ sdks/kotlin/.gradle/ sdks/kotlin/bin/ + +# Frontend build/cache artifacts +frontend/.svelte-kit/ +frontend/build/ +frontend/.vite/ +frontend/.vercel/ + +# Voice model artifacts downloaded at runtime (Piper) +en_US-*.onnx +en_US-*.onnx.json + +# Runtime logs and provider test outputs +logs/ + +# Local tooling/worktree manager state +.kilo/ +.idea/ +.local/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5d7ad7624..9e4620283 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,4 @@ +exclude: ^\.agents/ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 @@ -21,7 +22,7 @@ repos: hooks: - id: ty name: ty type checker - entry: uv run ty check bindu/ tests/ --exclude bindu/grpc/generated/ + entry: uv run ty check bindu/ --exclude bindu/grpc/generated/ language: system types: [python] pass_filenames: false @@ -29,7 +30,7 @@ repos: - id: pytest name: pytest with coverage - entry: bash -c 'uv run pytest -n auto --cov=bindu --cov-report= && coverage report --skip-covered --fail-under=60' + entry: bash -c 'uv run pytest -n auto --cov=bindu --cov-report= && uv run coverage report --skip-covered --fail-under=60' language: system pass_filenames: false always_run: true diff --git a/bindu/common/protocol/types.py b/bindu/common/protocol/types.py index bfa17f24e..eb95d5ec0 100644 --- a/bindu/common/protocol/types.py +++ b/bindu/common/protocol/types.py @@ -156,7 +156,7 @@ class FileWithUri(FileWithBytes): @pydantic.with_config(ConfigDict(alias_generator=to_camel)) -class FilePart(TextPart): +class FilePart(TypedDict): """Represents a file segment within a message or artifact. The file content can be provided either directly as bytes or as a URI. 
@@ -168,11 +168,15 @@ class FilePart(TextPart): file: Required[FileWithBytes | FileWithUri] """The file of the part.""" + metadata: NotRequired[dict[str, Any]] + """Metadata about the file part.""" + embeddings: NotRequired[list[float]] """The embeddings of File. """ -class DataPart(TextPart): +@pydantic.with_config(ConfigDict(alias_generator=to_camel)) +class DataPart(TypedDict): """Represents a structured data segment (e.g., JSON) within a message or artifact.""" kind: Required[Literal["data"]] @@ -181,6 +185,9 @@ class DataPart(TextPart): data: Required[dict[str, Any]] """The data of the part.""" + metadata: NotRequired[dict[str, Any]] + """Metadata about the data part.""" + embeddings: NotRequired[list[float]] """The embeddings of Data. """ diff --git a/bindu/extensions/__init__.py b/bindu/extensions/__init__.py index 43ee5b5f7..64b8fa86c 100644 --- a/bindu/extensions/__init__.py +++ b/bindu/extensions/__init__.py @@ -34,6 +34,13 @@ negotiate prices, request payments, and execute transactions based on cryptographic mandates, creating a true agent economy. +**Voice (Real-time Voice Conversations)** +Adds Vapi-like real-time voice capability to Bindu agents via WebSocket. +The pipeline uses Pipecat with configurable provider-based STT and TTS to provide +bidirectional voice conversations. Providers can be selected via the ``voice`` +config passed to ``bindufy()`` or through ``VOICE__ENABLED``-based environment +settings. 
+ Each extension follows the A2A protocol specification for extensions: https://a2a-protocol.org/v0.3.0/topics/extensions/ diff --git a/bindu/extensions/did/did_agent_extension.py b/bindu/extensions/did/did_agent_extension.py index b2dff811b..cc11d4222 100644 --- a/bindu/extensions/did/did_agent_extension.py +++ b/bindu/extensions/did/did_agent_extension.py @@ -1,12 +1,3 @@ -# |---------------------------------------------------------| -# | | -# | Give Feedback / Get Help | -# | https://github.com/getbindu/Bindu/issues/new/choose | -# | | -# |---------------------------------------------------------| -# -# Thank you users! We ā¤ļø you! - 🌻 - """DID (Decentralized Identifier) Extension for Bindu Agents. Why is DID an Extension? @@ -29,6 +20,8 @@ import os import platform +import subprocess +import getpass from datetime import datetime, timezone from functools import cached_property @@ -222,17 +215,31 @@ def generate_and_save_key_pair(self) -> dict[str, str]: # Windows does not enforce POSIX permissions — write directly self.private_key_path.write_bytes(private_pem) self.public_key_path.write_bytes(public_pem) + try: + self._harden_windows_key_file_acl(self.private_key_path) + self._harden_windows_key_file_acl(self.public_key_path) + except Exception: + for key_path in (self.private_key_path, self.public_key_path): + try: + key_path.unlink(missing_ok=True) + except OSError: + logger.warning( + f"Failed to remove key file during ACL hardening rollback: {key_path}" + ) + raise else: # POSIX: use os.open to set permissions atomically on creation fd = os.open( str(self.private_key_path), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600 ) + os.fchmod(fd, 0o600) with os.fdopen(fd, "wb") as f: f.write(private_pem) fd = os.open( str(self.public_key_path), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644 ) + os.fchmod(fd, 0o644) with os.fdopen(fd, "wb") as f: f.write(public_pem) @@ -241,6 +248,44 @@ def generate_and_save_key_pair(self) -> dict[str, str]: "public_key_path": 
str(self.public_key_path), } + @staticmethod + def _harden_windows_key_file_acl(path: Path) -> None: + """Restrict Windows ACLs to current user for key file confidentiality.""" + username = (os.getenv("USERNAME") or "").strip() + if not username: + try: + username = os.getlogin().strip() + except Exception: + username = "" + if not username: + try: + username = (getpass.getuser() or "").strip() + except Exception: + username = "" + if not username: + raise RuntimeError( + "Unable to determine Windows username for key ACL hardening" + ) + + # Remove inherited ACLs and grant full control to current user only. + cmd = [ + "icacls", + str(path), + "/inheritance:r", + "/grant:r", + f"{username}:F", + "/remove:g", + "Users", + "Authenticated Users", + "Everyone", + ] + result = subprocess.run(cmd, capture_output=True, text=True, check=False) + if result.returncode != 0: + raise RuntimeError( + "Failed to harden ACL for key file " + f"{path}: {result.stderr.strip() or result.stdout.strip()}" + ) + def _load_key_from_file(self, key_path: Path, key_type: str) -> bytes: """Load key PEM data from file. diff --git a/bindu/extensions/voice/__init__.py b/bindu/extensions/voice/__init__.py new file mode 100644 index 000000000..bbf0b2f85 --- /dev/null +++ b/bindu/extensions/voice/__init__.py @@ -0,0 +1,19 @@ +"""Voice Agent extension for real-time voice conversations. + +Provides Vapi-like voice capabilities (STT → Agent → TTS) integrated +into Bindu's A2A protocol and extension system. 
+ +Usage:: + + from bindu.extensions.voice import VoiceAgentExtension + + voice = VoiceAgentExtension( + stt_provider="deepgram", + tts_provider="elevenlabs", + tts_voice_id="21m00Tcm4TlvDq8ikWAM", + ) +""" + +from .voice_agent_extension import VoiceAgentExtension + +__all__ = ["VoiceAgentExtension"] diff --git a/bindu/extensions/voice/agent_bridge.py b/bindu/extensions/voice/agent_bridge.py new file mode 100644 index 000000000..73157876d --- /dev/null +++ b/bindu/extensions/voice/agent_bridge.py @@ -0,0 +1,510 @@ +"""Agent bridge between pipecat voice pipeline and Bindu A2A agents. + +This custom pipecat ``FrameProcessor`` converts STT transcription frames +into chat messages, invokes the Bindu manifest's ``run()`` method, and +emits text frames that are consumed by the TTS service. +""" + +from __future__ import annotations + +import asyncio +import inspect +from typing import TYPE_CHECKING, Any, Callable, AsyncIterator + +from pipecat.frames.frames import ( + Frame, + StartFrame, + TranscriptionFrame, + TextFrame, + InterruptionFrame, + EndFrame, + ErrorFrame, +) +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + +from bindu.server.workers.helpers.result_processor import ResultProcessor +from bindu.utils.logging import get_logger + +if TYPE_CHECKING: + from bindu.settings import VoiceSettings + +logger = get_logger("bindu.voice.agent_bridge") + +MAX_HISTORY_TURNS = 20 +DEFAULT_FIRST_TOKEN_TIMEOUT_SECONDS = 10.0 +DEFAULT_TOTAL_RESPONSE_TIMEOUT_SECONDS = 30.0 +DEFAULT_CANCELLATION_GRACE_SECONDS = 0.5 +DEFAULT_THINKING_TEXT = "One moment." +DEFAULT_TIMEOUT_FALLBACK_TEXT = "Sorry — I’m having trouble responding right now." + + +class AgentBridgeProcessor(FrameProcessor): + """Bridges pipecat STT ↔ Bindu manifest ↔ pipecat TTS. + + Flow: + 1. Receives ``TranscriptionFrame`` from STT (user utterance). + 2. Appends to conversation history as ``{"role": "user", "content": text}``. + 3. 
Calls ``manifest.run(history)`` through the provided run function in a task. + 4. Collects the result and appends as ``{"role": "assistant", "content": text}``. + 5. Emits a ``TextFrame`` for the downstream TTS service. + 6. Optionally sends real-time transcript events back to the WebSocket. + """ + + def __init__( + self, + manifest_run: Callable[..., Any], + context_id: str, + *, + voice_settings: "VoiceSettings | None" = None, + allow_interruptions: bool = True, + first_token_timeout_seconds: float = DEFAULT_FIRST_TOKEN_TIMEOUT_SECONDS, + total_response_timeout_seconds: float = DEFAULT_TOTAL_RESPONSE_TIMEOUT_SECONDS, + on_state_change: Callable[[str], Any] | None = None, + on_user_transcript: Callable[[str], Any] | None = None, + on_agent_response: Callable[[str], Any] | None = None, + on_agent_transcript: Callable[[str, bool], Any] | None = None, + ): + """Initialize bridge callbacks and context for one voice session.""" + super().__init__() + self._manifest_run = manifest_run + self._context_id = context_id + self._allow_interruptions = bool(allow_interruptions) + if voice_settings is not None: + self._first_token_timeout_seconds = float(voice_settings.agent_timeout_secs) + self._total_response_timeout_seconds = float( + voice_settings.utterance_timeout_secs + ) + self._cancellation_grace_seconds = float( + voice_settings.cancellation_grace_secs + ) + self._max_history_messages = int(voice_settings.conversation_history_limit) + else: + self._first_token_timeout_seconds = float(first_token_timeout_seconds) + self._total_response_timeout_seconds = float(total_response_timeout_seconds) + self._cancellation_grace_seconds = float(DEFAULT_CANCELLATION_GRACE_SECONDS) + self._max_history_messages = MAX_HISTORY_TURNS * 2 + + self._on_state_change = on_state_change + self._on_user_transcript = on_user_transcript + self._on_agent_response = on_agent_response + self._on_agent_transcript = on_agent_transcript + self._conversation_history: list[dict[str, str]] = [] + 
self._background_tasks: set[asyncio.Task[Any]] = set() + + self._current_agent_task: asyncio.Task | None = None + + async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process incoming pipecat frames.""" + await super().process_frame(frame, direction) + + if isinstance(frame, StartFrame): + # TTS and other downstream processors require lifecycle frames. + await self.push_frame(frame, direction) + return + + if isinstance(frame, TranscriptionFrame): + text = frame.text.strip() + if text: + await self._handle_user_utterance(text) + return + elif isinstance(frame, InterruptionFrame): + if not self._allow_interruptions: + return + logger.info("Interruption frame received; cancelling current agent task...") + await self._cancel_current_agent_task() + # Propagate interruption downstream so TTS/services can stop immediately. + await self.push_frame(frame, direction) + self._set_state("listening") + return + elif isinstance(frame, EndFrame): + # Session is ending, cleanup background tasks + await self.cleanup_background_tasks() + await self.push_frame(frame, direction) + return + elif isinstance(frame, ErrorFrame): + logger.error(f"Error frame received in pipeline: {frame.error}") + await self.push_frame(frame, direction) + return + + await self.push_frame(frame, direction) + + async def process_transcription( + self, text: str, *, emit_frames: bool = False + ) -> str | None: + """Process a user transcription and return the agent response. + + This is a convenience helper used by unit tests and non-pipeline callers. + It updates the conversation history and executes the agent handler + synchronously (no background task). When ``emit_frames=True``, it also + streams partial ``TextFrame`` deltas to downstream processors (TTS). 
+ """ + cleaned = text.strip() + if not cleaned: + return None + + if self._allow_interruptions: + await self._interrupt() + else: + await self._cancel_current_agent_task() + + if self._on_user_transcript: + self._safe_callback(self._on_user_transcript, cleaned) + + self._conversation_history.append({"role": "user", "content": cleaned}) + self._trim_history() + + response_text = await self._invoke_agent_streaming(emit_frames=emit_frames) + if response_text: + self._conversation_history.append( + {"role": "assistant", "content": response_text} + ) + self._trim_history() + if self._on_agent_response: + self._safe_callback(self._on_agent_response, response_text) + return response_text + + # Keep history consistent when invocation fails or yields no response. + if self._conversation_history: + last = self._conversation_history[-1] + if last.get("role") == "user" and last.get("content") == cleaned: + self._conversation_history.pop() + self._trim_history() + return None + + async def _handle_user_utterance(self, text: str) -> None: + """Process a completed user transcription and get agent response.""" + # Cancel any running agent task + if self._allow_interruptions: + await self._interrupt() + else: + await self._cancel_current_agent_task() + + # Notify caller about the user transcript + if self._on_user_transcript: + self._safe_callback(self._on_user_transcript, text) + + # Add user message to history + self._conversation_history.append({"role": "user", "content": text}) + self._trim_history() + logger.debug( + f"Voice user ({self._context_id}): {text[:80]}{'...' 
if len(text) > 80 else ''}" + ) + + # Start a new task to invoke the agent + self._current_agent_task = asyncio.create_task(self._invoke_and_emit(text)) + + async def _interrupt(self) -> None: + """Cancel the in-flight agent task and propagate interruption downstream.""" + if self._current_agent_task and not self._current_agent_task.done(): + self._current_agent_task.cancel() + # Tell downstream processors (esp. TTS) to stop immediately. + try: + await self.push_frame(InterruptionFrame()) + except Exception: + logger.exception("Failed to push InterruptionFrame downstream") + grace = max(0.0, float(self._cancellation_grace_seconds)) + try: + if grace > 0: + await asyncio.wait_for(self._current_agent_task, timeout=grace) + else: + await self._current_agent_task + except (asyncio.CancelledError, TimeoutError): + pass + + async def _invoke_and_emit(self, user_text: str): + """Invoke agent and emit text frames.""" + try: + response_text = await self._invoke_agent_streaming(emit_frames=True) + if response_text: + self._conversation_history.append( + {"role": "assistant", "content": response_text} + ) + self._trim_history() + logger.debug( + f"Voice agent ({self._context_id}): {response_text[:80]}{'...' if len(response_text) > 80 else ''}" + ) + + if self._on_agent_response: + self._safe_callback(self._on_agent_response, response_text) + elif self._conversation_history: + # Keep history consistent when invocation fails or yields no response. 
+ last = self._conversation_history[-1] + if last.get("role") == "user" and last.get("content") == user_text: + self._conversation_history.pop() + self._trim_history() + except asyncio.CancelledError: + # Handle cancellation (interruption) + logger.debug("Agent task was cancelled.") + self._set_state("listening") + # Remove the last user text if agent didn't finish responding + if self._conversation_history: + last = self._conversation_history[-1] + if last.get("role") == "user" and last.get("content") == user_text: + self._conversation_history.pop() + raise + except Exception: + if self._conversation_history: + last = self._conversation_history[-1] + if last.get("role") == "user" and last.get("content") == user_text: + self._conversation_history.pop() + self._trim_history() + logger.exception( + f"Error processing voice transcription in {self._context_id}" + ) + + async def _invoke_agent_streaming(self, *, emit_frames: bool) -> str | None: + """Invoke the agent handler and optionally stream deltas as TextFrames.""" + try: + raw = self._manifest_run(list(self._conversation_history)) + if inspect.isawaitable(raw) and not hasattr(raw, "__anext__"): + raw = await raw + + streamed_text = "" + last_emitted = "" + started_speaking = False + + async def _consume_chunk( + chunk_text: str, + streamed: str, + last: str, + speaking: bool, + ) -> tuple[str, str, bool]: + if not chunk_text: + return streamed, last, speaking + + # Some streaming handlers yield cumulative text. Emit only the delta. 
+ delta = self._trim_overlap_text(last, chunk_text) + if not delta: + return streamed, chunk_text, speaking + + if emit_frames: + if not speaking: + speaking = True + self._set_state("agent-speaking") + if self._on_agent_transcript: + self._safe_callback(self._on_agent_transcript, delta, False) + await self.push_frame(TextFrame(text=delta)) + + streamed = self._append_text(streamed, delta) + last = chunk_text + return streamed, last, speaking + + chunks = self._iter_text_chunks(raw) + try: + async with asyncio.timeout(self._total_response_timeout_seconds): + if emit_frames: + try: + first_task = asyncio.create_task(anext(chunks)) + timeout_seconds = max( + 0.0, self._first_token_timeout_seconds + ) + if timeout_seconds > 0: + done, _pending = await asyncio.wait( + {first_task}, timeout=timeout_seconds + ) + if not done: + # TTS filler so the agent doesn't feel "dead air". + await self.push_frame( + TextFrame(text=DEFAULT_THINKING_TEXT) + ) + first = await first_task + except StopAsyncIteration: + return None + + ( + streamed_text, + last_emitted, + started_speaking, + ) = await _consume_chunk( + first, streamed_text, last_emitted, started_speaking + ) + + async for chunk_text in chunks: + ( + streamed_text, + last_emitted, + started_speaking, + ) = await _consume_chunk( + chunk_text, streamed_text, last_emitted, started_speaking + ) + except TimeoutError: + if emit_frames: + self._set_state("error") + fallback = DEFAULT_TIMEOUT_FALLBACK_TEXT + if self._on_agent_transcript: + self._safe_callback(self._on_agent_transcript, fallback, True) + await self.push_frame(TextFrame(text=fallback)) + self._set_state("listening") + return fallback + return None + + if emit_frames: + if streamed_text and self._on_agent_transcript: + self._safe_callback(self._on_agent_transcript, streamed_text, True) + self._set_state("listening") + return streamed_text or None + + except asyncio.CancelledError: + raise + except Exception: + logger.exception("Agent invocation failed") + if 
emit_frames: + self._set_state("error") + return None + + async def _cancel_current_agent_task(self) -> None: + if self._current_agent_task and not self._current_agent_task.done(): + self._current_agent_task.cancel() + grace = max(0.0, float(self._cancellation_grace_seconds)) + try: + if grace > 0: + await asyncio.wait_for(self._current_agent_task, timeout=grace) + else: + await self._current_agent_task + except (asyncio.CancelledError, TimeoutError): + pass + + async def _iter_text_chunks(self, raw_results: Any) -> AsyncIterator[str]: + """Yield normalized text chunks from sync/async results.""" + if raw_results is None: + return + + # Async generator / async iterator + if hasattr(raw_results, "__anext__"): + async for item in raw_results: + text = self._extract_text(item) + if text: + yield text + return + + # Sync generator / iterator + if hasattr(raw_results, "__next__"): + for item in raw_results: # type: ignore[assignment] + text = self._extract_text(item) + if text: + yield text + return + + # Direct return + text = self._extract_text(raw_results) + if text: + yield text + + def _extract_text(self, value: Any) -> str | None: + """Extract text from handler output chunks.""" + if value is None: + return None + + normalized = ResultProcessor.normalize_result(value) + + if normalized is None: + return None + if isinstance(normalized, str): + return normalized + if isinstance(normalized, dict): + # Check "content" explicitly, preserving empty strings + if "content" in normalized and normalized["content"] is not None: + content = normalized["content"] + if isinstance(content, str): + return content + # Check "text" explicitly, preserving empty strings + if "text" in normalized and normalized["text"] is not None: + text = normalized["text"] + if isinstance(text, str): + return text + # Check "message" explicitly and convert to str if needed + if "message" in normalized and normalized["message"] is not None: + message = normalized["message"] + if isinstance(message, 
str): + return message + return str(message) + # Dict with state but no content: ignore for voice TTS. + if "state" in normalized: + return None + return str(normalized) + return str(normalized) + + def _append_text(self, existing: str, delta: str) -> str: + if not existing: + return delta.strip() + if not delta: + return existing + if existing.endswith((" ", "\n")) or delta.startswith((" ", "\n")): + return f"{existing}{delta}".strip() + return f"{existing} {delta}".strip() + + def _trim_overlap_text(self, previous: str, current: str) -> str: + """Remove overlap when providers stream cumulative or partial snapshots.""" + prev = previous.strip() + curr = current.strip() + if not prev: + return curr + if prev == curr: + return "" + + # Cumulative snapshot mode: emit only newly appended tail. + if curr.startswith(prev): + return curr[len(prev) :].strip() + + # Duplicate short chunk already present at end. + if prev.endswith(curr): + return "" + + # Character-level overlap to handle punctuation and non-tokenized chunks. 
+ max_overlap = min(len(prev), len(curr)) + for overlap in range(max_overlap, 0, -1): + if prev[-overlap:] == curr[:overlap]: + return curr[overlap:].strip() + + return curr + + @property + def history(self) -> list[dict[str, str]]: + """Return a read-only copy of the conversation history.""" + return list(self._conversation_history) + + def clear_history(self) -> None: + """Clear the conversation history.""" + self._conversation_history.clear() + + async def cleanup_background_tasks(self) -> None: + """Cancel and await any background callback tasks.""" + await self._cancel_current_agent_task() + + if not self._background_tasks: + return + + tasks = list(self._background_tasks) + self._background_tasks.clear() + + for task in tasks: + task.cancel() + + results = await asyncio.gather(*tasks, return_exceptions=True) + for result in results: + if isinstance(result, Exception) and not isinstance( + result, asyncio.CancelledError + ): + logger.error(f"Error while cleaning up voice callback task: {result}") + + def _trim_history(self) -> None: + """Keep only the most recent conversation turns.""" + overflow = len(self._conversation_history) - self._max_history_messages + if overflow > 0: + turns_to_drop = max(1, (overflow + 1) // 2) + del self._conversation_history[: turns_to_drop * 2] + + def _safe_callback(self, fn: Callable[..., Any], *args: Any) -> None: + """Call a callback, tracking async tasks so they are not GC'd early.""" + try: + result = fn(*args) + if asyncio.iscoroutine(result): + task = asyncio.create_task(result) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + except Exception: + logger.exception("Error in voice callback") + + def _set_state(self, state: str) -> None: + if self._on_state_change: + self._safe_callback(self._on_state_change, state) diff --git a/bindu/extensions/voice/audio_config.py b/bindu/extensions/voice/audio_config.py new file mode 100644 index 000000000..f234360b7 --- /dev/null +++ 
b/bindu/extensions/voice/audio_config.py @@ -0,0 +1,60 @@ +"""Audio format constants and validation for the voice extension. + +Standard speech processing format used across the pipeline: +PCM 16-bit signed little-endian, 16kHz, mono. +""" + +# Standard speech processing format +DEFAULT_SAMPLE_RATE: int = 16000 +DEFAULT_CHANNELS: int = 1 +DEFAULT_ENCODING: str = "linear16" # PCM 16-bit signed little-endian + + +def get_bytes_per_sample(encoding: str) -> int: + """Return the bytes per sample for a supported audio encoding.""" + if encoding == "linear16": + return 2 + if encoding in {"mulaw", "alaw"}: + return 1 + raise ValueError(f"Unsupported encoding: {encoding}") + + +# These constants apply to linear16 only. +BYTES_PER_SAMPLE: int = get_bytes_per_sample(DEFAULT_ENCODING) + +# Frame sizing for real-time streaming +FRAME_DURATION_MS: int = 20 # 20ms frames +FRAME_SIZE: int = ( + DEFAULT_SAMPLE_RATE * FRAME_DURATION_MS // 1000 * BYTES_PER_SAMPLE +) # 640 bytes + + +def get_frame_size(sample_rate: int, duration_ms: int, encoding: str) -> int: + """Return a frame size in bytes for the given encoding.""" + return sample_rate * duration_ms // 1000 * get_bytes_per_sample(encoding) + + +# Supported audio encodings +SUPPORTED_ENCODINGS: frozenset[str] = frozenset({"linear16", "mulaw", "alaw"}) + +# Limits +MIN_SAMPLE_RATE: int = 8000 +MAX_SAMPLE_RATE: int = 48000 + + +def validate_sample_rate(rate: int) -> int: + """Validate and return sample rate.""" + if not MIN_SAMPLE_RATE <= rate <= MAX_SAMPLE_RATE: + raise ValueError( + f"Sample rate must be between {MIN_SAMPLE_RATE} and {MAX_SAMPLE_RATE}, got {rate}" + ) + return rate + + +def validate_encoding(encoding: str) -> str: + """Validate and return audio encoding.""" + if encoding not in SUPPORTED_ENCODINGS: + raise ValueError( + f"Unsupported encoding '{encoding}'. 
Must be one of: {sorted(SUPPORTED_ENCODINGS)}" + ) + return encoding diff --git a/bindu/extensions/voice/pipeline_builder.py b/bindu/extensions/voice/pipeline_builder.py new file mode 100644 index 000000000..271a6bf5b --- /dev/null +++ b/bindu/extensions/voice/pipeline_builder.py @@ -0,0 +1,84 @@ +"""Voice pipeline builder. + +Assembles a pipecat-compatible voice pipeline: + WebSocket input → STT → Agent Bridge → TTS → WebSocket output + +The pipeline is built lazily when starting a voice session. +""" + +from __future__ import annotations + +from typing import Any, Callable + +from bindu.settings import app_settings +from bindu.utils.logging import get_logger + +from .agent_bridge import AgentBridgeProcessor +from .service_factory import create_stt_service, create_tts_service +from .voice_agent_extension import VoiceAgentExtension + +logger = get_logger("bindu.voice.pipeline_builder") + + +def build_voice_pipeline( + voice_ext: VoiceAgentExtension, + manifest_run: Callable[..., Any], + context_id: str, + *, + on_state_change: Callable[[str], Any] | None = None, + on_user_transcript: Callable[[str], Any] | None = None, + on_agent_response: Callable[[str], Any] | None = None, + on_agent_transcript: Callable[[str, bool], Any] | None = None, +) -> dict[str, Any]: + """Build the voice pipeline components. + + Returns a dict of pipeline components that can be wired up to a + WebSocket transport. This keeps the pipeline builder independent + of the specific pipecat transport implementation. + + Args: + voice_ext: Voice agent extension with STT/TTS config. + manifest_run: The agent manifest's ``run`` callable. + context_id: A2A context ID for this session. + on_user_transcript: Optional callback for user transcript events. + on_agent_response: Optional callback for agent response events. + + Returns: + Dictionary with ``stt``, ``tts``, ``bridge``, and ``vad`` components. 
+ """ + stt = create_stt_service(voice_ext) + tts = create_tts_service(voice_ext) + + bridge = AgentBridgeProcessor( + manifest_run=manifest_run, + context_id=context_id, + voice_settings=voice_ext.voice_settings, + allow_interruptions=voice_ext.allow_interruptions, + on_state_change=on_state_change, + on_user_transcript=on_user_transcript, + on_agent_response=on_agent_response, + on_agent_transcript=on_agent_transcript, + ) + + vad_component = None + if app_settings.voice.vad_enabled: + from pipecat.audio.vad.silero import SileroVADAnalyzer + from pipecat.processors.audio.vad_processor import VADProcessor + + vad_analyzer = SileroVADAnalyzer( + sample_rate=app_settings.voice.sample_rate, + ) + vad_component = VADProcessor(vad_analyzer=vad_analyzer) + + logger.info( + f"Voice pipeline built: STT={voice_ext.stt_provider}/{voice_ext.stt_model}, " + f"TTS={voice_ext.tts_provider}/{voice_ext.tts_voice_id}, " + f"context={context_id}, VAD={app_settings.voice.vad_enabled}" + ) + + return { + "stt": stt, + "tts": tts, + "bridge": bridge, + "vad": vad_component, + } diff --git a/bindu/extensions/voice/redis_session_manager.py b/bindu/extensions/voice/redis_session_manager.py new file mode 100644 index 000000000..393ffc967 --- /dev/null +++ b/bindu/extensions/voice/redis_session_manager.py @@ -0,0 +1,442 @@ +"""Redis-backed voice session manager. + +Provides multi-worker compatible session storage using Redis as the backend, +enabling session lookup across Uvicorn workers. 
+""" + +from __future__ import annotations + +import asyncio +import json +from urllib.parse import urlsplit +from typing import Literal, Any + +import redis.asyncio as redis +from uuid import uuid4 + +from bindu.extensions.voice.session_manager import VoiceSession +from bindu.utils.logging import get_logger + +logger = get_logger("bindu.voice.redis_session_manager") + +# Constants +REDIS_KEY_PREFIX = "voice:sessions" +REDIS_ACTIVE_SET_KEY = f"{REDIS_KEY_PREFIX}:active" +DEFAULT_SESSION_TTL = 300 # seconds + + +_CREATE_SESSION_LUA = """ +-- Atomically create a session if the active set is below the limit. +-- +-- KEYS[1]: Active session set key. +-- KEYS[2]: The key for the new session to create. +-- ARGV[1]: The maximum number of sessions allowed. +-- ARGV[2]: The serialized session data to store. +-- ARGV[3]: The TTL for the new session key. +-- +-- Returns: 1 if the session was created, 0 otherwise. +local active_members = redis.call('smembers', KEYS[1]) +for _, member in ipairs(active_members) do + if redis.call('exists', member) == 0 then + redis.call('srem', KEYS[1], member) + end +end + +local active_count = redis.call('scard', KEYS[1]) +if active_count >= tonumber(ARGV[1]) then + return 0 +end +redis.call('set', KEYS[2], ARGV[2], 'EX', ARGV[3]) +redis.call('sadd', KEYS[1], KEYS[2]) +return 1 +""" + + +_UPDATE_SESSION_LUA = """ +-- Atomically update the serialized session while keeping the TTL refreshed. +-- +-- KEYS[1]: The session key to update. +-- ARGV[1]: The new state value. +-- ARGV[2]: The TTL for the session key. +-- +-- Returns: 1 if the session existed and was updated, 0 otherwise. +local raw = redis.call('get', KEYS[1]) +if not raw then + return 0 +end +local session = cjson.decode(raw) +session['state'] = ARGV[1] +redis.call('set', KEYS[1], cjson.encode(session), 'EX', ARGV[2]) +return 1 +""" + + +_DELETE_SESSION_LUA = """ +-- Remove a session key and its active-set membership atomically. +-- +-- KEYS[1]: Active session set key. 
+-- KEYS[2]: The session key to delete. +-- +-- Returns: 1 if the session key existed, 0 otherwise. +local removed = redis.call('del', KEYS[2]) +redis.call('srem', KEYS[1], KEYS[2]) +return removed +""" + + +class RedisVoiceSessionManager: + """Manages active voice sessions with Redis backend. + + Uses Redis string keys with JSON-serialized session data via SET/GET, + enabling session sharing across multiple Uvicorn workers. Implements the + same interface as VoiceSessionManager for compatibility. + """ + + def __init__( + self, + redis_url: str, + max_sessions: int = 10, + session_timeout: int = DEFAULT_SESSION_TTL, + redis_session_ttl: int = DEFAULT_SESSION_TTL, + ): + """Initialize the Redis session manager. + + Args: + redis_url: Redis connection URL + max_sessions: Maximum concurrent sessions allowed + session_timeout: Session timeout in seconds (for cleanup) + redis_session_ttl: TTL for Redis keys in seconds + """ + self.redis_url = redis_url + self._max_sessions = max_sessions + self._session_timeout = session_timeout + self._redis_session_ttl = redis_session_ttl + self._redis_client: redis.Redis | None = None + self._cleanup_task: asyncio.Task[None] | None = None + self._create_session_script_sha: str | None = None + self._update_session_script_sha: str | None = None + self._delete_session_script_sha: str | None = None + + async def __aenter__(self) -> RedisVoiceSessionManager: + """Enter async context manager and initialize Redis connection.""" + self._redis_client = redis.from_url( + self.redis_url, + encoding="utf-8", + decode_responses=True, + ) + try: + await self._redis_client.ping() + logger.info( + f"Redis session manager connected to {self._safe_redis_target()}" + ) + self._create_session_script_sha = await self._redis_client.script_load( + _CREATE_SESSION_LUA + ) + self._update_session_script_sha = await self._redis_client.script_load( + _UPDATE_SESSION_LUA + ) + self._delete_session_script_sha = await self._redis_client.script_load( + 
_DELETE_SESSION_LUA + ) + except redis.RedisError as e: + logger.error(f"Failed to connect to Redis: {e}") + if self._redis_client is not None: + await self._redis_client.aclose() + self._redis_client = None + raise ConnectionError( + f"Unable to connect to Redis at {self._safe_redis_target()}: {e}" + ) from e + return self + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): + """Exit async context manager and close Redis connection.""" + cleanup_task = self._cleanup_task + if cleanup_task is not None: + cleanup_task.cancel() + try: + await cleanup_task + except asyncio.CancelledError: + pass + finally: + self._cleanup_task = None + + redis_client = self._redis_client + if redis_client is not None: + await redis_client.aclose() + logger.info("Redis session manager connection closed") + self._redis_client = None + + def _session_key(self, session_id: str) -> str: + """Generate Redis key for a session.""" + return f"{REDIS_KEY_PREFIX}:{session_id}" + + def _serialize_session(self, session: VoiceSession) -> str: + """Serialize session to JSON string.""" + return json.dumps(session.to_dict()) + + def _safe_redis_target(self) -> str: + """Return a redacted Redis target for logs and errors.""" + parsed = urlsplit(self.redis_url) + scheme = parsed.scheme or "redis" + host = parsed.hostname or "unknown-host" + port = f":{parsed.port}" if parsed.port else "" + return f"{scheme}://***@{host}{port}" + + def _deserialize_session(self, _key: str, data: str) -> VoiceSession: + """Deserialize session from JSON string.""" + data_dict = json.loads(data) + return VoiceSession.from_dict(data_dict) + + # ------------------------------------------------------------------ + # Session lifecycle + # ------------------------------------------------------------------ + + async def create_session( + self, + context_id: str, + *, + session_token: str | None = None, + session_token_expires_at: float | None = None, + ) -> VoiceSession: + """Create a new voice session. 
+ + Args: + context_id: A2A context ID to associate with this session. + + Returns: + The newly created ``VoiceSession``. + + Raises: + RuntimeError: If the maximum number of concurrent sessions is reached. + RuntimeError: If Redis client is not initialized. + """ + if not self._redis_client or not self._create_session_script_sha: + raise RuntimeError( + "Redis client not initialized. Use async context manager." + ) + + session_id = uuid4().hex + session = VoiceSession( + id=session_id, + context_id=context_id, + session_token=session_token, + session_token_expires_at=session_token_expires_at, + ) + key = self._session_key(session_id) + serialized_session = self._serialize_session(session) + + # Atomically check session count and create the new session using a Lua script. + # This prevents a race condition across multiple workers. + success = await self._redis_client.evalsha( + self._create_session_script_sha, + 2, + REDIS_ACTIVE_SET_KEY, + key, + self._max_sessions, + serialized_session, + self._redis_session_ttl, + ) + + if not success: + raise RuntimeError( + f"Maximum concurrent voice sessions ({self._max_sessions}) reached" + ) + + logger.info(f"Voice session created: {session_id} (context={context_id})") + return session + + async def get_session(self, session_id: str) -> VoiceSession | None: + """Get a session by ID, or ``None`` if not found. + + Args: + session_id: The session ID to look up. + + Returns: + The session if found, None otherwise. + """ + if not self._redis_client: + raise RuntimeError( + "Redis client not initialized. Use async context manager." + ) + + try: + key = self._session_key(session_id) + data = await self._redis_client.get(key) + + if data is None: + return None + + return self._deserialize_session(session_id, data) + except redis.RedisError as e: + logger.error(f"Error getting voice session {session_id}: {e}") + return None + + async def end_session(self, session_id: str) -> None: + """Gracefully end a voice session. 
+ + Marks the session as ``ended`` and removes it from Redis. + + Args: + session_id: The session ID to end. + """ + if not self._redis_client: + raise RuntimeError( + "Redis client not initialized. Use async context manager." + ) + + if not self._delete_session_script_sha: + raise RuntimeError( + "Redis delete script not initialized. Use async context manager." + ) + + key = self._session_key(session_id) + session_data = await self._redis_client.get(key) + + if session_data: + session = self._deserialize_session(session_id, session_data) + duration = session.duration_seconds + logger.info(f"Voice session ended: {session_id} (duration={duration:.1f}s)") + + await self._redis_client.evalsha( + self._delete_session_script_sha, + 2, + REDIS_ACTIVE_SET_KEY, + key, + ) + logger.debug(f"Voice session removed from Redis: {session_id}") + + async def update_state( + self, + session_id: str, + state: Literal["connecting", "active", "ending", "ended"], + ) -> None: + """Update the state of a session. + + Args: + session_id: The session ID to update. + state: The new state. + """ + if not self._redis_client: + raise RuntimeError( + "Redis client not initialized. Use async context manager." + ) + + key = self._session_key(session_id) + if not self._update_session_script_sha: + raise RuntimeError( + "Redis update script not initialized. Use async context manager." 
+ ) + + success = await self._redis_client.evalsha( + self._update_session_script_sha, + 1, + key, + state, + self._redis_session_ttl, + ) + if not success: + logger.debug(f"Session missing during update_state: {session_id}") + + async def get_active_count(self) -> int: + """Return the number of sessions that are not ended.""" + if not self._redis_client: + return 0 + + try: + members = await self._redis_client.smembers(REDIS_ACTIVE_SET_KEY) + if not members: + return 0 + + count = 0 + stale_members: list[str] = [] + for key in members: + data = await self._redis_client.get(key) + if data: + session = self._deserialize_session(key.split(":")[-1], data) + if session.state != "ended": + count += 1 + else: + stale_members.append(key) + + if stale_members: + await self._redis_client.srem(REDIS_ACTIVE_SET_KEY, *stale_members) + + return count + except redis.RedisError as e: + logger.error(f"Error getting active session count: {e}") + return 0 + + # ------------------------------------------------------------------ + # Background cleanup + # ------------------------------------------------------------------ + + async def start_cleanup_loop(self) -> None: + """Start the periodic session cleanup background task.""" + if self._cleanup_task is None: + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + logger.info("Redis voice session cleanup loop started") + + async def stop_cleanup_loop(self) -> None: + """Stop the cleanup background task.""" + if self._cleanup_task is not None: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + self._cleanup_task = None + logger.info("Redis voice session cleanup loop stopped") + + async def _cleanup_loop(self) -> None: + """Periodically expire sessions that exceed the timeout.""" + while True: + try: + await asyncio.sleep(30) # check every 30s + await self._expire_timed_out_sessions() + except asyncio.CancelledError: + break + except Exception: + logger.exception("Error 
in Redis voice session cleanup loop") + + async def _expire_timed_out_sessions(self) -> None: + """End sessions that have exceeded the configured timeout.""" + if not self._redis_client: + return + + try: + expired: list[str] = [] + members = await self._redis_client.smembers(REDIS_ACTIVE_SET_KEY) + + for key in members: + data = await self._redis_client.get(key) + if not data: + expired.append(key) + continue + + session = self._deserialize_session(key.split(":")[-1], data) + if ( + session.state != "ended" + and session.duration_seconds > self._session_timeout + ): + session_id = session.id + expired.append(key) + logger.warning( + f"Voice session timed out: {session_id} " + f"(duration={session.duration_seconds:.1f}s, " + f"limit={self._session_timeout}s)" + ) + + for key in expired: + if self._delete_session_script_sha: + await self._redis_client.evalsha( + self._delete_session_script_sha, + 2, + REDIS_ACTIVE_SET_KEY, + key, + ) + else: + await self._redis_client.delete(key) + await self._redis_client.srem(REDIS_ACTIVE_SET_KEY, key) + + except redis.RedisError as e: + logger.error(f"Error expiring timed out sessions: {e}") diff --git a/bindu/extensions/voice/service_factory.py b/bindu/extensions/voice/service_factory.py new file mode 100644 index 000000000..8d2c461b2 --- /dev/null +++ b/bindu/extensions/voice/service_factory.py @@ -0,0 +1,223 @@ +"""Factory for creating pipecat STT and TTS service instances. + +Creates configured Deepgram STT and Piper/ElevenLabs/Azure TTS services +from the VoiceAgentExtension configuration. +""" + +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING, Any + +from bindu.settings import app_settings +from bindu.utils.logging import get_logger + +if TYPE_CHECKING: + from .voice_agent_extension import VoiceAgentExtension + +logger = get_logger("bindu.voice.service_factory") + + +def create_stt_service(config: VoiceAgentExtension) -> Any: + """Create a Speech-to-Text service instance. 
+ + Args: + config: Voice extension configuration. + + Returns: + Configured pipecat STT service. + + Raises: + ImportError: If pipecat STT dependencies are not installed. + ValueError: If the STT API key is not configured. + """ + api_key = app_settings.voice.stt_api_key + if not api_key: + logger.warning( + "STT service configuration incomplete: missing API key", + setting="VOICE__STT_API_KEY", + ) + raise ValueError("STT service configuration incomplete") + + if config.stt_provider == "deepgram": + try: + deepgram_module = importlib.import_module("pipecat.services.deepgram.stt") + DeepgramSTTService = getattr(deepgram_module, "DeepgramSTTService") + except (ImportError, AttributeError) as e: + raise ImportError( + "Deepgram STT requires pipecat[deepgram]. " + "Install with: pip install 'bindu[voice]'" + ) from e + + stt = DeepgramSTTService( + api_key=api_key, + model=config.stt_model, + language=config.stt_language, + ) + logger.info( + f"Created Deepgram STT: model={config.stt_model}, lang={config.stt_language}" + ) + return stt + + logger.warning("Unsupported STT provider requested", provider=config.stt_provider) + raise ValueError("Unsupported STT provider") + + +def create_tts_service(config: VoiceAgentExtension) -> Any: + """Create a Text-to-Speech service instance. + + Args: + config: Voice extension configuration. + + Returns: + Configured pipecat TTS service. + + Raises: + ImportError: If pipecat TTS dependencies are not installed. + ValueError: If the TTS API key is not configured. 
def _create_tts_service_for_provider(provider: str, config: VoiceAgentExtension) -> Any:
    """Build the pipecat TTS service for a single named provider.

    Args:
        provider: One of ``"piper"``, ``"elevenlabs"`` or ``"azure"``.
        config: Voice extension configuration (voice id, model, sample rate).

    Returns:
        The configured pipecat TTS service instance.

    Raises:
        ImportError: If the provider's pipecat module or service class is
            unavailable.
        ValueError: If required credentials are missing or the provider is
            not supported.
    """
    if provider == "piper":
        return _create_piper_tts(config)
    if provider == "elevenlabs":
        return _create_elevenlabs_tts(config)
    if provider == "azure":
        return _create_azure_tts(config)

    logger.warning("Unsupported TTS provider requested", provider=provider)
    raise ValueError("Unsupported TTS provider")


def _load_tts_classes(
    module_path: str,
    service_name: str,
    settings_name: str,
    label: str,
    extra: str,
) -> tuple[Any, Any]:
    """Import a pipecat TTS module and return ``(service_cls, settings_cls)``.

    ``settings_cls`` is ``None`` when the module does not define it; callers
    then fall back to passing the options as plain keyword arguments.
    """
    try:
        module = importlib.import_module(module_path)
        service_cls = getattr(module, service_name)
    except (ImportError, AttributeError) as e:
        raise ImportError(
            f"{label} TTS requires pipecat[{extra}]. "
            "Install with: pip install 'bindu[voice]'"
        ) from e
    return service_cls, getattr(module, settings_name, None)


def _create_piper_tts(config: VoiceAgentExtension) -> Any:
    """Create the Piper TTS service; no API key is read for this provider."""
    voice_id = config.tts_voice_id
    service_cls, settings_cls = _load_tts_classes(
        "pipecat.services.piper.tts",
        "PiperTTSService",
        "PiperTTSSettings",
        "Piper",
        "piper",
    )

    if settings_cls is not None:
        tts = service_cls(
            settings=settings_cls(
                voice=voice_id,
            ),
            sample_rate=config.sample_rate,
        )
    else:
        tts = service_cls(
            voice_id=voice_id,
            sample_rate=config.sample_rate,
        )

    logger.info(f"Created Piper TTS: voice={voice_id}")
    return tts


def _create_elevenlabs_tts(config: VoiceAgentExtension) -> Any:
    """Create the ElevenLabs TTS service using ``app_settings.voice.tts_api_key``."""
    api_key = app_settings.voice.tts_api_key
    # Credentials are checked before the import so a missing key is reported
    # as a configuration problem even when pipecat is not installed.
    if not api_key:
        logger.warning(
            "TTS service configuration incomplete: missing API key",
            setting="VOICE__TTS_API_KEY",
        )
        raise ValueError("TTS service configuration incomplete")

    service_cls, settings_cls = _load_tts_classes(
        "pipecat.services.elevenlabs.tts",
        "ElevenLabsTTSService",
        "ElevenLabsTTSSettings",
        "ElevenLabs",
        "elevenlabs",
    )

    if settings_cls is not None:
        tts = service_cls(
            api_key=api_key,
            settings=settings_cls(
                voice=config.tts_voice_id,
                model=config.tts_model,
            ),
        )
    else:
        tts = service_cls(
            api_key=api_key,
            voice_id=config.tts_voice_id,
            model=config.tts_model,
        )
    logger.info(
        f"Created ElevenLabs TTS: voice={config.tts_voice_id}, model={config.tts_model}"
    )
    return tts


def _create_azure_tts(config: VoiceAgentExtension) -> Any:
    """Create the Azure TTS service from the ``azure_tts_*`` voice settings."""
    azure_api_key = app_settings.voice.azure_tts_api_key
    azure_region = app_settings.voice.azure_tts_region
    # The Azure-specific voice wins; the generic voice id is the fallback.
    azure_voice = app_settings.voice.azure_tts_voice or config.tts_voice_id
    if not azure_api_key:
        logger.warning(
            "Azure TTS configuration incomplete: missing API key",
            setting="VOICE__AZURE_TTS_API_KEY",
        )
        raise ValueError("Azure TTS configuration incomplete")
    if not azure_region:
        logger.warning(
            "Azure TTS configuration incomplete: missing region",
            setting="VOICE__AZURE_TTS_REGION",
        )
        raise ValueError("Azure TTS configuration incomplete")

    service_cls, settings_cls = _load_tts_classes(
        "pipecat.services.azure.tts",
        "AzureTTSService",
        "AzureTTSSettings",
        "Azure",
        "azure",
    )

    if settings_cls is not None:
        tts = service_cls(
            api_key=azure_api_key,
            region=azure_region,
            settings=settings_cls(voice=azure_voice),
            sample_rate=config.sample_rate,
        )
    else:
        tts = service_cls(
            api_key=azure_api_key,
            region=azure_region,
            voice=azure_voice,
            sample_rate=config.sample_rate,
        )

    logger.info(f"Created Azure TTS: voice={azure_voice}, region={azure_region}")
    return tts
async def create_session_manager(
    settings: VoiceSettings | None = None,
) -> SessionManagerBackend:
    """Build the voice session manager selected by ``session_backend``.

    Args:
        settings: Voice settings to read; ``app_settings.voice`` is used
            when this is not provided.

    Returns:
        SessionManagerBackend: An in-memory manager for ``"memory"``, or a
        Redis-backed manager (already connected) for ``"redis"``.

    Raises:
        ValueError: If the backend name is unknown, or the Redis backend is
            requested without the redis package or without a Redis URL.
    """
    from bindu.settings import app_settings

    voice_settings = settings if settings else app_settings.voice
    backend = voice_settings.session_backend

    logger.info(f"Creating voice session manager with backend: {backend}")

    if backend not in ("memory", "redis"):
        raise ValueError(
            f"Unknown session backend: {backend}. Supported backends: memory, redis"
        )

    session_limit = voice_settings.max_concurrent_sessions
    timeout_seconds = voice_settings.session_timeout

    if backend == "memory":
        logger.info("Using in-memory session manager (single-process)")
        return VoiceSessionManager(
            max_sessions=session_limit,
            session_timeout=timeout_seconds,
        )

    # Only the Redis backend remains from here on.
    redis_missing = not REDIS_AVAILABLE or RedisVoiceSessionManager is None
    if redis_missing:
        raise ValueError(
            "Redis session manager requires redis package. "
            "Install with: pip install redis[hiredis]"
        )

    logger.info("Using Redis session manager (distributed, multi-process)")

    target_url = voice_settings.redis_url
    if not target_url:
        raise ValueError(
            "Redis session manager requires a Redis URL. "
            "Please provide it via VOICE__REDIS_URL environment variable or config."
        )

    redis_manager = RedisVoiceSessionManager(
        redis_url=target_url,
        max_sessions=session_limit,
        session_timeout=timeout_seconds,
        redis_session_ttl=voice_settings.redis_session_ttl,
    )

    # The manager is returned already entered, so the Redis connection is
    # opened here rather than by the caller.
    try:
        await redis_manager.__aenter__()
    except Exception:
        # Release anything partially set up before propagating the failure.
        await redis_manager.__aexit__(None, None, None)
        raise
    return redis_manager
+ """ + cleanup_error: Exception | None = None + try: + await manager.stop_cleanup_loop() + logger.info(f"{type(manager).__name__} cleanup loop stopped") + except Exception as e: + cleanup_error = e + logger.error(f"Error stopping cleanup loop for {type(manager).__name__}: {e}") + finally: + if ( + REDIS_AVAILABLE + and RedisVoiceSessionManager is not None + and isinstance(manager, RedisVoiceSessionManager) + ): + try: + await manager.__aexit__(None, None, None) + logger.info(f"{type(manager).__name__} connection closed") + except Exception as e: + if cleanup_error is not None: + e.__cause__ = cleanup_error + cleanup_error = e + logger.error(f"Error closing {type(manager).__name__}: {e}") + + if cleanup_error is not None: + raise cleanup_error diff --git a/bindu/extensions/voice/session_manager.py b/bindu/extensions/voice/session_manager.py new file mode 100644 index 000000000..694da9847 --- /dev/null +++ b/bindu/extensions/voice/session_manager.py @@ -0,0 +1,291 @@ +"""Voice session manager. + +Tracks active voice sessions, enforces concurrency limits, +and runs a background cleanup task to expire timed-out sessions. 
@dataclass
class VoiceSession:
    """Represents an active voice conversation session.

    Instances are plain data. ``to_dict`` and ``from_dict`` convert to and
    from the dictionary form used for Redis storage.
    """

    id: str
    context_id: str
    task_id: str | None = None
    session_token: str | None = None
    session_token_expires_at: float | None = None
    start_time: float = field(default_factory=time.time)
    state: Literal["connecting", "active", "ending", "ended"] = "connecting"

    def __post_init__(self) -> None:
        """Replace a missing (``None``) or zero ``start_time`` with now."""
        if self.start_time is None or self.start_time == 0:
            self.start_time = time.time()

    @property
    def duration_seconds(self) -> float:
        """Elapsed time since session started, never negative."""
        return max(0.0, time.time() - self.start_time)

    def to_dict(self) -> dict:
        """Serialize session to dictionary for Redis storage."""
        return {
            "id": self.id,
            "context_id": self.context_id,
            "task_id": self.task_id,
            "session_token": self.session_token,
            "session_token_expires_at": self.session_token_expires_at,
            "start_time": self.start_time,
            "state": self.state,
        }

    @classmethod
    def from_dict(cls, data: dict) -> VoiceSession:
        """Deserialize session from dictionary.

        Args:
            data: Mapping with the keys produced by ``to_dict``. ``id`` and
                ``context_id`` are required; ``start_time`` defaults to the
                current time and ``state`` to ``"connecting"``.

        Returns:
            A validated ``VoiceSession``.

        Raises:
            ValueError: If any field is missing, has the wrong type, or
                holds a value outside its allowed set.
        """
        session_id = data.get("id")
        if not isinstance(session_id, str) or not session_id.strip():
            raise ValueError(
                "VoiceSession.id is required and must be a non-empty string"
            )

        context_id = data.get("context_id")
        if not isinstance(context_id, str) or not context_id.strip():
            raise ValueError(
                "VoiceSession.context_id is required and must be a non-empty string"
            )

        task_id = data.get("task_id")
        if task_id is not None and (
            not isinstance(task_id, str) or not task_id.strip()
        ):
            raise ValueError(
                "VoiceSession.task_id must be a non-empty string when provided"
            )

        session_token = data.get("session_token")
        if session_token is not None and (
            not isinstance(session_token, str) or not session_token.strip()
        ):
            raise ValueError(
                "VoiceSession.session_token must be a non-empty string when provided"
            )

        # bool is a subclass of int, so it is rejected explicitly for both
        # timestamp fields.
        session_token_expires_at = data.get("session_token_expires_at")
        if session_token_expires_at is not None and (
            not isinstance(session_token_expires_at, (int, float))
            or isinstance(session_token_expires_at, bool)
        ):
            raise ValueError(
                "VoiceSession.session_token_expires_at must be a numeric timestamp when provided"
            )

        start_time = data.get("start_time", time.time())
        if not isinstance(start_time, (int, float)) or isinstance(start_time, bool):
            raise ValueError("VoiceSession.start_time must be a numeric timestamp")

        state = data.get("state", "connecting")
        allowed_states = {"connecting", "active", "ending", "ended"}
        # The type is checked first: an unhashable value (a list or dict from
        # a malformed payload) would make the set membership test raise
        # TypeError instead of the ValueError this method promises.
        if not isinstance(state, str) or state not in allowed_states:
            raise ValueError(
                f"VoiceSession.state must be one of {sorted(allowed_states)}; got {state!r}"
            )

        return cls(
            id=session_id,
            context_id=context_id,
            task_id=task_id,
            session_token=session_token,
            session_token_expires_at=float(session_token_expires_at)
            if session_token_expires_at is not None
            else None,
            start_time=float(start_time),
            state=state,
        )
+ """ + + def __init__(self, max_sessions: int = 10, session_timeout: int = 300): + """Initialize in-memory voice session manager limits and cleanup state.""" + if ( + not isinstance(max_sessions, int) + or isinstance(max_sessions, bool) + or max_sessions <= 0 + ): + raise ValueError("max_sessions must be a positive integer") + if ( + not isinstance(session_timeout, int) + or isinstance(session_timeout, bool) + or session_timeout <= 0 + ): + raise ValueError("session_timeout must be a positive integer") + + self._sessions: dict[str, VoiceSession] = {} + self._max_sessions = max_sessions + self._session_timeout = session_timeout + self._cleanup_task: asyncio.Task[None] | None = None + self._lock = asyncio.Lock() + + # ------------------------------------------------------------------ + # Session lifecycle + # ------------------------------------------------------------------ + + async def create_session( + self, + context_id: str, + *, + session_token: str | None = None, + session_token_expires_at: float | None = None, + ) -> VoiceSession: + """Create a new voice session. + + Args: + context_id: A2A context ID as a non-empty string. + + Returns: + The newly created ``VoiceSession``. + + Raises: + RuntimeError: If the maximum number of concurrent sessions is reached. + ValueError: If context_id is not a non-empty string. 
+ """ + if not isinstance(context_id, str) or not context_id.strip(): + raise ValueError("context_id must be a non-empty string") + + async with self._lock: + # Prune any already-ended sessions first + self._sessions = { + k: v for k, v in self._sessions.items() if v.state != "ended" + } + + if len(self._sessions) >= self._max_sessions: + raise RuntimeError( + f"Maximum concurrent voice sessions ({self._max_sessions}) reached" + ) + + session_id = uuid4().hex + session = VoiceSession( + id=session_id, + context_id=context_id, + session_token=session_token, + session_token_expires_at=session_token_expires_at, + ) + self._sessions[session_id] = session + + logger.info( + f"Voice session created: {session_id} (context={context_id}, " + f"active={len(self._sessions)})" + ) + return session + + async def get_session(self, session_id: str) -> VoiceSession | None: + """Get a session by ID, or ``None`` if not found.""" + async with self._lock: + return self._sessions.get(session_id) + + async def end_session(self, session_id: str) -> None: + """Gracefully end a voice session. + + Marks the session as ``ended`` and removes it from the active map. 
+ """ + async with self._lock: + session = self._sessions.pop(session_id, None) + if session: + session.state = "ended" + logger.info( + f"Voice session ended: {session_id} " + f"(duration={session.duration_seconds:.1f}s)" + ) + + async def update_state( + self, + session_id: str, + state: Literal["connecting", "active", "ending", "ended"], + ) -> None: + """Update the state of a session.""" + async with self._lock: + if state == "ended": + session = self._sessions.pop(session_id, None) + else: + session = self._sessions.get(session_id) + if session: + session.state = state + + async def get_active_count(self) -> int: + """Return the number of sessions that are not ended.""" + async with self._lock: + return sum(1 for s in self._sessions.values() if s.state != "ended") + + # ------------------------------------------------------------------ + # Background cleanup + # ------------------------------------------------------------------ + + async def start_cleanup_loop(self) -> None: + """Start the periodic session cleanup background task.""" + async with self._lock: + if self._cleanup_task is None: + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + logger.info("Voice session cleanup loop started") + + async def stop_cleanup_loop(self) -> None: + """Stop the cleanup background task.""" + task: asyncio.Task[None] | None = None + async with self._lock: + if self._cleanup_task is not None: + task = self._cleanup_task + self._cleanup_task = None + + if task is not None: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + logger.info("Voice session cleanup loop stopped") + + async def _cleanup_loop(self) -> None: + """Periodically expire sessions that exceed the timeout.""" + while True: + try: + await asyncio.sleep(30) # check every 30 s + await self._expire_timed_out_sessions() + except asyncio.CancelledError: + break + except Exception: + logger.exception("Error in voice session cleanup loop") + + async def 
class VoiceAgentExtension:
    """Voice extension for real-time voice agent conversations.

    Configures the voice pipeline (STT, TTS, VAD) and exposes
    a discoverable ``AgentExtension`` in the agent card so clients
    know the agent supports voice.
    """

    def __init__(
        self,
        stt_provider: str = "deepgram",
        stt_model: str = "nova-3",
        stt_language: str = "en",
        tts_provider: str = "elevenlabs",
        tts_voice_id: str = "21m00Tcm4TlvDq8ikWAM",
        tts_model: str = "eleven_turbo_v2_5",
        sample_rate: int = 16000,
        allow_interruptions: bool = True,
        vad_enabled: bool = True,
        description: Optional[str] = None,
        voice_settings: "VoiceSettings | None" = None,
    ):
        """Initialize voice extension configuration for agent metadata and runtime.

        Args:
            stt_provider: Speech-to-text provider name.
            stt_model: Model identifier passed to the STT provider.
            stt_language: Language code passed to the STT provider.
            tts_provider: Text-to-speech provider name.
            tts_voice_id: Voice identifier passed to the TTS provider.
            tts_model: Model identifier passed to the TTS provider.
            sample_rate: Audio sample rate in Hz; checked by
                ``validate_sample_rate``.
            allow_interruptions: Advertised in the agent card params.
            vad_enabled: Stored on the instance and included in ``to_dict``.
            description: Overrides the default extension description.
            voice_settings: Settings object to use instead of the global
                ``app_settings.voice``.
        """
        self.stt_provider = stt_provider
        self.stt_model = stt_model
        self.stt_language = stt_language
        self.tts_provider = tts_provider
        self.tts_voice_id = tts_voice_id
        self.tts_model = tts_model
        self.sample_rate = sample_rate
        self.allow_interruptions = allow_interruptions
        self.vad_enabled = vad_enabled
        self._description = description
        self.voice_settings = voice_settings or app_settings.voice

        # Validate audio config eagerly
        from .audio_config import validate_sample_rate

        validate_sample_rate(sample_rate)

        logger.info(
            f"VoiceAgentExtension created: STT={stt_provider}/{stt_model}, "
            f"TTS={tts_provider}/{tts_voice_id}, rate={sample_rate}Hz"
        )

    @cached_property
    def agent_extension(self) -> AgentExtension:
        """Return AgentExtension metadata for the agent card.

        The URI and default description come from this instance's
        ``voice_settings``, so settings injected through the constructor are
        honored. When none were injected, that attribute is the global
        ``app_settings.voice``.
        """
        settings = self.voice_settings
        return AgentExtension(
            uri=settings.extension_uri,
            description=self._description or settings.extension_description,
            required=False,  # Clients can still use text
            params={
                "stt_provider": self.stt_provider,
                "tts_provider": self.tts_provider,
                "sample_rate": self.sample_rate,
                "allow_interruptions": self.allow_interruptions,
            },
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize to dictionary for logging/debugging."""
        return {
            "stt_provider": self.stt_provider,
            "stt_model": self.stt_model,
            "stt_language": self.stt_language,
            "tts_provider": self.tts_provider,
            "tts_voice_id": self.tts_voice_id,
            "tts_model": self.tts_model,
            "sample_rate": self.sample_rate,
            "allow_interruptions": self.allow_interruptions,
            "vad_enabled": self.vad_enabled,
        }

    def __repr__(self) -> str:
        """Return concise debug representation of this voice extension."""
        return (
            f"VoiceAgentExtension(stt={self.stt_provider}/{self.stt_model}, "
            f"tts={self.tts_provider}/{self.tts_voice_id})"
        )
""" + _require_grpc_dependencies() self._address = callback_address self._timeout = timeout self._use_streaming = use_streaming diff --git a/bindu/grpc/generated/agent_handler_pb2.py b/bindu/grpc/generated/agent_handler_pb2.py index 5db6ab4e2..4676d0263 100644 --- a/bindu/grpc/generated/agent_handler_pb2.py +++ b/bindu/grpc/generated/agent_handler_pb2.py @@ -2,71 +2,70 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # NO CHECKED-IN PROTOBUF GENCODE # source: agent_handler.proto -# Protobuf Python Version: 6.31.1 +# Protobuf Python Version: 5.29.0 """Generated protocol buffer code.""" - from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import runtime_version as _runtime_version from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder - _runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, 6, 31, 1, "", "agent_handler.proto" + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'agent_handler.proto' ) # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x13\x61gent_handler.proto\x12\nbindu.grpc"w\n\x14RegisterAgentRequest\x12\x13\n\x0b\x63onfig_json\x18\x01 \x01(\t\x12+\n\x06skills\x18\x02 \x03(\x0b\x32\x1b.bindu.grpc.SkillDefinition\x12\x1d\n\x15grpc_callback_address\x18\x03 \x01(\t"i\n\x15RegisterAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x10\n\x08\x61gent_id\x18\x02 \x01(\t\x12\x0b\n\x03\x64id\x18\x03 \x01(\t\x12\x11\n\tagent_url\x18\x04 \x01(\t\x12\r\n\x05\x65rror\x18\x05 \x01(\t"7\n\x10HeartbeatRequest\x12\x10\n\x08\x61gent_id\x18\x01 \x01(\t\x12\x11\n\ttimestamp\x18\x02 \x01(\x03"C\n\x11HeartbeatResponse\x12\x14\n\x0c\x61\x63knowledged\x18\x01 \x01(\x08\x12\x18\n\x10server_timestamp\x18\x02 \x01(\x03"*\n\x16UnregisterAgentRequest\x12\x10\n\x08\x61gent_id\x18\x01 
\x01(\t"9\n\x17UnregisterAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\r\n\x05\x65rror\x18\x02 \x01(\t",\n\x0b\x43hatMessage\x12\x0c\n\x04role\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\t"_\n\rHandleRequest\x12)\n\x08messages\x18\x01 \x03(\x0b\x32\x17.bindu.grpc.ChatMessage\x12\x0f\n\x07task_id\x18\x02 \x01(\t\x12\x12\n\ncontext_id\x18\x03 \x01(\t"\xbf\x01\n\x0eHandleResponse\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\t\x12\r\n\x05state\x18\x02 \x01(\t\x12\x0e\n\x06prompt\x18\x03 \x01(\t\x12\x10\n\x08is_final\x18\x04 \x01(\x08\x12:\n\x08metadata\x18\x05 \x03(\x0b\x32(.bindu.grpc.HandleResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01"\xb3\x01\n\x0fSkillDefinition\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x0c\n\x04tags\x18\x03 \x03(\t\x12\x13\n\x0binput_modes\x18\x04 \x03(\t\x12\x14\n\x0coutput_modes\x18\x05 \x03(\t\x12\x0f\n\x07version\x18\x06 \x01(\t\x12\x0e\n\x06\x61uthor\x18\x07 \x01(\t\x12\x13\n\x0braw_content\x18\x08 \x01(\t\x12\x0e\n\x06\x66ormat\x18\t \x01(\t"\x18\n\x16GetCapabilitiesRequest"\x96\x01\n\x17GetCapabilitiesResponse\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x0f\n\x07version\x18\x03 \x01(\t\x12\x1a\n\x12supports_streaming\x18\x04 \x01(\x08\x12+\n\x06skills\x18\x05 \x03(\x0b\x32\x1b.bindu.grpc.SkillDefinition"\x14\n\x12HealthCheckRequest"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\x8a\x02\n\x0c\x42induService\x12T\n\rRegisterAgent\x12 
.bindu.grpc.RegisterAgentRequest\x1a!.bindu.grpc.RegisterAgentResponse\x12H\n\tHeartbeat\x12\x1c.bindu.grpc.HeartbeatRequest\x1a\x1d.bindu.grpc.HeartbeatResponse\x12Z\n\x0fUnregisterAgent\x12".bindu.grpc.UnregisterAgentRequest\x1a#.bindu.grpc.UnregisterAgentResponse2\xd4\x02\n\x0c\x41gentHandler\x12G\n\x0eHandleMessages\x12\x19.bindu.grpc.HandleRequest\x1a\x1a.bindu.grpc.HandleResponse\x12O\n\x14HandleMessagesStream\x12\x19.bindu.grpc.HandleRequest\x1a\x1a.bindu.grpc.HandleResponse0\x01\x12Z\n\x0fGetCapabilities\x12".bindu.grpc.GetCapabilitiesRequest\x1a#.bindu.grpc.GetCapabilitiesResponse\x12N\n\x0bHealthCheck\x12\x1e.bindu.grpc.HealthCheckRequest\x1a\x1f.bindu.grpc.HealthCheckResponseB6\n\x11\x63om.getbindu.grpcP\x01Z\x1fgithub.com/getbindu/bindu/protob\x06proto3' -) + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13\x61gent_handler.proto\x12\nbindu.grpc\"w\n\x14RegisterAgentRequest\x12\x13\n\x0b\x63onfig_json\x18\x01 \x01(\t\x12+\n\x06skills\x18\x02 \x03(\x0b\x32\x1b.bindu.grpc.SkillDefinition\x12\x1d\n\x15grpc_callback_address\x18\x03 \x01(\t\"i\n\x15RegisterAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x10\n\x08\x61gent_id\x18\x02 \x01(\t\x12\x0b\n\x03\x64id\x18\x03 \x01(\t\x12\x11\n\tagent_url\x18\x04 \x01(\t\x12\r\n\x05\x65rror\x18\x05 \x01(\t\"7\n\x10HeartbeatRequest\x12\x10\n\x08\x61gent_id\x18\x01 \x01(\t\x12\x11\n\ttimestamp\x18\x02 \x01(\x03\"C\n\x11HeartbeatResponse\x12\x14\n\x0c\x61\x63knowledged\x18\x01 \x01(\x08\x12\x18\n\x10server_timestamp\x18\x02 \x01(\x03\"*\n\x16UnregisterAgentRequest\x12\x10\n\x08\x61gent_id\x18\x01 \x01(\t\"9\n\x17UnregisterAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\r\n\x05\x65rror\x18\x02 \x01(\t\",\n\x0b\x43hatMessage\x12\x0c\n\x04role\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\t\"_\n\rHandleRequest\x12)\n\x08messages\x18\x01 \x03(\x0b\x32\x17.bindu.grpc.ChatMessage\x12\x0f\n\x07task_id\x18\x02 \x01(\t\x12\x12\n\ncontext_id\x18\x03 
\x01(\t\"\xbf\x01\n\x0eHandleResponse\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\t\x12\r\n\x05state\x18\x02 \x01(\t\x12\x0e\n\x06prompt\x18\x03 \x01(\t\x12\x10\n\x08is_final\x18\x04 \x01(\x08\x12:\n\x08metadata\x18\x05 \x03(\x0b\x32(.bindu.grpc.HandleResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb3\x01\n\x0fSkillDefinition\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x0c\n\x04tags\x18\x03 \x03(\t\x12\x13\n\x0binput_modes\x18\x04 \x03(\t\x12\x14\n\x0coutput_modes\x18\x05 \x03(\t\x12\x0f\n\x07version\x18\x06 \x01(\t\x12\x0e\n\x06\x61uthor\x18\x07 \x01(\t\x12\x13\n\x0braw_content\x18\x08 \x01(\t\x12\x0e\n\x06\x66ormat\x18\t \x01(\t\"\x18\n\x16GetCapabilitiesRequest\"\x96\x01\n\x17GetCapabilitiesResponse\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x0f\n\x07version\x18\x03 \x01(\t\x12\x1a\n\x12supports_streaming\x18\x04 \x01(\x08\x12+\n\x06skills\x18\x05 \x03(\x0b\x32\x1b.bindu.grpc.SkillDefinition\"\x14\n\x12HealthCheckRequest\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\x8a\x02\n\x0c\x42induService\x12T\n\rRegisterAgent\x12 
.bindu.grpc.RegisterAgentRequest\x1a!.bindu.grpc.RegisterAgentResponse\x12H\n\tHeartbeat\x12\x1c.bindu.grpc.HeartbeatRequest\x1a\x1d.bindu.grpc.HeartbeatResponse\x12Z\n\x0fUnregisterAgent\x12\".bindu.grpc.UnregisterAgentRequest\x1a#.bindu.grpc.UnregisterAgentResponse2\xd4\x02\n\x0c\x41gentHandler\x12G\n\x0eHandleMessages\x12\x19.bindu.grpc.HandleRequest\x1a\x1a.bindu.grpc.HandleResponse\x12O\n\x14HandleMessagesStream\x12\x19.bindu.grpc.HandleRequest\x1a\x1a.bindu.grpc.HandleResponse0\x01\x12Z\n\x0fGetCapabilities\x12\".bindu.grpc.GetCapabilitiesRequest\x1a#.bindu.grpc.GetCapabilitiesResponse\x12N\n\x0bHealthCheck\x12\x1e.bindu.grpc.HealthCheckRequest\x1a\x1f.bindu.grpc.HealthCheckResponseB6\n\x11\x63om.getbindu.grpcP\x01Z\x1fgithub.com/getbindu/bindu/protob\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "agent_handler_pb2", _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'agent_handler_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: - _globals["DESCRIPTOR"]._loaded_options = None - _globals[ - "DESCRIPTOR" - ]._serialized_options = ( - b"\n\021com.getbindu.grpcP\001Z\037github.com/getbindu/bindu/proto" - ) - _globals["_HANDLERESPONSE_METADATAENTRY"]._loaded_options = None - _globals["_HANDLERESPONSE_METADATAENTRY"]._serialized_options = b"8\001" - _globals["_REGISTERAGENTREQUEST"]._serialized_start = 35 - _globals["_REGISTERAGENTREQUEST"]._serialized_end = 154 - _globals["_REGISTERAGENTRESPONSE"]._serialized_start = 156 - _globals["_REGISTERAGENTRESPONSE"]._serialized_end = 261 - _globals["_HEARTBEATREQUEST"]._serialized_start = 263 - _globals["_HEARTBEATREQUEST"]._serialized_end = 318 - _globals["_HEARTBEATRESPONSE"]._serialized_start = 320 - _globals["_HEARTBEATRESPONSE"]._serialized_end = 387 - _globals["_UNREGISTERAGENTREQUEST"]._serialized_start = 389 - _globals["_UNREGISTERAGENTREQUEST"]._serialized_end = 431 - 
_globals["_UNREGISTERAGENTRESPONSE"]._serialized_start = 433 - _globals["_UNREGISTERAGENTRESPONSE"]._serialized_end = 490 - _globals["_CHATMESSAGE"]._serialized_start = 492 - _globals["_CHATMESSAGE"]._serialized_end = 536 - _globals["_HANDLEREQUEST"]._serialized_start = 538 - _globals["_HANDLEREQUEST"]._serialized_end = 633 - _globals["_HANDLERESPONSE"]._serialized_start = 636 - _globals["_HANDLERESPONSE"]._serialized_end = 827 - _globals["_HANDLERESPONSE_METADATAENTRY"]._serialized_start = 780 - _globals["_HANDLERESPONSE_METADATAENTRY"]._serialized_end = 827 - _globals["_SKILLDEFINITION"]._serialized_start = 830 - _globals["_SKILLDEFINITION"]._serialized_end = 1009 - _globals["_GETCAPABILITIESREQUEST"]._serialized_start = 1011 - _globals["_GETCAPABILITIESREQUEST"]._serialized_end = 1035 - _globals["_GETCAPABILITIESRESPONSE"]._serialized_start = 1038 - _globals["_GETCAPABILITIESRESPONSE"]._serialized_end = 1188 - _globals["_HEALTHCHECKREQUEST"]._serialized_start = 1190 - _globals["_HEALTHCHECKREQUEST"]._serialized_end = 1210 - _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 1212 - _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 1267 - _globals["_BINDUSERVICE"]._serialized_start = 1270 - _globals["_BINDUSERVICE"]._serialized_end = 1536 - _globals["_AGENTHANDLER"]._serialized_start = 1539 - _globals["_AGENTHANDLER"]._serialized_end = 1879 + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\021com.getbindu.grpcP\001Z\037github.com/getbindu/bindu/proto' + _globals['_HANDLERESPONSE_METADATAENTRY']._loaded_options = None + _globals['_HANDLERESPONSE_METADATAENTRY']._serialized_options = b'8\001' + _globals['_REGISTERAGENTREQUEST']._serialized_start=35 + _globals['_REGISTERAGENTREQUEST']._serialized_end=154 + _globals['_REGISTERAGENTRESPONSE']._serialized_start=156 + _globals['_REGISTERAGENTRESPONSE']._serialized_end=261 + _globals['_HEARTBEATREQUEST']._serialized_start=263 + 
_globals['_HEARTBEATREQUEST']._serialized_end=318 + _globals['_HEARTBEATRESPONSE']._serialized_start=320 + _globals['_HEARTBEATRESPONSE']._serialized_end=387 + _globals['_UNREGISTERAGENTREQUEST']._serialized_start=389 + _globals['_UNREGISTERAGENTREQUEST']._serialized_end=431 + _globals['_UNREGISTERAGENTRESPONSE']._serialized_start=433 + _globals['_UNREGISTERAGENTRESPONSE']._serialized_end=490 + _globals['_CHATMESSAGE']._serialized_start=492 + _globals['_CHATMESSAGE']._serialized_end=536 + _globals['_HANDLEREQUEST']._serialized_start=538 + _globals['_HANDLEREQUEST']._serialized_end=633 + _globals['_HANDLERESPONSE']._serialized_start=636 + _globals['_HANDLERESPONSE']._serialized_end=827 + _globals['_HANDLERESPONSE_METADATAENTRY']._serialized_start=780 + _globals['_HANDLERESPONSE_METADATAENTRY']._serialized_end=827 + _globals['_SKILLDEFINITION']._serialized_start=830 + _globals['_SKILLDEFINITION']._serialized_end=1009 + _globals['_GETCAPABILITIESREQUEST']._serialized_start=1011 + _globals['_GETCAPABILITIESREQUEST']._serialized_end=1035 + _globals['_GETCAPABILITIESRESPONSE']._serialized_start=1038 + _globals['_GETCAPABILITIESRESPONSE']._serialized_end=1188 + _globals['_HEALTHCHECKREQUEST']._serialized_start=1190 + _globals['_HEALTHCHECKREQUEST']._serialized_end=1210 + _globals['_HEALTHCHECKRESPONSE']._serialized_start=1212 + _globals['_HEALTHCHECKRESPONSE']._serialized_end=1267 + _globals['_BINDUSERVICE']._serialized_start=1270 + _globals['_BINDUSERVICE']._serialized_end=1536 + _globals['_AGENTHANDLER']._serialized_start=1539 + _globals['_AGENTHANDLER']._serialized_end=1879 # @@protoc_insertion_point(module_scope) diff --git a/bindu/grpc/generated/agent_handler_pb2.pyi b/bindu/grpc/generated/agent_handler_pb2.pyi index 6a8b04ef9..3fbe3072f 100644 --- a/bindu/grpc/generated/agent_handler_pb2.pyi +++ b/bindu/grpc/generated/agent_handler_pb2.pyi @@ -1,8 +1,7 @@ from google.protobuf.internal import containers as _containers from google.protobuf import descriptor as 
_descriptor from google.protobuf import message as _message -from collections.abc import Iterable as _Iterable, Mapping as _Mapping -from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union DESCRIPTOR: _descriptor.FileDescriptor @@ -14,12 +13,7 @@ class RegisterAgentRequest(_message.Message): config_json: str skills: _containers.RepeatedCompositeFieldContainer[SkillDefinition] grpc_callback_address: str - def __init__( - self, - config_json: _Optional[str] = ..., - skills: _Optional[_Iterable[_Union[SkillDefinition, _Mapping]]] = ..., - grpc_callback_address: _Optional[str] = ..., - ) -> None: ... + def __init__(self, config_json: _Optional[str] = ..., skills: _Optional[_Iterable[_Union[SkillDefinition, _Mapping]]] = ..., grpc_callback_address: _Optional[str] = ...) -> None: ... class RegisterAgentResponse(_message.Message): __slots__ = ("success", "agent_id", "did", "agent_url", "error") @@ -33,14 +27,7 @@ class RegisterAgentResponse(_message.Message): did: str agent_url: str error: str - def __init__( - self, - success: bool = ..., - agent_id: _Optional[str] = ..., - did: _Optional[str] = ..., - agent_url: _Optional[str] = ..., - error: _Optional[str] = ..., - ) -> None: ... + def __init__(self, success: bool = ..., agent_id: _Optional[str] = ..., did: _Optional[str] = ..., agent_url: _Optional[str] = ..., error: _Optional[str] = ...) -> None: ... class HeartbeatRequest(_message.Message): __slots__ = ("agent_id", "timestamp") @@ -48,9 +35,7 @@ class HeartbeatRequest(_message.Message): TIMESTAMP_FIELD_NUMBER: _ClassVar[int] agent_id: str timestamp: int - def __init__( - self, agent_id: _Optional[str] = ..., timestamp: _Optional[int] = ... - ) -> None: ... + def __init__(self, agent_id: _Optional[str] = ..., timestamp: _Optional[int] = ...) -> None: ... 
class HeartbeatResponse(_message.Message): __slots__ = ("acknowledged", "server_timestamp") @@ -58,9 +43,7 @@ class HeartbeatResponse(_message.Message): SERVER_TIMESTAMP_FIELD_NUMBER: _ClassVar[int] acknowledged: bool server_timestamp: int - def __init__( - self, acknowledged: bool = ..., server_timestamp: _Optional[int] = ... - ) -> None: ... + def __init__(self, acknowledged: bool = ..., server_timestamp: _Optional[int] = ...) -> None: ... class UnregisterAgentRequest(_message.Message): __slots__ = ("agent_id",) @@ -82,9 +65,7 @@ class ChatMessage(_message.Message): CONTENT_FIELD_NUMBER: _ClassVar[int] role: str content: str - def __init__( - self, role: _Optional[str] = ..., content: _Optional[str] = ... - ) -> None: ... + def __init__(self, role: _Optional[str] = ..., content: _Optional[str] = ...) -> None: ... class HandleRequest(_message.Message): __slots__ = ("messages", "task_id", "context_id") @@ -94,12 +75,7 @@ class HandleRequest(_message.Message): messages: _containers.RepeatedCompositeFieldContainer[ChatMessage] task_id: str context_id: str - def __init__( - self, - messages: _Optional[_Iterable[_Union[ChatMessage, _Mapping]]] = ..., - task_id: _Optional[str] = ..., - context_id: _Optional[str] = ..., - ) -> None: ... + def __init__(self, messages: _Optional[_Iterable[_Union[ChatMessage, _Mapping]]] = ..., task_id: _Optional[str] = ..., context_id: _Optional[str] = ...) -> None: ... class HandleResponse(_message.Message): __slots__ = ("content", "state", "prompt", "is_final", "metadata") @@ -109,10 +85,7 @@ class HandleResponse(_message.Message): VALUE_FIELD_NUMBER: _ClassVar[int] key: str value: str - def __init__( - self, key: _Optional[str] = ..., value: _Optional[str] = ... - ) -> None: ... - + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... 
CONTENT_FIELD_NUMBER: _ClassVar[int] STATE_FIELD_NUMBER: _ClassVar[int] PROMPT_FIELD_NUMBER: _ClassVar[int] @@ -123,27 +96,10 @@ class HandleResponse(_message.Message): prompt: str is_final: bool metadata: _containers.ScalarMap[str, str] - def __init__( - self, - content: _Optional[str] = ..., - state: _Optional[str] = ..., - prompt: _Optional[str] = ..., - is_final: bool = ..., - metadata: _Optional[_Mapping[str, str]] = ..., - ) -> None: ... + def __init__(self, content: _Optional[str] = ..., state: _Optional[str] = ..., prompt: _Optional[str] = ..., is_final: bool = ..., metadata: _Optional[_Mapping[str, str]] = ...) -> None: ... class SkillDefinition(_message.Message): - __slots__ = ( - "name", - "description", - "tags", - "input_modes", - "output_modes", - "version", - "author", - "raw_content", - "format", - ) + __slots__ = ("name", "description", "tags", "input_modes", "output_modes", "version", "author", "raw_content", "format") NAME_FIELD_NUMBER: _ClassVar[int] DESCRIPTION_FIELD_NUMBER: _ClassVar[int] TAGS_FIELD_NUMBER: _ClassVar[int] @@ -162,18 +118,7 @@ class SkillDefinition(_message.Message): author: str raw_content: str format: str - def __init__( - self, - name: _Optional[str] = ..., - description: _Optional[str] = ..., - tags: _Optional[_Iterable[str]] = ..., - input_modes: _Optional[_Iterable[str]] = ..., - output_modes: _Optional[_Iterable[str]] = ..., - version: _Optional[str] = ..., - author: _Optional[str] = ..., - raw_content: _Optional[str] = ..., - format: _Optional[str] = ..., - ) -> None: ... + def __init__(self, name: _Optional[str] = ..., description: _Optional[str] = ..., tags: _Optional[_Iterable[str]] = ..., input_modes: _Optional[_Iterable[str]] = ..., output_modes: _Optional[_Iterable[str]] = ..., version: _Optional[str] = ..., author: _Optional[str] = ..., raw_content: _Optional[str] = ..., format: _Optional[str] = ...) -> None: ... 
class GetCapabilitiesRequest(_message.Message): __slots__ = () @@ -191,14 +136,7 @@ class GetCapabilitiesResponse(_message.Message): version: str supports_streaming: bool skills: _containers.RepeatedCompositeFieldContainer[SkillDefinition] - def __init__( - self, - name: _Optional[str] = ..., - description: _Optional[str] = ..., - version: _Optional[str] = ..., - supports_streaming: bool = ..., - skills: _Optional[_Iterable[_Union[SkillDefinition, _Mapping]]] = ..., - ) -> None: ... + def __init__(self, name: _Optional[str] = ..., description: _Optional[str] = ..., version: _Optional[str] = ..., supports_streaming: bool = ..., skills: _Optional[_Iterable[_Union[SkillDefinition, _Mapping]]] = ...) -> None: ... class HealthCheckRequest(_message.Message): __slots__ = () diff --git a/bindu/grpc/generated/agent_handler_pb2_grpc.py b/bindu/grpc/generated/agent_handler_pb2_grpc.py index 82840bb8b..f30f6f62d 100644 --- a/bindu/grpc/generated/agent_handler_pb2_grpc.py +++ b/bindu/grpc/generated/agent_handler_pb2_grpc.py @@ -1,30 +1,27 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" - import grpc +import warnings from bindu.grpc.generated import agent_handler_pb2 as agent__handler__pb2 -GRPC_GENERATED_VERSION = "1.78.0" +GRPC_GENERATED_VERSION = '1.71.2' GRPC_VERSION = grpc.__version__ _version_not_supported = False try: from grpc._utilities import first_version_is_lower - - _version_not_supported = first_version_is_lower( - GRPC_VERSION, GRPC_GENERATED_VERSION - ) + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) except ImportError: _version_not_supported = True if _version_not_supported: raise RuntimeError( - f"The grpc package installed is at version {GRPC_VERSION}," - + " but the generated code in agent_handler_pb2_grpc.py depends on" - + f" grpcio>={GRPC_GENERATED_VERSION}." 
- + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" - + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in agent_handler_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' ) @@ -42,23 +39,20 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.RegisterAgent = channel.unary_unary( - "/bindu.grpc.BinduService/RegisterAgent", - request_serializer=agent__handler__pb2.RegisterAgentRequest.SerializeToString, - response_deserializer=agent__handler__pb2.RegisterAgentResponse.FromString, - _registered_method=True, - ) + '/bindu.grpc.BinduService/RegisterAgent', + request_serializer=agent__handler__pb2.RegisterAgentRequest.SerializeToString, + response_deserializer=agent__handler__pb2.RegisterAgentResponse.FromString, + _registered_method=True) self.Heartbeat = channel.unary_unary( - "/bindu.grpc.BinduService/Heartbeat", - request_serializer=agent__handler__pb2.HeartbeatRequest.SerializeToString, - response_deserializer=agent__handler__pb2.HeartbeatResponse.FromString, - _registered_method=True, - ) + '/bindu.grpc.BinduService/Heartbeat', + request_serializer=agent__handler__pb2.HeartbeatRequest.SerializeToString, + response_deserializer=agent__handler__pb2.HeartbeatResponse.FromString, + _registered_method=True) self.UnregisterAgent = channel.unary_unary( - "/bindu.grpc.BinduService/UnregisterAgent", - request_serializer=agent__handler__pb2.UnregisterAgentRequest.SerializeToString, - response_deserializer=agent__handler__pb2.UnregisterAgentResponse.FromString, - _registered_method=True, - ) + '/bindu.grpc.BinduService/UnregisterAgent', + request_serializer=agent__handler__pb2.UnregisterAgentRequest.SerializeToString, + 
response_deserializer=agent__handler__pb2.UnregisterAgentResponse.FromString, + _registered_method=True) class BinduServiceServicer(object): @@ -74,50 +68,49 @@ def RegisterAgent(self, request, context): Returns agent identity and the A2A endpoint URL. """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def Heartbeat(self, request, context): - """Periodic heartbeat to signal the SDK is still alive.""" + """Periodic heartbeat to signal the SDK is still alive. + """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def UnregisterAgent(self, request, context): - """Unregister an agent and shut down its A2A server.""" + """Unregister an agent and shut down its A2A server. 
+ """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def add_BinduServiceServicer_to_server(servicer, server): rpc_method_handlers = { - "RegisterAgent": grpc.unary_unary_rpc_method_handler( - servicer.RegisterAgent, - request_deserializer=agent__handler__pb2.RegisterAgentRequest.FromString, - response_serializer=agent__handler__pb2.RegisterAgentResponse.SerializeToString, - ), - "Heartbeat": grpc.unary_unary_rpc_method_handler( - servicer.Heartbeat, - request_deserializer=agent__handler__pb2.HeartbeatRequest.FromString, - response_serializer=agent__handler__pb2.HeartbeatResponse.SerializeToString, - ), - "UnregisterAgent": grpc.unary_unary_rpc_method_handler( - servicer.UnregisterAgent, - request_deserializer=agent__handler__pb2.UnregisterAgentRequest.FromString, - response_serializer=agent__handler__pb2.UnregisterAgentResponse.SerializeToString, - ), + 'RegisterAgent': grpc.unary_unary_rpc_method_handler( + servicer.RegisterAgent, + request_deserializer=agent__handler__pb2.RegisterAgentRequest.FromString, + response_serializer=agent__handler__pb2.RegisterAgentResponse.SerializeToString, + ), + 'Heartbeat': grpc.unary_unary_rpc_method_handler( + servicer.Heartbeat, + request_deserializer=agent__handler__pb2.HeartbeatRequest.FromString, + response_serializer=agent__handler__pb2.HeartbeatResponse.SerializeToString, + ), + 'UnregisterAgent': grpc.unary_unary_rpc_method_handler( + servicer.UnregisterAgent, + request_deserializer=agent__handler__pb2.UnregisterAgentRequest.FromString, + response_serializer=agent__handler__pb2.UnregisterAgentResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - "bindu.grpc.BinduService", rpc_method_handlers - ) + 'bindu.grpc.BinduService', rpc_method_handlers) 
server.add_generic_rpc_handlers((generic_handler,)) - server.add_registered_method_handlers( - "bindu.grpc.BinduService", rpc_method_handlers - ) + server.add_registered_method_handlers('bindu.grpc.BinduService', rpc_method_handlers) -# This class is part of an EXPERIMENTAL API. + # This class is part of an EXPERIMENTAL API. class BinduService(object): """============================================================================= BinduService — SDK calls this on the Core to register and manage agents @@ -126,22 +119,20 @@ class BinduService(object): """ @staticmethod - def RegisterAgent( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def RegisterAgent(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/bindu.grpc.BinduService/RegisterAgent", + '/bindu.grpc.BinduService/RegisterAgent', agent__handler__pb2.RegisterAgentRequest.SerializeToString, agent__handler__pb2.RegisterAgentResponse.FromString, options, @@ -152,26 +143,23 @@ def RegisterAgent( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def Heartbeat( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def Heartbeat(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/bindu.grpc.BinduService/Heartbeat", + '/bindu.grpc.BinduService/Heartbeat', agent__handler__pb2.HeartbeatRequest.SerializeToString, 
agent__handler__pb2.HeartbeatResponse.FromString, options, @@ -182,26 +170,23 @@ def Heartbeat( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def UnregisterAgent( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def UnregisterAgent(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/bindu.grpc.BinduService/UnregisterAgent", + '/bindu.grpc.BinduService/UnregisterAgent', agent__handler__pb2.UnregisterAgentRequest.SerializeToString, agent__handler__pb2.UnregisterAgentResponse.FromString, options, @@ -212,8 +197,7 @@ def UnregisterAgent( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) class AgentHandlerStub(object): @@ -230,29 +214,25 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.HandleMessages = channel.unary_unary( - "/bindu.grpc.AgentHandler/HandleMessages", - request_serializer=agent__handler__pb2.HandleRequest.SerializeToString, - response_deserializer=agent__handler__pb2.HandleResponse.FromString, - _registered_method=True, - ) + '/bindu.grpc.AgentHandler/HandleMessages', + request_serializer=agent__handler__pb2.HandleRequest.SerializeToString, + response_deserializer=agent__handler__pb2.HandleResponse.FromString, + _registered_method=True) self.HandleMessagesStream = channel.unary_stream( - "/bindu.grpc.AgentHandler/HandleMessagesStream", - request_serializer=agent__handler__pb2.HandleRequest.SerializeToString, - response_deserializer=agent__handler__pb2.HandleResponse.FromString, - _registered_method=True, - ) + '/bindu.grpc.AgentHandler/HandleMessagesStream', + request_serializer=agent__handler__pb2.HandleRequest.SerializeToString, + response_deserializer=agent__handler__pb2.HandleResponse.FromString, + _registered_method=True) self.GetCapabilities = channel.unary_unary( - "/bindu.grpc.AgentHandler/GetCapabilities", - request_serializer=agent__handler__pb2.GetCapabilitiesRequest.SerializeToString, - response_deserializer=agent__handler__pb2.GetCapabilitiesResponse.FromString, - _registered_method=True, - ) + '/bindu.grpc.AgentHandler/GetCapabilities', + request_serializer=agent__handler__pb2.GetCapabilitiesRequest.SerializeToString, + response_deserializer=agent__handler__pb2.GetCapabilitiesResponse.FromString, + _registered_method=True) self.HealthCheck = channel.unary_unary( - "/bindu.grpc.AgentHandler/HealthCheck", - request_serializer=agent__handler__pb2.HealthCheckRequest.SerializeToString, - response_deserializer=agent__handler__pb2.HealthCheckResponse.FromString, - _registered_method=True, - ) + '/bindu.grpc.AgentHandler/HealthCheck', + request_serializer=agent__handler__pb2.HealthCheckRequest.SerializeToString, + response_deserializer=agent__handler__pb2.HealthCheckResponse.FromString, + _registered_method=True) 
class AgentHandlerServicer(object): @@ -267,63 +247,62 @@ def HandleMessages(self, request, context): Core sends messages, SDK runs the developer's handler, returns response. """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def HandleMessagesStream(self, request, context): """Execute a handler with streaming response (server-side streaming). SDK yields chunks; core collects them via ResultProcessor. """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def GetCapabilities(self, request, context): - """Query agent capabilities (skills, supported modes).""" + """Query agent capabilities (skills, supported modes). + """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def HealthCheck(self, request, context): - """Health check to verify the SDK process is responsive.""" + """Health check to verify the SDK process is responsive. 
+ """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def add_AgentHandlerServicer_to_server(servicer, server): rpc_method_handlers = { - "HandleMessages": grpc.unary_unary_rpc_method_handler( - servicer.HandleMessages, - request_deserializer=agent__handler__pb2.HandleRequest.FromString, - response_serializer=agent__handler__pb2.HandleResponse.SerializeToString, - ), - "HandleMessagesStream": grpc.unary_stream_rpc_method_handler( - servicer.HandleMessagesStream, - request_deserializer=agent__handler__pb2.HandleRequest.FromString, - response_serializer=agent__handler__pb2.HandleResponse.SerializeToString, - ), - "GetCapabilities": grpc.unary_unary_rpc_method_handler( - servicer.GetCapabilities, - request_deserializer=agent__handler__pb2.GetCapabilitiesRequest.FromString, - response_serializer=agent__handler__pb2.GetCapabilitiesResponse.SerializeToString, - ), - "HealthCheck": grpc.unary_unary_rpc_method_handler( - servicer.HealthCheck, - request_deserializer=agent__handler__pb2.HealthCheckRequest.FromString, - response_serializer=agent__handler__pb2.HealthCheckResponse.SerializeToString, - ), + 'HandleMessages': grpc.unary_unary_rpc_method_handler( + servicer.HandleMessages, + request_deserializer=agent__handler__pb2.HandleRequest.FromString, + response_serializer=agent__handler__pb2.HandleResponse.SerializeToString, + ), + 'HandleMessagesStream': grpc.unary_stream_rpc_method_handler( + servicer.HandleMessagesStream, + request_deserializer=agent__handler__pb2.HandleRequest.FromString, + response_serializer=agent__handler__pb2.HandleResponse.SerializeToString, + ), + 'GetCapabilities': grpc.unary_unary_rpc_method_handler( + servicer.GetCapabilities, + request_deserializer=agent__handler__pb2.GetCapabilitiesRequest.FromString, + 
response_serializer=agent__handler__pb2.GetCapabilitiesResponse.SerializeToString, + ), + 'HealthCheck': grpc.unary_unary_rpc_method_handler( + servicer.HealthCheck, + request_deserializer=agent__handler__pb2.HealthCheckRequest.FromString, + response_serializer=agent__handler__pb2.HealthCheckResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - "bindu.grpc.AgentHandler", rpc_method_handlers - ) + 'bindu.grpc.AgentHandler', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - server.add_registered_method_handlers( - "bindu.grpc.AgentHandler", rpc_method_handlers - ) + server.add_registered_method_handlers('bindu.grpc.AgentHandler', rpc_method_handlers) -# This class is part of an EXPERIMENTAL API. + # This class is part of an EXPERIMENTAL API. class AgentHandler(object): """============================================================================= AgentHandler — Core calls this on the SDK to execute tasks @@ -332,22 +311,20 @@ class AgentHandler(object): """ @staticmethod - def HandleMessages( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def HandleMessages(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/bindu.grpc.AgentHandler/HandleMessages", + '/bindu.grpc.AgentHandler/HandleMessages', agent__handler__pb2.HandleRequest.SerializeToString, agent__handler__pb2.HandleResponse.FromString, options, @@ -358,26 +335,23 @@ def HandleMessages( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def HandleMessagesStream( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - 
compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def HandleMessagesStream(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream( request, target, - "/bindu.grpc.AgentHandler/HandleMessagesStream", + '/bindu.grpc.AgentHandler/HandleMessagesStream', agent__handler__pb2.HandleRequest.SerializeToString, agent__handler__pb2.HandleResponse.FromString, options, @@ -388,26 +362,23 @@ def HandleMessagesStream( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def GetCapabilities( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def GetCapabilities(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/bindu.grpc.AgentHandler/GetCapabilities", + '/bindu.grpc.AgentHandler/GetCapabilities', agent__handler__pb2.GetCapabilitiesRequest.SerializeToString, agent__handler__pb2.GetCapabilitiesResponse.FromString, options, @@ -418,26 +389,23 @@ def GetCapabilities( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def HealthCheck( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def HealthCheck(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, 
target, - "/bindu.grpc.AgentHandler/HealthCheck", + '/bindu.grpc.AgentHandler/HealthCheck', agent__handler__pb2.HealthCheckRequest.SerializeToString, agent__handler__pb2.HealthCheckResponse.FromString, options, @@ -448,5 +416,4 @@ def HealthCheck( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) diff --git a/bindu/grpc/server.py b/bindu/grpc/server.py index d1e51119e..2d212d97b 100644 --- a/bindu/grpc/server.py +++ b/bindu/grpc/server.py @@ -22,9 +22,17 @@ from concurrent import futures -import grpc +try: + import grpc + + from bindu.grpc.generated import agent_handler_pb2_grpc +except ImportError as exc: # pragma: no cover - optional dependency path + grpc = None # type: ignore[assignment] + agent_handler_pb2_grpc = None # type: ignore[assignment] + _GRPC_IMPORT_ERROR = exc +else: + _GRPC_IMPORT_ERROR = None -from bindu.grpc.generated import agent_handler_pb2_grpc from bindu.grpc.registry import AgentRegistry from bindu.grpc.service import BinduServiceImpl from bindu.settings import app_settings @@ -33,6 +41,13 @@ logger = get_logger("bindu.grpc.server") +def _require_grpc_dependencies() -> None: + if grpc is None or agent_handler_pb2_grpc is None: + raise ImportError( + "gRPC support requires optional dependencies. Install with: pip install 'bindu[grpc]'" + ) from _GRPC_IMPORT_ERROR + + def start_grpc_server( registry: AgentRegistry | None = None, host: str | None = None, @@ -54,6 +69,8 @@ def start_grpc_server( The started grpc.Server instance. Call wait_for_termination() to block, or stop() to shut down. 
""" + _require_grpc_dependencies() + registry = registry or AgentRegistry() host = host or app_settings.grpc.host port = port or app_settings.grpc.port diff --git a/bindu/grpc/service.py b/bindu/grpc/service.py index d66d6bae8..9f5bf9efd 100644 --- a/bindu/grpc/service.py +++ b/bindu/grpc/service.py @@ -18,10 +18,19 @@ import time from pathlib import Path -import grpc +try: + import grpc + + from bindu.grpc.generated import agent_handler_pb2, agent_handler_pb2_grpc +except ImportError as exc: # pragma: no cover - optional dependency path + grpc = None # type: ignore[assignment] + agent_handler_pb2 = None # type: ignore[assignment] + agent_handler_pb2_grpc = None # type: ignore[assignment] + _GRPC_IMPORT_ERROR = exc +else: + _GRPC_IMPORT_ERROR = None from bindu.grpc.client import GrpcAgentClient -from bindu.grpc.generated import agent_handler_pb2, agent_handler_pb2_grpc from bindu.grpc.registry import AgentRegistry from bindu.settings import app_settings from bindu.utils.logging import get_logger @@ -29,6 +38,21 @@ logger = get_logger("bindu.grpc.service") +def _require_grpc_dependencies() -> None: + if grpc is None or agent_handler_pb2 is None or agent_handler_pb2_grpc is None: + raise ImportError( + "gRPC support requires optional dependencies. Install with: pip install 'bindu[grpc]'" + ) from _GRPC_IMPORT_ERROR + + +if agent_handler_pb2_grpc is not None: + _BinduServiceBase = agent_handler_pb2_grpc.BinduServiceServicer +else: # pragma: no cover - optional dependency path + + class _BinduServiceBase: # type: ignore[too-many-ancestors] + pass + + def _proto_skills_to_dicts( skills: list[agent_handler_pb2.SkillDefinition], ) -> list[dict]: @@ -45,6 +69,7 @@ def _proto_skills_to_dicts( List of skill dicts compatible with create_manifest(). 
""" result = [] + _require_grpc_dependencies() for skill in skills: skill_dict = { "name": skill.name, @@ -64,7 +89,7 @@ def _proto_skills_to_dicts( return result -class BinduServiceImpl(agent_handler_pb2_grpc.BinduServiceServicer): +class BinduServiceImpl(_BinduServiceBase): """gRPC servicer for BinduService — handles SDK registration and lifecycle. This runs on the Bindu core's gRPC server (port 3774). SDKs connect to @@ -75,6 +100,7 @@ class BinduServiceImpl(agent_handler_pb2_grpc.BinduServiceServicer): """ def __init__(self, registry: AgentRegistry) -> None: # noqa: D107 + _require_grpc_dependencies() self.registry = registry def RegisterAgent( diff --git a/bindu/penguin/bindufy.py b/bindu/penguin/bindufy.py index cfb4337ba..e14ae79e8 100644 --- a/bindu/penguin/bindufy.py +++ b/bindu/penguin/bindufy.py @@ -493,6 +493,30 @@ def _bindufy_core( # Add x402 extension to capabilities capabilities = add_extension_to_capabilities(capabilities, x402_extension) + # Voice extension (optional) + voice_config = validated_config.get("voice") + if voice_config and isinstance(voice_config, dict): + from bindu.extensions.voice import VoiceAgentExtension + + try: + voice_extension = VoiceAgentExtension(**voice_config) + capabilities = add_extension_to_capabilities(capabilities, voice_extension) + logger.info(f"Voice extension created: {voice_extension}") + except TypeError as e: + # Invalid or unexpected configuration keys + raise ValueError( + f"Invalid voice configuration: {e}. 
" + f"Expected keys: stt_provider, stt_model, stt_language, tts_provider, " + f"tts_voice_id, tts_model, sample_rate, allow_interruptions, vad_enabled, description" + ) from e + except ValueError as e: + # Validation errors (e.g., invalid sample_rate) + raise ValueError(f"Voice configuration validation failed: {e}") from e + except Exception as e: + # Catch any other unexpected errors + logger.error(f"Failed to create voice extension: {e}") + raise ValueError(f"Unable to initialize voice extension: {e}") from e + # Create agent manifest with loaded skills _manifest = create_manifest( agent_function=handler_callable, diff --git a/bindu/penguin/config_validator.py b/bindu/penguin/config_validator.py index 6d4be0e8b..ac4b8d689 100644 --- a/bindu/penguin/config_validator.py +++ b/bindu/penguin/config_validator.py @@ -151,8 +151,11 @@ def _validate_deployment_url(cls, config: Dict[str, Any]) -> None: "ConfigError: 'deployment.url' is a required field in the agent configuration." ) + deployment_url = deployment_url.strip() + deployment_config["url"] = deployment_url + parsed = urlparse(deployment_url) - if parsed.scheme not in {"http", "https"} or not parsed.netloc: + if parsed.scheme not in {"http", "https"} or not parsed.hostname: raise ConfigError( f"ConfigError: 'deployment.url' must be a valid http(s) URL, got '{deployment_url}'." 
) diff --git a/bindu/server/applications.py b/bindu/server/applications.py index e03ec3e2a..adad45851 100644 --- a/bindu/server/applications.py +++ b/bindu/server/applications.py @@ -110,6 +110,9 @@ def __init__( # Setup middleware chain x402_ext = get_x402_extension_from_capabilities(manifest) + from bindu.utils import get_voice_extension_from_capabilities + + voice_ext = get_voice_extension_from_capabilities(manifest) payment_requirements_for_middleware = None if x402_ext: # Type narrowing: if x402_ext exists, manifest must exist @@ -147,6 +150,8 @@ def __init__( self._scheduler: Scheduler | None = None self._agent_card_json_schema: bytes | None = None self._x402_ext = x402_ext + self._voice_ext = voice_ext + self._voice_session_manager = None self._payment_session_manager = None self._payment_requirements = None self._paywall_config = None @@ -239,6 +244,9 @@ async def root_redirect(app: BinduApplication, request: Request) -> Response: if self._x402_ext: self._register_payment_endpoints() + if self._voice_ext: + self._register_voice_endpoints() + def _register_payment_endpoints(self) -> None: """Register payment session endpoints.""" from .endpoints import ( @@ -266,6 +274,39 @@ def _register_payment_endpoints(self) -> None: with_app=True, ) + def _register_voice_endpoints(self) -> None: + """Register voice session REST + WebSocket endpoints.""" + from starlette.routing import WebSocketRoute + + from .endpoints.voice_endpoints import ( + voice_session_end, + voice_session_start, + voice_session_status, + voice_websocket, + ) + + self._add_route( + "/voice/session/start", + voice_session_start, + ["POST"], + with_app=True, + ) + self._add_route( + "/voice/session/{session_id}", + voice_session_end, + ["DELETE"], + with_app=True, + ) + self._add_route( + "/voice/session/{session_id}/status", + voice_session_status, + ["GET"], + with_app=True, + ) + self.router.routes.append( + WebSocketRoute("/ws/voice/{session_id}", voice_websocket) + ) + def _add_route( self, 
path: str, @@ -368,6 +409,18 @@ async def lifespan(app: BinduApplication) -> AsyncIterator[None]: if app._payment_session_manager: await app._payment_session_manager.start_cleanup_task() + # Initialize voice session manager if voice extension is enabled + if app._voice_ext: + from bindu.extensions.voice.session_factory import ( + create_session_manager, + ) + + app._voice_session_manager = await create_session_manager( + app_settings.voice + ) + await app._voice_session_manager.start_cleanup_loop() + logger.info("āœ… Voice session manager started") + # Start TaskManager if manifest: logger.info("šŸ”§ Starting TaskManager...") @@ -386,6 +439,13 @@ async def lifespan(app: BinduApplication) -> AsyncIterator[None]: if app._payment_session_manager: await app._payment_session_manager.stop_cleanup_task() + # Stop voice session manager + if app._voice_session_manager: + from bindu.extensions.voice.session_factory import close_session_manager + + await close_session_manager(app._voice_session_manager) + logger.info("šŸ›‘ Voice session manager stopped") + # Cleanup storage logger.info("🧹 Cleaning up storage...") from .storage.factory import close_storage diff --git a/bindu/server/endpoints/utils.py b/bindu/server/endpoints/utils.py index b7a8c671a..ee69c0cfc 100644 --- a/bindu/server/endpoints/utils.py +++ b/bindu/server/endpoints/utils.py @@ -5,6 +5,7 @@ from functools import wraps from typing import Any, Callable, Tuple, Type, get_args +from fastapi import HTTPException from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -104,6 +105,8 @@ async def wrapper(*args, **kwargs) -> Response: try: return await func(*args, **kwargs) + except HTTPException: + raise except Exception as e: logger.error( f"Error serving {endpoint_name} to {client_ip}: {e}", exc_info=True diff --git a/bindu/server/endpoints/voice_endpoints.py b/bindu/server/endpoints/voice_endpoints.py new file mode 100644 index 000000000..81b98e7c5 --- /dev/null +++ 
b/bindu/server/endpoints/voice_endpoints.py @@ -0,0 +1,980 @@ +"""Voice session REST + WebSocket endpoints. + +Provides: + POST /voice/session/start → Start a new voice session + DELETE /voice/session/{session_id} → End a voice session + GET /voice/session/{session_id}/status → Get session status + WS /ws/voice/{session_id} → Bidirectional audio stream +""" + +from __future__ import annotations + +import json +import asyncio +import importlib +import secrets +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any +from uuid import uuid4 + +from fastapi import HTTPException +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState + +from bindu.settings import app_settings +from bindu.utils.logging import get_logger +from bindu.server.endpoints.utils import handle_endpoint_errors + +if TYPE_CHECKING: + from bindu.server.applications import BinduApplication + +logger = get_logger("bindu.server.endpoints.voice") + +_BACKGROUND_TASKS: set[asyncio.Task[None]] = set() +_VOICE_RATE_LIMIT_LOCK = asyncio.Lock() +_VOICE_RATE_LIMIT_IP_BUCKET: dict[str, list[float]] = {} +_VOICE_RATE_LIMIT_REDIS_LOCK = asyncio.Lock() +_VOICE_RATE_LIMIT_REDIS_CLIENT: Any | None = None + +try: + import redis.asyncio as _redis_async # type: ignore[import-not-found] + + _REDIS_AVAILABLE = True +except Exception: # pragma: no cover + _redis_async = None # type: ignore[assignment] + _REDIS_AVAILABLE = False + + +_RATE_LIMIT_LUA = """ +-- Sliding-window rate limit using a sorted set. 
+-- KEYS[1] = zset key +-- ARGV[1] = now (seconds) +-- ARGV[2] = cutoff (seconds) +-- ARGV[3] = member (unique) +-- ARGV[4] = limit (int) +redis.call('ZREMRANGEBYSCORE', KEYS[1], 0, tonumber(ARGV[2])) +redis.call('ZADD', KEYS[1], tonumber(ARGV[1]), ARGV[3]) +local count = redis.call('ZCARD', KEYS[1]) +redis.call('EXPIRE', KEYS[1], 120) +if count > tonumber(ARGV[4]) then + return 0 +end +return 1 +""" + + +async def _get_rate_limit_redis_client() -> Any | None: + """Lazy init a Redis client for rate limiting (best-effort).""" + global _VOICE_RATE_LIMIT_REDIS_CLIENT + if _VOICE_RATE_LIMIT_REDIS_CLIENT is not None: + return _VOICE_RATE_LIMIT_REDIS_CLIENT + + if not _REDIS_AVAILABLE: + return None + + redis_url = app_settings.voice.redis_url + if not redis_url: + return None + + async with _VOICE_RATE_LIMIT_REDIS_LOCK: + if _VOICE_RATE_LIMIT_REDIS_CLIENT is not None: + return _VOICE_RATE_LIMIT_REDIS_CLIENT + try: + _VOICE_RATE_LIMIT_REDIS_CLIENT = _redis_async.from_url( # type: ignore[union-attr] + redis_url, + encoding="utf-8", + decode_responses=True, + ) + return _VOICE_RATE_LIMIT_REDIS_CLIENT + except Exception: + logger.exception("Failed to initialize Redis rate limiter client") + _VOICE_RATE_LIMIT_REDIS_CLIENT = None + return None + + +async def _rate_limit_allow_ip( + ip: str, + *, + limit_per_minute: int, + now: float | None = None, +) -> bool: + """Allow an IP through the sliding-window rate limiter.""" + if limit_per_minute <= 0: + return True + t = float(time.time() if now is None else now) + cutoff = t - 60.0 + + if app_settings.voice.rate_limit_backend == "redis": + client = await _get_rate_limit_redis_client() + if client is not None: + key = f"voice:rate_limit:ip:{ip}" + member = f"{t}:{time.time_ns()}" + try: + allowed = await client.eval( + _RATE_LIMIT_LUA, + 1, + key, + t, + cutoff, + member, + int(limit_per_minute), + ) + return bool(allowed) + except Exception: + logger.exception("Redis rate limiter failed; falling back to memory") + + async with 
_VOICE_RATE_LIMIT_LOCK: + window = _VOICE_RATE_LIMIT_IP_BUCKET.get(ip, []) + window = [ts for ts in window if ts >= cutoff] + + if not window: + _VOICE_RATE_LIMIT_IP_BUCKET.pop(ip, None) + + if len(window) >= limit_per_minute: + if window: + _VOICE_RATE_LIMIT_IP_BUCKET[ip] = window + return False + window.append(t) + _VOICE_RATE_LIMIT_IP_BUCKET[ip] = window + return True + + +@dataclass +class _VoiceControlState: + muted: bool = False + stopped: bool = False + suppress_audio_until: float = 0.0 + + +class _FilteredWebSocket: + """WebSocket wrapper that filters inbound frames. + + Used to keep Pipecat's transport focused on audio frames while this endpoint + consumes and handles JSON control messages (start/mute/unmute/stop/etc). + """ + + def __init__(self, websocket: WebSocket, queue: asyncio.Queue[dict[str, Any]]): + self._ws = websocket + self._queue = queue + + def __getattr__(self, name: str) -> Any: + return getattr(self._ws, name) + + async def receive(self) -> dict[str, Any]: + msg = await self._queue.get() + msg_type = msg.get("type", "unknown") + data_bytes = msg.get("bytes") + logger.info( + f"FilteredWebSocket.receive: type={msg_type}, bytes={len(data_bytes) if data_bytes else 0}" + ) + return msg + + async def receive_text(self) -> str: + message = await self.receive() + if message.get("type") == "websocket.disconnect": + raise WebSocketDisconnect(code=message.get("code", 1000)) + text = message.get("text") + if text is None: + raise RuntimeError("Expected text WebSocket message") + return text + + async def receive_bytes(self) -> bytes: + message = await self.receive() + if message.get("type") == "websocket.disconnect": + logger.info("FilteredWebSocket.receive_bytes: disconnect received") + raise WebSocketDisconnect(code=message.get("code", 1000)) + data = message.get("bytes") + logger.debug( + f"FilteredWebSocket.receive_bytes: got {len(data) if data else 0} bytes" + ) + if data is None: + raise RuntimeError("Expected bytes WebSocket message") + 
return data + + +try: + from pipecat.serializers.base_serializer import ( + FrameSerializer as _PipecatFrameSerializer, + ) +except Exception: # pragma: no cover + + class _PipecatFrameSerializer: # type: ignore[too-many-ancestors] + def __init__(self, *args: Any, **kwargs: Any) -> None: + pass + + +class _RawAudioFrameSerializer(_PipecatFrameSerializer): + """Serializer for raw PCM audio over WebSocket. + + Converts inbound binary frames into Pipecat input audio frames and sends + outbound audio frames as raw bytes for browser playback. + """ + + def __init__(self, sample_rate: int, num_channels: int): + super().__init__() + self._sample_rate = sample_rate + self._num_channels = num_channels + + async def setup(self, _frame: Any) -> None: + return + + async def serialize(self, frame: Any) -> str | bytes | None: + from pipecat.frames.frames import ( + OutputAudioRawFrame, + OutputTransportMessageFrame, + OutputTransportMessageUrgentFrame, + ) + + if isinstance(frame, OutputAudioRawFrame): + return frame.audio + + if isinstance( + frame, (OutputTransportMessageFrame, OutputTransportMessageUrgentFrame) + ): + return json.dumps(frame.message) + + return None + + async def deserialize(self, data: str | bytes) -> Any | None: + from pipecat.frames.frames import InputAudioRawFrame, InputTransportMessageFrame + + if isinstance(data, (bytes, bytearray)): + return InputAudioRawFrame( + audio=bytes(data), + sample_rate=self._sample_rate, + num_channels=self._num_channels, + ) + + if isinstance(data, str): + try: + payload = json.loads(data) + except json.JSONDecodeError: + return None + return InputTransportMessageFrame(message=payload) + + return None + + +async def _send_json( + websocket: Any, + payload: dict[str, Any], + send_lock: asyncio.Lock | None = None, +) -> bool: + """Send a JSON payload over a WebSocket safely. + + If a lock is provided, it is used to serialize concurrent send_text calls. 
+ """ + try: + client_state = getattr(websocket, "client_state", None) + application_state = getattr(websocket, "application_state", None) + if ( + client_state == WebSocketState.DISCONNECTED + or application_state == WebSocketState.DISCONNECTED + ): + return False + except Exception: + # If state inspection fails, still attempt the send below. + pass + + data = json.dumps(payload) + try: + if send_lock is None: + await websocket.send_text(data) + return True + async with send_lock: + await websocket.send_text(data) + return True + except (WebSocketDisconnect, RuntimeError): + # RuntimeError is raised by Starlette after a close frame is sent. + return False + + +def _trim_overlap_text(previous: str, current: str) -> str: + """Trim repeated word-overlap between adjacent transcript chunks. + + Returns only the delta portion of ``current`` that does not duplicate + the suffix of ``previous``. + """ + prev = (previous or "").strip() + curr = (current or "").strip() + if not prev: + return curr + if prev == curr: + return "" + + prev_tokens = prev.split() + curr_tokens = curr.split() + max_overlap = min(len(prev_tokens), len(curr_tokens)) + + for overlap in range(max_overlap, 0, -1): + if prev_tokens[-overlap:] == curr_tokens[:overlap]: + return " ".join(curr_tokens[overlap:]).strip() + + return curr + + +def _extract_bearer_token(value: str | None) -> str | None: + if not value: + return None + parts = value.split() + if len(parts) == 2 and parts[0].lower() == "bearer": + token = parts[1].strip() + return token or None + return None + + +def _extract_ws_session_token(websocket: WebSocket) -> str | None: + """Extract session token from WS headers (subprotocol or Authorization).""" + subprotocols = websocket.headers.get("sec-websocket-protocol") + if subprotocols: + # Starlette exposes the raw comma-separated list. 
+ for item in subprotocols.split(","): + token = item.strip() + if token: + return token + + return _extract_bearer_token(websocket.headers.get("authorization")) + + +def _has_ws_subprotocols(websocket: WebSocket) -> bool: + return bool(websocket.headers.get("sec-websocket-protocol")) + + +def _classify_voice_pipeline_error(exc: Exception) -> tuple[str, int]: + """Map internal pipeline exceptions to a user-facing error + close code.""" + module = type(exc).__module__ + name = type(exc).__name__ + message = str(exc) + + if isinstance(exc, asyncio.TimeoutError) or name == "TimeoutError": + return ("Voice pipeline timed out", 1011) + + if module.startswith("websockets") or "ConnectionClosed" in name: + return ("Voice provider connection closed", 1011) + + lowered = message.lower() + if "deepgram" in lowered and ("disconnect" in lowered or "closed" in lowered): + return ("Deepgram connection closed", 1011) + if "elevenlabs" in lowered and ("disconnect" in lowered or "closed" in lowered): + return ("ElevenLabs connection closed", 1011) + + return ("Voice pipeline error", 1011) + + +def _voice_preflight_error() -> str | None: + """Return a user-facing error if voice is not runnable in this process.""" + # Ensure optional dependency group is installed. + try: + importlib.import_module("pipecat") + except Exception: + return "Voice dependencies are not installed. Install with: pip install 'bindu[voice]'" + + # Ensure provider keys exist (provider-dependent). 
+ if ( + app_settings.voice.stt_provider == "deepgram" + and not app_settings.voice.stt_api_key + ): + return "VOICE__STT_API_KEY is required when VOICE__STT_PROVIDER=deepgram" + + tts_provider = app_settings.voice.tts_provider + if tts_provider == "elevenlabs" and not app_settings.voice.tts_api_key: + return "VOICE__TTS_API_KEY is required when VOICE__TTS_PROVIDER=elevenlabs" + if tts_provider == "azure": + if not app_settings.voice.azure_tts_api_key: + return "VOICE__AZURE_TTS_API_KEY is required when VOICE__TTS_PROVIDER=azure" + if not app_settings.voice.azure_tts_region: + return "VOICE__AZURE_TTS_REGION is required when VOICE__TTS_PROVIDER=azure" + if tts_provider not in {"elevenlabs", "piper", "azure"}: + return f"Unsupported VOICE__TTS_PROVIDER={tts_provider!r}" + + return None + + +async def _send_error_and_close( + websocket: WebSocket, + message: str, + *, + send_lock: asyncio.Lock, + close_code: int = 1008, +) -> None: + try: + await _send_json(websocket, {"type": "error", "message": message}, send_lock) + finally: + try: + await websocket.close(code=close_code, reason=message) + except Exception: + pass + + +async def _voice_control_reader( + websocket: WebSocket, + inbound_queue: asyncio.Queue[dict[str, Any]], + control: _VoiceControlState, + *, + vad_enabled: bool, + send_lock: asyncio.Lock, + on_user_text: Any | None = None, +) -> None: + """Read from the real WebSocket and push only audio frames to the queue.""" + max_binary_frame_bytes = 64 * 1024 + max_frames_per_second = 50 + max_frames_in_flight = 10 + + window_started_at = time.monotonic() + window_count = 0 + + while True: + message: dict[str, Any] = await websocket.receive() + message_type = message.get("type") + + if message_type == "websocket.disconnect": + await inbound_queue.put(message) + return + + if message_type != "websocket.receive": + await inbound_queue.put(message) + continue + + text = message.get("text") + if text is not None: + try: + payload = json.loads(text) + except 
json.JSONDecodeError: + await _send_error_and_close( + websocket, + "Malformed JSON control frame", + send_lock=send_lock, + ) + return + + frame_type = payload.get("type") + if frame_type == "mute": + control.muted = True + await _send_json( + websocket, {"type": "state", "state": "muted"}, send_lock + ) + continue + if frame_type == "unmute": + control.muted = False + await _send_json( + websocket, {"type": "state", "state": "listening"}, send_lock + ) + continue + if frame_type == "stop": + control.stopped = True + await _send_json( + websocket, {"type": "state", "state": "ended"}, send_lock + ) + try: + await websocket.close() + finally: + return + if frame_type == "user_text": + user_text = payload.get("text") + if isinstance(user_text, str) and user_text.strip() and on_user_text: + try: + await on_user_text(user_text.strip()) + except Exception: + logger.exception("Failed to handle user_text control frame") + continue + if frame_type in {"start", "commit_turn"}: + # If VAD is disabled, the transport/STT may rely on explicit turn boundary + # control frames (e.g. commit_turn). Forward these to the transport. + if not vad_enabled: + await inbound_queue.put(message) + continue + + # Unknown control frame: ignore to preserve forward-compat. 
+ continue + + data = message.get("bytes") + if data is None: + logger.debug("Voice control reader: no bytes in message, skipping") + continue + + logger.info(f"Voice control reader: received audio frame, {len(data)} bytes") + + now_monotonic = time.monotonic() + if control.muted or now_monotonic < float(control.suppress_audio_until): + continue + + if len(data) > max_binary_frame_bytes: + await _send_error_and_close( + websocket, + f"Audio frame too large (max {max_binary_frame_bytes} bytes)", + send_lock=send_lock, + ) + return + + if now_monotonic - window_started_at >= 1.0: + window_started_at = now_monotonic + window_count = 0 + window_count += 1 + if window_count > max_frames_per_second: + await _send_error_and_close( + websocket, + f"Too many audio frames per second (max {max_frames_per_second})", + send_lock=send_lock, + ) + return + + if inbound_queue.qsize() >= max_frames_in_flight: + await _send_error_and_close( + websocket, + f"Too many audio frames in flight (max {max_frames_in_flight})", + send_lock=send_lock, + ) + return + + await inbound_queue.put(message) + + +# --------------------------------------------------------------------------- +# REST Endpoints +# --------------------------------------------------------------------------- + + +@handle_endpoint_errors("start voice session") +async def voice_session_start(app: BinduApplication, request: Request) -> Response: + """Start a new voice session. 
+ + Request body (optional JSON): + { "context_id": "" } + + Returns: + { "session_id": "...", "ws_url": "ws://host/ws/voice/{session_id}" } + """ + session_manager = getattr(app, "_voice_session_manager", None) + if session_manager is None: + return JSONResponse( + {"error": "Voice extension is not enabled"}, status_code=501 + ) + + if not app_settings.voice.enabled: + return JSONResponse({"error": "Voice is disabled"}, status_code=501) + + preflight_error = _voice_preflight_error() + if preflight_error: + return JSONResponse({"error": preflight_error}, status_code=503) + + # Per-IP rate limit (best-effort; request.client may be missing in tests/proxies) + client_host = request.client.host if request.client else None + if client_host: + allowed = await _rate_limit_allow_ip( + client_host, + limit_per_minute=int(app_settings.voice.rate_limit_per_ip_per_minute), + ) + if not allowed: + return JSONResponse({"error": "Rate limit exceeded"}, status_code=429) + + # Parse optional context_id from body + context_id = str(uuid4()) + raw_body = await request.body() + if raw_body: + try: + body = json.loads(raw_body) + except json.JSONDecodeError as exc: + raise HTTPException( + status_code=400, detail="Malformed JSON payload" + ) from exc + + if isinstance(body, dict) and "context_id" in body: + context_id = str(body["context_id"]) + + session_token: str | None = None + session_token_expires_at: float | None = None + if app_settings.voice.session_auth_required: + session_token = secrets.token_urlsafe(32) + session_token_expires_at = time.time() + max( + 1, int(app_settings.voice.session_token_ttl) + ) + + try: + session = await session_manager.create_session( + context_id, + session_token=session_token, + session_token_expires_at=session_token_expires_at, + ) + except RuntimeError as exc: + return JSONResponse({"error": str(exc)}, status_code=429) + + # Build WebSocket URL from request + scheme = "wss" if request.url.scheme == "https" else "ws" + # Use hostname from 
request, fallback to client host, or raise error if unavailable + host = request.url.hostname or (request.client.host if request.client else None) + if not host: + return JSONResponse( + {"error": "Unable to determine host for WebSocket URL"}, + status_code=400, + ) + ws_url = f"{scheme}://{host}" + if request.url.port: + ws_url += f":{request.url.port}" + ws_url += f"/ws/voice/{session.id}" + + return JSONResponse( + { + "session_id": session.id, + "context_id": session.context_id, + "ws_url": ws_url, + **({"session_token": session_token} if session_token else {}), + }, + status_code=201, + ) + + +@handle_endpoint_errors("end voice session") +async def voice_session_end(app: BinduApplication, request: Request) -> Response: + """End a voice session. + + Path params: + session_id: The voice session ID + + Returns: + { "status": "ended" } + """ + session_manager = getattr(app, "_voice_session_manager", None) + if session_manager is None: + return JSONResponse( + {"error": "Voice extension is not enabled"}, status_code=501 + ) + + session_id = request.path_params["session_id"] + session = await session_manager.get_session(session_id) + if session is None: + return JSONResponse({"error": "Session not found"}, status_code=404) + + await session_manager.end_session(session_id) + return JSONResponse({"status": "ended"}) + + +@handle_endpoint_errors("voice session status") +async def voice_session_status(app: BinduApplication, request: Request) -> Response: + """Get voice session status. + + Path params: + session_id: The voice session ID + + Returns: + { "session_id": "...", "state": "...", "duration": 12.3, "context_id": "..." 
} + """ + session_manager = getattr(app, "_voice_session_manager", None) + if session_manager is None: + return JSONResponse( + {"error": "Voice extension is not enabled"}, status_code=501 + ) + + session_id = request.path_params["session_id"] + session = await session_manager.get_session(session_id) + if session is None: + return JSONResponse({"error": "Session not found"}, status_code=404) + + return JSONResponse( + { + "session_id": session.id, + "context_id": session.context_id, + "state": session.state, + "duration": round(session.duration_seconds, 1), + "task_id": session.task_id, + } + ) + + +# --------------------------------------------------------------------------- +# WebSocket Handler +# --------------------------------------------------------------------------- + + +async def voice_websocket(websocket: WebSocket) -> None: + """Bidirectional voice WebSocket handler using Pipecat pipeline.""" + app: BinduApplication = websocket.app # type: ignore[assignment] + session_id = websocket.path_params.get("session_id", "") + + session_manager = getattr(app, "_voice_session_manager", None) + if session_manager is None: + await websocket.close(code=1008, reason="Voice extension is not enabled") + return + if not app_settings.voice.enabled: + await websocket.close(code=1008, reason="Voice is disabled") + return + + preflight_error = _voice_preflight_error() + if preflight_error: + await websocket.close(code=1008, reason=preflight_error) + return + + session = await session_manager.get_session(session_id) + if session is None: + await websocket.close(code=1008, reason="Invalid session ID") + return + + # Per-IP rate limit on websocket connects (best-effort) + client_host = websocket.client.host if websocket.client else None + if client_host: + allowed = await _rate_limit_allow_ip( + client_host, + limit_per_minute=int(app_settings.voice.rate_limit_per_ip_per_minute), + ) + if not allowed: + await websocket.close(code=1008, reason="Rate limit exceeded") + return + + 
if app_settings.voice.session_auth_required: + expected = getattr(session, "session_token", None) + expires_at = getattr(session, "session_token_expires_at", None) + provided = _extract_ws_session_token(websocket) + + # If the client sent Sec-WebSocket-Protocol, the server should select one. + # We select the token itself as the negotiated subprotocol. + if provided and _has_ws_subprotocols(websocket): + await websocket.accept(subprotocol=provided) + else: + await websocket.accept() + if not provided: + try: + provided = (await websocket.receive_text()).strip() + except Exception: + await websocket.close(code=1008, reason="Missing session token") + return + + if not expected or provided != expected: + await websocket.close(code=1008, reason="Invalid session token") + return + if isinstance(expires_at, (int, float)) and time.time() > float(expires_at): + await websocket.close(code=1008, reason="Session token expired") + return + else: + await websocket.accept() + + await session_manager.update_state(session_id, "active") + + voice_ext = getattr(app, "_voice_ext", None) + manifest = getattr(app, "manifest", None) + if voice_ext is None or manifest is None or not hasattr(manifest, "run"): + await websocket.send_text( + json.dumps({"type": "error", "message": "Agent not configured for voice"}) + ) + await websocket.close(code=1011) + return + + from bindu.extensions.voice.pipeline_builder import build_voice_pipeline + from pipecat.transports.websocket.fastapi import ( + FastAPIWebsocketTransport, + FastAPIWebsocketParams, + ) + from pipecat.pipeline.pipeline import Pipeline + from pipecat.pipeline.task import PipelineTask + from pipecat.pipeline.runner import PipelineRunner + + # Notify UI we are listening + send_lock = asyncio.Lock() + await _send_json(websocket, {"type": "state", "state": "listening"}, send_lock) + + async def _on_user_transcript(text: str) -> None: + await _send_json( + websocket, + {"type": "transcript", "role": "user", "text": text, "is_final": 
True}, + send_lock, + ) + + async def _on_agent_response(text: str) -> None: + control.suppress_audio_until = max( + float(control.suppress_audio_until), time.monotonic() + 0.6 + ) + await _send_json( + websocket, + {"type": "agent_response", "text": text, "task_id": session.task_id}, + send_lock, + ) + + async def _on_state_change(state: str) -> None: + if state == "agent-speaking": + control.suppress_audio_until = max( + float(control.suppress_audio_until), time.monotonic() + 1.0 + ) + elif state == "listening": + control.suppress_audio_until = max( + float(control.suppress_audio_until), time.monotonic() + 0.35 + ) + await _send_json(websocket, {"type": "state", "state": state}, send_lock) + + async def _on_agent_transcript(text: str, is_final: bool) -> None: + control.suppress_audio_until = max( + float(control.suppress_audio_until), + time.monotonic() + (0.6 if is_final else 0.9), + ) + logger.info( + f"on_agent_transcript: got text='{text[:50]}...' is_final={is_final}" + ) + await _send_json( + websocket, + { + "type": "transcript", + "role": "agent", + "text": text, + "is_final": is_final, + }, + send_lock, + ) + + components = build_voice_pipeline( + voice_ext=voice_ext, + manifest_run=manifest.run, + context_id=session.context_id, + on_state_change=_on_state_change, + on_user_transcript=_on_user_transcript, + on_agent_response=_on_agent_response, + on_agent_transcript=_on_agent_transcript, + ) + + control = _VoiceControlState() + inbound_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=10) + filtered_ws = _FilteredWebSocket(websocket, inbound_queue) + + async def _handle_user_text(text: str) -> None: + await _send_json( + websocket, + {"type": "transcript", "role": "user", "text": text, "is_final": True}, + send_lock, + ) + response = await components["bridge"].process_transcription( + text, emit_frames=True + ) + if response: + await _on_agent_response(response) + + reader_task: asyncio.Task[Any] | None = None + runner_task: asyncio.Task[Any] | 
None = None + + try: + transport = FastAPIWebsocketTransport( + websocket=filtered_ws, # type: ignore[arg-type] + params=FastAPIWebsocketParams( + audio_in_enabled=True, + audio_out_enabled=True, + audio_in_sample_rate=app_settings.voice.sample_rate, + audio_out_sample_rate=app_settings.voice.sample_rate, + add_wav_header=False, + serializer=_RawAudioFrameSerializer( + sample_rate=app_settings.voice.sample_rate, + num_channels=app_settings.voice.audio_channels, + ), + ), + ) + + logger.info( + f"Voice pipeline: transport created, sample_rate={app_settings.voice.sample_rate}" + ) + logger.info( + f"Voice pipeline: components - STT={type(components['stt']).__name__}, " + f"Bridge={type(components['bridge']).__name__}, TTS={type(components['tts']).__name__}" + ) + + pipeline_components = [transport.input()] + if components.get("vad"): + pipeline_components.append(components["vad"]) + logger.info("Voice pipeline: VAD enabled and added") + + pipeline_components.extend( + [ + components["stt"], + components["bridge"], + components["tts"], + transport.output(), + ] + ) + logger.info( + f"Voice pipeline: total components in pipeline: {len(pipeline_components)}" + ) + + pipeline = Pipeline(pipeline_components) + logger.info("Voice pipeline: Pipeline created successfully") + + task = PipelineTask(pipeline) + logger.info("Voice pipeline: PipelineTask created, starting runner...") + runner = PipelineRunner() + + runner_task = asyncio.create_task(runner.run(task)) + + # Start reading control/audio only after pipeline runner is live so + # user_text cannot emit TTS frames before StartFrame initialization. 
+ reader_task = asyncio.create_task( + _voice_control_reader( + websocket, + inbound_queue, + control, + vad_enabled=app_settings.voice.vad_enabled, + send_lock=send_lock, + on_user_text=_handle_user_text, + ) + ) + + async with asyncio.timeout(float(app_settings.voice.session_timeout)): + await runner_task + except WebSocketDisconnect: + logger.info(f"Voice WebSocket disconnected: {session_id}") + except TimeoutError: + logger.info(f"Voice session timed out: {session_id}") + if websocket.client_state == WebSocketState.CONNECTED: + await _send_json( + websocket, + {"type": "error", "message": "Voice session timed out"}, + send_lock, + ) + except Exception as e: + logger.exception(f"Error in voice WebSocket: {session_id}: {e}") + if websocket.client_state == WebSocketState.CONNECTED: + user_message, close_code = _classify_voice_pipeline_error(e) + await _send_json( + websocket, {"type": "error", "message": user_message}, send_lock + ) + try: + await websocket.close(code=close_code, reason=user_message) + except Exception: + pass + finally: + if runner_task and not runner_task.done(): + runner_task.cancel() + try: + await runner_task + except asyncio.CancelledError: + pass + except Exception: + logger.exception("Voice pipeline runner task failed") + + if reader_task and not reader_task.done(): + reader_task.cancel() + try: + await reader_task + except asyncio.CancelledError: + pass + except Exception: + logger.exception("Voice control reader task failed") + + if websocket.client_state == WebSocketState.CONNECTED: + try: + await _send_json( + websocket, {"type": "state", "state": "ended"}, send_lock + ) + except Exception: + pass + + try: + await session_manager.update_state(session_id, "ending") + except Exception: + pass + try: + await components["bridge"].cleanup_background_tasks() + except Exception: + pass + try: + await session_manager.end_session(session_id) + except Exception: + pass + if websocket.client_state == WebSocketState.CONNECTED: + try: + await 
websocket.close() + except Exception: + pass diff --git a/bindu/server/metrics.py b/bindu/server/metrics.py index b5eb8eab2..24a216e4d 100644 --- a/bindu/server/metrics.py +++ b/bindu/server/metrics.py @@ -7,6 +7,7 @@ from collections import defaultdict from threading import Lock +from typing import cast from bindu.utils.logging import get_logger @@ -359,4 +360,4 @@ def get_metrics() -> PrometheusMetrics: with _metrics_init_lock: if _metrics_instance is None: _metrics_instance = PrometheusMetrics() - return _metrics_instance + return cast(PrometheusMetrics, _metrics_instance) diff --git a/bindu/server/scheduler/memory_scheduler.py b/bindu/server/scheduler/memory_scheduler.py index 609da0b42..f30315c00 100644 --- a/bindu/server/scheduler/memory_scheduler.py +++ b/bindu/server/scheduler/memory_scheduler.py @@ -28,6 +28,10 @@ DEFAULT_RETRY_MIN_WAIT = 0.1 DEFAULT_RETRY_MAX_WAIT = 1.0 +# Bounded buffer prevents unbounded memory growth while allowing the API +# handler to enqueue a task before the worker loop is ready to receive. +_TASK_QUEUE_BUFFER_SIZE = 100 + class InMemoryScheduler(Scheduler): """A scheduler that schedules tasks in memory.""" @@ -37,12 +41,11 @@ async def __aenter__(self): self.aexit_stack = AsyncExitStack() await self.aexit_stack.__aenter__() - # Buffer of 100 prevents deadlock: without buffering the sender blocks - # until a worker is ready to receive, which stalls the API server. - # math.inf was previously used here but removed to restore backpressure. + # Bounded buffer allows the API handler to enqueue tasks before the + # worker loop is ready while preventing unbounded memory growth. self._write_stream, self._read_stream = anyio.create_memory_object_stream[ TaskOperation - ](100) + ](_TASK_QUEUE_BUFFER_SIZE) await self.aexit_stack.enter_async_context(self._read_stream) await self.aexit_stack.enter_async_context(self._write_stream) @@ -60,15 +63,22 @@ async def _send_operation( ) -> None: """Send task operation with live span for trace context. 
+ Uses bounded backpressure with a timeout so brief spikes wait for + worker capacity instead of failing immediately. + Args: operation_class: The operation class to instantiate operation: Operation type string params: Task parameters + + Raises: + TimeoutError: If the queue stays blocked past the timeout window. """ task_op = operation_class( operation=operation, params=params, _current_span=get_current_span() ) - await self._write_stream.send(task_op) + with anyio.fail_after(5.0): + await self._write_stream.send(task_op) @retry_scheduler_operation( max_attempts=DEFAULT_RETRY_ATTEMPTS, diff --git a/bindu/server/workers/base.py b/bindu/server/workers/base.py index 570b29d45..ed0060935 100644 --- a/bindu/server/workers/base.py +++ b/bindu/server/workers/base.py @@ -21,12 +21,18 @@ from __future__ import annotations as _annotations from abc import ABC, abstractmethod -from contextlib import asynccontextmanager, nullcontext +from contextlib import asynccontextmanager from dataclasses import dataclass from typing import Any, AsyncIterator import anyio from opentelemetry.trace import get_tracer, use_span +from opentelemetry.trace.span import ( + INVALID_SPAN_CONTEXT, + NonRecordingSpan, + SpanContext, + TraceFlags, +) from bindu.server.scheduler import TaskOperation @@ -39,6 +45,36 @@ logger = get_logger(__name__) +def _reconstruct_span(trace_id: str | None, span_id: str | None) -> NonRecordingSpan: + """Reconstruct a NonRecordingSpan from serialized trace_id/span_id strings. + + Used to restore OpenTelemetry trace context after the scheduler serializes + the span into primitive strings (required for Redis JSON serialization). 
+ + Args: + trace_id: Hex-encoded trace ID (32 chars) or None + span_id: Hex-encoded span ID (16 chars) or None + + Returns: + A NonRecordingSpan that carries the trace context for correlation + """ + if trace_id and span_id: + try: + ctx = SpanContext( + trace_id=int(trace_id, 16), + span_id=int(span_id, 16), + is_remote=True, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + return NonRecordingSpan(ctx) + except (ValueError, TypeError): + logger.warning( + f"Invalid trace context: trace_id={trace_id}, span_id={span_id}" + ) + # Return a no-op span with invalid context as fallback + return NonRecordingSpan(INVALID_SPAN_CONTEXT) + + @dataclass class Worker(ABC): """Abstract base worker for A2A protocol task execution. @@ -106,7 +142,10 @@ async def _handle_task_operation(self, task_operation: TaskOperation) -> None: """Dispatch task operation to appropriate handler. Args: - task_operation: Operation dict with 'operation', 'params', and '_current_span' + task_operation: Operation dict with 'operation', 'params', and tracing metadata. + Prefer providing _current_span first. It should be a Span-like object or + a dict containing the live span context. If _current_span is unavailable, + provide trace_id and span_id as strings so the span can be reconstructed. Supported Operations: - run: Execute a task @@ -126,9 +165,19 @@ async def _handle_task_operation(self, task_operation: TaskOperation) -> None: } try: - # Preserve trace context from scheduler (if available) + # Prefer in-memory span when available; otherwise reconstruct + # from serialized trace_id/span_id for Redis-backed operations. 
span = task_operation.get("_current_span") - ctx_manager = use_span(span) if span else nullcontext() + if span is not None: + ctx_manager = use_span(span) + else: + ctx_manager = use_span( + _reconstruct_span( + task_operation.get("trace_id"), + task_operation.get("span_id"), + ) + ) + with ctx_manager: with tracer.start_as_current_span( f"{task_operation['operation']} task", diff --git a/bindu/settings.py b/bindu/settings.py index 9a2c8d1fc..f19d4d9f1 100644 --- a/bindu/settings.py +++ b/bindu/settings.py @@ -3,7 +3,7 @@ This module defines the configuration settings for the application using pydantic models. """ -from pydantic import Field, computed_field, BaseModel, HttpUrl +from pydantic import BaseModel, Field, HttpUrl, computed_field, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic import AliasChoices from typing import Literal @@ -875,6 +875,12 @@ class SentrySettings(BaseSettings): and release health tracking for production deployments. """ + model_config = SettingsConfigDict( + env_file=".env", + env_prefix="SENTRY__", + extra="allow", + ) + # Enable/disable Sentry enabled: bool = False @@ -1016,6 +1022,117 @@ class GrpcSettings(BaseSettings): ) +class VoiceSettings(BaseSettings): + """Voice agent configuration settings. + + Configures the real-time voice pipeline powered by Pipecat, + including STT (Deepgram), TTS (Piper/ElevenLabs/Azure), VAD, + and session management. 
+ """ + + model_config = SettingsConfigDict( + env_file=".env", + env_prefix="VOICE__", + extra="allow", + ) + + # Master toggle + enabled: bool = False + + # Speech-to-Text + stt_provider: Literal["deepgram"] = "deepgram" + stt_api_key: str = "" + stt_model: str = "nova-3" + stt_language: str = "en" + + # Provider URLs + provider_urls: dict[str, str] = Field( + default_factory=lambda: { + "deepgram_listen": "https://api.deepgram.com/v1/listen", + "elevenlabs_tts": "https://api.elevenlabs.io/v1/text-to-speech", + "azure_tts_voices": "https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list", + } + ) + + # Optional credential lifecycle (for short-lived tokens, if used) + token_refresh_endpoint: str = "" + token_expiry_seconds: int = 0 + token_refresh_leeway_seconds: int = 60 + + # HTTP client behavior + http_timeout_seconds: float = 30.0 + + # Text-to-Speech + tts_provider: Literal["elevenlabs", "piper", "azure"] = "elevenlabs" + tts_fallback_provider: Literal["none", "elevenlabs", "azure"] = "none" + tts_api_key: str = "" + tts_voice_id: str = "21m00Tcm4TlvDq8ikWAM" # ElevenLabs "Rachel" + tts_model: str = "eleven_turbo_v2_5" + tts_stability: float = 0.5 + tts_similarity_boost: float = 0.75 + + # Azure Text-to-Speech (used when tts_provider=azure or fallback is azure) + azure_tts_api_key: str = "" + azure_tts_region: str = "" + azure_tts_voice: str = "en-US-SaraNeural" + + # Audio format + sample_rate: int = 16000 + audio_channels: int = 1 + audio_encoding: str = "linear16" # PCM 16-bit + audio_sample_width_bytes: int = 2 + + # Audio chunking and streaming + chunk_overlap_fraction: float = 0.25 + chunk_throttle_ms: int = 800 + + # Voice Activity Detection + vad_enabled: bool = True + vad_threshold: float = 0.5 + + # Behavior + allow_interruptions: bool = True + agent_timeout_secs: float = 10.0 + utterance_timeout_secs: float = 30.0 + retry_attempts: int = 3 + retry_backoff_start_ms: int = 200 + retry_backoff_factor: float = 2.0 + 
retry_backoff_max_ms: int = 2000 + cancellation_grace_secs: float = 0.5 + conversation_history_limit: int = 50 + conversation_policy: Literal["unsaved", "terminate"] = "unsaved" + session_timeout: int = 300 # seconds (5 min) + max_concurrent_sessions: int = 10 + + # Session storage backend (for multi-worker compatibility) + session_backend: Literal["memory", "redis"] = "memory" + redis_url: str = "" # e.g., "redis://localhost:6379/0" + redis_session_ttl: int = 300 # seconds, TTL for session keys in Redis + + # WebSocket session authentication + session_auth_required: bool = False + session_token_ttl: int = 300 # seconds; must be <= session_timeout + + # Rate limiting (0 disables) + rate_limit_per_ip_per_minute: int = 120 + rate_limit_backend: Literal["memory", "redis"] = "memory" + + # Extension metadata + # Note: bindu:// is an internal routing scheme used by the voice agent extension. + # Consumers should handle this as a special case for internal routing. + extension_uri: str = "bindu://voice" + extension_description: str = "Real-time voice conversation for Bindu agents" + + @model_validator(mode="after") + def _validate_session_token_ttl(self) -> "VoiceSettings": + # Enforce that tokens never outlive sessions. + # If misconfigured via env (e.g. TTL > timeout), clamp TTL to timeout + # to fail safe without preventing the server from starting. 
+ if self.session_token_ttl > self.session_timeout: + self.session_token_ttl = self.session_timeout + return self + + class Settings(BaseSettings): """Main settings class that aggregates all configuration components.""" @@ -1044,6 +1161,7 @@ class Settings(BaseSettings): negotiation: NegotiationSettings = NegotiationSettings() sentry: SentrySettings = SentrySettings() grpc: GrpcSettings = GrpcSettings() + voice: VoiceSettings = VoiceSettings() app_settings = Settings() diff --git a/bindu/utils/__init__.py b/bindu/utils/__init__.py index ef9cb971b..d7e288f0a 100644 --- a/bindu/utils/__init__.py +++ b/bindu/utils/__init__.py @@ -13,6 +13,7 @@ # Core utilities (kept at top level) from .capabilities import ( add_extension_to_capabilities, + get_voice_extension_from_capabilities, get_x402_extension_from_capabilities, ) from .exceptions import ( @@ -41,6 +42,7 @@ "find_skill_by_id", # Capability utilities "add_extension_to_capabilities", + "get_voice_extension_from_capabilities", "get_x402_extension_from_capabilities", # DID utilities "validate_did_extension", diff --git a/bindu/utils/capabilities.py b/bindu/utils/capabilities.py index 9c64aef00..b69b5e812 100644 --- a/bindu/utils/capabilities.py +++ b/bindu/utils/capabilities.py @@ -55,3 +55,21 @@ def get_x402_extension_from_capabilities(manifest: Any) -> Optional[Any]: return ext return None + + +def get_voice_extension_from_capabilities(manifest: Any) -> Optional[Any]: + """Extract Voice extension from manifest capabilities. 
+ + Args: + manifest: Agent manifest with capabilities + + Returns: + VoiceAgentExtension instance if configured, None otherwise + """ + from bindu.extensions.voice import VoiceAgentExtension + + for ext in manifest.capabilities.get("extensions", []): + if isinstance(ext, VoiceAgentExtension): + return ext + + return None diff --git a/bindu/utils/logging.py b/bindu/utils/logging.py index c9ed86208..8909f4e78 100644 --- a/bindu/utils/logging.py +++ b/bindu/utils/logging.py @@ -4,7 +4,7 @@ import sys from pathlib import Path -from typing import Optional +from typing import Optional, cast from loguru import logger from rich.console import Console @@ -42,8 +42,8 @@ def _get_console() -> Console: show_locals=app_settings.logging.show_locals, width=app_settings.logging.traceback_width, ) - # Type narrowing: _console is guaranteed to be Console here - return _console + # Type narrowing: _console is guaranteed to be Console here. + return cast(Console, _console) def configure_logger( diff --git a/bindu/utils/notifications.py b/bindu/utils/notifications.py index 10e76694b..6b53c5ee2 100644 --- a/bindu/utils/notifications.py +++ b/bindu/utils/notifications.py @@ -11,6 +11,7 @@ from dataclasses import dataclass from typing import Any from urllib.parse import urlparse +from urllib import error, request from bindu.common.protocol.types import PushNotificationConfig from bindu.utils.logging import get_logger @@ -36,38 +37,25 @@ ] -def _resolve_and_check_ip(hostname: str) -> str: - """Resolve hostname to an IP address and verify it is not in a blocked range. - - Returns the resolved IP address string so callers can connect directly to it, - preventing a DNS-rebinding attack where a second resolution (inside urlopen) could - return a different—potentially internal—address. - - Args: - hostname: The hostname to resolve and validate. - - Returns: - The resolved IP address as a string. - - Raises: - ValueError: If the hostname cannot be resolved or resolves to a blocked range. 
- """ +def _resolve_and_check_ip(hostname: str, port: int) -> str: + """Resolve hostname to an IP address and verify it is not in a blocked range.""" try: - resolved_ip = str(socket.getaddrinfo(hostname, None)[0][4][0]) - addr = ipaddress.ip_address(resolved_ip) + infos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM) + addrs = [ipaddress.ip_address(info[4][0]) for info in infos] except (socket.gaierror, ValueError) as exc: raise ValueError( f"Push notification URL hostname could not be resolved: {exc}" ) from exc - for blocked in _BLOCKED_NETWORKS: - if addr in blocked: - raise ValueError( - f"Push notification URL resolves to a blocked address range " - f"({addr} is in {blocked}). Internal addresses are not allowed." - ) + for addr in addrs: + for blocked in _BLOCKED_NETWORKS: + if addr in blocked: + raise ValueError( + f"Push notification URL resolves to a blocked address range " + f"({addr} is in {blocked}). Internal addresses are not allowed." + ) - return resolved_ip + return str(addrs[0]) class NotificationDeliveryError(Exception): @@ -126,6 +114,8 @@ def validate_config(self, config: PushNotificationConfig) -> str: Returns the resolved IP address so the caller can connect directly to it, eliminating the DNS-rebinding race between validation and connection. + + Validates destination safety before any outbound call. """ parsed = urlparse(config["url"]) if parsed.scheme not in {"http", "https"}: @@ -136,11 +126,45 @@ def validate_config(self, config: PushNotificationConfig) -> str: # SSRF defence: resolve the hostname and reject internal/private addresses. # The returned IP is used directly for the connection so that no second # DNS lookup can return a different (internal) address. 
+ return self._resolve_and_validate_destination(config["url"]) + + @staticmethod + def _resolve_and_validate_destination(url: str) -> str: + parsed = urlparse(url) + scheme = (parsed.scheme or "").lower() + if scheme not in {"http", "https"}: + raise ValueError("Push notification URL must use http or https scheme.") hostname = parsed.hostname if not hostname: raise ValueError("Push notification URL must include a valid hostname.") - return _resolve_and_check_ip(hostname) + port = parsed.port or (443 if scheme == "https" else 80) + + # If hostname is an IP literal, validate without DNS. + try: + addr = ipaddress.ip_address(hostname) + except ValueError: + try: + return _resolve_and_check_ip(hostname, port) + except (socket.gaierror, ValueError) as exc: + if "blocked address range" in str(exc): + raise + logger.warning( + "Push notification hostname resolution failed; blocking registration", + hostname=hostname, + error=str(exc), + ) + raise ValueError( + "Push notification URL hostname could not be resolved." + ) from exc + + for blocked in _BLOCKED_NETWORKS: + if addr in blocked: + raise ValueError( + f"Push notification URL resolves to a blocked address range " + f"({addr} is in {blocked}). Internal addresses are not allowed." + ) + return str(addr) @create_retry_decorator("api", max_attempts=3, min_wait=0.5, max_wait=5.0) async def _post_with_retry( @@ -157,7 +181,7 @@ async def _post_with_retry( try: status = await asyncio.to_thread( - self._post_once, url, resolved_ip, headers, payload + self._post_once, url, headers, payload, resolved_ip ) logger.debug( "Delivered push notification", @@ -191,67 +215,116 @@ async def _post_with_retry( raise def _post_once( - self, url: str, resolved_ip: str, headers: dict[str, str], payload: bytes + self, + url: str, + headers: dict[str, str], + payload: bytes, + resolved_ip: str | None = None, ) -> int: - """POST *payload* to *url*, connecting directly to *resolved_ip*. 
- - Bypassing a second DNS lookup closes the DNS-rebinding window: the IP - has already been validated in validate_config() and we re-use it here so - that no attacker-controlled DNS TTL change can redirect the connection to - an internal address between validation and delivery. - - For HTTPS, a raw TCP socket is opened to *resolved_ip* and then wrapped - with TLS using *hostname* as the SNI server_hostname, so certificate - validation still uses the original domain name rather than the IP. - """ parsed = urlparse(url) - hostname = parsed.hostname or "" - port = parsed.port or (443 if parsed.scheme == "https" else 80) - path = parsed.path or "/" - if parsed.query: - path = f"{path}?{parsed.query}" - - # Set the Host header explicitly so virtual-host routing works correctly - # even though we are connecting to a raw IP. - host_header = f"{hostname}:{port}" if parsed.port else hostname - + scheme = (parsed.scheme or "").lower() + hostname = parsed.hostname + if scheme not in {"http", "https"}: + raise NotificationDeliveryError( + None, "Push notification URL must use http or https scheme." + ) + if not hostname: + raise NotificationDeliveryError( + None, "Push notification URL must include a valid hostname." + ) try: - if parsed.scheme == "https": - # Open a plain TCP socket to the pre-validated IP, then wrap it - # with TLS using the original hostname for SNI and cert validation. - # This avoids a second DNS lookup while keeping TLS correct. 
- ctx = ssl.create_default_context() - raw_sock = socket.create_connection( - (resolved_ip, port), timeout=self.timeout - ) - tls_sock = ctx.wrap_socket(raw_sock, server_hostname=hostname) - conn = http.client.HTTPSConnection( - resolved_ip, port, timeout=self.timeout, context=ctx - ) - conn.sock = tls_sock - else: - conn = http.client.HTTPConnection( - resolved_ip, port, timeout=self.timeout + destination_ip = resolved_ip or self._resolve_and_validate_destination(url) + except ValueError as exc: + raise NotificationDeliveryError(None, str(exc)) from exc + + default_port = 443 if scheme == "https" else 80 + port = parsed.port or default_port + host_header = f"[{hostname}]" if ":" in hostname else hostname + if parsed.port is not None and parsed.port != default_port: + host_header = f"{host_header}:{parsed.port}" + + if scheme == "https": + path = parsed.path or "/" + if parsed.query: + path = f"{path}?{parsed.query}" + + context = ssl.create_default_context() + request_headers = {"Host": host_header, "Connection": "close", **headers} + request_headers["Content-Length"] = str(len(payload)) + try: + request_lines = [f"POST {path} HTTP/1.1\r\n"] + request_lines.extend( + f"{key}: {value}\r\n" for key, value in request_headers.items() ) + request_lines.append("\r\n") + request_bytes = "".join(request_lines).encode("latin-1") + payload + + with socket.create_connection( + (destination_ip, port), timeout=self.timeout + ) as sock: + with context.wrap_socket( + sock, server_hostname=hostname + ) as tls_socket: + tls_socket.sendall(request_bytes) + response = http.client.HTTPResponse(tls_socket) + response.begin() + status = response.status + body = response.read() or b"" + if 200 <= status < 300: + return status + message = body.decode("utf-8", errors="ignore").strip() + raise NotificationDeliveryError( + status, message or f"HTTP error {status}" + ) + except (OSError, http.client.HTTPException, ssl.SSLError) as exc: + raise NotificationDeliveryError( + None, f"Connection 
error: {exc}" + ) from exc + + target_url = url + if isinstance(destination_ip, str) and destination_ip: + destination_host = ( + f"[{destination_ip}]" if ":" in destination_ip else destination_ip + ) + netloc = ( + f"{destination_host}:{parsed.port}" + if parsed.port is not None + else destination_host + ) + target_url = parsed._replace(netloc=netloc).geturl() - request_headers = dict(headers) - request_headers["Host"] = host_header - - conn.request("POST", path, body=payload, headers=request_headers) - response = conn.getresponse() - status = response.status - if 200 <= status < 300: - return status + req = request.Request(target_url, data=payload, method="POST") + req.add_header("Host", host_header) + for key, value in headers.items(): + req.add_header(key, value) + try: + # URL scheme is validated in validate_config() to only allow http/https + with request.urlopen(req, timeout=self.timeout) as response: # nosec B310 + status = response.getcode() + if 200 <= status < 300: + return status + raise NotificationDeliveryError( + status, f"Unexpected status code: {status}" + ) + except error.HTTPError as exc: + status = exc.code body = b"" try: - body = response.read() or b"" + body = exc.read() or b"" except OSError: body = b"" message = body.decode("utf-8", errors="ignore").strip() - raise NotificationDeliveryError(status, message or f"HTTP error {status}") + raise NotificationDeliveryError( + status, message or f"HTTP error {status}" + ) from exc except NotificationDeliveryError: raise + except error.URLError as exc: + reason = getattr(exc, "reason", exc) + raise NotificationDeliveryError( + None, f"Connection error: {reason}" + ) from exc except (OSError, http.client.HTTPException) as exc: raise NotificationDeliveryError(None, f"Connection error: {exc}") from exc diff --git a/bindu/utils/retry.py b/bindu/utils/retry.py index b025b90d8..418c3335b 100644 --- a/bindu/utils/retry.py +++ b/bindu/utils/retry.py @@ -139,8 +139,9 @@ async def wrapper(*args: Any, 
**kwargs: Any) -> Any: reraise=True, ): with attempt: + func_name = getattr(func, "__name__", func.__class__.__name__) logger.debug( - f"Executing {operation_type} operation {func.__name__} " + f"Executing {operation_type} operation {func_name} " f"(attempt {attempt.retry_state.attempt_number}/{_max_attempts})" ) return await func(*args, **kwargs) diff --git a/bindu/utils/worker/messages.py b/bindu/utils/worker/messages.py index 993b3b743..b76eee30e 100644 --- a/bindu/utils/worker/messages.py +++ b/bindu/utils/worker/messages.py @@ -18,6 +18,8 @@ logger = get_logger("bindu.utils.worker.messages") +MAX_FILE_SIZE = 10 * 1024 * 1024 + # Type aliases for better readability ChatMessage = dict[str, str] ProtocolMessage = Message @@ -52,6 +54,22 @@ def _extract_docx(file_bytes: bytes) -> str: logger.error(f"Failed to parse DOCX: {e}") return "[Error: Could not parse DOCX content]" + @staticmethod + def _decode_plain_text(file_bytes: bytes) -> str: + """Decode plain text with UTF-8 first and safe fallbacks.""" + for encoding in ("utf-8", "cp1252", "latin-1"): + try: + if encoding == "utf-8": + return file_bytes.decode(encoding) + text = file_bytes.decode(encoding) + logger.info(f"Decoded plain text file using {encoding}") + return text + except UnicodeDecodeError: + continue + + logger.warning("Falling back to replacement decoding for plain text file") + return file_bytes.decode("utf-8", errors="replace") + @classmethod def intercept_and_parse(cls, parts: list[Part]) -> list[dict[str, Any]]: """Intercept file parts, extract text, and replace with text parts.""" @@ -62,29 +80,51 @@ def intercept_and_parse(cls, parts: list[Part]) -> list[dict[str, Any]]: processed_parts.append(part) continue - mime_type = part.get("mimeType", "") - base64_data = str(part.get("data", "")) + file_info = part.get("file") or {} + mime_type = file_info.get("mimeType", "") + file_name = file_info.get("name", "uploaded file") + base64_data = file_info.get("bytes") or file_info.get("data", "") if 
mime_type not in cls.SUPPORTED_MIME_TYPES: logger.warning(f"Unsupported MIME type rejected: {mime_type}") processed_parts.append( { "kind": "text", - "text": f"[Unsupported file type: {mime_type}]", + "text": ( + f"[System: User uploaded an unsupported file format " + f"({mime_type or 'unknown'}) for {file_name}]" + ), } ) continue try: # Decode the Base64 payload + if not base64_data: + raise ValueError("Missing file bytes") + + padding = ( + 2 + if base64_data.endswith("==") + else 1 + if base64_data.endswith("=") + else 0 + ) + estimated_size = (len(base64_data) * 3) // 4 - padding + if estimated_size > MAX_FILE_SIZE: + raise ValueError("File too large") + file_bytes = base64.b64decode(base64_data) + if len(file_bytes) > MAX_FILE_SIZE: + raise ValueError("File too large") + extracted_text = "" # Route to specific parser based on MIME type if mime_type == "application/pdf": extracted_text = cls._extract_pdf(file_bytes) elif mime_type == "text/plain": - extracted_text = file_bytes.decode("utf-8") + extracted_text = cls._decode_plain_text(file_bytes) elif ( mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" @@ -95,12 +135,16 @@ def intercept_and_parse(cls, parts: list[Part]) -> list[dict[str, Any]]: processed_parts.append( { "kind": "text", - "text": f"--- Document Uploaded ---\n{extracted_text}\n--- End of Document ---", + "text": ( + f"--- Document Uploaded: {file_name} ({mime_type}) ---\n" + f"{extracted_text}\n" + f"--- End of Document ---" + ), } ) except Exception as e: - logger.error(f"Base64 decoding or routing failed: {e}") + logger.exception(f"Base64 decoding or routing failed: {e}") processed_parts.append( { "kind": "text", diff --git a/bindu/utils/worker/parts.py b/bindu/utils/worker/parts.py index 835652026..31fd65bec 100644 --- a/bindu/utils/worker/parts.py +++ b/bindu/utils/worker/parts.py @@ -35,8 +35,7 @@ def dict_to_part(data: dict[str, Any]) -> Part: return part_class(**data) # Fallback: convert unknown dict to 
DataPart - # DataPart requires 'text' field even though it's a data part - return DataPart(kind="data", data=data, text="") + return DataPart(kind="data", data=data) @staticmethod def result_to_parts(result: Any) -> list[Part]: diff --git a/frontend/.env b/frontend/.env.example similarity index 82% rename from frontend/.env rename to frontend/.env.example index a03e77c7c..3d8c2c121 100644 --- a/frontend/.env +++ b/frontend/.env.example @@ -7,7 +7,7 @@ OPENAI_BASE_URL=https://router.huggingface.co/v1 # Canonical auth token for any OpenAI-compatible provider -OPENAI_API_KEY=#your provider API key (works for HF router, OpenAI, LM Studio, etc.). +OPENAI_API_KEY= # When set to true, user token will be used for inference calls USE_USER_TOKEN=false # Automatically redirect to oauth login page if user is not logged in, when set to "true" @@ -20,21 +20,19 @@ AUTOMATIC_LOGIN=false ## Public app configuration ## PUBLIC_APP_NAME="Chat with Bindu" # name used as title throughout the app PUBLIC_APP_ASSETS=chatui # used to find logos & favicons in static/$PUBLIC_APP_ASSETS -PUBLIC_APP_DESCRIPTION="Chat with Bindu AI agents"# description used throughout the app +PUBLIC_APP_DESCRIPTION="Chat with Bindu AI agents" # description used throughout the app PUBLIC_SMOOTH_UPDATES=false # set to true to enable smoothing of messages client-side, can be CPU intensive +PUBLIC_AGENT_BASE_URL=http://localhost:3773 # base URL for direct agent API calls from the client PUBLIC_ORIGIN= PUBLIC_SHARE_PREFIX= PUBLIC_GOOGLE_ANALYTICS_ID= PUBLIC_PLAUSIBLE_SCRIPT_URL= PUBLIC_APPLE_APP_ID= -COUPLE_SESSION_WITH_COOKIE_NAME= # when OPEN_ID is configured, users are required to login after the welcome modal OPENID_CLIENT_ID="" # You can set to "__CIMD__" for automatic oauth app creation when deployed, see https://datatracker.ietf.org/doc/draft-ietf-oauth-client-id-metadata-document/ OPENID_CLIENT_SECRET= OPENID_SCOPES="openid profile inference-api read-mcp read-billing" -USE_USER_TOKEN= -AUTOMATIC_LOGIN=# if 
true authentication is required on all routes ## Models overrides @@ -69,7 +67,7 @@ PUBLIC_LLM_ROUTER_ALIAS_ID=omni # Voice-to-text transcription using Whisper models # If set, enables the microphone button in the chat input # Example: openai/whisper-large-v3-turbo -TRANSCRIPTION_MODEL= +TRANSCRIPTION_MODEL=openai/whisper-large-v3-turbo # Optional: Base URL for transcription API (defaults to HF inference) # Default: https://router.huggingface.co/hf-inference/models TRANSCRIPTION_BASE_URL= @@ -88,6 +86,7 @@ ALTERNATIVE_REDIRECT_URLS=[] COOKIE_NAME=hf-chat # If the value of this cookie changes, the session is destroyed. Useful if chat-ui is deployed on a subpath # of your domain, and you want chat ui sessions to reset if the user's auth changes +# Canonical cookie name for coupling sessions with cookies COUPLE_SESSION_WITH_COOKIE_NAME= # specify secure behaviour for cookies COOKIE_SAMESITE=# can be "lax", "strict", "none" or left empty @@ -96,12 +95,13 @@ TRUSTED_EMAIL_HEADER=# header to use to get the user email, only use if you know ### Admin stuff ### ADMIN_CLI_LOGIN=true # set to false to disable the CLI login -ADMIN_TOKEN=#We recommend leaving this empty, you can get the token from the terminal. +# We recommend leaving this empty, you can get the token from the terminal. 
+ADMIN_TOKEN= ### Feature Flags ### LLM_SUMMARIZATION=true # generate conversation titles with LLMs -ALLOW_IFRAME=true # Allow the app to be embedded in an iframe +ALLOW_IFRAME=false # Disallow embedding in iframes; set to true only if intentionally needed ### Bindu Agent Configuration ### # Base URL of the Bindu agent server (A2A Protocol) @@ -117,6 +117,18 @@ BINDU_ONLY_MODE=false # Timeout for Bindu requests in milliseconds (default: 10000) BINDU_TIMEOUT_MS=10000 +### Voice ### +# Enable the voice session overlay and WebSocket voice pipeline +VOICE__ENABLED=false +# Deepgram STT API key used by the voice extension +VOICE__STT_API_KEY= +# ElevenLabs TTS API key used by the voice extension +VOICE__TTS_API_KEY= +# Browser and server audio should stay aligned on 16 kHz mono PCM +VOICE__SAMPLE_RATE=16000 +VOICE__AUDIO_CHANNELS=1 +VOICE__AUDIO_ENCODING=linear16 + ### Rate limits ### # See `src/lib/server/usageLimits.ts` # { @@ -134,7 +146,7 @@ USAGE_LIMITS={} # Used for setting early access & admin flags to users HF_ORG_ADMIN= HF_ORG_EARLY_ACCESS= -WEBHOOK_URL_REPORT_ASSISTANT=#provide slack webhook url to get notified for reports/feature requests +WEBHOOK_URL_REPORT_ASSISTANT= # provide slack webhook url to get notified for reports/feature requests ### Metrics ### @@ -147,7 +159,7 @@ LOG_LEVEL=info # Not in use anymore but useful to export conversations to a parquet file as a HuggingFace dataset PARQUET_EXPORT_DATASET= PARQUET_EXPORT_HF_TOKEN= -ADMIN_API_SECRET=# secret to admin API calls, like computing usage stats or exporting parquet data +ADMIN_API_SECRET= # secret to admin API calls, like computing usage stats or exporting parquet data ### Config ### ENABLE_CONFIG_MANAGER=true @@ -155,7 +167,7 @@ ENABLE_CONFIG_MANAGER=true ### Docker build variables ### # These values cannot be updated at runtime # They need to be passed when building the docker image -# See https://github.com/huggingface/chat-ui/main/.github/workflows/deploy-prod.yml#L44-L47 +# See 
https://github.com/huggingface/chat-ui/blob/main/.github/workflows/deploy-prod.yml#L44-L47 APP_BASE="" # base path of the app, e.g. /chat, left blank as default ### Body size limit for SvelteKit https://svelte.dev/docs/kit/adapter-node#Environment-variables-BODY_SIZE_LIMIT BODY_SIZE_LIMIT=15728640 @@ -163,7 +175,7 @@ PUBLIC_COMMIT_SHA= ### LEGACY parameters ALLOW_INSECURE_COOKIES=false # LEGACY! Use COOKIE_SECURE and COOKIE_SAMESITE instead -PARQUET_EXPORT_SECRET=#DEPRECATED, use ADMIN_API_SECRET instead +PARQUET_EXPORT_SECRET= # DEPRECATED, use ADMIN_API_SECRET instead RATE_LIMIT= # /!\ DEPRECATED definition of messages per minute. Use USAGE_LIMITS.messagesPerMinute instead OPENID_NAME_CLAIM="name" # Change to "username" for some providers that do not provide name OPENID_PROVIDER_URL=https://huggingface.co # for Google, use https://accounts.google.com diff --git a/frontend/.gitignore b/frontend/.gitignore index a97fcbc00..34761100f 100644 --- a/frontend/.gitignore +++ b/frontend/.gitignore @@ -7,10 +7,11 @@ node_modules .env.* vite.config.js.timestamp-* vite.config.ts.timestamp-* +tsconfig.tsbuildinfo SECRET_CONFIG .idea !.env.ci -!.env +!.env.example gcp-*.json db data diff --git a/frontend/src/lib/buildPrompt.ts b/frontend/src/lib/buildPrompt.ts index 4d7458db0..1a392599e 100644 --- a/frontend/src/lib/buildPrompt.ts +++ b/frontend/src/lib/buildPrompt.ts @@ -26,7 +26,7 @@ export async function buildPrompt({ }) // Not super precise, but it's truncated in the model's backend anyway .split(" ") - .slice(-(model.parameters?.truncate ?? 0)) + .slice(-(model.parameters?.truncate ?? 
Infinity)) .join(" "); return prompt; diff --git a/frontend/src/lib/components/ShareConversationModal.svelte b/frontend/src/lib/components/ShareConversationModal.svelte index 402d8874a..700656cca 100644 --- a/frontend/src/lib/components/ShareConversationModal.svelte +++ b/frontend/src/lib/components/ShareConversationModal.svelte @@ -23,10 +23,15 @@ let justCopied = $state(false); async function handleCreate() { + const id = page.params.id; + if (!id) { + errorMsg = "No conversation selected"; + return; + } try { creating = true; errorMsg = null; - createdUrl = await createShareLink(page.params.id); + createdUrl = await createShareLink(id); } catch (e) { errorMsg = (e as Error).message || "Could not create link"; } finally { diff --git a/frontend/src/lib/components/chat/ChatInput.svelte b/frontend/src/lib/components/chat/ChatInput.svelte index 7e338afda..1106654d5 100644 --- a/frontend/src/lib/components/chat/ChatInput.svelte +++ b/frontend/src/lib/components/chat/ChatInput.svelte @@ -13,6 +13,7 @@ import CarbonClose from "~icons/carbon/close"; import UrlFetchModal from "./UrlFetchModal.svelte"; import { TEXT_MIME_ALLOWLIST, IMAGE_MIME_ALLOWLIST_DEFAULT } from "$lib/constants/mime"; + import { error } from "$lib/stores/errors"; import { isVirtualKeyboard } from "$lib/utils/isVirtualKeyboard"; import { requireAuthUser } from "$lib/utils/auth"; @@ -32,7 +33,10 @@ children?: import("svelte").Snippet; onPaste?: (e: ClipboardEvent) => void; focused?: boolean; - onsubmit?: () => void; + onsubmit?: ( + message: string, + fileParts: Array<{ name: string; mime: string; value: string }> + ) => Promise | boolean; } let { @@ -155,7 +159,7 @@ adjustTextareaHeight(); }); - function handleKeydown(event: KeyboardEvent) { + async function handleKeydown(event: KeyboardEvent) { if ( event.key === "Enter" && !event.shiftKey && @@ -164,11 +168,51 @@ value.trim() !== "" ) { event.preventDefault(); - tick(); - onsubmit?.(); + await tick(); + let fileParts = []; + try { + fileParts = await 
getFileParts(); + } catch (err) { + console.error("Error reading file parts:", err); + $error = `Error reading file: ${err instanceof Error ? err.message : String(err)}`; + // Reset file input state + if (fileInputEl) fileInputEl.value = ""; + files = []; + return; + } + + try { + if (typeof onsubmit === "function") { + const submitted = await onsubmit(value, fileParts); + if (submitted) { + // Clear files and input after successful submit + files = []; + value = ""; + } + } + } catch (err) { + console.error("Error submitting message:", err); + alert( + `Error submitting message: ${err instanceof Error ? err.message : String(err)}` + ); + } } } + // Helper to convert files to base64 parts for agentMessageHandler + async function getFileParts() { + if (!files || files.length === 0) return []; + const file2base64 = (await import("$lib/utils/file2base64")).default; + const fileParts = await Promise.all( + files.map(async (file) => ({ + name: file.name, + mime: file.type || "application/octet-stream", + value: await file2base64(file), + })) + ); + return fileParts; + } + function handleFocus() { if (requireAuthUser()) { return; diff --git a/frontend/src/lib/components/chat/ChatMessage.svelte b/frontend/src/lib/components/chat/ChatMessage.svelte index ae0de78cb..d401bfe1d 100644 --- a/frontend/src/lib/components/chat/ChatMessage.svelte +++ b/frontend/src/lib/components/chat/ChatMessage.svelte @@ -67,6 +67,22 @@ } } + function handleMessageClick(e: MouseEvent) { + const target = e.target as HTMLElement | null; + if ( + target && + target.closest( + "button, a, input, textarea, select, [role='button'], [data-no-toggle]" + ) + ) { + return; + } + handleContentClick(e); + if (!(e.target instanceof HTMLImageElement)) { + isTapped = !isTapped; + } + } + $effect(() => { // referenced to appease linter for currently-unused props void _isAuthor; @@ -252,9 +268,15 @@ : ''}" data-message-id={message.id} data-message-role="assistant" - role="presentation" - onclick={() => (isTapped 
= !isTapped)} - onkeydown={() => (isTapped = !isTapped)} + role="button" + tabindex="0" + onclick={handleMessageClick} + onkeydown={(event) => { + if (event.key === "Enter" || event.key === " ") { + event.preventDefault(); + isTapped = !isTapped; + } + }} > -
+
{#if isLast && loading && blocks.length === 0} {/if} diff --git a/frontend/src/lib/components/chat/ChatWindow.svelte b/frontend/src/lib/components/chat/ChatWindow.svelte index d9646d156..39989b90d 100644 --- a/frontend/src/lib/components/chat/ChatWindow.svelte +++ b/frontend/src/lib/components/chat/ChatWindow.svelte @@ -6,7 +6,10 @@ import IconMic from "~icons/lucide/mic"; import ChatInput from "./ChatInput.svelte"; + import { sendAgentMessage } from "$lib/utils/agentMessageHandler"; import VoiceRecorder from "./VoiceRecorder.svelte"; + import VoiceCallButton from "$lib/components/voice/VoiceCallButton.svelte"; + import VoiceCallPanel from "$lib/components/voice/VoiceCallPanel.svelte"; import StopGeneratingBtn from "../StopGeneratingBtn.svelte"; import type { Model } from "$lib/types/Model"; import FileDropzone from "./FileDropzone.svelte"; @@ -29,7 +32,7 @@ import ReplyIndicator from "./ReplyIndicator.svelte"; import { agentInspector, resetAgentInspector } from "$lib/stores/agentInspector"; - import { fly } from "svelte/transition"; + import { fade, fly } from "svelte/transition"; import { cubicInOut } from "svelte/easing"; import { isVirtualKeyboard } from "$lib/utils/isVirtualKeyboard"; @@ -40,7 +43,13 @@ isMessageToolErrorUpdate, isMessageToolResultUpdate, } from "$lib/utils/messageUpdates"; + import { MessageUpdateType } from "$lib/types/MessageUpdate"; import type { ToolFront } from "$lib/types/Tool"; + import { + startVoiceSession, + voiceSessionId, + voiceError, + } from "$lib/stores/voice"; interface Props { messages?: Message[]; @@ -139,11 +148,53 @@ ); let isTouchDevice = $derived(browser && navigator.maxTouchPoints > 0); - const handleSubmit = () => { - if (requireAuthUser() || loading || !draft) return; - onmessage?.(draft); - draft = ""; - }; + async function handleSubmit(): Promise { + if (!draft) { + return false; + } + try { + const fileParts = (sources ? 
await Promise.all(sources) : []).filter(Boolean) as MessageFile[]; + return await submit(draft, { fileParts }); + } catch (err) { + console.error("Error preparing chat submission:", err); + $error = err instanceof Error ? err.message : String(err); + return false; + } + } + + async function submit( + message: string, + options: { fileParts?: MessageFile[] } = {} + ): Promise { + const fileParts = options.fileParts ?? []; + if (!message || loading || isReadOnly) return false; + if (requireAuthUser()) return false; + + const contextId = agentContextId ?? undefined; + let success = false; + try { + let streamed = ""; + for await (const update of sendAgentMessage(message, contextId, { fileParts })) { + if (update.type === MessageUpdateType.Stream) { + streamed += update.token ?? ""; + onmessage?.(streamed); + } else if (update.type === MessageUpdateType.FinalAnswer) { + onmessage?.(update.text ?? ""); + } + } + success = true; + } catch (err) { + console.error("Error sending agent message:", err); + $error = err instanceof Error ? err.message : String(err); + } + + if (success) { + draft = ""; + files = []; + } + + return success; + } let lastTarget: EventTarget | null = null; @@ -251,11 +302,11 @@ }); let sources = $derived( - files?.map>((file) => + files?.map((file) => file2base64(file).then((value) => ({ type: "base64", value, - mime: file.type, + mime: file.type || "application/octet-stream", name: file.name, })) ) @@ -309,15 +360,15 @@ import { TEXT_MIME_ALLOWLIST, IMAGE_MIME_ALLOWLIST_DEFAULT, DOCUMENT_MIME_ALLOWLIST } from "$lib/constants/mime"; let activeMimeTypes = $derived( - Array.from( - new Set([ - ...TEXT_MIME_ALLOWLIST, - ...DOCUMENT_MIME_ALLOWLIST, - ...(modelIsMultimodal - ? (currentModel.multimodalAcceptedMimetypes ?? [...IMAGE_MIME_ALLOWLIST_DEFAULT]) - : []), - ]) - ) + Array.from( + new Set([ + ...TEXT_MIME_ALLOWLIST, + ...DOCUMENT_MIME_ALLOWLIST, + ...(modelIsMultimodal + ? (currentModel.multimodalAcceptedMimetypes ?? 
[...IMAGE_MIME_ALLOWLIST_DEFAULT]) + : []), + ]) + ) ); let isFileUploadEnabled = $derived(activeMimeTypes.length > 0); let focused = $state(false); @@ -372,7 +423,7 @@ if (trimmedText) { // Set draft and send immediately draft = draft.trim() ? `${draft.trim()} ${trimmedText}` : trimmedText; - handleSubmit(); + await handleSubmit(); } } catch (err) { console.error("Transcription error:", err); @@ -387,6 +438,31 @@ isRecording = false; $error = message; } + + async function toggleVoiceSession() { + if ($voiceSessionId) { + return; + } + + try { + await startVoiceSession(agentContextId ?? undefined); + } catch (err) { + console.error("Voice session error:", err); + $error = (err as Error).message || "Failed to start voice session"; + } + } + + let previousVoiceError: string | null = null; + $effect(() => { + if ($voiceError) { + $error = $voiceError; + previousVoiceError = $voiceError; + } else if ($error && $error === previousVoiceError) { + // Voice error was cleared, and the global error is that voice error, so clear it. + $error = undefined; + previousVoiceError = null; + } + }); {#each [ { text: "Generate an image", icon: "šŸŽØ" }, @@ -473,14 +549,14 @@ { text: "Find a dataset", icon: "šŸ“Š" }, { text: "Gift ideas", icon: "šŸŽ" } ] as prompt} - {/each} @@ -495,6 +571,10 @@
+ {#if $voiceSessionId} + + {/if} +
{ + files = files.filter((_, i) => i !== index); + }} /> {/await} {/each} @@ -520,7 +603,11 @@ {/if}
- + {#if $voiceError && !$voiceSessionId} +

+ {$voiceError} +

+ {/if}
{#if !loading && lastIsError} @@ -543,15 +630,15 @@
-
{ - e.preventDefault(); - handleSubmit(); - }} - class="composer {isReadOnly ? 'opacity-30' : ''} {focused && isVirtualKeyboard() ? 'max-sm:mb-4' : ''} {pastedLongContent ? 'paste-glow' : ''}" - > + { + e.preventDefault(); + await handleSubmit(); + }} + class="composer {isReadOnly ? 'opacity-30' : ''} {focused && isVirtualKeyboard() ? 'max-sm:mb-4' : ''} {pastedLongContent ? 'paste-glow' : ''}" + > {#if isRecording || isTranscribing} {:else} - {#if lastIsError} - - {:else} - - {#if loading} - onstop?.()} - showBorder={true} - classNames="composer-btn icon-btn" - /> - {:else} - {#if transcriptionEnabled} - - {/if} +
+ {#if lastIsError} + + {:else} + { + return await submit(message, { fileParts: fileParts as MessageFile[] }); + }} + {onPaste} + disabled={isReadOnly || lastIsError} + {modelIsMultimodal} + {modelSupportsTools} + bind:focused + /> + {/if} + + {#if loading} + onstop?.()} + showBorder={true} + classNames="absolute bottom-2 right-2 size-8 sm:size-7 self-end rounded-full border bg-white text-black shadow transition-none dark:border-transparent dark:bg-gray-600 dark:text-white" + /> + {:else} + {#if transcriptionEnabled} {/if} - - {/if} + + + {/if} +
{/if}
diff --git a/frontend/src/lib/components/voice/LiveTranscript.svelte b/frontend/src/lib/components/voice/LiveTranscript.svelte new file mode 100644 index 000000000..150758d69 --- /dev/null +++ b/frontend/src/lib/components/voice/LiveTranscript.svelte @@ -0,0 +1,24 @@ + + +
+ {#if items.length === 0} +

Voice transcript will appear here...

+ {:else} + {#each items as item, idx (item.id)} +
+
+ {item.role === 'user' ? 'You' : 'Agent'} +
+
{item.text}
+
+ {/each} + {/if} +
diff --git a/frontend/src/lib/components/voice/VoiceCallButton.svelte b/frontend/src/lib/components/voice/VoiceCallButton.svelte new file mode 100644 index 000000000..6e56971e0 --- /dev/null +++ b/frontend/src/lib/components/voice/VoiceCallButton.svelte @@ -0,0 +1,23 @@ + + + diff --git a/frontend/src/lib/components/voice/VoiceCallPanel.svelte b/frontend/src/lib/components/voice/VoiceCallPanel.svelte new file mode 100644 index 000000000..667ac9302 --- /dev/null +++ b/frontend/src/lib/components/voice/VoiceCallPanel.svelte @@ -0,0 +1,238 @@ + + +
+
+
+
Voice Session
+
State: {$voiceState}
+
+
+ + +
+
+ + {#if $voiceError} +

+ {$voiceError} +

+ {/if} + + +
diff --git a/frontend/src/lib/constants/mime.ts b/frontend/src/lib/constants/mime.ts index 56ec0656b..d5eec3b11 100644 --- a/frontend/src/lib/constants/mime.ts +++ b/frontend/src/lib/constants/mime.ts @@ -10,6 +10,8 @@ export const TEXT_MIME_ALLOWLIST = [ export const IMAGE_MIME_ALLOWLIST_DEFAULT = ["image/jpeg", "image/png"] as const; +// Document types permitted as attachments. +// Includes binary document formats (PDF, Word, Excel, PowerPoint) plus text/plain and text/csv. export const DOCUMENT_MIME_ALLOWLIST = [ "application/pdf", "application/msword", diff --git a/frontend/src/lib/jobs/refresh-conversation-stats.ts b/frontend/src/lib/jobs/refresh-conversation-stats.ts index df5a6a59e..d0dc45dc5 100644 --- a/frontend/src/lib/jobs/refresh-conversation-stats.ts +++ b/frontend/src/lib/jobs/refresh-conversation-stats.ts @@ -254,7 +254,9 @@ async function computeStats(params: { }, ]; - await collections.conversations.aggregate(pipeline, { allowDiskUse: true }).next(); + await collections.conversations + .aggregate(pipeline as Record[], { allowDiskUse: true }) + .next(); logger.debug( { minDate, dateField: params.dateField, span: params.span, type: params.type }, diff --git a/frontend/src/lib/migrations/lock.ts b/frontend/src/lib/migrations/lock.ts index f78e73476..b0a6f6800 100644 --- a/frontend/src/lib/migrations/lock.ts +++ b/frontend/src/lib/migrations/lock.ts @@ -1,6 +1,6 @@ import { collections } from "$lib/server/database"; import ObjectId from "bson-objectid"; -import type { Semaphores } from "$lib/types/Semaphore"; +import type { Semaphore, Semaphores } from "$lib/types/Semaphore"; /** * Returns the lock id if the lock was acquired, false otherwise @@ -8,14 +8,15 @@ import type { Semaphores } from "$lib/types/Semaphore"; export async function acquireLock(key: Semaphores | string): Promise { try { const id = new ObjectId(); - - const insert = await collections.semaphores.insertOne({ + const lockDocument = { _id: id, key, createdAt: new Date(), updatedAt: new 
Date(), deleteAt: new Date(Date.now() + 1000 * 60 * 3), // 3 minutes - }); + } satisfies Semaphore & { _id: ObjectId }; + + const insert = await collections.semaphores.insertOne(lockDocument); return insert.acknowledged ? id : false; // true if the document was inserted } catch (e) { diff --git a/frontend/src/lib/migrations/routines/02-update-assistants-models.ts b/frontend/src/lib/migrations/routines/02-update-assistants-models.ts index e2d6df9a6..c732fa356 100644 --- a/frontend/src/lib/migrations/routines/02-update-assistants-models.ts +++ b/frontend/src/lib/migrations/routines/02-update-assistants-models.ts @@ -17,7 +17,7 @@ const updateAssistantsModels: Migration = { // Find all assistants whose modelId is not in modelIds, and update it const bulkOps = await assistants .find({ modelId: { $nin: modelIds } }) - .map((assistant) => { + .map((assistant: { _id: unknown; modelId?: string }) => { // has an old model let newModelId = defaultModelId; diff --git a/frontend/src/lib/migrations/routines/10-update-reports-assistantid.ts b/frontend/src/lib/migrations/routines/10-update-reports-assistantid.ts index bd5590ed5..10e2e3532 100644 --- a/frontend/src/lib/migrations/routines/10-update-reports-assistantid.ts +++ b/frontend/src/lib/migrations/routines/10-update-reports-assistantid.ts @@ -6,22 +6,28 @@ const migration: Migration = { _id: new ObjectId("000000000000000000000010"), name: "Update reports with assistantId to use contentId", up: async () => { - await collections.reports.updateMany( - { + const reports = await collections.reports + .find({ assistantId: { $exists: true, $ne: null }, - }, - [ + }) + .toArray(); + + for (const report of reports) { + const assistantId = report.assistantId; + if (!assistantId) { + continue; + } + await collections.reports.updateOne( + { _id: report._id }, { $set: { object: "assistant", - contentId: "$assistantId", + contentId: assistantId, }, - }, - { - $unset: "assistantId", - }, - ] - ); + $unset: { assistantId: "" }, + } + ); + 
} return true; }, }; diff --git a/frontend/src/lib/server/config.ts b/frontend/src/lib/server/config.ts index da371c613..2b5e2ee8d 100644 --- a/frontend/src/lib/server/config.ts +++ b/frontend/src/lib/server/config.ts @@ -141,10 +141,19 @@ class ConfigManager { }; } + const configWithDefaults = { + ...(config as Record), + PUBLIC_AGENT_BASE_URL: + (config as Record).PUBLIC_AGENT_BASE_URL || + "http://localhost:3773", + }; + const publicEnvKeys = Object.keys(publicEnv); return Object.fromEntries( - Object.entries(config).filter(([key]) => publicEnvKeys.includes(key)) + Object.entries(configWithDefaults).filter( + ([key]) => publicEnvKeys.includes(key) || key === "PUBLIC_AGENT_BASE_URL" + ) ) as Record; } } @@ -162,6 +171,7 @@ type ExtraConfigKeys = | "HF_TOKEN" | "OLD_MODELS" | "ENABLE_ASSISTANTS" + | "ENABLE_DATA_EXPORT" | "METRICS_ENABLED" | "METRICS_PORT"; diff --git a/frontend/src/lib/server/database.ts b/frontend/src/lib/server/database.ts index ef5eb818d..6e2cab19d 100644 --- a/frontend/src/lib/server/database.ts +++ b/frontend/src/lib/server/database.ts @@ -59,13 +59,26 @@ class InMemoryCollection> { this.name = name; } - async findOne(filter: Record = {}): Promise { - for (const doc of this.data.values()) { - if (this.matchesFilter(doc, filter)) { - return doc; + async findOne( + filter: Record = {}, + options?: { sort?: Record } + ): Promise { + const docs = Array.from(this.data.values()); + const matched = docs.filter((doc) => this.matchesFilter(doc, filter)); + if (options?.sort) { + const [sortKey, sortDir] = Object.entries(options.sort)[0] ?? []; + if (sortKey && sortDir) { + matched.sort((a, b) => { + const aVal = this.getNestedValue(a, sortKey); + const bVal = this.getNestedValue(b, sortKey); + if (aVal === bVal) return 0; + if (aVal === undefined || aVal === null) return 1; + if (bVal === undefined || bVal === null) return -1; + return aVal < bVal ? -1 * sortDir : 1 * sortDir; + }); } } - return null; + return matched[0] ?? 
null; } find(filter: Record = {}): InMemoryCursor { @@ -78,9 +91,25 @@ class InMemoryCollection> { return new InMemoryCursor(results); } - async insertOne(doc: T): Promise<{ insertedId: ObjectId; acknowledged: boolean }> { - const id = new ObjectId(); - const docWithId = { ...doc, _id: id } as T; + async insertOne( + doc: Partial | T + ): Promise<{ insertedId: ObjectId; acknowledged: boolean }> { + const source = doc as Record; + let id: ObjectId; + if ("_id" in source && source._id) { + if (source._id instanceof ObjectId) { + id = source._id; + } else { + try { + id = new ObjectId(String(source._id)); + } catch { + id = new ObjectId(); + } + } + } else { + id = new ObjectId(); + } + const docWithId = { ...source, _id: id } as T; this.data.set(id.toString(), docWithId); return { insertedId: id, acknowledged: true }; } @@ -98,28 +127,36 @@ class InMemoryCollection> { filter: Record, update: Record, options: { upsert?: boolean } = {} - ): Promise<{ matchedCount: number; modifiedCount: number; upsertedId?: ObjectId }> { + ): Promise<{ matchedCount: number; modifiedCount: number; upsertedId?: ObjectId; acknowledged: boolean }> { for (const [id, doc] of this.data.entries()) { if (this.matchesFilter(doc, filter)) { const updated = this.applyUpdate(doc, update); this.data.set(id, updated); - return { matchedCount: 1, modifiedCount: 1 }; + return { matchedCount: 1, modifiedCount: 1, acknowledged: true }; } } if (options.upsert) { const $set = (update.$set || {}) as Partial; const $setOnInsert = (update.$setOnInsert || {}) as Partial; const newDoc = { ...filter, ...$set, ...$setOnInsert } as T; - const result = await this.insertOne(newDoc); - return { matchedCount: 0, modifiedCount: 0, upsertedId: result.insertedId }; + const result = await this.insertOne(newDoc as Partial); + return { + matchedCount: 0, + modifiedCount: 0, + upsertedId: result.insertedId, + acknowledged: true, + }; } - return { matchedCount: 0, modifiedCount: 0 }; + return { matchedCount: 0, modifiedCount: 
0, acknowledged: true }; } async updateMany( filter: Record, - update: Record + update: Record | Array> ): Promise<{ matchedCount: number; modifiedCount: number }> { + if (Array.isArray(update)) { + throw new Error("Pipeline-style updates are unsupported in InMemoryCollection.updateMany"); + } let count = 0; for (const [id, doc] of this.data.entries()) { if (this.matchesFilter(doc, filter)) { @@ -171,7 +208,10 @@ class InMemoryCollection> { return "index_" + Date.now(); } - aggregate(_pipeline: Record[]): InMemoryCursor { + aggregate( + _pipeline: Record[], + _options?: Record + ): InMemoryCursor { // Simplified aggregation - just return all docs return new InMemoryCursor(Array.from(this.data.values()) as unknown as R[]); } @@ -369,6 +409,10 @@ class InMemoryCursor { return this.getProcessed(); } + map(fn: (value: T, index: number) => U): InMemoryCursor { + return new InMemoryCursor(this.getProcessed().map(fn)); + } + async hasNext(): Promise { return this._index < this.getProcessed().length; } @@ -416,22 +460,94 @@ class InMemoryBucket { if (_event === "finish") setTimeout(() => cb(), 0); return stream; }, + once: (_event: string, cb: (err?: Error) => void) => { + if (_event === "finish") setTimeout(() => cb(), 0); + return stream; + }, }; return stream; } openDownloadStream(id: ObjectId | string) { const file = this.files.get(id.toString()); + const listeners = new Map void>>(); + let destroyed = false; + const scheduled = { + error: false, + data: false, + end: false, + }; + + const emit = (event: string, arg?: Error | Uint8Array) => { + if (destroyed) return; + const handlers = listeners.get(event); + if (!handlers) return; + for (const handler of handlers) { + handler(arg); + } + }; + + const addListener = ( + event: string, + handler: (arg?: Error | Uint8Array) => void + ) => { + if (!listeners.has(event)) { + listeners.set(event, new Set()); + } + listeners.get(event)?.add(handler); + }; + + const scheduleInitialEmit = (event: string) => { + if (event === 
"error" && !scheduled.error && !file) { + scheduled.error = true; + setTimeout(() => emit("error", new Error("File not found")), 0); + } + if (event === "data" && !scheduled.data && file) { + scheduled.data = true; + setTimeout(() => emit("data", file.data), 0); + } + if (event === "end" && !scheduled.end && file) { + scheduled.end = true; + setTimeout(() => emit("end"), 0); + } + }; + const downloadStream = { pipe void; end: () => void }>(dest: T): T { if (file) dest.write(file.data); dest.end(); return dest; }, - on(_event: string, cb: (err?: Error) => void) { - if (_event === "error" && !file) cb(new Error("File not found")); + on(_event: string, cb: (arg?: Error | Uint8Array) => void) { + addListener(_event, cb); + scheduleInitialEmit(_event); + return downloadStream; + }, + once(event: string, cb: (arg?: Error | Uint8Array) => void) { + const wrapped = (arg?: Error | Uint8Array) => { + downloadStream.off(event, wrapped); + cb(arg); + }; + addListener(event, wrapped); + scheduleInitialEmit(event); + return downloadStream; + }, + off(event: string, cb: (arg?: Error | Uint8Array) => void) { + listeners.get(event)?.delete(cb); return downloadStream; }, + destroy(err?: Error) { + if (err) { + const errorHandlers = listeners.get("error"); + if (errorHandlers) { + for (const handler of errorHandlers) { + handler(err); + } + } + } + destroyed = true; + listeners.clear(); + }, }; return downloadStream; } @@ -470,6 +586,12 @@ export class Database { // Return a mock client for compatibility with MongoDB client interface return { connect: async () => ({ + startSession: () => ({ + withTransaction: async (fn: () => Promise) => { + return await fn(); + }, + endSession: async () => {}, + }), db: () => ({ collection: (name: string) => new InMemoryCollection(name), }), @@ -478,7 +600,9 @@ export class Database { collection: (name: string) => new InMemoryCollection(name), }), startSession: () => ({ - withTransaction: async (fn: () => Promise) => { await fn(); }, + 
withTransaction: async (fn: () => Promise) => { + return await fn(); + }, endSession: async () => {}, }), }; diff --git a/frontend/src/lib/server/files/downloadFile.ts b/frontend/src/lib/server/files/downloadFile.ts index d289fc10c..62e608b42 100644 --- a/frontend/src/lib/server/files/downloadFile.ts +++ b/frontend/src/lib/server/files/downloadFile.ts @@ -25,10 +25,37 @@ export async function downloadFile( const buffer = await new Promise((resolve, reject) => { const chunks: Uint8Array[] = []; - fileStream.on("data", (chunk) => chunks.push(chunk)); - fileStream.on("error", reject); - fileStream.on("end", () => resolve(Buffer.concat(chunks))); + const onData = (chunk: unknown) => { + if (chunk instanceof Uint8Array) { + chunks.push(chunk); + return; + } + const err = new Error("Unexpected chunk type from fileStream"); + fileStream.destroy(); + cleanup(); + reject(err); + }; + + const onError = (err: Error | null | undefined) => { + cleanup(); + reject(err ?? new Error("File download failed")); + }; + + const onEnd = () => { + cleanup(); + resolve(Buffer.concat(chunks)); + }; + + const cleanup = () => { + fileStream.off("data", onData); + fileStream.off("error", onError); + fileStream.off("end", onEnd); + }; + + fileStream.on("data", onData); + fileStream.once("error", onError); + fileStream.once("end", onEnd); }); - return { type: "base64", name, value: buffer.toString("base64"), mime }; + return { type: "base64", name, value: buffer.toString("base64"), mime: String(mime ?? 
"") }; } diff --git a/frontend/src/lib/server/files/uploadFile.ts b/frontend/src/lib/server/files/uploadFile.ts index 97b335bea..361859c91 100644 --- a/frontend/src/lib/server/files/uploadFile.ts +++ b/frontend/src/lib/server/files/uploadFile.ts @@ -5,8 +5,8 @@ import { fileTypeFromBuffer } from "file-type"; import { collections } from "$lib/server/database"; export async function uploadFile(file: File, conv: Conversation): Promise { - const sha = await sha256(await file.text()); const buffer = await file.arrayBuffer(); + const sha = await sha256(Buffer.from(buffer)); // Attempt to detect the mime type of the file, fallback to the uploaded mime const mime = await fileTypeFromBuffer(buffer).then((fileType) => fileType?.mime ?? file.type); @@ -15,15 +15,38 @@ export async function uploadFile(file: File, conv: Conversation): Promise { - upload.once("finish", () => - resolve({ type: "hash", value: sha, mime: file.type, name: file.name }) - ); - upload.once("error", reject); - setTimeout(() => reject(new Error("Upload timed out")), 20_000); + const timeoutId = setTimeout(() => { + if (typeof upload.off === "function") { + upload.off("finish", handleFinish); + upload.off("error", handleError); + } + reject(new Error("Upload timed out")); + }, 20_000); + + const clearListeners = () => { + clearTimeout(timeoutId); + if (typeof upload.off === "function") { + upload.off("finish", handleFinish); + upload.off("error", handleError); + } + }; + + const handleFinish = () => { + clearListeners(); + resolve({ type: "hash", value: sha, mime: mime ?? file.type, name: file.name }); + }; + + const handleError = (err?: Error) => { + clearListeners(); + reject(err ?? 
new Error("Upload failed")); + }; + + upload.once("finish", handleFinish); + upload.once("error", handleError); + + upload.write(Buffer.from(buffer)); + upload.end(); }); } diff --git a/frontend/src/lib/server/models.ts b/frontend/src/lib/server/models.ts index cd7f57d5b..77c04b772 100644 --- a/frontend/src/lib/server/models.ts +++ b/frontend/src/lib/server/models.ts @@ -13,6 +13,13 @@ interface ModelConfig { name: string; displayName?: string; description?: string; + websiteUrl?: string; + modelUrl?: string; + datasetName?: string; + datasetUrl?: string; + logoUrl?: string; + promptExamples?: Array<{ title: string; prompt: string }>; + providers?: Array<{ provider: string } & Record>; preprompt: string; multimodal: boolean; multimodalAcceptedMimetypes?: string[]; @@ -47,7 +54,11 @@ const processModel = async (m: ModelConfig) => ({ id: m.id || m.name, displayName: m.displayName || m.name, preprompt: m.preprompt, - parameters: { stop_sequences: [] as string[] }, + parameters: { + stop_sequences: [] as string[], + stop: [] as string[], + truncate: null, + }, unlisted: m.unlisted ?? 
false, }); @@ -254,5 +265,16 @@ export const validateModel = (_models: BackendModel[]) => { export type BackendModel = Optional< typeof defaultModel, - "preprompt" | "parameters" | "multimodal" | "unlisted" | "hasInferenceAPI" + | "preprompt" + | "parameters" + | "multimodal" + | "unlisted" + | "hasInferenceAPI" + | "websiteUrl" + | "modelUrl" + | "datasetName" + | "datasetUrl" + | "logoUrl" + | "promptExamples" + | "providers" >; diff --git a/frontend/src/lib/services/agent-api.ts b/frontend/src/lib/services/agent-api.ts index 0f3ad0843..4161a9920 100644 --- a/frontend/src/lib/services/agent-api.ts +++ b/frontend/src/lib/services/agent-api.ts @@ -48,6 +48,10 @@ export class AgentAPI { return this.authToken; } + getBaseUrl(): string { + return this.baseUrl; + } + private getHeaders(): Record { const headers: Record = { 'Content-Type': 'application/json' @@ -210,3 +214,7 @@ export class AgentAPI { } export const agentAPI = new AgentAPI(); + +export function getAgentBaseUrl(): string { + return agentAPI.getBaseUrl(); +} diff --git a/frontend/src/lib/services/voice-client.ts b/frontend/src/lib/services/voice-client.ts new file mode 100644 index 000000000..3941416f8 --- /dev/null +++ b/frontend/src/lib/services/voice-client.ts @@ -0,0 +1,492 @@ +import { getAgentBaseUrl } from "$lib/services/agent-api"; + +export type VoiceState = + | "idle" + | "connecting" + | "active" + | "listening" + | "muted" + | "agent-speaking" + | "ended" + | "error"; + +export type TranscriptEvent = { + role: "user" | "agent"; + text: string; + isFinal: boolean; + ts: number; +}; + +type VoiceSessionStart = { + session_id: string; + context_id: string; + ws_url: string; + session_token?: string; +}; + +export class VoiceClient { + private ws: WebSocket | null = null; + private sessionId: string | null = null; + private state: VoiceState = "idle"; + private mediaStream: MediaStream | null = null; + private audioContext: AudioContext | null = null; + private sourceNode: 
MediaStreamAudioSourceNode | null = null; + private processorNode: AudioWorkletNode | null = null; + private silentGainNode: GainNode | null = null; + private isStreamingAudio = false; + private pendingConnectResolve: (() => void) | null = null; + private pendingConnectReject: ((reason?: unknown) => void) | null = null; + private duplexHoldUntilMs = 0; + private isStopping = false; + private stopToken = 0; + + private extendDuplexHold(ms: number): void { + this.duplexHoldUntilMs = Math.max(this.duplexHoldUntilMs, Date.now() + ms); + } + + holdMicFor(ms: number): void { + if (ms <= 0) { + return; + } + this.extendDuplexHold(ms); + } + + onTranscript?: (event: TranscriptEvent) => void; + onAgentResponse?: (text: string) => void; + onAgentAudio?: (audioData: ArrayBuffer) => void; + onStateChange?: (state: VoiceState) => void; + onError?: (message: string) => void; + + async startSession(contextId?: string): Promise { + const baseUrl = getAgentBaseUrl(); + let response: Response; + try { + response = await fetch(`${baseUrl}/voice/session/start`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(contextId ? { context_id: contextId } : {}), + }); + } catch { + throw new Error(`Cannot reach voice backend at ${baseUrl}. Is the agent running?`); + } + + if (!response.ok) { + const text = await response.text().catch(() => "Unknown error"); + if (response.status === 404) { + throw new Error( + "Voice endpoint not found. Run a voice-enabled agent (examples/voice-agent/main.py)." + ); + } + throw new Error(`Failed to start voice session: ${response.status} ${text}`); + } + + try { + return (await response.json()) as VoiceSessionStart; + } catch (err) { + throw new Error( + `Failed to parse voice session response: ${err instanceof Error ? 
err.message : String(err)}` + ); + } + } + + async connect(wsUrl: string, sessionId: string, sessionToken?: string): Promise { + this.setState("connecting"); + this.sessionId = sessionId; + + await new Promise((resolve, reject) => { + this.pendingConnectResolve = resolve; + this.pendingConnectReject = reject; + try { + const resolved = this.resolveWebSocketUrl(wsUrl); + // If the backend requires session auth, it expects the token either as: + // - Sec-WebSocket-Protocol header (preferred), or + // - first text frame after connect (fallback). + this.ws = sessionToken ? new WebSocket(resolved, [sessionToken]) : new WebSocket(resolved); + } catch (err) { + this.pendingConnectResolve = null; + this.pendingConnectReject = null; + reject(err); + return; + } + + if (!this.ws) { + this.pendingConnectResolve = null; + this.pendingConnectReject = null; + reject(new Error("WebSocket initialization failed")); + return; + } + + this.ws.binaryType = "arraybuffer"; + + this.ws.onopen = () => { + this.isStopping = false; + this.pendingConnectResolve = null; + this.pendingConnectReject = null; + if (sessionToken) { + // Only send the token as a first text frame if subprotocol negotiation + // did not succeed. Otherwise the server will treat this as a malformed + // JSON control frame and close the connection. + const negotiatedProtocol = this.ws?.protocol; + if (!negotiatedProtocol) { + try { + this.ws?.send(sessionToken); + } catch { + // Ignore and proceed; server may already have token via headers. 
+ } + } + } + this.sendControl({ type: "start", config: { sampleRate: 16000 } }); + this.setState("active"); + resolve(); + }; + + this.ws.onerror = () => { + this.pendingConnectResolve = null; + this.pendingConnectReject = null; + this.setState("error"); + this.onError?.("Voice WebSocket connection error"); + reject(new Error("Voice WebSocket connection error")); + }; + + this.ws.onclose = () => { + if (this.pendingConnectReject) { + this.pendingConnectReject(new Error("WebSocket closed before open")); + this.pendingConnectResolve = null; + this.pendingConnectReject = null; + } + this.cleanupAudioStreaming(); + if (this.state !== "ended" && this.state !== "idle") { + this.setState("ended"); + } + }; + + this.ws.onmessage = (event) => { + if (event.data instanceof ArrayBuffer) { + // While agent audio is arriving, hold mic uplink briefly to avoid + // speaker-to-mic feedback loops. + this.extendDuplexHold(1800); + this.onAgentAudio?.(event.data); + return; + } + + if (typeof event.data !== "string") { + return; + } + + try { + const data = JSON.parse(event.data) as { + type?: string; + role?: "user" | "agent"; + text?: string; + is_final?: boolean; + state?: VoiceState; + message?: string; + }; + + if (data.type === "transcript" && data.role && data.text) { + if (data.role === "agent") { + this.extendDuplexHold(data.is_final ? 1800 : 2600); + } + this.onTranscript?.({ + role: data.role, + text: data.text, + isFinal: Boolean(data.is_final ?? 
true), + ts: Date.now(), + }); + return; + } + + if (data.type === "agent_response" && data.text) { + this.extendDuplexHold(1800); + this.onAgentResponse?.(data.text); + return; + } + + if (data.type === "state" && data.state) { + this.setState(data.state); + return; + } + + if (data.type === "error" && data.message) { + this.setState("error"); + this.onError?.(data.message); + } + } catch { + // Ignore malformed frames + } + }; + }); + } + + async sendUserText(text: string): Promise { + if (!text.trim()) { + return; + } + this.sendControl({ type: "user_text", text: text.trim() }); + } + + commitTurn(): void { + this.sendControl({ type: "commit_turn" }); + } + + mute(): void { + if (this.sendControl({ type: "mute" })) { + this.setState("muted"); + } + } + + unmute(): void { + if (this.sendControl({ type: "unmute" })) { + this.setState("listening"); + } + } + + async stopSession(): Promise { + this.isStopping = true; + this.stopToken += 1; + const id = this.sessionId; + this.sendControl({ type: "stop" }); + this.cleanupAudioStreaming(); + + if (this.ws) { + this.ws.close(); + this.ws = null; + } + + this.sessionId = null; + this.setState("ended"); + + if (id) { + const baseUrl = getAgentBaseUrl(); + await fetch(`${baseUrl}/voice/session/${id}`, { method: "DELETE" }).catch(() => undefined); + } + } + + private sendControl(payload: Record): boolean { + if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { + return false; + } + this.ws.send(JSON.stringify(payload)); + return true; + } + + private setState(state: VoiceState): void { + this.state = state; + if (state === "agent-speaking") { + this.extendDuplexHold(2500); + } else if (state === "listening") { + this.extendDuplexHold(800); + } + this.onStateChange?.(state); + } + + private canSendMicAudio(): boolean { + if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { + return false; + } + if (this.state === "muted" || this.state === "agent-speaking") { + return false; + } + return Date.now() >= 
this.duplexHoldUntilMs; + } + + async startAudioStreaming(): Promise { + if (this.isStreamingAudio) { + return; + } + + if (this.isStopping || !this.sessionId) { + return; + } + + const setupToken = this.stopToken; + + if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { + throw new Error("Voice WebSocket is not connected"); + } + + const stream = await navigator.mediaDevices.getUserMedia({ + audio: { + channelCount: 1, + sampleRate: 16000, + echoCancellation: true, + noiseSuppression: true, + autoGainControl: true, + }, + }); + + if (this.isStopping || setupToken !== this.stopToken || !this.sessionId) { + stream.getTracks().forEach((track) => track.stop()); + return; + } + + const AudioContextCtor = + window.AudioContext || + (window as typeof window & { webkitAudioContext?: typeof AudioContext }).webkitAudioContext; + + if (!AudioContextCtor) { + stream.getTracks().forEach((track) => track.stop()); + throw new Error("AudioContext is not supported in this browser"); + } + + const desiredSampleRate = 16000; + const audioContext = new AudioContextCtor({ sampleRate: desiredSampleRate }); + await audioContext.resume(); + + if (this.isStopping || setupToken !== this.stopToken || !this.sessionId) { + void audioContext.close(); + stream.getTracks().forEach((track) => track.stop()); + return; + } + + // Inline AudioWorkletProcessor code to avoid external file dependencies + const workletCode = ` + class PCM16Processor extends AudioWorkletProcessor { + constructor() { + super(); + this.buffer = new Float32Array(4096); + this.bufferIndex = 0; + } + + process(inputs, outputs, parameters) { + const input = inputs[0]; + if (!input || !input[0]) return true; + + const channelData = input[0]; + + for (let i = 0; i < channelData.length; i++) { + this.buffer[this.bufferIndex++] = channelData[i]; + if (this.bufferIndex >= this.buffer.length) { + this.flushBuffer(); + } + } + + return true; + } + + flushBuffer() { + const pcm16 = new Int16Array(this.bufferIndex); + for (let j = 
0; j < this.bufferIndex; j++) { + const s = Math.max(-1, Math.min(1, this.buffer[j])); + pcm16[j] = s < 0 ? s * 0x8000 : s * 0x7FFF; + } + this.port.postMessage(pcm16.buffer, [pcm16.buffer]); + this.bufferIndex = 0; + } + } + registerProcessor('pcm16-processor', PCM16Processor); + `; + + const blob = new Blob([workletCode], { type: "application/javascript" }); + const workletUrl = URL.createObjectURL(blob); + + try { + await audioContext.audioWorklet.addModule(workletUrl); + } finally { + URL.revokeObjectURL(workletUrl); + } + + if (this.isStopping || setupToken !== this.stopToken || !this.sessionId) { + void audioContext.close(); + stream.getTracks().forEach((track) => track.stop()); + return; + } + + const sourceNode = audioContext.createMediaStreamSource(stream); + const workletNode = new AudioWorkletNode(audioContext, "pcm16-processor"); + const silentGainNode = audioContext.createGain(); + silentGainNode.gain.value = 0; + + workletNode.port.onmessage = (event) => { + if (!this.canSendMicAudio()) { + return; + } + const ws = this.ws; + if (!ws || ws.readyState !== WebSocket.OPEN) { + return; + } + ws.send(event.data); + }; + + sourceNode.connect(workletNode); + workletNode.connect(silentGainNode); + silentGainNode.connect(audioContext.destination); + + if (this.isStopping || setupToken !== this.stopToken || !this.sessionId) { + workletNode.disconnect(); + silentGainNode.disconnect(); + sourceNode.disconnect(); + void audioContext.close(); + stream.getTracks().forEach((track) => track.stop()); + return; + } + + this.mediaStream = stream; + this.audioContext = audioContext; + this.sourceNode = sourceNode; + this.processorNode = workletNode; + this.silentGainNode = silentGainNode; + this.isStreamingAudio = true; + } + + stopAudioStreaming(): void { + this.cleanupAudioStreaming(); + this.commitTurn(); + } + + private cleanupAudioStreaming(): void { + this.stopToken += 1; + this.isStreamingAudio = false; + + if (this.processorNode) { + 
this.processorNode.disconnect(); + this.processorNode.port.onmessage = null; + this.processorNode = null; + } + + if (this.silentGainNode) { + this.silentGainNode.disconnect(); + this.silentGainNode = null; + } + + if (this.sourceNode) { + this.sourceNode.disconnect(); + this.sourceNode = null; + } + + if (this.audioContext) { + void this.audioContext.close(); + this.audioContext = null; + } + + if (this.mediaStream) { + this.mediaStream.getTracks().forEach((track) => track.stop()); + this.mediaStream = null; + } + } + + private resolveWebSocketUrl(wsUrl: string): string { + const configuredBaseUrl = getAgentBaseUrl().replace(/^http/, "ws"); + + try { + const endpointUrl = new URL(wsUrl); + const proxyBaseUrl = new URL(configuredBaseUrl); + endpointUrl.protocol = proxyBaseUrl.protocol; + endpointUrl.host = proxyBaseUrl.host; + return endpointUrl.toString(); + } catch { + return wsUrl; + } + } +} + +function convertFloat32ToPcm16(input: Float32Array): ArrayBuffer { + const pcm = new Int16Array(input.length); + for (let index = 0; index < input.length; index += 1) { + const sample = Math.max(-1, Math.min(1, input[index] ?? 0)); + pcm[index] = sample < 0 ? sample * 0x8000 : sample * 0x7fff; + } + return pcm.buffer; +} diff --git a/frontend/src/lib/stores/chat.ts b/frontend/src/lib/stores/chat.ts index 1d2b27a53..8f7b34526 100644 --- a/frontend/src/lib/stores/chat.ts +++ b/frontend/src/lib/stores/chat.ts @@ -12,6 +12,63 @@ export interface DisplayMessage { timestamp: number; } +type TextPart = { + kind: 'text'; + text: string; + metadata?: Record; +}; + +type FilePart = { + kind: 'file'; + fileId?: string; + file?: { + bytes?: string; + mimeType?: string; + name?: string; + }; + metadata?: Record; +}; + +type Part = TextPart | FilePart; + +function base64FromBytes(bytes: Uint8Array): string { + let binary = ''; + for (let i = 0; i < bytes.length; i++) { + binary += String.fromCharCode(bytes[i] ?? 
0); + } + return btoa(binary); +} + +async function normalizeFileBytes(value: string | ArrayBuffer | Uint8Array | Blob): Promise { + if (typeof value === 'string') return value; + if (value instanceof Uint8Array) return base64FromBytes(value); + if (value instanceof ArrayBuffer) return base64FromBytes(new Uint8Array(value)); + if (value instanceof Blob) { + const buffer = await value.arrayBuffer(); + return base64FromBytes(new Uint8Array(buffer)); + } + return ''; +} + +async function normalizePartsForSend(parts: Part[]): Promise { + const normalized = await Promise.all( + parts.map(async (part) => { + if (part.kind !== 'file') return part; + const file = part.file; + if (!file?.bytes) return part; + const normalizedBytes = await normalizeFileBytes(file.bytes as string | ArrayBuffer | Uint8Array | Blob); + return { + ...part, + file: { + ...file, + bytes: normalizedBytes, + }, + }; + }) + ); + return normalized; +} + export const currentTaskId = writable(null); export const currentTaskState = writable(null); export const contextId = writable(null); @@ -226,7 +283,7 @@ export async function clearContext(ctxId: string) { } } -export async function sendMessage(text: string) { +export async function sendMessage(parts: Part[]) { const currentState = get(currentTaskState); const currentTask = get(currentTaskId); const currentContext = get(contextId); @@ -256,8 +313,18 @@ export async function sendMessage(text: string) { const useContextId = currentContext || generateUUID(); try { - // Add user message immediately - addMessage(text, 'user', taskId); + const normalizedParts = await normalizePartsForSend(parts || []); + // Add user message immediately (combine all text parts for display) + const text = normalizedParts + .filter( + (p): p is TextPart => + p.kind === 'text' && typeof p.text === 'string' && p.text.trim().length > 0 + ) + .map((p) => p.text) + .join('\n'); + if (text) { + addMessage(text, 'user', taskId); + } replyToTaskId.set(null); isThinking.set(true); @@ 
-265,7 +332,7 @@ export async function sendMessage(text: string) { const task = await agentAPI.sendMessage({ message: { role: 'user' as const, - parts: [{ kind: 'text' as const, text }], + parts: normalizedParts, kind: 'message' as const, messageId, contextId: useContextId, diff --git a/frontend/src/lib/stores/voice.ts b/frontend/src/lib/stores/voice.ts new file mode 100644 index 000000000..bd48faa5d --- /dev/null +++ b/frontend/src/lib/stores/voice.ts @@ -0,0 +1,242 @@ +import { get, writable } from "svelte/store"; +import { VoiceClient, type TranscriptEvent, type VoiceState } from "../services/voice-client"; + +export type VoiceTranscript = TranscriptEvent & { id: string }; + +export const voiceSessionId = writable(null); +export const voiceContextId = writable(null); +export const voiceState = writable("idle"); +export const isVoiceMuted = writable(false); +export const transcripts = writable([]); +export const currentUserTranscript = writable(""); +export const currentAgentTranscript = writable(""); +export const latestAgentAudio = writable(null); +export const voiceError = writable(null); + +let client: VoiceClient | null = null; +let isStarting = false; +let startTokenCounter = 0; +let transcriptIdCounter = 0; + +function resetVoiceCallState(): void { + isVoiceMuted.set(false); + latestAgentAudio.set(null); +} + +function mergeTranscriptText(previous: string, incoming: string): string { + const prev = previous.trim(); + const next = incoming.trim(); + + if (!prev) return next; + if (!next) return prev; + + // If transport sends cumulative text, prefer the longer/latest cumulative value. + if (next.startsWith(prev)) return next; + if (prev.startsWith(next)) return prev; + + // If values only differ in whitespace around punctuation, keep the longer one. 
+ const normalize = (value: string): string => + value + .replace(/\s+/g, " ") + .replace(/\s+([.,!?;:])/g, "$1") + .trim() + .toLowerCase(); + const prevNormalized = normalize(prev); + const nextNormalized = normalize(next); + if (prevNormalized === nextNormalized) { + return next.length >= prev.length ? next : prev; + } + + // Character overlap merge to handle non-tokenized stream chunks. + const maxOverlap = Math.min(prev.length, next.length); + for (let overlap = maxOverlap; overlap > 0; overlap -= 1) { + if (prev.slice(-overlap) === next.slice(0, overlap)) { + return `${prev}${next.slice(overlap)}`.trim(); + } + } + + // If transport sends token deltas, stitch naturally. + if (/^[.,!?;:]$/.test(next)) { + return `${prev}${next}`; + } + + return `${prev} ${next}`; +} + +function appendTranscript(event: TranscriptEvent): void { + transcripts.update((items) => { + const last = items[items.length - 1]; + + if (last && last.role === event.role && !last.isFinal) { + const mergedText = mergeTranscriptText(last.text, event.text); + const merged = { + ...last, + text: mergedText, + isFinal: event.isFinal, + ts: event.ts, + }; + + if (event.role === "user") { + currentUserTranscript.set(mergedText); + } else { + currentAgentTranscript.set(mergedText); + } + + return [...items.slice(0, -1), merged]; + } + + const transcript = { + ...event, + id: `${event.role}-${event.ts}-${transcriptIdCounter++}`, + }; + + if (transcript.role === "user") { + currentUserTranscript.set(transcript.text); + } else { + currentAgentTranscript.set(transcript.text); + } + + return [...items, transcript]; + }); +} + +export async function startVoiceSession(contextId?: string): Promise { + if (isStarting) { + throw new Error("A voice session is already starting"); + } + + isStarting = true; + const startToken = ++startTokenCounter; + const localClient = new VoiceClient(); + + voiceError.set(null); + transcripts.set([]); + currentUserTranscript.set(""); + currentAgentTranscript.set(""); + 
resetVoiceCallState(); + + // Clean up any existing client before creating a new one + const existingClient = client; + if (existingClient) { + client = null; + try { + await existingClient.stopSession(); + } catch (err) { + console.error("Error cleaning up existing client:", err); + } + } + + localClient.onTranscript = appendTranscript; + localClient.onStateChange = (state) => { + voiceState.set(state); + }; + localClient.onAgentAudio = (audioData) => { + latestAgentAudio.set(audioData); + }; + localClient.onError = (message) => { + voiceError.set(message); + voiceState.set("error"); + }; + + voiceState.set("connecting"); + try { + const session = await localClient.startSession(contextId); + if (startToken !== startTokenCounter) { + await localClient.stopSession().catch(() => undefined); + return; + } + voiceSessionId.set(session.session_id); + voiceContextId.set(session.context_id); + + client = localClient; + await localClient.connect(session.ws_url, session.session_id, session.session_token); + if (startToken !== startTokenCounter) { + await localClient.stopSession().catch(() => undefined); + return; + } + } catch (err) { + // On failure, clear the partially-initialized client + if (client === localClient) { + client = null; + } + const errorMessage = err instanceof Error ? 
err.message : String(err); + voiceError.set(errorMessage); + voiceState.set("error"); + await localClient.stopSession().catch(() => undefined); + voiceSessionId.set(null); + voiceContextId.set(null); + throw err; + } finally { + if (startToken === startTokenCounter) { + isStarting = false; + } + } +} + +export async function endVoiceSession(): Promise { + startTokenCounter += 1; + const active = client; + client = null; + + if (active) { + try { + await active.stopSession(); + } catch (err) { + console.error("Error stopping voice session:", err); + // Continue with cleanup even if stopSession throws + } + } + + // Reset all state variables + voiceSessionId.set(null); + voiceContextId.set(null); + resetVoiceCallState(); + voiceState.set("idle"); +} + +export function toggleMute(): void { + if (!client) { + return; + } + + const muted = get(isVoiceMuted); + if (muted) { + client.unmute(); + isVoiceMuted.set(false); + } else { + client.mute(); + isVoiceMuted.set(true); + } +} + +export async function sendVoiceText(text: string): Promise { + if (!client) { + throw new Error("No active voice session"); + } + await client.sendUserText(text); +} + +export function commitVoiceTurn(): void { + if (!client) { + return; + } + client.commitTurn(); +} + +export async function startVoiceStreaming(): Promise { + if (!client) { + throw new Error("No active voice session"); + } + await client.startAudioStreaming(); +} + +export function stopVoiceStreaming(): void { + client?.stopAudioStreaming(); +} + +export function holdVoiceInputFor(ms: number): void { + if (!client || ms <= 0) { + return; + } + client.holdMicFor(ms); +} diff --git a/frontend/src/lib/types/ConvSidebar.ts b/frontend/src/lib/types/ConvSidebar.ts index bbba9abc5..5c4c6bb59 100644 --- a/frontend/src/lib/types/ConvSidebar.ts +++ b/frontend/src/lib/types/ConvSidebar.ts @@ -1,4 +1,4 @@ -import type { ObjectId } from "bson"; +import type ObjectId from "bson-objectid"; export interface ConvSidebar { id: ObjectId | 
string; diff --git a/frontend/src/lib/types/Model.ts b/frontend/src/lib/types/Model.ts index 6e9377543..bb8f8b29d 100644 --- a/frontend/src/lib/types/Model.ts +++ b/frontend/src/lib/types/Model.ts @@ -15,4 +15,11 @@ export type Model = Pick< | "description" | "preprompt" | "multimodalAcceptedMimetypes" + | "websiteUrl" + | "modelUrl" + | "datasetName" + | "datasetUrl" + | "logoUrl" + | "promptExamples" + | "providers" >>; diff --git a/frontend/src/lib/types/Session.ts b/frontend/src/lib/types/Session.ts index 8bba6b942..b72dd6e95 100644 --- a/frontend/src/lib/types/Session.ts +++ b/frontend/src/lib/types/Session.ts @@ -1,4 +1,4 @@ -import type { ObjectId } from "bson"; +import type ObjectId from "bson-objectid"; import type { Timestamps } from "./Timestamps"; import type { User } from "./User"; diff --git a/frontend/src/lib/utils/agentMessageHandler.ts b/frontend/src/lib/utils/agentMessageHandler.ts index 01ba88e3c..1807f5891 100644 --- a/frontend/src/lib/utils/agentMessageHandler.ts +++ b/frontend/src/lib/utils/agentMessageHandler.ts @@ -7,8 +7,9 @@ import type { MessageUpdate } from '$lib/types/MessageUpdate'; import { MessageUpdateType, MessageUpdateStatus } from '$lib/types/MessageUpdate'; import { handlePaymentRequired, getPaymentHeaders, clearPaymentToken } from './paymentHandler'; +import { env as publicEnv } from '$env/dynamic/public'; -const AGENT_BASE_URL = 'http://localhost:3773'; +const AGENT_BASE_URL = publicEnv.PUBLIC_AGENT_BASE_URL || 'http://localhost:3773'; /** * Submit feedback for a task @@ -61,9 +62,20 @@ export async function submitTaskFeedback( return result.result; } + +interface FilePart { + kind: 'file'; + file: { + bytes: string; + mimeType?: string; + name?: string; + }; +} + + interface AgentMessage { role: 'user' | 'agent'; - parts: Array<{ kind: 'text'; text: string }>; + parts: Array<{ kind: 'text'; text: string } | FilePart>; kind: 'message'; messageId: string; contextId: string; @@ -107,6 +119,27 @@ function getAuthToken(): string | 
null { return localStorage.getItem('bindu_oauth_token'); } +function base64FromBytes(bytes: Uint8Array): string { + let binary = ''; + for (let i = 0; i < bytes.length; i++) { + binary += String.fromCharCode(bytes[i] ?? 0); + } + return btoa(binary); +} + +async function normalizeFileValue( + value: string | ArrayBuffer | Uint8Array | Blob +): Promise { + if (typeof value === 'string') return value; + if (value instanceof Uint8Array) return base64FromBytes(value); + if (value instanceof ArrayBuffer) return base64FromBytes(new Uint8Array(value)); + if (value instanceof Blob) { + const buffer = await value.arrayBuffer(); + return base64FromBytes(new Uint8Array(buffer)); + } + return ''; +} + /** * Extract text from task (artifacts or history) */ @@ -161,11 +194,25 @@ function extractTextFromTask(task: AgentTask): string { export async function* sendAgentMessage( message: string, contextId?: string, - abortSignal?: AbortSignal, - currentTaskId?: string, - taskState?: string, - replyToTaskId?: string + options: { + abortSignal?: AbortSignal; + currentTaskId?: string; + taskState?: string; + replyToTaskId?: string; + fileParts?: Array<{ + name: string; + mime: string; + value: string | ArrayBuffer | Uint8Array | Blob; + }>; + } = {} ): AsyncGenerator { + const { + abortSignal, + currentTaskId, + taskState, + replyToTaskId, + fileParts, + } = options; const token = typeof window !== 'undefined' ? localStorage.getItem('bindu_oauth_token') : null; const headers: Record = { 'Content-Type': 'application/json', @@ -214,14 +261,61 @@ export async function* sendAgentMessage( : generateId(); // Build message with optional referenceTaskIds + const parts: AgentMessage["parts"] = [{ kind: 'text', text: message }]; + if (fileParts && fileParts.length > 0) { + for (const f of fileParts) { + const mime = typeof f.mime === 'string' ? f.mime.trim() : ''; + const name = typeof f.name === 'string' ? 
f.name.trim() : ''; + const value = f.value; + + const hasValue = + typeof value === 'string' + ? value.length > 0 + : value instanceof ArrayBuffer + ? value.byteLength > 0 + : value instanceof Uint8Array + ? value.byteLength > 0 + : value instanceof Blob + ? value.size > 0 + : Boolean(value); + + if (!hasValue || !mime || !name) { + console.warn('[agentMessageHandler] Dropping invalid file part', { + hasValue, + mime, + name, + }); + continue; + } + + const normalizedValue = await normalizeFileValue(value); + if (!normalizedValue) { + console.warn('[agentMessageHandler] Dropping empty file part after normalization', { + mime, + name, + }); + continue; + } + + parts.push({ + kind: 'file', + file: { + bytes: normalizedValue, + mimeType: mime, + name, + }, + }); + } + } + const agentMessage: AgentMessage = { role: 'user', - parts: [{ kind: 'text', text: message }], + parts, kind: 'message', messageId, contextId: newContextId, taskId, - ...(referenceTaskIds.length > 0 && { referenceTaskIds }) + ...(referenceTaskIds.length > 0 && { referenceTaskIds }), }; // Step 1: Send message diff --git a/frontend/src/lib/utils/paymentHandler.ts b/frontend/src/lib/utils/paymentHandler.ts index 01e95c0ee..e2d00d9ef 100644 --- a/frontend/src/lib/utils/paymentHandler.ts +++ b/frontend/src/lib/utils/paymentHandler.ts @@ -3,7 +3,9 @@ * Handles payment sessions and token management for agents requiring payment */ -const AGENT_BASE_URL = 'http://localhost:3773'; +import { env as publicEnv } from '$env/dynamic/public'; + +const AGENT_BASE_URL = publicEnv.PUBLIC_AGENT_BASE_URL || 'http://localhost:3773'; // Payment state let paymentToken: string | null = null; diff --git a/frontend/src/lib/utils/tree/addChildren.spec.ts b/frontend/src/lib/utils/tree/addChildren.spec.ts index 8ef5c0bf0..5326ad9b3 100644 --- a/frontend/src/lib/utils/tree/addChildren.spec.ts +++ b/frontend/src/lib/utils/tree/addChildren.spec.ts @@ -16,7 +16,7 @@ Object.freeze(newMessage); describe("addChildren", async () => { 
it("should let you append on legacy conversations", async () => { const convId = await insertLegacyConversation(); - const conv = await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); const convLength = conv.messages.length; @@ -26,14 +26,14 @@ describe("addChildren", async () => { }); it("should not let you create branches on legacy conversations", async () => { const convId = await insertLegacyConversation(); - const conv = await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); expect(() => addChildren(conv, newMessage, conv.messages[0].id)).toThrow(); }); it("should not let you create a message that already exists", async () => { const convId = await insertLegacyConversation(); - const conv = await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); const messageThatAlreadyExists: Message = { @@ -46,7 +46,7 @@ describe("addChildren", async () => { }); it("should let you create branches on conversations with subtrees", async () => { const convId = await insertSideBranchesConversation(); - const conv = await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); const nChildren = conv.messages[0].children?.length; @@ -57,7 +57,7 @@ describe("addChildren", async () => { it("should let you create a new leaf", async () => { const convId = await insertSideBranchesConversation(); - const conv = await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await 
collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); const parentId = conv.messages[conv.messages.length - 1].id; @@ -84,7 +84,7 @@ describe("addChildren", async () => { it("should throw if you don't specify a parentId in a conversation with messages", async () => { const convId = await insertLegacyConversation(); - const conv = await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); expect(() => addChildren(conv, newMessage)).toThrow(); @@ -92,7 +92,7 @@ describe("addChildren", async () => { it("should return the id of the new message", async () => { const convId = await insertLegacyConversation(); - const conv = await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); expect(addChildren(conv, newMessage, conv.messages[conv.messages.length - 1].id)).toEqual( diff --git a/frontend/src/lib/utils/tree/addSibling.spec.ts b/frontend/src/lib/utils/tree/addSibling.spec.ts index 89e0a7ade..7d7888f15 100644 --- a/frontend/src/lib/utils/tree/addSibling.spec.ts +++ b/frontend/src/lib/utils/tree/addSibling.spec.ts @@ -29,7 +29,7 @@ describe("addSibling", async () => { it("should fail on legacy conversations", async () => { const convId = await insertLegacyConversation(); - const conv = await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); expect(() => addSibling(conv, newMessage, conv.messages[0].id)).toThrow( @@ -39,7 +39,7 @@ describe("addSibling", async () => { it("should fail if the sibling message doesn't exist", async () => { const convId = await insertSideBranchesConversation(); - const conv = 
await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); expect(() => addSibling(conv, newMessage, "not-a-real-id-test")).toThrow( @@ -50,7 +50,7 @@ describe("addSibling", async () => { // TODO: This behaviour should be fixed, we do not need to fail on the root message. it("should fail if the sibling message is the root message", async () => { const convId = await insertSideBranchesConversation(); - const conv = await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); if (!conv.rootMessageId) throw new Error("Root message not found"); @@ -61,10 +61,10 @@ describe("addSibling", async () => { it("should add a sibling to a message", async () => { const convId = await insertSideBranchesConversation(); - const conv = await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); - // add sibling and check children count for parnets + // Add a sibling and check the parent children count. 
const nChildren = conv.messages[1].children?.length; const siblingId = addSibling(conv, newMessage, conv.messages[2].id); diff --git a/frontend/src/lib/utils/tree/buildSubtree.spec.ts b/frontend/src/lib/utils/tree/buildSubtree.spec.ts index 6acabf31d..83594c99e 100644 --- a/frontend/src/lib/utils/tree/buildSubtree.spec.ts +++ b/frontend/src/lib/utils/tree/buildSubtree.spec.ts @@ -12,7 +12,7 @@ import { buildSubtree } from "./buildSubtree"; describe("buildSubtree", () => { it("a subtree in a legacy conversation should be just a slice", async () => { const convId = await insertLegacyConversation(); - const conv = await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); // check middle @@ -33,7 +33,7 @@ describe("buildSubtree", () => { it("a subtree in a linear branch conversation should be the ancestors and the message", async () => { const convId = await insertLinearBranchConversation(); - const conv = await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); // check middle @@ -54,7 +54,7 @@ describe("buildSubtree", () => { it("should throw an error if the message is not found", async () => { const convId = await insertLinearBranchConversation(); - const conv = await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); const id = "not-a-real-id-test"; @@ -64,7 +64,7 @@ describe("buildSubtree", () => { it("should throw an error if the ancestor is not found", async () => { const convId = await insertLinearBranchConversation(); - const conv = await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await 
collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); const id = "1-1-1-1-2"; @@ -87,7 +87,7 @@ describe("buildSubtree", () => { it("should work for conversation with subtrees", async () => { const convId = await insertSideBranchesConversation(); - const conv = await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); const subtree = buildSubtree(conv, "1-1-1-1-2"); diff --git a/frontend/src/lib/utils/tree/convertLegacyConversation.spec.ts b/frontend/src/lib/utils/tree/convertLegacyConversation.spec.ts index 4577c9b06..eeb66dc44 100644 --- a/frontend/src/lib/utils/tree/convertLegacyConversation.spec.ts +++ b/frontend/src/lib/utils/tree/convertLegacyConversation.spec.ts @@ -8,7 +8,7 @@ import { insertLegacyConversation } from "./treeHelpers.spec"; describe("convertLegacyConversation", () => { it("should convert a legacy conversation", async () => { const convId = await insertLegacyConversation(); - const conv = await collections.conversations.findOne({ _id: new ObjectId(convId) }); + const conv = await collections.conversations.findOne({ _id: convId }); if (!conv) throw new Error("Conversation not found"); const newConv = convertLegacyConversation(conv); diff --git a/frontend/src/routes/+layout.svelte b/frontend/src/routes/+layout.svelte index dedda2206..45535b684 100644 --- a/frontend/src/routes/+layout.svelte +++ b/frontend/src/routes/+layout.svelte @@ -1,12 +1,11 @@ @@ -130,7 +128,7 @@
- + diff --git a/frontend/src/routes/+page.svelte b/frontend/src/routes/+page.svelte index c14b33a22..d561df31c 100644 --- a/frontend/src/routes/+page.svelte +++ b/frontend/src/routes/+page.svelte @@ -1,5 +1,6 @@