Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 155 additions & 9 deletions backend/open_webui/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import json
import asyncio

import hashlib
from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator
from fastapi import (
Expand Down Expand Up @@ -57,7 +57,116 @@
log.setLevel(SRC_LOG_LEVELS["MAIN"])


def validate_function_access(request: Request, pipe_id: str, user: UserModel = None):
"""
Validate that the user has access to the function and that the function exists.
"""
function = Functions.get_function_by_id(pipe_id)

if not function:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND
)

if not function.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)

# Verify function integrity
if hasattr(function, 'content') and hasattr(function, 'meta'):
expected_hash = function.meta.get('content_hash')
if expected_hash:
actual_hash = hashlib.sha256(function.content.encode()).hexdigest()
if actual_hash != expected_hash:
log.error(f"Function {pipe_id} failed integrity check")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Function integrity check failed"
)

return function


def sanitize_value(value, max_depth=5, current_depth=0):
"""
Sanitize values to prevent code injection.
Only allows primitive types and basic structures.
"""
if current_depth > max_depth:
log.warning(f"Maximum nesting depth exceeded during sanitization")
return None

if value is None:
return None

if isinstance(value, (str, int, float, bool)):
return value

if isinstance(value, list):
return [sanitize_value(v, max_depth, current_depth + 1) for v in value]

if isinstance(value, dict):
return {
str(k): sanitize_value(v, max_depth, current_depth + 1)
for k, v in value.items()
if isinstance(k, (str, int, float))
}

log.warning(f"Skipping unsupported type during sanitization: {type(value)}")
return None


def sanitize_params(params):
"""
Sanitize all parameters to prevent code injection attacks.
"""
sanitized = {}

for key, value in params.items():
# Skip special internal parameters
if key.startswith('__') and key.endswith('__'):
sanitized[key] = value
continue

# Handle body parameter specially as it's a dict
if key == 'body' and isinstance(value, dict):
sanitized_body = {}
for k, v in value.items():
sanitized_body[k] = sanitize_value(v)
sanitized[key] = sanitized_body
else:
sanitized[key] = sanitize_value(value)

return sanitized


def validate_valves_schema(valves_class, valves_data):
"""
Validate that user valves data only contains fields defined in the Valves schema.
This prevents arbitrary object injection through extra fields.
"""
if not hasattr(valves_class, '__annotations__'):
return {}

allowed_fields = set(valves_class.__annotations__.keys())
validated_data = {}

for key, value in valves_data.items():
if key in allowed_fields:
validated_data[key] = value
else:
log.warning(f"Ignoring unexpected valve field: {key}")

return validated_data


def get_function_module_by_id(request: Request, pipe_id: str):
# Validate function access and integrity
validate_function_access(request, pipe_id)

function_module, _, _ = get_function_module_from_cache(request, pipe_id)

if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
Expand All @@ -66,12 +175,23 @@ def get_function_module_by_id(request: Request, pipe_id: str):

if valves:
try:
function_module.valves = Valves(
**{k: v for k, v in valves.items() if v is not None}
)
# Sanitize valve values to prevent injection
sanitized_valves = {}
for k, v in valves.items():
sanitized_v = sanitize_value(v)
if sanitized_v is not None or v is None:
sanitized_valves[k] = sanitized_v

# Validate schema to prevent object injection
validated_valves = validate_valves_schema(Valves, sanitized_valves)

function_module.valves = Valves(**validated_valves)
except Exception as e:
log.exception(f"Error loading valves for function {pipe_id}: {e}")
raise e
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error loading function configuration: {str(e)}"
)
else:
function_module.valves = Valves()

Expand Down Expand Up @@ -160,10 +280,23 @@ async def generate_function_chat_completion(
request, form_data, user, models: dict = {}
):
async def execute_pipe(pipe, params):
# Sanitize params before execution to prevent code injection
sanitized_params = sanitize_params(params)

# Validate pipe signature to ensure we're only passing expected parameters
sig = inspect.signature(pipe)
validated_params = {}

for key, value in sanitized_params.items():
if key in sig.parameters:
validated_params[key] = value
else:
log.warning(f"Skipping unexpected parameter: {key}")

if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
return await pipe(**validated_params)
else:
return pipe(**params)
return pipe(**validated_params)

async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
if isinstance(res, str):
Expand Down Expand Up @@ -212,7 +345,20 @@ def get_function_params(function_module, form_data, user, extra_params=None):
if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
# Sanitize user valve values
sanitized_user_valves = {}
for k, v in user_valves.items():
sanitized_v = sanitize_value(v)
if sanitized_v is not None or v is None:
sanitized_user_valves[k] = sanitized_v

# Validate schema to prevent object injection
validated_user_valves = validate_valves_schema(
function_module.UserValves,
sanitized_user_valves
)

params["__user__"]["valves"] = function_module.UserValves(**validated_user_valves)
except Exception as e:
log.exception(e)
params["__user__"]["valves"] = function_module.UserValves()
Expand Down Expand Up @@ -350,4 +496,4 @@ async def stream_content():
return res.model_dump()

message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)
return openai_chat_completion_message_template(form_data["model"], message)