Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 231 additions & 19 deletions backend/open_webui/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import inspect
import json
import asyncio
import hashlib
import re
from functools import wraps

from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator
Expand Down Expand Up @@ -57,7 +60,186 @@
log.setLevel(SRC_LOG_LEVELS["MAIN"])


def validate_pipe_id(pipe_id: str) -> bool:
"""Validate that pipe_id contains only safe characters"""
# Allow alphanumeric, hyphens, underscores, and dots only
if not re.match(r'^[a-zA-Z0-9._-]+$', pipe_id):
return False
# Prevent path traversal
if '..' in pipe_id or pipe_id.startswith('/') or pipe_id.startswith('\\'):
return False
return True


def verify_function_integrity(pipe_id: str, function_code: str) -> bool:
"""Verify that the function hasn't been tampered with"""
function_record = Functions.get_function_by_id(pipe_id)
if not function_record:
return False

# Check if function is marked as active/approved
if not getattr(function_record, 'is_active', True):
return False

return True


def sanitize_params(params: dict) -> dict:
"""Sanitize parameters to prevent code injection"""
sanitized = {}
allowed_types = (str, int, float, bool, list, dict, type(None), BaseModel)

for key, value in params.items():
# Check if key is a valid identifier
if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', key):
log.warning(f"Skipping invalid parameter key: {key}")
continue

# Check value type
if isinstance(value, allowed_types):
if isinstance(value, dict):
sanitized[key] = sanitize_params(value)
elif isinstance(value, list):
sanitized[key] = [
sanitize_params(item) if isinstance(item, dict) else item
for item in value
if isinstance(item, allowed_types)
]
else:
sanitized[key] = value
elif isinstance(value, BaseModel):
sanitized[key] = value
else:
log.warning(f"Skipping parameter {key} with disallowed type: {type(value)}")

return sanitized


def validate_pipe_signature(pipe_func) -> bool:
"""Validate that the pipe function has a safe signature"""
try:
sig = inspect.signature(pipe_func)

# Check for suspicious parameter names or types
for param_name, param in sig.parameters.items():
# Block parameters that could be used for code injection
if param_name.startswith('__') and param_name not in [
'__user__', '__event_emitter__', '__event_call__',
'__chat_id__', '__session_id__', '__message_id__',
'__task__', '__task_body__', '__files__', '__metadata__',
'__oauth_token__', '__request__', '__tools__', '__model__',
'__messages__'
]:
log.warning(f"Suspicious parameter name detected: {param_name}")
return False

return True
except Exception as e:
log.error(f"Error validating pipe signature: {e}")
return False


def safe_execute_pipes_callable(function_module, pipe_id: str):
"""
Safely execute the pipes callable with validation and sandboxing.
Returns a list of pipes or empty list on error.
"""
if not hasattr(function_module, 'pipes'):
return []

pipes_attr = function_module.pipes

# If it's already a list, validate and return it
if isinstance(pipes_attr, list):
return pipes_attr

# If it's callable, execute with safety checks
if callable(pipes_attr):
try:
# Validate the callable signature to ensure it doesn't require unexpected parameters
sig = inspect.signature(pipes_attr)
if len(sig.parameters) > 0:
log.warning(f"Function {pipe_id} pipes() callable has unexpected parameters, skipping")
return []

# Execute in a controlled manner
if asyncio.iscoroutinefunction(pipes_attr):
# For async functions, we need to run them in the current event loop
result = asyncio.create_task(pipes_attr())
# Wait for the result with a timeout
loop = asyncio.get_event_loop()
result = loop.run_until_complete(asyncio.wait_for(result, timeout=5.0))
else:
result = pipes_attr()

# Validate the result is a list
if not isinstance(result, list):
log.warning(f"Function {pipe_id} pipes() returned non-list type, skipping")
return []

# Validate each pipe in the result
validated_pipes = []
for pipe in result:
if isinstance(pipe, dict) and 'id' in pipe and 'name' in pipe:
# Additional validation for pipe structure
if isinstance(pipe['id'], str) and isinstance(pipe['name'], str):
validated_pipes.append(pipe)
else:
log.warning(f"Invalid pipe structure in {pipe_id}, skipping entry")
else:
log.warning(f"Invalid pipe format in {pipe_id}, skipping entry")

return validated_pipes

except asyncio.TimeoutError:
log.error(f"Timeout executing pipes() for function {pipe_id}")
return []
except Exception as e:
log.error(f"Error executing pipes() for function {pipe_id}: {e}")
return []

return []


def get_function_module_by_id(request: Request, pipe_id: str):
# Validate pipe_id format
if not validate_pipe_id(pipe_id):
log.error(f"Invalid pipe_id format: {pipe_id}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid function identifier"
)

# Verify the function exists and is authorized
function_record = Functions.get_function_by_id(pipe_id)
if not function_record:
log.error(f"Function not found: {pipe_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Function not found"
)

# Verify function integrity
function_code = getattr(function_record, 'content', '')
if not verify_function_integrity(pipe_id, function_code):
log.error(f"Function integrity check failed: {pipe_id}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Function failed security validation"
)

# Get user from request and verify access
user = None
if hasattr(request.state, 'user'):
user = request.state.user

if user and not has_access(user.id, type="read", access_control=getattr(function_record, 'access_control', {})):
log.error(f"User {user.id} does not have access to function {pipe_id}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to this function"
)

function_module, _, _ = get_function_module_from_cache(request, pipe_id)

if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
Expand All @@ -84,6 +266,11 @@ async def get_function_models(request):

for pipe in pipes:
try:
# Validate pipe ID before processing
if not validate_pipe_id(pipe.id):
log.warning(f"Skipping invalid pipe_id: {pipe.id}")
continue

function_module = get_function_module_by_id(request, pipe.id)

has_user_valves = False
Expand All @@ -92,27 +279,21 @@ async def get_function_models(request):

# Check if function is a manifold
if hasattr(function_module, "pipes"):
sub_pipes = []

# Handle pipes being a list, sync function, or async function
try:
if callable(function_module.pipes):
if asyncio.iscoroutinefunction(function_module.pipes):
sub_pipes = await function_module.pipes()
else:
sub_pipes = function_module.pipes()
else:
sub_pipes = function_module.pipes
except Exception as e:
log.exception(e)
sub_pipes = []
# Use safe execution method
sub_pipes = safe_execute_pipes_callable(function_module, pipe.id)

log.debug(
f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
)

for p in sub_pipes:
sub_pipe_id = f'{pipe.id}.{p["id"]}'

# Validate sub-pipe ID
if not validate_pipe_id(sub_pipe_id):
log.warning(f"Skipping invalid sub_pipe_id: {sub_pipe_id}")
continue

sub_pipe_name = p["name"]

if hasattr(function_module, "name"):
Expand Down Expand Up @@ -160,10 +341,33 @@ async def generate_function_chat_completion(
request, form_data, user, models: dict = {}
):
async def execute_pipe(pipe, params):
if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
else:
return pipe(**params)
# Validate pipe function signature
if not validate_pipe_signature(pipe):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid pipe function signature detected"
)

# Sanitize parameters before execution
sanitized_params = sanitize_params(params)

# Execute with timeout to prevent hanging
try:
if inspect.iscoroutinefunction(pipe):
return await asyncio.wait_for(pipe(**sanitized_params), timeout=300.0)
else:
# Run sync function in executor to prevent blocking
loop = asyncio.get_event_loop()
return await asyncio.wait_for(
loop.run_in_executor(None, lambda: pipe(**sanitized_params)),
timeout=300.0
)
except asyncio.TimeoutError:
log.error("Pipe execution timeout")
raise HTTPException(
status_code=status.HTTP_408_REQUEST_TIMEOUT,
detail="Function execution timeout"
)

async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
if isinstance(res, str):
Expand Down Expand Up @@ -290,6 +494,14 @@ def get_function_params(function_module, form_data, user, extra_params=None):
form_data = apply_system_prompt_to_body(system, form_data, metadata, user)

pipe_id = get_pipe_id(form_data)

# Validate pipe_id before getting function module
if not validate_pipe_id(pipe_id):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid function identifier"
)

function_module = get_function_module_by_id(request, pipe_id)

pipe = function_module.pipe
Expand Down Expand Up @@ -350,4 +562,4 @@ async def stream_content():
return res.model_dump()

message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)
return openai_chat_completion_message_template(form_data["model"], message)