diff --git a/api/app_factory.py b/api/app_factory.py index 70bc3611..17bb7a3e 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -11,6 +11,7 @@ from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.base import BaseHTTPMiddleware +from api.auth.oauth_handlers import setup_oauth_handlers from api.routes.auth import auth_router, init_auth from api.routes.graphs import graphs_router from api.routes.database import database_router @@ -83,6 +84,8 @@ def create_app(): app.include_router(graphs_router, prefix="/graphs") app.include_router(database_router) + setup_oauth_handlers(app, app.state.oauth) + @app.exception_handler(Exception) async def handle_oauth_error(request: Request, exc: Exception): """Handle OAuth-related errors gracefully""" diff --git a/api/routes/auth.py b/api/routes/auth.py index 40eee554..b2446d26 100644 --- a/api/routes/auth.py +++ b/api/routes/auth.py @@ -161,6 +161,12 @@ async def google_authorized(request: Request) -> RedirectResponse: request.session["google_token"] = token request.session["token_validated_at"] = time.time() + # Call the registered Google callback handler if it exists to store user data. + handler = getattr(request.app.state, "google_callback_handler", None) + if handler: + # call the registered handler (await if async) + await handler(request, token, user_info) + return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) except AuthlibBaseError as e: @@ -235,6 +241,12 @@ async def github_authorized(request: Request) -> RedirectResponse: request.session["github_token"] = token request.session["token_validated_at"] = time.time() + # Call the registered GitHub callback handler if it exists to store user data. + handler = getattr(request.app.state, "github_callback_handler", None) + if handler: + # call the registered handler (await if async) + await handler(request, token, user_info) + return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) except AuthlibBaseError as e: