Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions backend/app/rag/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,28 @@ def _build_traceable(name: str, run_type: str, metadata: Optional[dict[str, Any]
return _langsmith_traceable(name=name, run_type=run_type)


import inspect

async def async_trace_call(
name: str,
fn: Callable[..., Any],
*args: Any,
run_type: str = "chain",
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Execute an asynchronous callable with LangSmith tracing when available."""
if not LANGSMITH_ENABLED:
return await fn(*args, **kwargs)

decorator = _build_traceable(name, run_type, metadata)
if decorator is None:
return await fn(*args, **kwargs)

traced_fn = decorator(fn)
return await traced_fn(*args, **kwargs)


def trace_call(
name: str,
fn: Callable[..., Any],
Expand All @@ -65,7 +87,10 @@ def trace_call(
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Execute a callable with LangSmith tracing when available."""
"""Execute a callable with LangSmith tracing when available. Supports both sync and async."""
if inspect.iscoroutinefunction(fn):
return async_trace_call(name, fn, *args, run_type=run_type, metadata=metadata, **kwargs)

if not LANGSMITH_ENABLED:
return fn(*args, **kwargs)

Expand All @@ -83,8 +108,22 @@ def trace_function(
run_type: str = "chain",
metadata_factory: Optional[Callable[..., dict[str, Any]]] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Decorator wrapper that becomes a no-op when LangSmith is disabled."""
"""Decorator wrapper that becomes a no-op when LangSmith is disabled. Supports both sync and async."""
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
if inspect.iscoroutinefunction(fn):
@wraps(fn)
async def async_wrapped(*args: Any, **kwargs: Any) -> Any:
metadata = metadata_factory(*args, **kwargs) if metadata_factory else None
return await trace_call(
name,
fn,
*args,
run_type=run_type,
metadata=metadata,
**kwargs,
)
return async_wrapped

@wraps(fn)
def wrapped(*args: Any, **kwargs: Any) -> Any:
metadata = metadata_factory(*args, **kwargs) if metadata_factory else None
Expand All @@ -96,7 +135,6 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
metadata=metadata,
**kwargs,
)

return wrapped

return decorator
return decorator
Loading