diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index 316efe18e7f..8489e1bc5c9 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -3,8 +3,9 @@ import inspect import json import asyncio +import hashlib -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from typing import AsyncGenerator, Generator, Iterator from fastapi import ( Depends, @@ -57,33 +58,177 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"]) +def validate_pipe_id(pipe_id: str) -> bool: + """ + Validate pipe_id to prevent code injection attacks. + Only allows alphanumeric characters, underscores, hyphens, and dots. + """ + if not pipe_id: + return False + + # Allow alphanumeric, underscore, hyphen, and dot (for manifold pipes) + import re + if not re.match(r'^[a-zA-Z0-9_\-\.]+$', pipe_id): + return False + + # Additional length check to prevent abuse + if len(pipe_id) > 256: + return False + + return True + + +def safe_instantiate_valves(Valves, valves_data: dict, pipe_id: str): + """ + Safely instantiate Valves objects with validation. + Only accepts primitive types and validates against the Pydantic model schema. + """ + if not valves_data: + return Valves() + + # Filter out None values + filtered_data = {k: v for k, v in valves_data.items() if v is not None} + + # Validate that all values are primitive types (no complex objects that could be malicious) + allowed_types = (str, int, float, bool, type(None), list, dict) + + def is_safe_value(value): + """Recursively check if a value contains only safe primitive types""" + if isinstance(value, allowed_types): + if isinstance(value, dict): + return all(isinstance(k, str) and is_safe_value(v) for k, v in value.items()) + elif isinstance(value, list): + return all(is_safe_value(item) for item in value) + return True + return False + + # Validate all values in the data + for key, value in filtered_data.items(): + if not is_safe_value(value): + log.error(f"Unsafe value type detected in valves for {pipe_id}: {key}={type(value)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid valve data: unsafe type for key '{key}'" + ) + + # Use Pydantic's validation to safely instantiate the Valves object + try: + return Valves(**filtered_data) + except ValidationError as e: + log.error(f"Validation error instantiating valves for {pipe_id}: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid valve data: {str(e)}" + ) + except Exception as e: + log.error(f"Error instantiating valves for {pipe_id}: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error loading function configuration" + ) + + def get_function_module_by_id(request: Request, pipe_id: str): + # Validate pipe_id to prevent code injection + if not validate_pipe_id(pipe_id): + log.error(f"Invalid pipe_id format: {pipe_id}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid function ID format" + ) + + # Verify the function exists in the database before attempting to load + base_pipe_id = pipe_id.split('.')[0] if '.' in pipe_id else pipe_id + function_record = Functions.get_function_by_id(base_pipe_id) + + if not function_record: + log.error(f"Function not found: {base_pipe_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Function not found" + ) + + # Verify the function is active + if not function_record.is_active: + log.error(f"Function is not active: {base_pipe_id}") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Function is not active" + ) + function_module, _, _ = get_function_module_from_cache(request, pipe_id) if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): Valves = function_module.Valves - valves = Functions.get_function_valves_by_id(pipe_id) + valves = Functions.get_function_valves_by_id(base_pipe_id) if valves: - try: - function_module.valves = Valves( - **{k: v for k, v in valves.items() if v is not None} - ) - except Exception as e: - log.exception(f"Error loading valves for function {pipe_id}: {e}") - raise e + function_module.valves = safe_instantiate_valves(Valves, valves, base_pipe_id) else: function_module.valves = Valves() return function_module +def safe_get_sub_pipes(function_module, pipe_id: str): + """ + Safely retrieve sub-pipes from a function module. + Validates that the pipes attribute is callable or a list before execution. + """ + if not hasattr(function_module, "pipes"): + return None + + pipes_attr = function_module.pipes + + # If it's a list, validate and return it directly + if isinstance(pipes_attr, list): + # Validate list structure + for item in pipes_attr: + if not isinstance(item, dict): + log.error(f"Invalid pipe item in {pipe_id}: expected dict, got {type(item)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Invalid pipe configuration" + ) + if "id" not in item or "name" not in item: + log.error(f"Invalid pipe item in {pipe_id}: missing required keys") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Invalid pipe configuration" + ) + return pipes_attr + + # If it's callable, validate it's a legitimate function before calling + if callable(pipes_attr): + # Verify it's a function or method, not an arbitrary callable + if not (inspect.isfunction(pipes_attr) or + inspect.ismethod(pipes_attr) or + inspect.iscoroutinefunction(pipes_attr)): + log.error(f"Invalid pipes attribute in {pipe_id}: not a function or method") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Invalid pipe configuration" + ) + return pipes_attr + + log.error(f"Invalid pipes attribute in {pipe_id}: expected callable or list, got {type(pipes_attr)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Invalid pipe configuration" + ) + + async def get_function_models(request): pipes = Functions.get_functions_by_type("pipe", active_only=True) pipe_models = [] for pipe in pipes: try: + # Validate pipe ID before processing + if not validate_pipe_id(pipe.id): + log.warning(f"Skipping pipe with invalid ID: {pipe.id}") + continue + function_module = get_function_module_by_id(request, pipe.id) has_user_valves = False @@ -94,17 +239,30 @@ async def get_function_models(request): if hasattr(function_module, "pipes"): sub_pipes = [] - # Handle pipes being a list, sync function, or async function + # Safely retrieve and call pipes try: - if callable(function_module.pipes): - if asyncio.iscoroutinefunction(function_module.pipes): - sub_pipes = await function_module.pipes() + pipes_attr = safe_get_sub_pipes(function_module, pipe.id) + + if pipes_attr is None: + continue + + # Handle callable pipes + if callable(pipes_attr): + if asyncio.iscoroutinefunction(pipes_attr): + sub_pipes = await pipes_attr() else: - sub_pipes = function_module.pipes() + sub_pipes = pipes_attr() else: - sub_pipes = function_module.pipes + # Already validated as a list + sub_pipes = pipes_attr + + # Validate the result + if not isinstance(sub_pipes, list): + log.error(f"pipes() for {pipe.id} returned non-list: {type(sub_pipes)}") + continue + except Exception as e: - log.exception(e) + log.exception(f"Error calling pipes() for {pipe.id}: {e}") sub_pipes = [] log.debug( @@ -112,7 +270,17 @@ async def get_function_models(request): ) for p in sub_pipes: + if not isinstance(p, dict) or "id" not in p or "name" not in p: + log.warning(f"Invalid sub-pipe structure in {pipe.id}: {p}") + continue + sub_pipe_id = f'{pipe.id}.{p["id"]}' + + # Validate sub-pipe ID + if not validate_pipe_id(sub_pipe_id): + log.warning(f"Skipping sub-pipe with invalid ID: {sub_pipe_id}") + continue + sub_pipe_name = p["name"] if hasattr(function_module, "name"): @@ -212,7 +380,9 @@ def get_function_params(function_module, form_data, user, extra_params=None): if "__user__" in params and hasattr(function_module, "UserValves"): user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) try: - params["__user__"]["valves"] = function_module.UserValves(**user_valves) + params["__user__"]["valves"] = safe_instantiate_valves( + function_module.UserValves, user_valves, pipe_id + ) except Exception as e: log.exception(e) params["__user__"]["valves"] = function_module.UserValves() @@ -220,6 +390,15 @@ def get_function_params(function_module, form_data, user, extra_params=None): return params model_id = form_data.get("model") + + # Validate model_id to prevent code injection + if not validate_pipe_id(model_id): + log.error(f"Invalid model_id format: {model_id}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid model ID format" + ) + model_info = Models.get_model_by_id(model_id) metadata = form_data.pop("metadata", {}) @@ -350,4 +529,4 @@ async def stream_content(): return res.model_dump() message = await get_message_content(res) - return openai_chat_completion_message_template(form_data["model"], message) + return openai_chat_completion_message_template(form_data["model"], message) \ No newline at end of file