Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 197 additions & 18 deletions backend/open_webui/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import inspect
import json
import asyncio
import hashlib

from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from typing import AsyncGenerator, Generator, Iterator
from fastapi import (
Depends,
Expand Down Expand Up @@ -57,33 +58,177 @@
log.setLevel(SRC_LOG_LEVELS["MAIN"])


def validate_pipe_id(pipe_id: str) -> bool:
"""
Validate pipe_id to prevent code injection attacks.
Only allows alphanumeric characters, underscores, hyphens, and dots.
"""
if not pipe_id:
return False

# Allow alphanumeric, underscore, hyphen, and dot (for manifold pipes)
import re
if not re.match(r'^[a-zA-Z0-9_\-\.]+$', pipe_id):
return False

# Additional length check to prevent abuse
if len(pipe_id) > 256:
return False

return True


def safe_instantiate_valves(Valves, valves_data: dict, pipe_id: str):
"""
Safely instantiate Valves objects with validation.
Only accepts primitive types and validates against the Pydantic model schema.
"""
if not valves_data:
return Valves()

# Filter out None values
filtered_data = {k: v for k, v in valves_data.items() if v is not None}

# Validate that all values are primitive types (no complex objects that could be malicious)
allowed_types = (str, int, float, bool, type(None), list, dict)

def is_safe_value(value):
"""Recursively check if a value contains only safe primitive types"""
if isinstance(value, allowed_types):
if isinstance(value, dict):
return all(isinstance(k, str) and is_safe_value(v) for k, v in value.items())
elif isinstance(value, list):
return all(is_safe_value(item) for item in value)
return True
return False

# Validate all values in the data
for key, value in filtered_data.items():
if not is_safe_value(value):
log.error(f"Unsafe value type detected in valves for {pipe_id}: {key}={type(value)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid valve data: unsafe type for key '{key}'"
)

# Use Pydantic's validation to safely instantiate the Valves object
try:
return Valves(**filtered_data)
except ValidationError as e:
log.error(f"Validation error instantiating valves for {pipe_id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid valve data: {str(e)}"
)
except Exception as e:
log.error(f"Error instantiating valves for {pipe_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error loading function configuration"
)


def get_function_module_by_id(request: Request, pipe_id: str):
# Validate pipe_id to prevent code injection
if not validate_pipe_id(pipe_id):
log.error(f"Invalid pipe_id format: {pipe_id}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid function ID format"
)

# Verify the function exists in the database before attempting to load
base_pipe_id = pipe_id.split('.')[0] if '.' in pipe_id else pipe_id
function_record = Functions.get_function_by_id(base_pipe_id)

if not function_record:
log.error(f"Function not found: {base_pipe_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Function not found"
)

# Verify the function is active
if not function_record.is_active:
log.error(f"Function is not active: {base_pipe_id}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Function is not active"
)

function_module, _, _ = get_function_module_from_cache(request, pipe_id)

if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
Valves = function_module.Valves
valves = Functions.get_function_valves_by_id(pipe_id)
valves = Functions.get_function_valves_by_id(base_pipe_id)

if valves:
try:
function_module.valves = Valves(
**{k: v for k, v in valves.items() if v is not None}
)
except Exception as e:
log.exception(f"Error loading valves for function {pipe_id}: {e}")
raise e
function_module.valves = safe_instantiate_valves(Valves, valves, base_pipe_id)
else:
function_module.valves = Valves()

return function_module


def safe_get_sub_pipes(function_module, pipe_id: str):
"""
Safely retrieve sub-pipes from a function module.
Validates that the pipes attribute is callable or a list before execution.
"""
if not hasattr(function_module, "pipes"):
return None

pipes_attr = function_module.pipes

# If it's a list, validate and return it directly
if isinstance(pipes_attr, list):
# Validate list structure
for item in pipes_attr:
if not isinstance(item, dict):
log.error(f"Invalid pipe item in {pipe_id}: expected dict, got {type(item)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Invalid pipe configuration"
)
if "id" not in item or "name" not in item:
log.error(f"Invalid pipe item in {pipe_id}: missing required keys")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Invalid pipe configuration"
)
return pipes_attr

# If it's callable, validate it's a legitimate function before calling
if callable(pipes_attr):
# Verify it's a function or method, not an arbitrary callable
if not (inspect.isfunction(pipes_attr) or
inspect.ismethod(pipes_attr) or
inspect.iscoroutinefunction(pipes_attr)):
log.error(f"Invalid pipes attribute in {pipe_id}: not a function or method")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Invalid pipe configuration"
)
return pipes_attr

log.error(f"Invalid pipes attribute in {pipe_id}: expected callable or list, got {type(pipes_attr)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Invalid pipe configuration"
)


async def get_function_models(request):
pipes = Functions.get_functions_by_type("pipe", active_only=True)
pipe_models = []

for pipe in pipes:
try:
# Validate pipe ID before processing
if not validate_pipe_id(pipe.id):
log.warning(f"Skipping pipe with invalid ID: {pipe.id}")
continue

function_module = get_function_module_by_id(request, pipe.id)

has_user_valves = False
Expand All @@ -94,25 +239,48 @@ async def get_function_models(request):
if hasattr(function_module, "pipes"):
sub_pipes = []

# Handle pipes being a list, sync function, or async function
# Safely retrieve and call pipes
try:
if callable(function_module.pipes):
if asyncio.iscoroutinefunction(function_module.pipes):
sub_pipes = await function_module.pipes()
pipes_attr = safe_get_sub_pipes(function_module, pipe.id)

if pipes_attr is None:
continue

# Handle callable pipes
if callable(pipes_attr):
if asyncio.iscoroutinefunction(pipes_attr):
sub_pipes = await pipes_attr()
else:
sub_pipes = function_module.pipes()
sub_pipes = pipes_attr()
else:
sub_pipes = function_module.pipes
# Already validated as a list
sub_pipes = pipes_attr

# Validate the result
if not isinstance(sub_pipes, list):
log.error(f"pipes() for {pipe.id} returned non-list: {type(sub_pipes)}")
continue

except Exception as e:
log.exception(e)
log.exception(f"Error calling pipes() for {pipe.id}: {e}")
sub_pipes = []

log.debug(
f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
)

for p in sub_pipes:
if not isinstance(p, dict) or "id" not in p or "name" not in p:
log.warning(f"Invalid sub-pipe structure in {pipe.id}: {p}")
continue

sub_pipe_id = f'{pipe.id}.{p["id"]}'

# Validate sub-pipe ID
if not validate_pipe_id(sub_pipe_id):
log.warning(f"Skipping sub-pipe with invalid ID: {sub_pipe_id}")
continue

sub_pipe_name = p["name"]

if hasattr(function_module, "name"):
Expand Down Expand Up @@ -212,14 +380,25 @@ def get_function_params(function_module, form_data, user, extra_params=None):
if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
params["__user__"]["valves"] = safe_instantiate_valves(
function_module.UserValves, user_valves, pipe_id
)
except Exception as e:
log.exception(e)
params["__user__"]["valves"] = function_module.UserValves()

return params

model_id = form_data.get("model")

# Validate model_id to prevent code injection
if not validate_pipe_id(model_id):
log.error(f"Invalid model_id format: {model_id}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid model ID format"
)

model_info = Models.get_model_by_id(model_id)

metadata = form_data.pop("metadata", {})
Expand Down Expand Up @@ -350,4 +529,4 @@ async def stream_content():
return res.model_dump()

message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)
return openai_chat_completion_message_template(form_data["model"], message)