Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ FASTAPI_SECRET_KEY=your_super_secret_key_here
# Default: development
# APP_ENV=development

# General prefix for graph names used for Demo Graphs
# GENERAL_PREFIX=your_general_prefix_here

# Optional: allow OAuth over HTTP in development (disable in production)
# OAUTHLIB_INSECURE_TRANSPORT=1

Expand Down
65 changes: 61 additions & 4 deletions api/app_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,69 @@ def create_app():
"""Create and configure the FastAPI application."""
app = FastAPI(
title="QueryWeaver",
description=(
"Text2SQL with "
"Graph-Powered Schema Understanding"
),
description="Text2SQL with Graph-Powered Schema Understanding",
openapi_tags=[
{"name": "Authentication",
"description": "User authentication and OAuth operations"},
{"name": "Graphs & Databases",
"description": "Database schema management and querying"},
{"name": "Database Connection",
"description": "Connect to external databases"},
{"name": "API Tokens",
"description": "Manage API tokens for authentication"}
]
)

# Add security schemes to OpenAPI after app creation
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema

# pylint: disable=import-outside-toplevel
from fastapi.openapi.utils import get_openapi
openapi_schema = get_openapi(
title=app.title,
version=app.version,
description=app.description,
routes=app.routes,
)

# Add security schemes
openapi_schema["components"]["securitySchemes"] = {
"ApiTokenAuth": {
"type": "apiKey",
"in": "cookie",
"name": "api_token",
"description": "API token for programmatic access. "
"Generate via POST /tokens/generate after OAuth login."
},
"SessionAuth": {
"type": "apiKey",
"in": "cookie",
"name": "session",
"description": "Session cookie for web browsers. "
"Login via Google/GitHub at /login/google or /login/github."
}
}

# Add security requirements to protected endpoints
for _, path_item in openapi_schema["paths"].items():
for method, operation in path_item.items():
if method in ["get", "post", "put", "delete", "patch"]:
# Check if endpoint has token_required (look for 401 response)
if "401" in operation.get("responses", {}):
# Use OR logic - user needs EITHER ApiTokenAuth OR
# SessionAuth (not both)
operation["security"] = [
{"ApiTokenAuth": []}, # Option 1: API Token
{"SessionAuth": []} # Option 2: OAuth Session
]

app.openapi_schema = openapi_schema
return app.openapi_schema

app.openapi = custom_openapi

app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*")

app.add_middleware(
Expand Down
2 changes: 1 addition & 1 deletion api/routes/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


# Router
auth_router = APIRouter()
auth_router = APIRouter(tags=["Authentication"])
TEMPLATES_DIR = str((Path(__file__).resolve().parents[1] / "../app/templates").resolve())

TEMPLATES_CACHE_DIR = "/tmp/jinja_cache"
Expand Down
6 changes: 4 additions & 2 deletions api/routes/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from api.loaders.postgres_loader import PostgresLoader
from api.loaders.mysql_loader import MySQLLoader

database_router = APIRouter()
database_router = APIRouter(tags=["Database Connection"])

# Use the same delimiter as in the JavaScript frontend for streaming chunks
MESSAGE_DELIMITER = "|||FALKORDB_MESSAGE_BOUNDARY|||"
Expand All @@ -26,7 +26,9 @@ class DatabaseConnectionRequest(BaseModel):

url: str

@database_router.post("/database", operation_id="connect_database")
@database_router.post("/database", operation_id="connect_database", responses={
401: {"description": "Unauthorized - Please log in or provide a valid API token"}
})
@token_required
async def connect_database(request: Request, db_request: DatabaseConnectionRequest):
"""
Expand Down
Loading