Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 49 additions & 5 deletions backend/open_webui/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@


def get_function_module_by_id(request: Request, pipe_id: str):
# Validate that the function exists in the database and is authorized
function = Functions.get_function_by_id(pipe_id.split('.')[0])
if not function:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Function not found"
)

# Verify the function is active
if not function.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Function is not active"
)

function_module, _, _ = get_function_module_from_cache(request, pipe_id)

if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
Expand Down Expand Up @@ -160,10 +175,32 @@ async def generate_function_chat_completion(
request, form_data, user, models: dict = {}
):
async def execute_pipe(pipe, params):
if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
else:
return pipe(**params)
# Validate that pipe is callable and from a trusted source
if not callable(pipe):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Invalid pipe function"
)

# Verify the pipe function has the expected signature
sig = inspect.signature(pipe)
param_names = set(sig.parameters.keys())
provided_params = set(params.keys())

# Only pass parameters that the function expects
filtered_params = {k: v for k, v in params.items() if k in param_names}

try:
if inspect.iscoroutinefunction(pipe):
return await pipe(**filtered_params)
else:
return pipe(**filtered_params)
except Exception as e:
log.error(f"Error executing pipe: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Pipe execution failed: {str(e)}"
)

async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
if isinstance(res, str):
Expand Down Expand Up @@ -292,6 +329,13 @@ def get_function_params(function_module, form_data, user, extra_params=None):
pipe_id = get_pipe_id(form_data)
function_module = get_function_module_by_id(request, pipe_id)

# Verify the function module has a pipe attribute and it's callable
if not hasattr(function_module, "pipe"):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Function module does not have a pipe method"
)

pipe = function_module.pipe
params = get_function_params(function_module, form_data, user, extra_params)

Expand Down Expand Up @@ -350,4 +394,4 @@ async def stream_content():
return res.model_dump()

message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)
return openai_chat_completion_message_template(form_data["model"], message)