Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 116 additions & 5 deletions backend/open_webui/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
import json
import asyncio
import re

from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator
Expand Down Expand Up @@ -57,7 +58,79 @@
log.setLevel(SRC_LOG_LEVELS["MAIN"])


def validate_pipe_id(pipe_id: str) -> bool:
"""
Validate that the pipe_id is safe and belongs to an authorized function.
Only allow alphanumeric characters, hyphens, underscores, and dots.
"""
if not pipe_id or not isinstance(pipe_id, str):
return False

# Only allow alphanumeric, hyphens, underscores, and dots
if not re.match(r'^[a-zA-Z0-9._-]+$', pipe_id):
return False

# Prevent path traversal attempts
if '..' in pipe_id or pipe_id.startswith('.') or pipe_id.startswith('/'):
return False

return True


def validate_pipe_params(params: dict) -> dict:
"""
Validate and sanitize parameters passed to pipe functions.
Remove any potentially dangerous keys and ensure data types are safe.
"""
if not isinstance(params, dict):
raise ValueError("Parameters must be a dictionary")

# Create a copy to avoid modifying the original
safe_params = {}

# Whitelist of allowed parameter keys
allowed_keys = {
'body', '__event_emitter__', '__event_call__', '__chat_id__',
'__session_id__', '__message_id__', '__task__', '__task_body__',
'__files__', '__user__', '__metadata__', '__oauth_token__',
'__request__', '__tools__', '__model__', '__messages__'
}

for key, value in params.items():
if key in allowed_keys:
safe_params[key] = value
else:
log.warning(f"Filtered out unexpected parameter key: {key}")

return safe_params


def get_function_module_by_id(request: Request, pipe_id: str):
# Validate pipe_id to prevent code injection
if not validate_pipe_id(pipe_id):
log.error(f"Invalid pipe_id format: {pipe_id}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid function ID format"
)

# Verify that the function exists in the database
function = Functions.get_function_by_id(pipe_id.split('.')[0])
if not function:
log.error(f"Function not found: {pipe_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Function not found"
)

# Verify the function is active
if not function.is_active:
log.error(f"Function is not active: {pipe_id}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Function is not active"
)

function_module, _, _ = get_function_module_from_cache(request, pipe_id)

if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
Expand Down Expand Up @@ -160,10 +233,48 @@ async def generate_function_chat_completion(
request, form_data, user, models: dict = {}
):
async def execute_pipe(pipe, params):
if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
else:
return pipe(**params)
"""
Execute pipe function with validated parameters and security controls.
"""
# Validate that pipe is actually a callable function
if not callable(pipe):
raise ValueError("Invalid pipe: must be a callable function")

# Validate and sanitize parameters
try:
safe_params = validate_pipe_params(params)
except Exception as e:
log.error(f"Parameter validation failed: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid parameters provided to pipe function"
)

# Execute with timeout to prevent infinite loops
try:
if inspect.iscoroutinefunction(pipe):
# Set a reasonable timeout (e.g., 300 seconds)
result = await asyncio.wait_for(
pipe(**safe_params),
timeout=300.0
)
else:
# For sync functions, run in executor to avoid blocking
loop = asyncio.get_event_loop()
result = await asyncio.wait_for(
loop.run_in_executor(None, lambda: pipe(**safe_params)),
timeout=300.0
)
return result
except asyncio.TimeoutError:
log.error(f"Pipe execution timed out")
raise HTTPException(
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
detail="Pipe execution timed out"
)
except Exception as e:
log.error(f"Pipe execution failed: {e}")
raise

async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
if isinstance(res, str):
Expand Down Expand Up @@ -350,4 +461,4 @@ async def stream_content():
return res.model_dump()

message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)
return openai_chat_completion_message_template(form_data["model"], message)