Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 80 additions & 106 deletions backend/open_webui/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,29 +227,95 @@ def parse_section(section):
# DATA/FRONTEND BUILD DIR
####################################

DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
def validate_data_dir(data_dir_path: Path, allowed_base_dirs: list[Path]) -> Path:
"""
Validate that the data directory is within allowed base directories.

Args:
data_dir_path: The resolved data directory path
allowed_base_dirs: List of allowed base directories

Returns:
The validated data directory path

Raises:
ValueError: If the data directory is outside allowed directories
"""
data_dir_resolved = data_dir_path.resolve()

# Check if the path is within any of the allowed base directories
for allowed_base in allowed_base_dirs:
allowed_base_resolved = allowed_base.resolve()
try:
data_dir_resolved.relative_to(allowed_base_resolved)
return data_dir_resolved
except ValueError:
continue

# If we get here, the path is not within any allowed directory
raise ValueError(
f"DATA_DIR '{data_dir_path}' is outside allowed directories. "
f"Path must be within: {', '.join(str(d) for d in allowed_base_dirs)}"
)

# Define allowed base directories for DATA_DIR
ALLOWED_DATA_DIR_BASES = [BASE_DIR, BACKEND_DIR, OPEN_WEBUI_DIR]

# Get and validate DATA_DIR
data_dir_env = os.getenv("DATA_DIR", str(BACKEND_DIR / "data"))
try:
DATA_DIR = validate_data_dir(Path(data_dir_env), ALLOWED_DATA_DIR_BASES)
except ValueError as e:
log.error(f"Invalid DATA_DIR configuration: {e}")
log.warning(f"Falling back to default: {BACKEND_DIR / 'data'}")
DATA_DIR = (BACKEND_DIR / "data").resolve()

if FROM_INIT_PY:
NEW_DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")).resolve()
new_data_dir_env = os.getenv("DATA_DIR", str(OPEN_WEBUI_DIR / "data"))
try:
NEW_DATA_DIR = validate_data_dir(Path(new_data_dir_env), ALLOWED_DATA_DIR_BASES)
except ValueError as e:
log.error(f"Invalid NEW_DATA_DIR configuration: {e}")
log.warning(f"Falling back to default: {OPEN_WEBUI_DIR / 'data'}")
NEW_DATA_DIR = (OPEN_WEBUI_DIR / "data").resolve()

NEW_DATA_DIR.mkdir(parents=True, exist_ok=True)

# Check if the data directory exists in the package directory
if DATA_DIR.exists() and DATA_DIR != NEW_DATA_DIR:
log.info(f"Moving {DATA_DIR} to {NEW_DATA_DIR}")
for item in DATA_DIR.iterdir():
dest = NEW_DATA_DIR / item.name
if item.is_dir():
shutil.copytree(item, dest, dirs_exist_ok=True)
try:
for item in DATA_DIR.iterdir():
# Validate that the destination is still within NEW_DATA_DIR
dest = NEW_DATA_DIR / item.name
if not dest.resolve().is_relative_to(NEW_DATA_DIR.resolve()):
log.error(f"Skipping {item.name}: destination would be outside NEW_DATA_DIR")
continue

if item.is_dir():
shutil.copytree(item, dest, dirs_exist_ok=True)
else:
shutil.copy2(item, dest)

# Zip the data directory
archive_dest = DATA_DIR.parent / "open_webui_data"
if archive_dest.resolve().is_relative_to(DATA_DIR.parent.resolve()):
shutil.make_archive(str(archive_dest), "zip", DATA_DIR)
else:
shutil.copy2(item, dest)

# Zip the data directory
shutil.make_archive(DATA_DIR.parent / "open_webui_data", "zip", DATA_DIR)
log.error("Archive destination would be outside allowed directory")

# Remove the old data directory
shutil.rmtree(DATA_DIR)
# Remove the old data directory
shutil.rmtree(DATA_DIR)
except Exception as e:
log.error(f"Error during data directory migration: {e}")

DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data"))
data_dir_final_env = os.getenv("DATA_DIR", str(OPEN_WEBUI_DIR / "data"))
try:
DATA_DIR = validate_data_dir(Path(data_dir_final_env), ALLOWED_DATA_DIR_BASES)
except ValueError as e:
log.error(f"Invalid final DATA_DIR configuration: {e}")
log.warning(f"Falling back to default: {OPEN_WEBUI_DIR / 'data'}")
DATA_DIR = (OPEN_WEBUI_DIR / "data").resolve()

STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static"))

Expand Down Expand Up @@ -720,96 +786,4 @@ def parse_section(section):
# METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper()
try:
MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
except ValueError:
MAX_BODY_LOG_SIZE = 2048

# Comma separated list for urls to exclude from audit
AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(
","
)
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]


####################################
# OPENTELEMETRY
####################################

ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true"
ENABLE_OTEL_TRACES = os.environ.get("ENABLE_OTEL_TRACES", "False").lower() == "true"
ENABLE_OTEL_METRICS = os.environ.get("ENABLE_OTEL_METRICS", "False").lower() == "true"
ENABLE_OTEL_LOGS = os.environ.get("ENABLE_OTEL_LOGS", "False").lower() == "true"

OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
)
OTEL_METRICS_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_METRICS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT
)
OTEL_LOGS_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_LOGS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT
)
OTEL_EXPORTER_OTLP_INSECURE = (
os.environ.get("OTEL_EXPORTER_OTLP_INSECURE", "False").lower() == "true"
)
OTEL_METRICS_EXPORTER_OTLP_INSECURE = (
os.environ.get(
"OTEL_METRICS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE)
).lower()
== "true"
)
OTEL_LOGS_EXPORTER_OTLP_INSECURE = (
os.environ.get(
"OTEL_LOGS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE)
).lower()
== "true"
)
OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
"OTEL_RESOURCE_ATTRIBUTES", ""
) # e.g. key1=val1,key2=val2
OTEL_TRACES_SAMPLER = os.environ.get(
"OTEL_TRACES_SAMPLER", "parentbased_always_on"
).lower()
OTEL_BASIC_AUTH_USERNAME = os.environ.get("OTEL_BASIC_AUTH_USERNAME", "")
OTEL_BASIC_AUTH_PASSWORD = os.environ.get("OTEL_BASIC_AUTH_PASSWORD", "")

OTEL_METRICS_BASIC_AUTH_USERNAME = os.environ.get(
"OTEL_METRICS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME
)
OTEL_METRICS_BASIC_AUTH_PASSWORD = os.environ.get(
"OTEL_METRICS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD
)
OTEL_LOGS_BASIC_AUTH_USERNAME = os.environ.get(
"OTEL_LOGS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME
)
OTEL_LOGS_BASIC_AUTH_PASSWORD = os.environ.get(
"OTEL_LOGS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD
)

OTEL_OTLP_SPAN_EXPORTER = os.environ.get(
"OTEL_OTLP_SPAN_EXPORTER", "grpc"
).lower() # grpc or http

OTEL_METRICS_OTLP_SPAN_EXPORTER = os.environ.get(
"OTEL_METRICS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER
).lower() # grpc or http

OTEL_LOGS_OTLP_SPAN_EXPORTER = os.environ.get(
"OTEL_LOGS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER
).lower() # grpc or http

####################################
# TOOLS/FUNCTIONS PIP OPTIONS
####################################

PIP_OPTIONS = os.getenv("PIP_OPTIONS", "").split()
PIP_PACKAGE_INDEX_OPTIONS = os.getenv("PIP_PACKAGE_INDEX_OPTIONS", "").split()


####################################
# PROGRESSIVE WEB APP OPTIONS
####################################

EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL")
MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE")
112 changes: 107 additions & 5 deletions backend/open_webui/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import inspect
import json
import asyncio
import hashlib
from typing import Set

from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator
Expand Down Expand Up @@ -56,22 +58,122 @@
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])

# Whitelist of allowed attributes for function modules
ALLOWED_FUNCTION_ATTRIBUTES: Set[str] = {
"pipe",
"pipes",
"valves",
"Valves",
"UserValves",
"name",
"__name__",
"__doc__",
}


def validate_function_id(pipe_id: str) -> str:
"""
Validate and sanitize function ID to prevent injection attacks.
Only allow alphanumeric characters, hyphens, underscores, and dots.
"""
if not pipe_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid function ID: ID cannot be empty"
)

# Remove any path traversal attempts
pipe_id = pipe_id.replace("..", "").replace("/", "").replace("\\", "")

# Validate characters
allowed_chars = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.")
if not all(c in allowed_chars for c in pipe_id):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid function ID: contains invalid characters"
)

return pipe_id


def verify_function_integrity(pipe_id: str, function_module) -> bool:
"""
Verify the integrity of the loaded function module.
Check that it only contains expected attributes.
"""
# Get all attributes of the module
module_attrs = set(dir(function_module))

# Get attributes that are not built-in Python attributes
custom_attrs = {attr for attr in module_attrs if not attr.startswith('_') or attr in ALLOWED_FUNCTION_ATTRIBUTES}

# Check for suspicious attributes
suspicious_attrs = custom_attrs - ALLOWED_FUNCTION_ATTRIBUTES

# Filter out harmless built-in attributes
harmless_builtins = {'__builtins__', '__file__', '__package__', '__loader__', '__spec__', '__cached__'}
suspicious_attrs = suspicious_attrs - harmless_builtins

if suspicious_attrs:
log.warning(f"Function {pipe_id} contains unexpected attributes: {suspicious_attrs}")

# Verify that the pipe function exists and is callable
if not hasattr(function_module, "pipe"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid function module: missing 'pipe' function"
)

if not callable(getattr(function_module, "pipe")):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid function module: 'pipe' is not callable"
)

return True


def get_function_module_by_id(request: Request, pipe_id: str):
# Validate and sanitize the pipe_id
pipe_id = validate_function_id(pipe_id)

# Verify the function exists in the database and user has access
function_record = Functions.get_function_by_id(pipe_id)
if not function_record:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Function not found: {pipe_id}"
)

# Load the function module from cache
function_module, _, _ = get_function_module_from_cache(request, pipe_id)

# Verify the integrity of the loaded module
verify_function_integrity(pipe_id, function_module)

if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
Valves = function_module.Valves
valves = Functions.get_function_valves_by_id(pipe_id)

if valves:
try:
function_module.valves = Valves(
**{k: v for k, v in valves.items() if v is not None}
)
# Sanitize valve values to prevent injection
sanitized_valves = {}
for k, v in valves.items():
if v is not None:
# Convert to string and limit length to prevent DoS
if isinstance(v, str) and len(v) > 10000:
log.warning(f"Valve value for {k} exceeds maximum length, truncating")
v = v[:10000]
sanitized_valves[k] = v

function_module.valves = Valves(**sanitized_valves)
except Exception as e:
log.exception(f"Error loading valves for function {pipe_id}: {e}")
raise e
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error loading function valves: {str(e)}"
)
else:
function_module.valves = Valves()

Expand Down Expand Up @@ -350,4 +452,4 @@ async def stream_content():
return res.model_dump()

message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)
return openai_chat_completion_message_template(form_data["model"], message)