diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index 316efe18e7f..d43a8782349 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -3,8 +3,10 @@ import inspect import json import asyncio +import hashlib +import hmac -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from typing import AsyncGenerator, Generator, Iterator from fastapi import ( Depends, @@ -57,7 +59,139 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"]) +def validate_function_code(pipe_id: str, function_code: str) -> bool: + """ + Validate function code against stored hash to prevent code injection. + Returns True if validation passes, False otherwise. + """ + try: + # Get the function from database + function = Functions.get_function_by_id(pipe_id) + if not function: + log.error(f"Function {pipe_id} not found in database") + return False + + # Verify code hash if available + if hasattr(function, 'code_hash'): + computed_hash = hashlib.sha256(function_code.encode()).hexdigest() + if not hmac.compare_digest(computed_hash, function.code_hash): + log.error(f"Code hash mismatch for function {pipe_id}") + return False + + # Additional validation: check for dangerous imports/operations + dangerous_patterns = [ + 'import os', + 'import subprocess', + 'import sys', + '__import__', + 'eval(', + 'exec(', + 'compile(', + 'open(', + ] + + for pattern in dangerous_patterns: + if pattern in function_code: + log.warning(f"Potentially dangerous pattern '{pattern}' found in function {pipe_id}") + # This is a warning, not blocking, as legitimate functions may need these + + return True + + except Exception as e: + log.error(f"Error validating function code for {pipe_id}: {e}") + return False + + +def sanitize_valve_value(value): + """ + Sanitize valve values to prevent code injection. + Only allows primitive types and basic collections. + """ + if value is None: + return None + + # Allow only safe primitive types + if isinstance(value, (str, int, float, bool)): + return value + + # Allow lists of primitives + if isinstance(value, list): + return [sanitize_valve_value(item) for item in value] + + # Allow dicts of primitives + if isinstance(value, dict): + return {k: sanitize_valve_value(v) for k, v in value.items() if isinstance(k, str)} + + # Reject any other types (objects, functions, etc.) + log.warning(f"Rejecting non-primitive valve value of type {type(value)}") + return None + + +def validate_valves(Valves, valves_data: dict, pipe_id: str): + """ + Validate and sanitize valve data before instantiation. + Uses Pydantic validation and additional sanitization. + """ + try: + # First, sanitize all values + sanitized_valves = {} + for k, v in valves_data.items(): + if v is not None: + sanitized_value = sanitize_valve_value(v) + if sanitized_value is not None: + sanitized_valves[k] = sanitized_value + + # Get the Valves class fields to validate against schema + if hasattr(Valves, 'model_fields'): + # Pydantic v2 + valid_fields = set(Valves.model_fields.keys()) + elif hasattr(Valves, '__fields__'): + # Pydantic v1 + valid_fields = set(Valves.__fields__.keys()) + else: + # Fallback: allow all fields + valid_fields = set(sanitized_valves.keys()) + + # Only keep fields that are defined in the Valves schema + filtered_valves = {k: v for k, v in sanitized_valves.items() if k in valid_fields} + + # Validate using Pydantic + validated_valves = Valves(**filtered_valves) + + return validated_valves + + except ValidationError as e: + log.error(f"Valve validation error for function {pipe_id}: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid valve configuration: {str(e)}" + ) + except Exception as e: + log.error(f"Unexpected error validating valves for function {pipe_id}: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error validating valve configuration" + ) + + def get_function_module_by_id(request: Request, pipe_id: str): + # Validate the function before loading + function = Functions.get_function_by_id(pipe_id) + if not function: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Function {pipe_id} not found" + ) + + # Verify function code integrity + if hasattr(function, 'content'): + if not validate_function_code(pipe_id, function.content): + log.error(f"Function code validation failed for {pipe_id}") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Function code validation failed - potential security issue detected" + ) + function_module, _, _ = get_function_module_from_cache(request, pipe_id) if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): @@ -66,12 +200,16 @@ def get_function_module_by_id(request: Request, pipe_id: str): if valves: try: - function_module.valves = Valves( - **{k: v for k, v in valves.items() if v is not None} - ) + # Use validated and sanitized valves instantiation + function_module.valves = validate_valves(Valves, valves, pipe_id) + except HTTPException: + raise except Exception as e: log.exception(f"Error loading valves for function {pipe_id}: {e}") - raise e + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error loading valve configuration: {str(e)}" + ) else: function_module.valves = Valves() @@ -212,7 +350,14 @@ def get_function_params(function_module, form_data, user, extra_params=None): if "__user__" in params and hasattr(function_module, "UserValves"): user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) try: - params["__user__"]["valves"] = function_module.UserValves(**user_valves) + # Validate user valves as well + params["__user__"]["valves"] = validate_valves( + function_module.UserValves, user_valves, pipe_id + ) + except HTTPException: + # If validation fails, use default valves + log.warning(f"User valve validation failed for {pipe_id}, using defaults") + params["__user__"]["valves"] = function_module.UserValves() except Exception as e: log.exception(e) params["__user__"]["valves"] = function_module.UserValves() @@ -350,4 +495,4 @@ async def stream_content(): return res.model_dump() message = await get_message_content(res) - return openai_chat_completion_message_template(form_data["model"], message) + return openai_chat_completion_message_template(form_data["model"], message) \ No newline at end of file