diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py
--- a/backend/open_webui/functions.py
+++ b/backend/open_webui/functions.py
@@ -159,11 +159,16 @@ async def get_function_models(request):
 async def generate_function_chat_completion(
     request, form_data, user, models: dict = {}
 ):
-    async def execute_pipe(pipe, params):
+    async def execute_pipe(pipe, params, allowed_params=None):
+        # Keep only the keyword arguments the pipe can accept. None disables
+        # filtering; it is used for pipes that declare **kwargs.
+        if allowed_params is not None:
+            params = {k: v for k, v in params.items() if k in allowed_params}
+
         if inspect.iscoroutinefunction(pipe):
             return await pipe(**params)
         else:
             return pipe(**params)
 
     async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
         if isinstance(res, str):
@@ -294,12 +299,21 @@ def get_function_params(function_module, form_data, user, extra_params=None):
 
     pipe = function_module.pipe
     params = get_function_params(function_module, form_data, user, extra_params)
+
+    # A pipe that declares **kwargs accepts any key, so filtering is disabled
+    # (None) for it; otherwise only its named parameters are forwarded.
+    sig = inspect.signature(pipe)
+    accepts_any_kwargs = any(
+        param.kind is inspect.Parameter.VAR_KEYWORD
+        for param in sig.parameters.values()
+    )
+    allowed_params = None if accepts_any_kwargs else set(sig.parameters)
 
     if form_data.get("stream", False):
 
         async def stream_content():
             try:
-                res = await execute_pipe(pipe, params)
+                res = await execute_pipe(pipe, params, allowed_params)
 
                 # Directly return if the response is a StreamingResponse
                 if isinstance(res, StreamingResponse):
@@ -338,7 +352,7 @@ async def stream_content():
         return StreamingResponse(stream_content(), media_type="text/event-stream")
     else:
         try:
-            res = await execute_pipe(pipe, params)
+            res = await execute_pipe(pipe, params, allowed_params)
 
         except Exception as e:
             log.error(f"Error: {e}")
diff --git a/backend/open_webui/utils/plugin.py b/backend/open_webui/utils/plugin.py
--- a/backend/open_webui/utils/plugin.py
+++ b/backend/open_webui/utils/plugin.py
@@ -6,6 +6,7 @@
 import types
 import tempfile
 import logging
+import hashlib
 
 from open_webui.env import SRC_LOG_LEVELS, PIP_OPTIONS, PIP_PACKAGE_INDEX_OPTIONS
 from open_webui.models.functions import Functions
@@ -15,6 +16,91 @@
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
+# Allow-list for function/tool identifiers: letters, digits, "_" and "-".
+# Canonical UUID strings (hyphenated or plain hex) already satisfy it.
+_SAFE_ID_PATTERN = re.compile(r"[a-zA-Z0-9_-]+")
+
+# Heuristic patterns worth flagging in plugin code. The look-behind keeps
+# attribute calls such as re.compile(...) or model.eval() from matching;
+# matching is case-sensitive because Python identifiers are.
+_DANGEROUS_PATTERNS = tuple(
+    re.compile(pattern)
+    for pattern in (
+        r'__import__\s*\(\s*["\']os["\']',
+        r'__import__\s*\(\s*["\']subprocess["\']',
+        r"(?<![\w.])eval\s*\(",
+        r"(?<![\w.])compile\s*\(",
+        r"(?<![\w.])globals\s*\(\s*\)",
+        r"(?<![\w.])locals\s*\(\s*\)",
+        r"(?<![\w.])setattr\s*\(",
+        r"(?<![\w.])delattr\s*\(",
+        r"__builtins__",
+    )
+)
+
+
+def _is_safe_identifier(identifier) -> bool:
+    """
+    Return True if identifier is a non-empty string of allowed characters.
+
+    Non-string input is rejected rather than raising. fullmatch() is used
+    because "$" would also accept a trailing newline.
+    """
+    return (
+        isinstance(identifier, str)
+        and _SAFE_ID_PATTERN.fullmatch(identifier) is not None
+    )
+
+
+def validate_function_id(function_id: str) -> bool:
+    """
+    Validate that the function_id is a valid UUID or alphanumeric string.
+
+    uuid.UUID() is deliberately not used: it also accepts braces and a
+    "urn:uuid:" prefix, which the allow-list must reject.
+    """
+    return _is_safe_identifier(function_id)
+
+
+def validate_tool_id(tool_id: str) -> bool:
+    """
+    Validate that the tool_id is a valid UUID or alphanumeric string.
+    """
+    return _is_safe_identifier(tool_id)
+
+
+def validate_content_integrity(content: str, expected_hash: str = None) -> bool:
+    """
+    Validate content integrity using hash if provided.
+
+    expected_hash is the hex SHA-256 digest of the UTF-8 encoded content. It is
+    compared case-insensitively; when it is None the check is skipped.
+    """
+    if expected_hash is None:
+        return True
+
+    content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()
+    # hexdigest() is lowercase, so normalise the caller-supplied value.
+    return content_hash == expected_hash.strip().lower()
+
+
+def sanitize_code_content(content: str) -> str:
+    """
+    Scan code content for potentially dangerous patterns.
+
+    Best-effort audit aid only: each match is logged as a warning and the
+    content is returned unchanged. Nothing is removed and nothing is raised,
+    so the result must not be treated as vetted code.
+    """
+    for pattern in _DANGEROUS_PATTERNS:
+        if pattern.search(content):
+            log.warning(
+                f"Potentially dangerous pattern detected: {pattern.pattern}"
+            )
+
+    return content
+
+
 def extract_frontmatter(content):
     """
     Extract frontmatter as a dictionary from the provided content string.
@@ -69,6 +155,9 @@ def replace_imports(content):
 
 
 def load_tool_module_by_id(tool_id, content=None):
+    # Validate tool_id
+    if not validate_tool_id(tool_id):
+        raise ValueError(f"Invalid tool_id format: {tool_id}")
 
     if content is None:
         tool = Tools.get_tool_by_id(tool_id)
@@ -80,6 +169,9 @@ def load_tool_module_by_id(tool_id, content=None):
         content = replace_imports(content)
         Tools.update_tool_by_id(tool_id, {"content": content})
     else:
+        # Log suspicious patterns; the content itself is not modified
+        content = sanitize_code_content(content)
+
         frontmatter = extract_frontmatter(content)
         # Install required packages found within the frontmatter
         install_frontmatter_requirements(frontmatter.get("requirements", ""))
@@ -116,6 +208,10 @@ def load_tool_module_by_id(tool_id, content=None):
 
 
 def load_function_module_by_id(function_id: str, content: str | None = None):
+    # Validate function_id
+    if not validate_function_id(function_id):
+        raise ValueError(f"Invalid function_id format: {function_id}")
+
     if content is None:
         function = Functions.get_function_by_id(function_id)
         if not function:
@@ -125,6 +221,9 @@ def load_function_module_by_id(function_id: str, content: str | None = None):
         content = replace_imports(content)
         Functions.update_function_by_id(function_id, {"content": content})
     else:
+        # Log suspicious patterns; the content itself is not modified
+        content = sanitize_code_content(content)
+
         frontmatter = extract_frontmatter(content)
         install_frontmatter_requirements(frontmatter.get("requirements", ""))
 
@@ -167,6 +266,10 @@ def load_function_module_by_id(function_id: str, content: str | None = None):
 
 
 def get_tool_module_from_cache(request, tool_id, load_from_db=True):
+    # Validate tool_id
+    if not validate_tool_id(tool_id):
+        raise ValueError(f"Invalid tool_id format: {tool_id}")
+
     if load_from_db:
         # Always load from the database by default
         tool = Tools.get_tool_by_id(tool_id)
@@ -209,6 +312,10 @@ def get_tool_module_from_cache(request, tool_id, load_from_db=True):
 
 
 def get_function_module_from_cache(request, function_id, load_from_db=True):
+    # Validate function_id
+    if not validate_function_id(function_id):
+        raise ValueError(f"Invalid function_id format: {function_id}")
+
     if load_from_db:
         # Always load from the database by default
         # This is useful for hooks like "inlet" or "outlet" where the content might change
@@ -268,6 +375,19 @@ def install_frontmatter_requirements(requirements: str):
     if requirements:
         try:
             req_list = [req.strip() for req in requirements.split(",")]
+            # Drop empty entries left by stray or trailing commas
+            req_list = [req for req in req_list if req]
+
+            # Validate package specifications before handing them to pip. The
+            # first character must be alphanumeric so an entry can never be
+            # parsed as a pip option (e.g. "-rfile" or "--target=dir").
+            safe_package_pattern = re.compile(
+                r"[a-zA-Z0-9][a-zA-Z0-9_\-.\[\]<>=!~]*"
+            )
+            for req in req_list:
+                if not safe_package_pattern.fullmatch(req):
+                    raise ValueError(f"Invalid package specification: {req}")
+
             log.info(f"Installing requirements: {' '.join(req_list)}")
             subprocess.check_call(
                 [sys.executable, "-m", "pip", "install"]