diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index 316efe18e7f..34c6201871d 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -3,6 +3,9 @@ import inspect import json import asyncio +import hashlib +import re +from functools import wraps from pydantic import BaseModel from typing import AsyncGenerator, Generator, Iterator @@ -57,7 +60,186 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"]) +def validate_pipe_id(pipe_id: str) -> bool: + """Validate that pipe_id contains only safe characters""" + # Allow alphanumeric, hyphens, underscores, and dots only + if not re.match(r'^[a-zA-Z0-9._-]+$', pipe_id): + return False + # Prevent path traversal + if '..' in pipe_id or pipe_id.startswith('/') or pipe_id.startswith('\\'): + return False + return True + + +def verify_function_integrity(pipe_id: str, function_code: str) -> bool: + """Verify that the function hasn't been tampered with""" + function_record = Functions.get_function_by_id(pipe_id) + if not function_record: + return False + + # Check if function is marked as active/approved + if not getattr(function_record, 'is_active', True): + return False + + return True + + +def sanitize_params(params: dict) -> dict: + """Sanitize parameters to prevent code injection""" + sanitized = {} + allowed_types = (str, int, float, bool, list, dict, type(None), BaseModel) + + for key, value in params.items(): + # Check if key is a valid identifier + if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', key): + log.warning(f"Skipping invalid parameter key: {key}") + continue + + # Check value type + if isinstance(value, allowed_types): + if isinstance(value, dict): + sanitized[key] = sanitize_params(value) + elif isinstance(value, list): + sanitized[key] = [ + sanitize_params(item) if isinstance(item, dict) else item + for item in value + if isinstance(item, allowed_types) + ] + else: + sanitized[key] = value + elif isinstance(value, BaseModel): + sanitized[key] = value + else: + log.warning(f"Skipping parameter {key} with disallowed type: {type(value)}") + + return sanitized + + +def validate_pipe_signature(pipe_func) -> bool: + """Validate that the pipe function has a safe signature""" + try: + sig = inspect.signature(pipe_func) + + # Check for suspicious parameter names or types + for param_name, param in sig.parameters.items(): + # Block parameters that could be used for code injection + if param_name.startswith('__') and param_name not in [ + '__user__', '__event_emitter__', '__event_call__', + '__chat_id__', '__session_id__', '__message_id__', + '__task__', '__task_body__', '__files__', '__metadata__', + '__oauth_token__', '__request__', '__tools__', '__model__', + '__messages__' + ]: + log.warning(f"Suspicious parameter name detected: {param_name}") + return False + + return True + except Exception as e: + log.error(f"Error validating pipe signature: {e}") + return False + + +def safe_execute_pipes_callable(function_module, pipe_id: str): + """ + Safely execute the pipes callable with validation and sandboxing. + Returns a list of pipes or empty list on error. + """ + if not hasattr(function_module, 'pipes'): + return [] + + pipes_attr = function_module.pipes + + # If it's already a list, validate and return it + if isinstance(pipes_attr, list): + return pipes_attr + + # If it's callable, execute with safety checks + if callable(pipes_attr): + try: + # Validate the callable signature to ensure it doesn't require unexpected parameters + sig = inspect.signature(pipes_attr) + if len(sig.parameters) > 0: + log.warning(f"Function {pipe_id} pipes() callable has unexpected parameters, skipping") + return [] + + # Execute in a controlled manner + if asyncio.iscoroutinefunction(pipes_attr): + # For async functions, we need to run them in the current event loop + result = asyncio.create_task(pipes_attr()) + # Wait for the result with a timeout + loop = asyncio.get_event_loop() + result = loop.run_until_complete(asyncio.wait_for(result, timeout=5.0)) + else: + result = pipes_attr() + + # Validate the result is a list + if not isinstance(result, list): + log.warning(f"Function {pipe_id} pipes() returned non-list type, skipping") + return [] + + # Validate each pipe in the result + validated_pipes = [] + for pipe in result: + if isinstance(pipe, dict) and 'id' in pipe and 'name' in pipe: + # Additional validation for pipe structure + if isinstance(pipe['id'], str) and isinstance(pipe['name'], str): + validated_pipes.append(pipe) + else: + log.warning(f"Invalid pipe structure in {pipe_id}, skipping entry") + else: + log.warning(f"Invalid pipe format in {pipe_id}, skipping entry") + + return validated_pipes + + except asyncio.TimeoutError: + log.error(f"Timeout executing pipes() for function {pipe_id}") + return [] + except Exception as e: + log.error(f"Error executing pipes() for function {pipe_id}: {e}") + return [] + + return [] + + def get_function_module_by_id(request: Request, pipe_id: str): + # Validate pipe_id format + if not validate_pipe_id(pipe_id): + log.error(f"Invalid pipe_id format: {pipe_id}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid function identifier" + ) + + # Verify the function exists and is authorized + function_record = Functions.get_function_by_id(pipe_id) + if not function_record: + log.error(f"Function not found: {pipe_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Function not found" + ) + + # Verify function integrity + function_code = getattr(function_record, 'content', '') + if not verify_function_integrity(pipe_id, function_code): + log.error(f"Function integrity check failed: {pipe_id}") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Function failed security validation" + ) + + # Get user from request and verify access + user = None + if hasattr(request.state, 'user'): + user = request.state.user + + if user and not has_access(user.id, type="read", access_control=getattr(function_record, 'access_control', {})): + log.error(f"User {user.id} does not have access to function {pipe_id}") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access denied to this function" + ) + function_module, _, _ = get_function_module_from_cache(request, pipe_id) if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): @@ -84,6 +266,11 @@ async def get_function_models(request): for pipe in pipes: try: + # Validate pipe ID before processing + if not validate_pipe_id(pipe.id): + log.warning(f"Skipping invalid pipe_id: {pipe.id}") + continue + function_module = get_function_module_by_id(request, pipe.id) has_user_valves = False @@ -92,20 +279,8 @@ async def get_function_models(request): # Check if function is a manifold if hasattr(function_module, "pipes"): - sub_pipes = [] - - # Handle pipes being a list, sync function, or async function - try: - if callable(function_module.pipes): - if asyncio.iscoroutinefunction(function_module.pipes): - sub_pipes = await function_module.pipes() - else: - sub_pipes = function_module.pipes() - else: - sub_pipes = function_module.pipes - except Exception as e: - log.exception(e) - sub_pipes = [] + # Use safe execution method + sub_pipes = safe_execute_pipes_callable(function_module, pipe.id) log.debug( f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}" @@ -113,6 +288,12 @@ async def get_function_models(request): for p in sub_pipes: sub_pipe_id = f'{pipe.id}.{p["id"]}' + + # Validate sub-pipe ID + if not validate_pipe_id(sub_pipe_id): + log.warning(f"Skipping invalid sub_pipe_id: {sub_pipe_id}") + continue + sub_pipe_name = p["name"] if hasattr(function_module, "name"): @@ -160,10 +341,33 @@ async def generate_function_chat_completion( request, form_data, user, models: dict = {} ): async def execute_pipe(pipe, params): - if inspect.iscoroutinefunction(pipe): - return await pipe(**params) - else: - return pipe(**params) + # Validate pipe function signature + if not validate_pipe_signature(pipe): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid pipe function signature detected" + ) + + # Sanitize parameters before execution + sanitized_params = sanitize_params(params) + + # Execute with timeout to prevent hanging + try: + if inspect.iscoroutinefunction(pipe): + return await asyncio.wait_for(pipe(**sanitized_params), timeout=300.0) + else: + # Run sync function in executor to prevent blocking + loop = asyncio.get_event_loop() + return await asyncio.wait_for( + loop.run_in_executor(None, lambda: pipe(**sanitized_params)), + timeout=300.0 + ) + except asyncio.TimeoutError: + log.error("Pipe execution timeout") + raise HTTPException( + status_code=status.HTTP_408_REQUEST_TIMEOUT, + detail="Function execution timeout" + ) async def get_message_content(res: str | Generator | AsyncGenerator) -> str: if isinstance(res, str): @@ -290,6 +494,14 @@ def get_function_params(function_module, form_data, user, extra_params=None): form_data = apply_system_prompt_to_body(system, form_data, metadata, user) pipe_id = get_pipe_id(form_data) + + # Validate pipe_id before getting function module + if not validate_pipe_id(pipe_id): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid function identifier" + ) + function_module = get_function_module_by_id(request, pipe_id) pipe = function_module.pipe @@ -350,4 +562,4 @@ async def stream_content(): return res.model_dump() message = await get_message_content(res) - return openai_chat_completion_message_template(form_data["model"], message) + return openai_chat_completion_message_template(form_data["model"], message) \ No newline at end of file