Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 82 additions & 118 deletions backend/open_webui/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,33 +223,97 @@ def parse_section(section):

WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")

####################################
# PATH VALIDATION HELPER
####################################

def validate_data_dir(
    path: "str | os.PathLike | Path",
    allowed_base: "str | os.PathLike | Path",
) -> Path:
    """
    Resolve ``path`` and confirm that it lies inside ``allowed_base``.

    Both arguments are passed through ``Path(...)`` before being resolved, so
    plain strings and other path-like objects are accepted in addition to
    ``Path`` instances.

    Args:
        path: The path to validate.
        allowed_base: The base directory that the path must be within. A
            ``path`` that resolves to the base directory itself is accepted.

    Returns:
        The validated, resolved path.

    Raises:
        ValueError: If the resolved path is outside the resolved base
            directory, or if resolving either path raises ``RuntimeError``.
    """
    try:
        resolved_path = Path(path).resolve()
        resolved_base = Path(allowed_base).resolve()

        # relative_to() raises ValueError when resolved_path is not located
        # under resolved_base; its return value is not needed.
        resolved_path.relative_to(resolved_base)
    except (ValueError, RuntimeError) as e:
        log.error("Invalid DATA_DIR path: %s. Must be within %s", path, allowed_base)
        raise ValueError(f"Invalid DATA_DIR path: {path}. Path traversal detected.") from e

    return resolved_path

####################################
# DATA/FRONTEND BUILD DIR
####################################

DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
# Get the raw DATA_DIR from environment or use default
raw_data_dir = os.getenv("DATA_DIR", str(BACKEND_DIR / "data"))
DATA_DIR = validate_data_dir(Path(raw_data_dir), BACKEND_DIR)

if FROM_INIT_PY:
NEW_DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")).resolve()
raw_new_data_dir = os.getenv("DATA_DIR", str(OPEN_WEBUI_DIR / "data"))
NEW_DATA_DIR = validate_data_dir(Path(raw_new_data_dir), OPEN_WEBUI_DIR)
NEW_DATA_DIR.mkdir(parents=True, exist_ok=True)

# Check if the data directory exists in the package directory
if DATA_DIR.exists() and DATA_DIR != NEW_DATA_DIR:
log.info(f"Moving {DATA_DIR} to {NEW_DATA_DIR}")
for item in DATA_DIR.iterdir():
dest = NEW_DATA_DIR / item.name
if item.is_dir():
shutil.copytree(item, dest, dirs_exist_ok=True)
else:
shutil.copy2(item, dest)

# Zip the data directory
shutil.make_archive(DATA_DIR.parent / "open_webui_data", "zip", DATA_DIR)

# Remove the old data directory
shutil.rmtree(DATA_DIR)

DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data"))
# Validate both paths are safe before migration
try:
# Ensure DATA_DIR is within BACKEND_DIR
DATA_DIR.relative_to(BACKEND_DIR.resolve())
# Ensure NEW_DATA_DIR is within OPEN_WEBUI_DIR
NEW_DATA_DIR.relative_to(OPEN_WEBUI_DIR.resolve())

log.info(f"Moving {DATA_DIR} to {NEW_DATA_DIR}")
for item in DATA_DIR.iterdir():
# Validate each item to prevent symlink attacks
item_resolved = item.resolve()
# Ensure the resolved item is still within DATA_DIR
try:
item_resolved.relative_to(DATA_DIR.resolve())
except ValueError:
log.warning(f"Skipping {item}: resolves outside DATA_DIR")
continue

dest = NEW_DATA_DIR / item.name
# Validate destination
dest_resolved = dest.resolve()
try:
dest_resolved.relative_to(NEW_DATA_DIR.resolve())
except ValueError:
log.warning(f"Skipping {item}: destination outside NEW_DATA_DIR")
continue

if item.is_dir():
shutil.copytree(item_resolved, dest_resolved, dirs_exist_ok=True, symlinks=False)
else:
shutil.copy2(item_resolved, dest_resolved, follow_symlinks=False)

# Zip the data directory with safe archive path
archive_base = DATA_DIR.parent / "open_webui_data"
# Ensure archive path is safe
archive_base_resolved = archive_base.resolve()
archive_base_resolved.relative_to(BACKEND_DIR.resolve())
shutil.make_archive(str(archive_base_resolved), "zip", DATA_DIR)

# Remove the old data directory
shutil.rmtree(DATA_DIR)
except ValueError as e:
log.error(f"Data migration failed due to path validation: {e}")
raise

DATA_DIR = validate_data_dir(Path(os.getenv("DATA_DIR", str(OPEN_WEBUI_DIR / "data"))), OPEN_WEBUI_DIR)

STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static"))

Expand Down Expand Up @@ -732,104 +796,4 @@ def parse_section(section):

# Comma separated list of logger names to use for audit logging
# Default is "uvicorn.access" which is the access log for Uvicorn
# You can add more logger names to this list if you want to capture more logs
AUDIT_UVICORN_LOGGER_NAMES = os.getenv(
"AUDIT_UVICORN_LOGGER_NAMES", "uvicorn.access"
).split(",")

# METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper()
try:
MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
except ValueError:
MAX_BODY_LOG_SIZE = 2048

# Comma separated list for urls to exclude from audit
AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(
","
)
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]


####################################
# OPENTELEMETRY
####################################

ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true"
ENABLE_OTEL_TRACES = os.environ.get("ENABLE_OTEL_TRACES", "False").lower() == "true"
ENABLE_OTEL_METRICS = os.environ.get("ENABLE_OTEL_METRICS", "False").lower() == "true"
ENABLE_OTEL_LOGS = os.environ.get("ENABLE_OTEL_LOGS", "False").lower() == "true"

OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
)
OTEL_METRICS_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_METRICS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT
)
OTEL_LOGS_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_LOGS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT
)
OTEL_EXPORTER_OTLP_INSECURE = (
os.environ.get("OTEL_EXPORTER_OTLP_INSECURE", "False").lower() == "true"
)
OTEL_METRICS_EXPORTER_OTLP_INSECURE = (
os.environ.get(
"OTEL_METRICS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE)
).lower()
== "true"
)
OTEL_LOGS_EXPORTER_OTLP_INSECURE = (
os.environ.get(
"OTEL_LOGS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE)
).lower()
== "true"
)
OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
"OTEL_RESOURCE_ATTRIBUTES", ""
) # e.g. key1=val1,key2=val2
OTEL_TRACES_SAMPLER = os.environ.get(
"OTEL_TRACES_SAMPLER", "parentbased_always_on"
).lower()
OTEL_BASIC_AUTH_USERNAME = os.environ.get("OTEL_BASIC_AUTH_USERNAME", "")
OTEL_BASIC_AUTH_PASSWORD = os.environ.get("OTEL_BASIC_AUTH_PASSWORD", "")

OTEL_METRICS_BASIC_AUTH_USERNAME = os.environ.get(
"OTEL_METRICS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME
)
OTEL_METRICS_BASIC_AUTH_PASSWORD = os.environ.get(
"OTEL_METRICS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD
)
OTEL_LOGS_BASIC_AUTH_USERNAME = os.environ.get(
"OTEL_LOGS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME
)
OTEL_LOGS_BASIC_AUTH_PASSWORD = os.environ.get(
"OTEL_LOGS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD
)

OTEL_OTLP_SPAN_EXPORTER = os.environ.get(
"OTEL_OTLP_SPAN_EXPORTER", "grpc"
).lower() # grpc or http

OTEL_METRICS_OTLP_SPAN_EXPORTER = os.environ.get(
"OTEL_METRICS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER
).lower() # grpc or http

OTEL_LOGS_OTLP_SPAN_EXPORTER = os.environ.get(
"OTEL_LOGS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER
).lower() # grpc or http

####################################
# TOOLS/FUNCTIONS PIP OPTIONS
####################################

PIP_OPTIONS = os.getenv("PIP_OPTIONS", "").split()
PIP_PACKAGE_INDEX_OPTIONS = os.getenv("PIP_PACKAGE_INDEX_OPTIONS", "").split()


####################################
# PROGRESSIVE WEB APP OPTIONS
####################################

EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL")
# You can add more logger names
124 changes: 119 additions & 5 deletions backend/open_webui/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import inspect
import json
import asyncio
import hashlib
import hmac

from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator
Expand Down Expand Up @@ -57,6 +59,93 @@
log.setLevel(SRC_LOG_LEVELS["MAIN"])


def validate_valves(valves_dict: dict, valves_class) -> bool:
    """
    Check stored valve values against the fields declared on a Valves class.

    Field metadata is read from ``valves_class.model_fields`` when present and
    from the legacy ``__fields__`` attribute otherwise. If the class exposes
    neither, there is nothing to validate against and the values are accepted.

    Args:
        valves_dict: Mapping of valve name to stored value. An empty or
            ``None`` mapping is always valid.
        valves_class: The Valves class (or any object exposing
            ``model_fields`` / ``__fields__``) describing the allowed keys.

    Returns:
        True if every key is a declared field and no value has a clearly
        incompatible type; False otherwise, including when an unexpected
        error occurs during validation.
    """
    if not valves_dict:
        return True

    try:
        expected_fields = getattr(valves_class, "model_fields", None)
        if expected_fields is None:
            expected_fields = getattr(valves_class, "__fields__", None)
        if expected_fields is None:
            return True

        # Check that all provided valves are expected
        for key in valves_dict:
            if key not in expected_fields:
                log.warning(f"Unexpected valve key: {key}")
                return False

        # Validate types match expectations
        for key, value in valves_dict.items():
            if value is None:
                continue

            expected_type = getattr(expected_fields[key], "annotation", None)
            if expected_type is None or hasattr(expected_type, "__origin__"):
                # No annotation, or a parametrised generic (List[int],
                # Optional[str], ...): not checkable with isinstance here.
                continue

            try:
                matches = isinstance(value, expected_type)
            except TypeError:
                # The annotation cannot be used with isinstance() at all
                # (e.g. typing.Any or a forward-reference string). That is a
                # limitation of this check, not a problem with the value.
                continue

            # Mismatches against basic scalar types are tolerated because the
            # value may still be convertible when the Valves model is built.
            if matches or expected_type in (int, float, str, bool):
                continue

            log.warning(f"Type mismatch for valve {key}: expected {expected_type}, got {type(value)}")
            return False

        return True
    except Exception as e:
        log.error(f"Error validating valves: {e}")
        return False


# Substrings (all lower-case) that cause a string valve value to be rejected.
# Matching is done against the lower-cased value, so it is case-insensitive.
_PROHIBITED_VALVE_PATTERNS = (
    "__import__",
    "exec(",
    "eval(",
    "compile(",
    "os.system",
    "subprocess",
    "__builtins__",
    "__globals__",
    "__code__",
    "open(",
    "file(",
)


def sanitize_valve_value(value):
    """
    Reject valve values that contain prohibited patterns or unsupported objects.

    Strings are scanned case-insensitively for the substrings listed in
    ``_PROHIBITED_VALVE_PATTERNS``. Lists and dicts are rebuilt with every
    element / dict value checked recursively (dict keys are not inspected).
    Tuples, sets and frozensets have their elements checked the same way and
    are returned unchanged.

    Args:
        value: The valve value to check.

    Returns:
        ``None`` for ``None``; a new list or dict for list / dict input;
        otherwise the value itself.

    Raises:
        ValueError: If a string (at any nesting depth) contains a prohibited
            pattern, or if the value is an object carrying a ``__dict__``
            that is not one of the handled basic types.
    """
    if value is None:
        return None

    if isinstance(value, str):
        lowered = value.lower()
        for pattern in _PROHIBITED_VALVE_PATTERNS:
            if pattern in lowered:
                log.warning(f"Potentially dangerous pattern detected in valve value: {pattern}")
                raise ValueError("Invalid valve value: contains prohibited pattern")
        # No early return: a str subclass instance still goes through the
        # complex-object check below.
    elif isinstance(value, list):
        return [sanitize_valve_value(v) for v in value]
    elif isinstance(value, dict):
        return {k: sanitize_valve_value(v) for k, v in value.items()}
    elif isinstance(value, (int, float, bool)):
        return value
    elif isinstance(value, (tuple, set, frozenset)):
        # Checking never alters a value, so it is enough to walk the elements
        # for the ValueError side effect and keep the original container.
        for item in value:
            sanitize_valve_value(item)

    # For complex objects, only allow if they're from safe types
    if hasattr(value, "__dict__"):
        log.warning(f"Complex object passed as valve value: {type(value)}")
        raise ValueError(f"Invalid valve value type: {type(value)}")

    return value


def get_function_module_by_id(request: Request, pipe_id: str):
function_module, _, _ = get_function_module_from_cache(request, pipe_id)

Expand All @@ -66,9 +155,23 @@ def get_function_module_by_id(request: Request, pipe_id: str):

if valves:
try:
function_module.valves = Valves(
**{k: v for k, v in valves.items() if v is not None}
)
# Validate valves before using them
if not validate_valves(valves, Valves):
log.error(f"Invalid valves configuration for function {pipe_id}")
raise ValueError("Invalid valves configuration")

# Sanitize valve values
sanitized_valves = {}
for k, v in valves.items():
if v is not None:
try:
sanitized_valves[k] = sanitize_valve_value(v)
except ValueError as e:
log.error(f"Error sanitizing valve {k} for function {pipe_id}: {e}")
raise ValueError(f"Invalid valve value for {k}")

# Create valves instance with sanitized values
function_module.valves = Valves(**sanitized_valves)
except Exception as e:
log.exception(f"Error loading valves for function {pipe_id}: {e}")
raise e
Expand Down Expand Up @@ -212,7 +315,18 @@ def get_function_params(function_module, form_data, user, extra_params=None):
if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
# Validate and sanitize user valves
if validate_valves(user_valves, function_module.UserValves):
sanitized_user_valves = {}
for k, v in user_valves.items():
try:
sanitized_user_valves[k] = sanitize_valve_value(v)
except ValueError as e:
log.error(f"Error sanitizing user valve {k}: {e}")
continue
params["__user__"]["valves"] = function_module.UserValves(**sanitized_user_valves)
else:
params["__user__"]["valves"] = function_module.UserValves()
except Exception as e:
log.exception(e)
params["__user__"]["valves"] = function_module.UserValves()
Expand Down Expand Up @@ -350,4 +464,4 @@ async def stream_content():
return res.model_dump()

message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)
return openai_chat_completion_message_template(form_data["model"], message)