diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py
index 316efe18e7f..e8e854d3190 100644
--- a/backend/open_webui/functions.py
+++ b/backend/open_webui/functions.py
@@ -3,6 +3,8 @@
 import inspect
 import json
 import asyncio
+import hashlib
+import hmac
 
 from pydantic import BaseModel
 from typing import AsyncGenerator, Generator, Iterator
@@ -57,7 +59,108 @@
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
+def validate_valves_schema(Valves, valves_data: dict) -> bool:
+    """
+    Check whether valves_data can be used to build an instance of Valves.
+
+    Pydantic models are validated by instantiating them with the non-None
+    values, so field types are checked and unknown keys are handled by the
+    model's own ``extra`` setting. Any other class is only checked for keys
+    missing from its ``__annotations__``; value types are not inspected.
+
+    Returns True when the data is acceptable and False otherwise.
+    """
+    if not inspect.isclass(Valves):
+        return False
+
+    # Anything but a dict cannot be expanded into keyword arguments; reject
+    # it here instead of failing on .items()/.keys() further down.
+    if not isinstance(valves_data, dict):
+        log.error("valves data is not a dictionary")
+        return False
+
+    if issubclass(Valves, BaseModel):
+        # Let Pydantic do the validation; the instance itself is discarded.
+        try:
+            Valves(**{k: v for k, v in valves_data.items() if v is not None})
+        except Exception as e:
+            log.error(f"Pydantic validation failed for valves: {e}")
+            return False
+        return True
+
+    # Plain classes: every supplied key must be a declared annotation.
+    valid_fields = set(getattr(Valves, "__annotations__", {}))
+    invalid_keys = set(valves_data) - valid_fields
+    if invalid_keys:
+        log.error(f"Invalid valve keys detected: {invalid_keys}")
+        return False
+
+    return True
+
+
+def validate_pipes_output(pipes_output) -> bool:
+    """
+    Validate that pipes() output is a bounded list of dictionaries, each with
+    a string 'id' and a string 'name'.
+
+    Size limits are checked before any per-item or per-character work so that
+    an oversized result is rejected without being walked or logged in full.
+    """
+    if not isinstance(pipes_output, list):
+        log.error("pipes output is not a list")
+        return False
+
+    # Limit number of pipes to prevent DoS
+    if len(pipes_output) > 1000:
+        log.error("pipes output exceeds maximum number of items")
+        return False
+
+    for pipe in pipes_output:
+        if not isinstance(pipe, dict):
+            log.error("pipe item is not a dictionary")
+            return False
+
+        # Check for required fields
+        if 'id' not in pipe or 'name' not in pipe:
+            log.error("pipe item missing required 'id' or 'name' field")
+            return False
+
+        # Validate types
+        if not isinstance(pipe['id'], str) or not isinstance(pipe['name'], str):
+            log.error("pipe 'id' or 'name' is not a string")
+            return False
+
+        # Limit size to prevent DoS (before the id is scanned or logged)
+        if len(pipe['id']) > 256 or len(pipe['name']) > 512:
+            log.error("pipe id or name exceeds maximum length")
+            return False
+
+        # Only alphanumerics (per str.isalnum), '-', '_' and '.' are accepted
+        if not all(c.isalnum() or c in '-_.' for c in pipe['id']):
+            log.error(f"pipe id contains unsafe characters: {pipe['id']}")
+            return False
+
+    return True
+
+
 def get_function_module_by_id(request: Request, pipe_id: str):
+    # Verify user has access to this function
+    user = request.state.user if hasattr(request.state, 'user') else None
+    if user:
+        function = Functions.get_function_by_id(pipe_id)
+        if not function:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail="Function not found"
+            )
+
+        # Check if user has permission to access this function
+        if not has_access(user.id, type="read", access_control=function.access_control):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail="Insufficient permissions to access this function"
+            )
+
     function_module, _, _ = get_function_module_from_cache(request, pipe_id)
 
     if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
@@ -65,15 +168,33 @@ def get_function_module_by_id(request: Request, pipe_id: str):
         valves = Functions.get_function_valves_by_id(pipe_id)
 
         if valves:
-            try:
-                function_module.valves = Valves(
-                    **{k: v for k, v in valves.items() if v is not None}
+            # Validate valves data before instantiation
+            filtered_valves = {k: v for k, v in valves.items() if v is not None}
+
+            if not validate_valves_schema(Valves, filtered_valves):
+                log.error(f"Invalid valves schema for function {pipe_id}")
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail="Invalid valves configuration"
                 )
+
+            try:
+                function_module.valves = Valves(**filtered_valves)
             except Exception as e:
                 log.exception(f"Error loading valves for function {pipe_id}: {e}")
-                raise e
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=f"Error instantiating valves: {str(e)}"
+                ) from e
         else:
-            function_module.valves = Valves()
+            try:
+                function_module.valves = Valves()
+            except Exception as e:
+                log.exception(f"Error creating default valves for function {pipe_id}: {e}")
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=f"Error creating default valves: {str(e)}"
+                ) from e
 
     return function_module
 
@@ -103,8 +224,14 @@ async def get_function_models(request):
                             sub_pipes = function_module.pipes()
                     else:
                         sub_pipes = function_module.pipes
+
+                    # Discard pipes output that fails structural validation
+                    if not validate_pipes_output(sub_pipes):
+                        log.error(f"Invalid pipes output for function {pipe.id}, skipping")
+                        sub_pipes = []
+
                 except Exception as e:
-                    log.exception(e)
+                    log.exception(f"Error executing pipes for function {pipe.id}: {e}")
                     sub_pipes = []
 
                 log.debug(
@@ -211,11 +338,17 @@ def get_function_params(function_module, form_data, user, extra_params=None):
 
         if "__user__" in params and hasattr(function_module, "UserValves"):
             user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
-            try:
-                params["__user__"]["valves"] = function_module.UserValves(**user_valves)
-            except Exception as e:
-                log.exception(e)
+
+            # Validate user valves before instantiation
+            if user_valves and not validate_valves_schema(function_module.UserValves, user_valves):
+                log.error(f"Invalid user valves schema for function {pipe_id} and user {user.id}")
                 params["__user__"]["valves"] = function_module.UserValves()
+            else:
+                try:
+                    params["__user__"]["valves"] = function_module.UserValves(**user_valves)
+                except Exception as e:
+                    log.exception(e)
+                    params["__user__"]["valves"] = function_module.UserValves()
 
         return params
 