Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 152 additions & 7 deletions backend/open_webui/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import inspect
import json
import asyncio
import hashlib
import hmac

from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from typing import AsyncGenerator, Generator, Iterator
from fastapi import (
Depends,
Expand Down Expand Up @@ -57,7 +59,139 @@
log.setLevel(SRC_LOG_LEVELS["MAIN"])


def validate_function_code(pipe_id: str, function_code: str) -> bool:
"""
Validate function code against stored hash to prevent code injection.
Returns True if validation passes, False otherwise.
"""
try:
# Get the function from database
function = Functions.get_function_by_id(pipe_id)
if not function:
log.error(f"Function {pipe_id} not found in database")
return False

# Verify code hash if available
if hasattr(function, 'code_hash'):
computed_hash = hashlib.sha256(function_code.encode()).hexdigest()
if not hmac.compare_digest(computed_hash, function.code_hash):
log.error(f"Code hash mismatch for function {pipe_id}")
return False

# Additional validation: check for dangerous imports/operations
dangerous_patterns = [
'import os',
'import subprocess',
'import sys',
'__import__',
'eval(',
'exec(',
'compile(',
'open(',
]

for pattern in dangerous_patterns:
if pattern in function_code:
log.warning(f"Potentially dangerous pattern '{pattern}' found in function {pipe_id}")
# This is a warning, not blocking, as legitimate functions may need these

return True

except Exception as e:
log.error(f"Error validating function code for {pipe_id}: {e}")
return False


def sanitize_valve_value(value):
"""
Sanitize valve values to prevent code injection.
Only allows primitive types and basic collections.
"""
if value is None:
return None

# Allow only safe primitive types
if isinstance(value, (str, int, float, bool)):
return value

# Allow lists of primitives
if isinstance(value, list):
return [sanitize_valve_value(item) for item in value]

# Allow dicts of primitives
if isinstance(value, dict):
return {k: sanitize_valve_value(v) for k, v in value.items() if isinstance(k, str)}

# Reject any other types (objects, functions, etc.)
log.warning(f"Rejecting non-primitive valve value of type {type(value)}")
return None


def validate_valves(Valves, valves_data: dict, pipe_id: str):
"""
Validate and sanitize valve data before instantiation.
Uses Pydantic validation and additional sanitization.
"""
try:
# First, sanitize all values
sanitized_valves = {}
for k, v in valves_data.items():
if v is not None:
sanitized_value = sanitize_valve_value(v)
if sanitized_value is not None:
sanitized_valves[k] = sanitized_value

# Get the Valves class fields to validate against schema
if hasattr(Valves, 'model_fields'):
# Pydantic v2
valid_fields = set(Valves.model_fields.keys())
elif hasattr(Valves, '__fields__'):
# Pydantic v1
valid_fields = set(Valves.__fields__.keys())
else:
# Fallback: allow all fields
valid_fields = set(sanitized_valves.keys())

# Only keep fields that are defined in the Valves schema
filtered_valves = {k: v for k, v in sanitized_valves.items() if k in valid_fields}

# Validate using Pydantic
validated_valves = Valves(**filtered_valves)

return validated_valves

except ValidationError as e:
log.error(f"Valve validation error for function {pipe_id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid valve configuration: {str(e)}"
)
except Exception as e:
log.error(f"Unexpected error validating valves for function {pipe_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error validating valve configuration"
)


def get_function_module_by_id(request: Request, pipe_id: str):
# Validate the function before loading
function = Functions.get_function_by_id(pipe_id)
if not function:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Function {pipe_id} not found"
)

# Verify function code integrity
if hasattr(function, 'content'):
if not validate_function_code(pipe_id, function.content):
log.error(f"Function code validation failed for {pipe_id}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Function code validation failed - potential security issue detected"
)

function_module, _, _ = get_function_module_from_cache(request, pipe_id)

if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
Expand All @@ -66,12 +200,16 @@ def get_function_module_by_id(request: Request, pipe_id: str):

if valves:
try:
function_module.valves = Valves(
**{k: v for k, v in valves.items() if v is not None}
)
# Use validated and sanitized valves instantiation
function_module.valves = validate_valves(Valves, valves, pipe_id)
except HTTPException:
raise
except Exception as e:
log.exception(f"Error loading valves for function {pipe_id}: {e}")
raise e
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error loading valve configuration: {str(e)}"
)
else:
function_module.valves = Valves()

Expand Down Expand Up @@ -212,7 +350,14 @@ def get_function_params(function_module, form_data, user, extra_params=None):
if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
# Validate user valves as well
params["__user__"]["valves"] = validate_valves(
function_module.UserValves, user_valves, pipe_id
)
except HTTPException:
# If validation fails, use default valves
log.warning(f"User valve validation failed for {pipe_id}, using defaults")
params["__user__"]["valves"] = function_module.UserValves()
except Exception as e:
log.exception(e)
params["__user__"]["valves"] = function_module.UserValves()
Expand Down Expand Up @@ -350,4 +495,4 @@ async def stream_content():
return res.model_dump()

message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)
return openai_chat_completion_message_template(form_data["model"], message)