def validate_pipe_id(pipe_id: str) -> str:
    """Validate a pipe identifier before it is used to look up a function.

    An accepted identifier consists solely of ASCII letters, digits,
    hyphens, underscores and dots, does not start with a dot, and does not
    contain two consecutive dots.

    Args:
        pipe_id: Identifier taken from the incoming request.

    Returns:
        The unchanged ``pipe_id`` when it is well formed.

    Raises:
        HTTPException: With status 400 when ``pipe_id`` is empty, is not a
            string, contains a character outside the allowed set, or uses a
            leading/doubled dot.
    """
    if not pipe_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid pipe_id: empty value",
        )

    # fullmatch() rather than match() with "^...$": "$" also matches just
    # before a trailing newline, so "name\n" used to pass the check.
    if not isinstance(pipe_id, str) or not re.fullmatch(
        r"[a-zA-Z0-9._-]+", pipe_id
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid pipe_id: contains invalid characters",
        )

    # "/" cannot reach this point (the character check rejects it), so only
    # the dot layouts usable for directory traversal remain to be refused.
    if pipe_id.startswith(".") or ".." in pipe_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid pipe_id: invalid format",
        )

    return pipe_id
+ """ + if not valves_data: + return {} + + # Get the field types from the Valves class + if not hasattr(Valves, '__annotations__'): + return {} + + sanitized_valves = {} + allowed_types = (str, int, float, bool, type(None)) + + for key, value in valves_data.items(): + if value is None: + continue + + # Only process fields that are defined in the Valves class + if key not in Valves.__annotations__: + log.warning(f"Ignoring undefined valve field: {key}") + continue + + # Check if the value is a primitive type + if not isinstance(value, allowed_types): + # Allow lists and dicts only if they contain primitive types + if isinstance(value, list): + if all(isinstance(item, allowed_types) for item in value): + sanitized_valves[key] = value + else: + log.warning(f"Ignoring valve field with complex list: {key}") + elif isinstance(value, dict): + if all(isinstance(k, str) and isinstance(v, allowed_types) for k, v in value.items()): + sanitized_valves[key] = value + else: + log.warning(f"Ignoring valve field with complex dict: {key}") + else: + log.warning(f"Ignoring valve field with non-primitive type: {key}") + else: + sanitized_valves[key] = value + + return sanitized_valves + + +def validate_function_module(function_module, pipe_id: str): + """ + Validate that the function module has a proper pipe function. + Ensures the pipe function is actually callable and not arbitrary code. 
def validate_function_module(function_module, pipe_id: str):
    """Check that a loaded function module exposes a usable ``pipe`` callable.

    Args:
        function_module: The loaded module/object expected to provide ``pipe``.
        pipe_id: Identifier of the function, used only in error messages.

    Returns:
        ``True`` when ``pipe`` exists, is callable and has an inspectable
        signature.

    Raises:
        HTTPException: With status 500 when ``pipe`` is missing, is not
            callable, or ``inspect.signature`` cannot describe it.
    """
    if not hasattr(function_module, "pipe"):
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Function module {pipe_id} does not have a 'pipe' function",
        )

    pipe = function_module.pipe
    if not callable(pipe):
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Function module {pipe_id} 'pipe' attribute is not callable",
        )

    # Only whether a signature can be obtained matters here; the signature
    # object itself is not needed.
    try:
        inspect.signature(pipe)
    except (ValueError, TypeError) as e:
        # Chain the cause so the original failure stays in the traceback.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Function module {pipe_id} has invalid pipe signature: {str(e)}",
        ) from e

    return True
async def execute_pipe(pipe, params, pipe_id: str):
    """Call ``pipe`` with ``params`` and return whatever it produces.

    Coroutine functions are awaited; any other callable is invoked directly
    and its return value is handed back as is.

    Args:
        pipe: The callable to execute.
        params: Keyword arguments passed to ``pipe``.
        pipe_id: Identifier of the function, used only for logging.

    Returns:
        The value returned by ``pipe``.

    Raises:
        HTTPException: With status 500 when ``pipe`` is not callable or
            raises an unexpected exception. An ``HTTPException`` raised by
            ``pipe`` itself is propagated unchanged.
    """
    if not callable(pipe):
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Invalid pipe function",
        )

    try:
        if inspect.iscoroutinefunction(pipe):
            return await pipe(**params)
        return pipe(**params)
    except HTTPException:
        # The pipe already chose a status code and detail; re-wrapping it
        # would turn every such error into a generic 500.
        raise
    except Exception as e:
        # log.exception keeps the traceback that log.error would drop.
        log.exception(f"Error executing pipe {pipe_id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error executing function: {str(e)}",
        ) from e
user, extra_params=None): pipe_id = get_pipe_id(form_data) function_module = get_function_module_by_id(request, pipe_id) + + # Validate the function module before execution + validate_function_module(function_module, pipe_id) pipe = function_module.pipe params = get_function_params(function_module, form_data, user, extra_params) @@ -299,7 +445,7 @@ def get_function_params(function_module, form_data, user, extra_params=None): async def stream_content(): try: - res = await execute_pipe(pipe, params) + res = await execute_pipe(pipe, params, pipe_id) # Directly return if the response is a StreamingResponse if isinstance(res, StreamingResponse): @@ -338,7 +484,7 @@ async def stream_content(): return StreamingResponse(stream_content(), media_type="text/event-stream") else: try: - res = await execute_pipe(pipe, params) + res = await execute_pipe(pipe, params, pipe_id) except Exception as e: log.error(f"Error: {e}") @@ -350,4 +496,4 @@ async def stream_content(): return res.model_dump() message = await get_message_content(res) - return openai_chat_completion_message_template(form_data["model"], message) + return openai_chat_completion_message_template(form_data["model"], message) \ No newline at end of file