Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 139 additions & 17 deletions backend/open_webui/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import asyncio

from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from typing import AsyncGenerator, Generator, Iterator
from fastapi import (
Depends,
Expand Down Expand Up @@ -57,21 +57,121 @@
log.setLevel(SRC_LOG_LEVELS["MAIN"])


def validate_pipe_id(pipe_id: str) -> bool:
    """Check that ``pipe_id`` names a stored function whose type is ``"pipe"``.

    An id containing a dot (manifold pipes) is reduced to the part before the
    first dot before the lookup is made.

    Args:
        pipe_id: Identifier supplied by the caller.

    Returns:
        ``True`` when ``pipe_id`` is a non-empty string, ``Functions`` returns
        a record for the base id, and that record's ``type`` is ``"pipe"``.
        ``False`` in every other case; the two post-lookup rejections also
        log a warning.
    """
    # Empty values and non-strings are rejected before any lookup happens.
    if not (pipe_id and isinstance(pipe_id, str)):
        return False

    # Text before the first dot, or the whole id when it has no dot.
    base_pipe_id = pipe_id.partition(".")[0]

    function = Functions.get_function_by_id(base_pipe_id)
    if not function:
        log.warning(f"Attempted to load non-existent function: {base_pipe_id}")
        return False

    # Only records of type 'pipe' are accepted.
    if function.type == "pipe":
        return True

    log.warning(f"Attempted to load non-pipe function: {base_pipe_id}")
    return False


def validate_pipe_function(pipe_function, pipe_id: str) -> bool:
    """Run basic sanity checks on a pipe callable before it is executed.

    The callable is rejected when it is not callable, when its signature
    cannot be inspected, or when one of its parameter annotations mentions
    ``eval`` or ``exec`` as a standalone name.

    Args:
        pipe_function: The object that is about to be invoked as a pipe.
        pipe_id: Identifier of the pipe, used only in log messages.

    Returns:
        ``True`` if every check passes, ``False`` otherwise. Each rejection
        is logged at error level.

    NOTE(review): inspecting annotations is a heuristic and says nothing
    about what the function body does; it should not be treated as a
    security boundary.
    """
    if not callable(pipe_function):
        log.error(f"Pipe function for {pipe_id} is not callable")
        return False

    blocked_names = {"eval", "exec"}

    try:
        sig = inspect.signature(pipe_function)
        for param_name, param in sig.parameters.items():
            # ``empty`` is a sentinel, so compare by identity.
            if param.annotation is inspect.Parameter.empty:
                continue

            annotation_text = str(param.annotation).lower()
            # Split on every non-alphanumeric character and compare whole
            # tokens. A plain substring test also rejected harmless names
            # such as ``RetrievalConfig`` ("retri-eval") or ``Executor``.
            tokens = "".join(
                ch if ch.isalnum() else " " for ch in annotation_text
            ).split()
            if blocked_names.intersection(tokens):
                log.error(f"Dangerous parameter type detected in pipe function {pipe_id}: {param_name}")
                return False
    except Exception as e:
        # Covers callables without an introspectable signature as well as
        # annotations whose ``str()`` fails.
        log.error(f"Error inspecting pipe function {pipe_id}: {e}")
        return False

    return True


# Double-underscore parameter names that may be passed through to a pipe.
_ALLOWED_DUNDER_PARAMS = frozenset(
    {
        "__event_emitter__",
        "__event_call__",
        "__chat_id__",
        "__session_id__",
        "__message_id__",
        "__task__",
        "__task_body__",
        "__files__",
        "__user__",
        "__metadata__",
        "__oauth_token__",
        "__request__",
        "__tools__",
        "__model__",
        "__messages__",
    }
)


def sanitize_params(params: dict) -> dict:
    """Return a copy of ``params`` without disallowed dunder-prefixed keys.

    String keys starting with ``__`` are dropped (with a warning) unless they
    appear in ``_ALLOWED_DUNDER_PARAMS``. The same rule is applied to every
    dict reachable through nested dicts and lists. Keys that are not strings
    are kept untouched, as are all non-dict, non-list values.

    Args:
        params: Mapping of parameter names to values.

    Returns:
        A new dict; nested dicts and lists are rebuilt, and the input is not
        modified.
    """
    sanitized = {}

    for key, value in params.items():
        # Only strings can carry the '__' prefix; guarding with isinstance
        # keeps non-string keys (e.g. ints in nested data) from raising
        # AttributeError on ``startswith``.
        if (
            isinstance(key, str)
            and key.startswith("__")
            and key not in _ALLOWED_DUNDER_PARAMS
        ):
            log.warning(f"Blocked suspicious parameter name: {key}")
            continue

        sanitized[key] = _sanitize_param_value(value)

    return sanitized


def _sanitize_param_value(value):
    """Sanitize one value, descending into dicts and arbitrarily nested lists."""
    if isinstance(value, dict):
        return sanitize_params(value)
    if isinstance(value, list):
        # Recurse per item so dicts inside lists of lists are covered too.
        return [_sanitize_param_value(item) for item in value]
    return value


def get_function_module_by_id(request: Request, pipe_id: str):
# Validate pipe_id before loading
if not validate_pipe_id(pipe_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid or unauthorized function ID"
)

function_module, _, _ = get_function_module_from_cache(request, pipe_id)

if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
Valves = function_module.Valves

# Verify that Valves is a Pydantic BaseModel subclass
if not (inspect.isclass(Valves) and issubclass(Valves, BaseModel)):
log.error(f"Valves class for function {pipe_id} is not a valid Pydantic BaseModel")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Invalid Valves configuration class"
)

valves = Functions.get_function_valves_by_id(pipe_id)

if valves:
try:
function_module.valves = Valves(
**{k: v for k, v in valves.items() if v is not None}
)
# Validate that valves is a dictionary
if not isinstance(valves, dict):
log.error(f"Valves data for function {pipe_id} is not a dictionary")
raise ValueError("Valves data must be a dictionary")

# Filter out None values and ensure only valid keys
filtered_valves = {k: v for k, v in valves.items() if v is not None}

# Use Pydantic's validation to safely deserialize
# This will raise ValidationError if data is invalid
function_module.valves = Valves.model_validate(filtered_valves)
except ValidationError as e:
log.error(f"Validation error loading valves for function {pipe_id}: {e}")
# Fall back to default valves instead of propagating the error
function_module.valves = Valves()
except Exception as e:
log.exception(f"Error loading valves for function {pipe_id}: {e}")
raise e
# Fall back to default valves for any other errors
function_module.valves = Valves()
else:
function_module.valves = Valves()

Expand Down Expand Up @@ -159,11 +259,18 @@ async def get_function_models(request):
async def generate_function_chat_completion(
request, form_data, user, models: dict = {}
):
async def execute_pipe(pipe, params):
async def execute_pipe(pipe, params, pipe_id: str):
# Validate the pipe function before execution
if not validate_pipe_function(pipe, pipe_id):
raise ValueError(f"Invalid or unsafe pipe function: {pipe_id}")

# Sanitize parameters before execution
sanitized_params = sanitize_params(params)

if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
return await pipe(**sanitized_params)
else:
return pipe(**params)
return pipe(**sanitized_params)

async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
if isinstance(res, str):
Expand Down Expand Up @@ -210,12 +317,27 @@ def get_function_params(function_module, form_data, user, extra_params=None):
}

if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
except Exception as e:
log.exception(e)
params["__user__"]["valves"] = function_module.UserValves()
# Verify that UserValves is a Pydantic BaseModel subclass
UserValves = function_module.UserValves
if not (inspect.isclass(UserValves) and issubclass(UserValves, BaseModel)):
log.error(f"UserValves class for function {pipe_id} is not a valid Pydantic BaseModel")
params["__user__"]["valves"] = None
else:
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try:
# Validate that user_valves is a dictionary
if not isinstance(user_valves, dict):
log.error(f"User valves data for function {pipe_id} is not a dictionary")
params["__user__"]["valves"] = UserValves()
else:
# Use Pydantic's validation to safely deserialize
params["__user__"]["valves"] = UserValves.model_validate(user_valves)
except ValidationError as e:
log.error(f"Validation error loading user valves for function {pipe_id}: {e}")
params["__user__"]["valves"] = UserValves()
except Exception as e:
log.exception(e)
params["__user__"]["valves"] = UserValves()

return params

Expand Down Expand Up @@ -299,7 +421,7 @@ def get_function_params(function_module, form_data, user, extra_params=None):

async def stream_content():
try:
res = await execute_pipe(pipe, params)
res = await execute_pipe(pipe, params, pipe_id)

# Directly return if the response is a StreamingResponse
if isinstance(res, StreamingResponse):
Expand Down Expand Up @@ -338,7 +460,7 @@ async def stream_content():
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
try:
res = await execute_pipe(pipe, params)
res = await execute_pipe(pipe, params, pipe_id)

except Exception as e:
log.error(f"Error: {e}")
Expand All @@ -350,4 +472,4 @@ async def stream_content():
return res.model_dump()

message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)
return openai_chat_completion_message_template(form_data["model"], message)