diff --git a/backend/app/rag/tracing.py b/backend/app/rag/tracing.py index f95e8b18..10a1fd86 100644 --- a/backend/app/rag/tracing.py +++ b/backend/app/rag/tracing.py @@ -12,29 +12,77 @@ logger = logging.getLogger(__name__) settings = get_settings() -try: - from langsmith import traceable as _langsmith_traceable -except Exception: # pragma: no cover - optional dependency safety - _langsmith_traceable = None - - -def configure_langsmith() -> bool: - """Configure LangSmith environment variables when tracing is enabled.""" - if not settings.LANGSMITH_TRACING: - return False - - if not settings.LANGSMITH_API_KEY: - logger.warning("LangSmith tracing enabled but LANGSMITH_API_KEY is not set; tracing disabled.") - return False +from abc import ABC, abstractmethod + +# 1. Base Class Strategy Interface +class BaseTracingProvider(ABC): + @abstractmethod + def trace_call( + self, + name: str, + fn: Callable[..., Any], + *args: Any, + run_type: str = "chain", + metadata: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + pass + + +# 2. Refactored LangSmith Implementation +class LangSmithProvider(BaseTracingProvider): + def __init__(self): + try: + from langsmith import traceable as _langsmith_traceable + self._traceable = _langsmith_traceable + except Exception: + self._traceable = None + self.enabled = self._configure() + + def _configure(self) -> bool: + # Check if the global tracing provider setting matches this provider + provider_setting = getattr(settings, "TRACING_PROVIDER", "none").lower() + if provider_setting != "langsmith": + return False + + if not settings.LANGSMITH_API_KEY: + logger.warning("LangSmith tracing enabled but LANGSMITH_API_KEY is not set; tracing disabled.") + return False + + os.environ["LANGSMITH_TRACING"] = "true" + os.environ["LANGSMITH_API_KEY"] = settings.LANGSMITH_API_KEY + os.environ["LANGSMITH_ENDPOINT"] = settings.LANGSMITH_ENDPOINT + os.environ["LANGSMITH_PROJECT"] = settings.LANGSMITH_PROJECT + return self._traceable is not None + + def trace_call(self, name, fn, *args, run_type="chain", metadata=None, **kwargs): + if not self.enabled or self._traceable is None: + return fn(*args, **kwargs) + + sanitized = {k: v for k, v in (metadata or {}).items() if v is not None} + try: + decorator = self._traceable(name=name, run_type=run_type, metadata=sanitized or None) + except TypeError: + decorator = self._traceable(name=name, run_type=run_type) + + return decorator(fn)(*args, **kwargs) + + +# 3. Fallback/No-Op Provider for "none" or alternative defaults +class NoOpProvider(BaseTracingProvider): + def trace_call(self, name, fn, *args, run_type="chain", metadata=None, **kwargs): + return fn(*args, **kwargs) - os.environ["LANGSMITH_TRACING"] = "true" - os.environ["LANGSMITH_API_KEY"] = settings.LANGSMITH_API_KEY - os.environ["LANGSMITH_ENDPOINT"] = settings.LANGSMITH_ENDPOINT - os.environ["LANGSMITH_PROJECT"] = settings.LANGSMITH_PROJECT - return _langsmith_traceable is not None +# 4. Factory Initialization +def _get_active_provider() -> BaseTracingProvider: + provider_setting = getattr(settings, "TRACING_PROVIDER", "none").lower() + if provider_setting == "langsmith": + return LangSmithProvider() + # Note: You can easily plug in LangfuseProvider here later! + return NoOpProvider() -LANGSMITH_ENABLED = configure_langsmith() +_active_provider = _get_active_provider() def _sanitize_metadata(metadata: Optional[dict[str, Any]]) -> dict[str, Any]: @@ -65,16 +113,10 @@ def trace_call( metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Any: - """Execute a callable with LangSmith tracing when available.""" - if not LANGSMITH_ENABLED: - return fn(*args, **kwargs) - - decorator = _build_traceable(name, run_type, metadata) - if decorator is None: - return fn(*args, **kwargs) - - traced_fn = decorator(fn) - return traced_fn(*args, **kwargs) + """Execute a callable routing to the configured monitoring provider.""" + return _active_provider.trace_call( + name, fn, *args, run_type=run_type, metadata=metadata, **kwargs + ) def trace_function(