Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 160 additions & 14 deletions backend/open_webui/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import inspect
import json
import asyncio
import re

from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from typing import AsyncGenerator, Generator, Iterator
from fastapi import (
Depends,
Expand Down Expand Up @@ -57,7 +58,120 @@
log.setLevel(SRC_LOG_LEVELS["MAIN"])


def validate_pipe_id(pipe_id: str) -> str:
"""
Validate pipe_id to prevent code injection attacks.
Only allow alphanumeric characters, hyphens, underscores, and dots.
"""
if not pipe_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid pipe_id: empty value"
)

# Allow alphanumeric, hyphens, underscores, and dots only
if not re.match(r'^[a-zA-Z0-9._-]+$', pipe_id):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid pipe_id: contains invalid characters"
)

# Prevent directory traversal
if '..' in pipe_id or pipe_id.startswith('.') or pipe_id.startswith('/'):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid pipe_id: invalid format"
)

return pipe_id


def validate_and_sanitize_valves(valves_data: dict, Valves: type) -> dict:
"""
Validate and sanitize valves data before passing to Valves constructor.
Only allows primitive types and ensures data matches the Valves schema.
"""
if not valves_data:
return {}

# Get the field types from the Valves class
if not hasattr(Valves, '__annotations__'):
return {}

sanitized_valves = {}
allowed_types = (str, int, float, bool, type(None))

for key, value in valves_data.items():
if value is None:
continue

# Only process fields that are defined in the Valves class
if key not in Valves.__annotations__:
log.warning(f"Ignoring undefined valve field: {key}")
continue

# Check if the value is a primitive type
if not isinstance(value, allowed_types):
# Allow lists and dicts only if they contain primitive types
if isinstance(value, list):
if all(isinstance(item, allowed_types) for item in value):
sanitized_valves[key] = value
else:
log.warning(f"Ignoring valve field with complex list: {key}")
elif isinstance(value, dict):
if all(isinstance(k, str) and isinstance(v, allowed_types) for k, v in value.items()):
sanitized_valves[key] = value
else:
log.warning(f"Ignoring valve field with complex dict: {key}")
else:
log.warning(f"Ignoring valve field with non-primitive type: {key}")
else:
sanitized_valves[key] = value

return sanitized_valves


def validate_function_module(function_module, pipe_id: str):
"""
Validate that the function module has a proper pipe function.
Ensures the pipe function is actually callable and not arbitrary code.
"""
if not hasattr(function_module, "pipe"):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Function module {pipe_id} does not have a 'pipe' function"
)

if not callable(function_module.pipe):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Function module {pipe_id} 'pipe' attribute is not callable"
)

# Verify the function has a proper signature
try:
sig = inspect.signature(function_module.pipe)
except (ValueError, TypeError) as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Function module {pipe_id} has invalid pipe signature: {str(e)}"
)

return True


def get_function_module_by_id(request: Request, pipe_id: str):
# Validate pipe_id before processing
pipe_id = validate_pipe_id(pipe_id)

# Verify the function exists in the database
function = Functions.get_function_by_id(pipe_id.split('.')[0])
if not function:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Function not found: {pipe_id}"
)

function_module, _, _ = get_function_module_from_cache(request, pipe_id)

if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
Expand All @@ -66,12 +180,19 @@ def get_function_module_by_id(request: Request, pipe_id: str):

if valves:
try:
function_module.valves = Valves(
**{k: v for k, v in valves.items() if v is not None}
)
# Validate and sanitize valves data before passing to constructor
sanitized_valves = validate_and_sanitize_valves(valves, Valves)
function_module.valves = Valves(**sanitized_valves)
except ValidationError as e:
log.error(f"Validation error loading valves for function {pipe_id}: {e}")
# Fall back to default valves on validation error
function_module.valves = Valves()
except Exception as e:
log.exception(f"Error loading valves for function {pipe_id}: {e}")
raise e
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error loading function valves: {str(e)}"
)
else:
function_module.valves = Valves()

Expand Down Expand Up @@ -159,11 +280,28 @@ async def get_function_models(request):
async def generate_function_chat_completion(
request, form_data, user, models: dict = {}
):
async def execute_pipe(pipe, params):
if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
else:
return pipe(**params)
async def execute_pipe(pipe, params, pipe_id: str):
"""
Securely execute a validated pipe function with sanitized parameters.
"""
# Additional validation before execution
if not callable(pipe):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Invalid pipe function"
)

try:
if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
else:
return pipe(**params)
except Exception as e:
log.error(f"Error executing pipe {pipe_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error executing function: {str(e)}"
)

async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
if isinstance(res, str):
Expand Down Expand Up @@ -212,7 +350,12 @@ def get_function_params(function_module, form_data, user, extra_params=None):
if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
# Validate and sanitize user valves data
sanitized_user_valves = validate_and_sanitize_valves(user_valves, function_module.UserValves)
params["__user__"]["valves"] = function_module.UserValves(**sanitized_user_valves)
except ValidationError as e:
log.error(f"Validation error loading user valves: {e}")
params["__user__"]["valves"] = function_module.UserValves()
except Exception as e:
log.exception(e)
params["__user__"]["valves"] = function_module.UserValves()
Expand Down Expand Up @@ -291,6 +434,9 @@ def get_function_params(function_module, form_data, user, extra_params=None):

pipe_id = get_pipe_id(form_data)
function_module = get_function_module_by_id(request, pipe_id)

# Validate the function module before execution
validate_function_module(function_module, pipe_id)

pipe = function_module.pipe
params = get_function_params(function_module, form_data, user, extra_params)
Expand All @@ -299,7 +445,7 @@ def get_function_params(function_module, form_data, user, extra_params=None):

async def stream_content():
try:
res = await execute_pipe(pipe, params)
res = await execute_pipe(pipe, params, pipe_id)

# Directly return if the response is a StreamingResponse
if isinstance(res, StreamingResponse):
Expand Down Expand Up @@ -338,7 +484,7 @@ async def stream_content():
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
try:
res = await execute_pipe(pipe, params)
res = await execute_pipe(pipe, params, pipe_id)

except Exception as e:
log.error(f"Error: {e}")
Expand All @@ -350,4 +496,4 @@ async def stream_content():
return res.model_dump()

message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)
return openai_chat_completion_message_template(form_data["model"], message)