From 2f34f066f451f9a71347acc92002f36f8a75e817 Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Mon, 23 Mar 2026 11:18:15 +0100 Subject: [PATCH 1/3] chore: set-cookie to pass session id --- k8s/welearn-api/values.dev.yaml | 1 + k8s/welearn-api/values.prod.yaml | 1 + k8s/welearn-api/values.staging.yaml | 1 + src/app/api/api_v1/api.py | 5 +- src/app/api/api_v1/endpoints/chat.py | 5 +- src/app/api/api_v1/endpoints/metric.py | 3 +- src/app/core/config.py | 1 + src/app/middleware/monitor_requests.py | 5 +- src/app/search/api/router.py | 3 +- src/app/services/data_collection.py | 12 +-- src/app/services/sql_db/queries_user.py | 21 ++++- src/app/shared/domain/constants.py | 4 + src/app/shared/utils/requests.py | 23 +++++ src/app/tests/api/api_v1/test_search.py | 6 +- src/app/tutor/api/router.py | 7 +- src/app/user/__init__.py | 0 src/app/user/api/__init__.py | 0 .../endpoints/user.py => user/api/router.py} | 90 ++++++++++++++++--- src/app/user/utils/__init__.py | 0 src/app/user/utils/utils.py | 38 ++++++++ 20 files changed, 191 insertions(+), 35 deletions(-) create mode 100644 src/app/shared/utils/requests.py create mode 100644 src/app/user/__init__.py create mode 100644 src/app/user/api/__init__.py rename src/app/{api/api_v1/endpoints/user.py => user/api/router.py} (54%) create mode 100644 src/app/user/utils/__init__.py create mode 100644 src/app/user/utils/utils.py diff --git a/k8s/welearn-api/values.dev.yaml b/k8s/welearn-api/values.dev.yaml index cd36fee..336ad7a 100644 --- a/k8s/welearn-api/values.dev.yaml +++ b/k8s/welearn-api/values.dev.yaml @@ -6,6 +6,7 @@ config: PG_HOST: dev-lab-projects-backend.postgres.database.azure.com TIKA_URL_BASE: https://tika.k8s.lp-i.dev/ DATA_COLLECTION_ORIGIN_PREFIX: welearn + ENV: development allowedHostsRegexes: mainUrl: |- https:\/\/welearn\.k8s\.lp-i\.dev diff --git a/k8s/welearn-api/values.prod.yaml b/k8s/welearn-api/values.prod.yaml index 5965741..cefda86 100644 --- a/k8s/welearn-api/values.prod.yaml +++ b/k8s/welearn-api/values.prod.yaml @@ -6,6 +6,7 @@ config: PG_HOST: prod-prod-projects-backend.postgres.database.azure.com TIKA_URL_BASE: https://tika.k8s.lp-i.org/ DATA_COLLECTION_ORIGIN_PREFIX: workshop + ENV: production allowedHostsRegexes: alphaUrls: |- https://[a-zA-Z0-9-]*\.alpha-welearn\.lp-i\.org diff --git a/k8s/welearn-api/values.staging.yaml b/k8s/welearn-api/values.staging.yaml index 50a6a14..8da3d01 100644 --- a/k8s/welearn-api/values.staging.yaml +++ b/k8s/welearn-api/values.staging.yaml @@ -8,6 +8,7 @@ config: PG_DATABASE: welearn_datastack_staging TIKA_URL_BASE: https://tika.k8s.lp-i.dev/ DATA_COLLECTION_ORIGIN_PREFIX: welearn + ENV: staging allowedHostsRegexes: mainUrl: |- https:\/\/welearn\.k8s\.lp-i\.xyz diff --git a/src/app/api/api_v1/api.py b/src/app/api/api_v1/api.py index a769be6..e6a6958 100644 --- a/src/app/api/api_v1/api.py +++ b/src/app/api/api_v1/api.py @@ -2,9 +2,10 @@ from fastapi import APIRouter -from src.app.api.api_v1.endpoints import chat, metric, micro_learning, user +from src.app.api.api_v1.endpoints import chat, metric, micro_learning from src.app.search.api import router as search_router from src.app.tutor.api import router as tutor_router +from src.app.user.api import router as user_router api_router = APIRouter() api_router.include_router(chat.router, prefix="/qna", tags=["qna"]) @@ -14,7 +15,7 @@ api_router.include_router( micro_learning.router, prefix="/micro_learning", tags=["micro_learning"] ) -api_router.include_router(user.router, prefix="/user", tags=["user"]) +api_router.include_router(user_router.router, prefix="/user", tags=["user"]) api_tags_metadata = [ diff --git a/src/app/api/api_v1/endpoints/chat.py b/src/app/api/api_v1/endpoints/chat.py index f708167..fafe8e7 100644 --- a/src/app/api/api_v1/endpoints/chat.py +++ b/src/app/api/api_v1/endpoints/chat.py @@ -23,6 +23,7 @@ ) from src.app.shared.infra.abst_chat import get_chat_service from src.app.shared.utils.dependencies import get_settings +from src.app.shared.utils.requests import extract_session_cookie from src.app.utils.logger import logger as utils_logger logger = utils_logger(__name__) @@ -256,7 +257,7 @@ async def q_and_a_ans( str: openai chat completion content """ - session_id = request.headers.get("X-Session-ID") + session_id = extract_session_cookie(request) try: content = await chatfactory.chat_message( @@ -354,7 +355,7 @@ async def agent_response( data_collection=Depends(get_data_collection_service), ) -> Optional[Dict]: try: - session_id = request.headers.get("X-Session-ID") + session_id = extract_session_cookie(request) docs = [] if body.query is None: diff --git a/src/app/api/api_v1/endpoints/metric.py b/src/app/api/api_v1/endpoints/metric.py index 2d53ed5..9a9b726 100644 --- a/src/app/api/api_v1/endpoints/metric.py +++ b/src/app/api/api_v1/endpoints/metric.py @@ -6,6 +6,7 @@ from src.app.services.data_collection import get_data_collection_service from src.app.services.sql_db.queries import get_document_qty_table_info_sync from src.app.shared.utils.dependencies import get_settings +from src.app.shared.utils.requests import extract_session_cookie from src.app.utils.logger import logger as utils_logger logger = utils_logger(__name__) @@ -55,6 +56,6 @@ async def update_clicked_doc_from_chat_message( async def register_syllabus_download( request: Request, data_collection=Depends(get_data_collection_service) ) -> str: - session_id = request.headers.get("X-Session-ID") + session_id = extract_session_cookie(request) await data_collection.register_syllabus_download(session_id) return "registered" diff --git a/src/app/core/config.py b/src/app/core/config.py index b1194b8..b454e7c 100644 --- a/src/app/core/config.py +++ b/src/app/core/config.py @@ -42,6 +42,7 @@ def get_api_version(self) -> dict: AZURE_API_VERSION: str LLM_MODEL_NAME: str + ENV: str # PG PG_USER: Optional[str] = None diff --git a/src/app/middleware/monitor_requests.py b/src/app/middleware/monitor_requests.py index d4f5a0c..22f20fc 100644 --- a/src/app/middleware/monitor_requests.py +++ b/src/app/middleware/monitor_requests.py @@ -3,6 +3,7 @@ from starlette.middleware.base import BaseHTTPMiddleware from src.app.services.sql_db.queries import register_endpoint +from src.app.shared.utils.requests import extract_session_cookie from src.app.utils.logger import logger as logger_utils logger = logger_utils(__name__) @@ -11,7 +12,7 @@ class MonitorRequestsMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): if request.url.path.startswith("/api/v1/"): - session_id = request.headers.get("X-Session-ID") + session_id = extract_session_cookie(request) if session_id: try: await run_in_threadpool( @@ -24,7 +25,7 @@ async def dispatch(self, request: Request, call_next): logger.error(f"Failed to register endpoint {request.url.path}: {e}") else: logger.warning( - f"No X-Session-ID header provided for {request.url.path}" + f"No X-Session-ID cookie provided for {request.url.path}" ) response = await call_next(request) diff --git a/src/app/search/api/router.py b/src/app/search/api/router.py index ac20c5c..d2c93f7 100644 --- a/src/app/search/api/router.py +++ b/src/app/search/api/router.py @@ -32,6 +32,7 @@ ModelNotFoundError, bad_request, ) +from src.app.shared.utils.requests import extract_session_cookie from src.app.utils.logger import logger as logger_utils router = APIRouter() @@ -206,7 +207,7 @@ async def search_all( data_collection=Depends(get_data_collection_service), ): try: - session_id = request.headers.get("X-Session-ID") + session_id = extract_session_cookie(request) res = await sp.search_handler( qp=qp, method=SearchMethods.BY_DOCUMENT, background_tasks=background_tasks diff --git a/src/app/services/data_collection.py b/src/app/services/data_collection.py index 062cc52..2f7fcfa 100644 --- a/src/app/services/data_collection.py +++ b/src/app/services/data_collection.py @@ -65,7 +65,7 @@ def get_campaign_state( async def register_search_data( self, - session_id: str | None, + session_id: uuid.UUID | None, query: str, search_results: list[Document | ScoredPoint], sdg_filter: list[int] | None = None, @@ -88,9 +88,7 @@ async def register_search_data( }, ) - user_id = await run_in_threadpool( - get_user_from_session_id, uuid.UUID(session_id) - ) + user_id = await run_in_threadpool(get_user_from_session_id, session_id) if not user_id: raise HTTPException( @@ -241,7 +239,7 @@ async def register_syllabus_update( async def register_chat_data( self, - session_id: str | None, + session_id: uuid.UUID | None, user_query: str, conversation_id: uuid.UUID | None, answer_content: str, @@ -264,9 +262,7 @@ async def register_chat_data( }, ) - user_id = await run_in_threadpool( - get_user_from_session_id, uuid.UUID(session_id) - ) + user_id = await run_in_threadpool(get_user_from_session_id, session_id) if not user_id: raise HTTPException( diff --git a/src/app/services/sql_db/queries_user.py b/src/app/services/sql_db/queries_user.py index 7e49bdb..b21fcf4 100644 --- a/src/app/services/sql_db/queries_user.py +++ b/src/app/services/sql_db/queries_user.py @@ -3,6 +3,7 @@ import uuid from datetime import datetime, timedelta +from fastapi import HTTPException from sqlalchemy.sql import select from welearn_database.data.models import Bookmark, InferredUser, Session @@ -27,6 +28,12 @@ def get_or_create_user_sync( ).first() if user: return user.id + else: + logger.error(f"User with id {user_id} not found,") + raise HTTPException( + status_code=404, detail=f"User with id {user_id} not found" + ) + user = InferredUser(origin_referrer=referer) s.add(user) s.commit() @@ -70,11 +77,23 @@ def get_or_create_session_sync( return new_session.id -def get_user_from_session_id(session_id: uuid.UUID) -> uuid.UUID | None: +def get_user_from_session_id(session_id: uuid.UUID | None) -> uuid.UUID | None: + if not session_id: + return None + with session_maker() as s: session = s.execute(select(Session).where(Session.id == session_id)).first() if session: + logger.info( + "Valid session. user_id=%s session_id=%s", + session[0].inferred_user_id, + session_id, + ) return session[0].inferred_user_id + else: + HTTPException( + status_code=404, detail=f"Session with id {session_id} not found" + ) return None diff --git a/src/app/shared/domain/constants.py b/src/app/shared/domain/constants.py index de70d46..4695eb3 100644 --- a/src/app/shared/domain/constants.py +++ b/src/app/shared/domain/constants.py @@ -25,3 +25,7 @@ } APP_NAME = "welearn-api" + + +SESSION_COOKIE_NAME = "x-session-id" +SESSION_TTL_SECONDS = 60 * 60 * 24 * 400 # 400 days diff --git a/src/app/shared/utils/requests.py b/src/app/shared/utils/requests.py new file mode 100644 index 0000000..dd986ed --- /dev/null +++ b/src/app/shared/utils/requests.py @@ -0,0 +1,23 @@ +from typing import Optional +from uuid import UUID +from fastapi import Request +from src.app.shared.domain.constants import SESSION_COOKIE_NAME +from src.app.utils.logger import logger as utils_logger + +logger = utils_logger(__name__) + + +def extract_session_cookie(request: Request) -> Optional[UUID]: + cookie_value = request.cookies.get(SESSION_COOKIE_NAME) + if not cookie_value: + return None + + try: + return UUID(cookie_value) + except ValueError: + logger.warning("Invalid session cookie format: %s", cookie_value) + return None + + +def extract_origin_from_request(request: Request) -> str: + return request.headers.get("origin", "unknown") diff --git a/src/app/tests/api/api_v1/test_search.py b/src/app/tests/api/api_v1/test_search.py index 3ebc105..0d53d3e 100644 --- a/src/app/tests/api/api_v1/test_search.py +++ b/src/app/tests/api/api_v1/test_search.py @@ -333,7 +333,7 @@ async def test_search_all_no_collections(self, *mocks): headers={ "X-API-Key": "test", "origin": "test.com", - "X-Session-ID": str(uuid.uuid4()), + "Cookie": f"x-session-id={str(uuid.uuid4())}", }, ) self.assertEqual(response.status_code, 404) @@ -349,7 +349,7 @@ async def test_search_all_no_result(self, *mocks): headers={ "X-API-Key": "test", "origin": "test.com", - "X-Session-ID": str(uuid.uuid4()), + "Cookie": f"x-session-id={str(uuid.uuid4())}", }, # noqa: E501 ) @@ -363,7 +363,7 @@ async def test_search_all_no_query(self, *mocks): headers={ "X-API-Key": "test", "origin": "test.com", - "X-Session-Id": str(uuid.uuid4()), + "Cookie": f"x-session-id={str(uuid.uuid4())}", }, ) self.assertEqual(response.status_code, 400) diff --git a/src/app/tutor/api/router.py b/src/app/tutor/api/router.py index 53f3379..faae374 100644 --- a/src/app/tutor/api/router.py +++ b/src/app/tutor/api/router.py @@ -19,6 +19,7 @@ from src.app.shared.domain.exceptions import NoResultsError from src.app.shared.infra.abst_chat import get_chat_service from src.app.shared.utils.dependencies import get_settings +from src.app.shared.utils.requests import extract_session_cookie from src.app.shared.utils.utils import get_files_content from src.app.tutor.service.agents import TEMPLATES from src.app.tutor.service.models import ( @@ -161,7 +162,7 @@ async def create_syllabus( data_collection=Depends(get_data_collection_service), settings: Settings = Depends(get_settings), ) -> SyllabusResponse: - session_id = request.headers.get("X-Session-ID") + session_id = extract_session_cookie(request) results = await tutor_manager(body, lang, settings) # TODO: handle errors @@ -228,7 +229,7 @@ async def handle_syllabus_feedback( chatfactory=Depends(get_chat_service), data_collection=Depends(get_data_collection_service), ): - session_id = request.headers.get("X-Session-ID") + session_id = extract_session_cookie(request) messages = [ { @@ -282,7 +283,7 @@ async def register_syllabus_user_update( body: SyllabusUserUpdate, data_collection=Depends(get_data_collection_service), ): - session_id = request.headers.get("X-Session-ID") + session_id = extract_session_cookie(request) await data_collection.register_syllabus_update( session_id=session_id, diff --git a/src/app/user/__init__.py b/src/app/user/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/app/user/api/__init__.py b/src/app/user/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/app/api/api_v1/endpoints/user.py b/src/app/user/api/router.py similarity index 54% rename from src/app/api/api_v1/endpoints/user.py rename to src/app/user/api/router.py index b082d95..f9382b7 100644 --- a/src/app/api/api_v1/endpoints/user.py +++ b/src/app/user/api/router.py @@ -1,6 +1,6 @@ import uuid -from fastapi import APIRouter, HTTPException, Request +from fastapi import APIRouter, HTTPException, Request, Response from fastapi.concurrency import run_in_threadpool from src.app.services.sql_db.queries_user import ( @@ -10,13 +10,48 @@ get_or_create_session_sync, get_or_create_user_sync, get_user_bookmarks_sync, + get_user_from_session_id, ) +from src.app.shared.domain.constants import SESSION_COOKIE_NAME, SESSION_TTL_SECONDS +from src.app.shared.utils.requests import ( + extract_origin_from_request, + extract_session_cookie, +) +from src.app.user.utils.utils import resolve_user_and_session from src.app.utils.logger import logger as logger_utils + router = APIRouter() logger = logger_utils(__name__) +@router.post( + "/user_and_session", summary="Create new user and session", response_model=dict +) +async def handle_user_and_session( + request: Request, response: Response, referer: str | None = None +): + host = extract_origin_from_request(request) + session_uuid = extract_session_cookie(request) + + _, session_uuid = await resolve_user_and_session( + session_uuid=session_uuid, + host=host, + referer=referer, + ) + + response.set_cookie( + key=SESSION_COOKIE_NAME, + value=str(session_uuid), + max_age=SESSION_TTL_SECONDS, + httponly=True, + samesite="lax", + secure=False, # True in production (HTTPS) + ) + + return {"message": "session created"} + + @router.post("/user", summary="Create new user", response_model=dict) async def handle_user(user_id: uuid.UUID | None = None, referer: str | None = None): try: @@ -35,7 +70,7 @@ async def handle_session( referer: str | None = None, ): try: - host = request.headers.get("origin", "unknown") + host = extract_origin_from_request(request) session_id = await run_in_threadpool( get_or_create_session_sync, user_id, session_id, host, referer ) @@ -47,10 +82,19 @@ async def handle_session( raise HTTPException(status_code=500, detail=str(e)) -@router.get("/:user_id/bookmarks", summary="Get user bookmarks") -async def get_user_bookmarks(user_id: uuid.UUID): +@router.get("/bookmarks", summary="Get user bookmarks") +async def get_user_bookmarks(request: Request): + session_uuid = extract_session_cookie(request) + host = extract_origin_from_request(request) + + user_id, _ = await resolve_user_and_session( + session_uuid=session_uuid, + host=host, + referer=None, + ) try: bookmarks = await run_in_threadpool(get_user_bookmarks_sync, user_id) + print(f"Bookmarks for user {user_id}: {bookmarks}") return {"bookmarks": bookmarks} except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) @@ -59,10 +103,16 @@ async def get_user_bookmarks(user_id: uuid.UUID): raise HTTPException(status_code=500, detail=str(e)) -@router.delete( - "/:user_id/bookmarks", summary="Delete all user bookmarks", response_model=dict -) -async def delete_user_bookmarks(user_id: uuid.UUID): +@router.delete("/bookmarks", summary="Delete all user bookmarks", response_model=dict) +async def delete_user_bookmarks(request: Request): + session_uuid = extract_session_cookie(request) + host = extract_origin_from_request(request) + + user_id, _ = await resolve_user_and_session( + session_uuid=session_uuid, + host=host, + referer=None, + ) try: deleted_count = await run_in_threadpool(delete_user_bookmarks_sync, user_id) return {"deleted": deleted_count} @@ -74,11 +124,19 @@ async def delete_user_bookmarks(user_id: uuid.UUID): @router.delete( - "/:user_id/bookmarks/:document_id", + "/bookmarks/:document_id", summary="Delete a user bookmark", response_model=dict, ) -async def delete_user_bookmark(user_id: uuid.UUID, document_id: uuid.UUID): +async def delete_user_bookmark(request: Request, document_id: uuid.UUID): + session_uuid = extract_session_cookie(request) + host = extract_origin_from_request(request) + + user_id, _ = await resolve_user_and_session( + session_uuid=session_uuid, + host=host, + referer=None, + ) try: deleted_id = await run_in_threadpool( delete_user_bookmark_sync, user_id, document_id @@ -92,9 +150,17 @@ async def delete_user_bookmark(user_id: uuid.UUID, document_id: uuid.UUID): @router.post( - "/:user_id/bookmarks/:document_id", summary="Add user bookmark", response_model=dict + "/bookmarks/:document_id", summary="Add user bookmark", response_model=dict ) -async def add_user_bookmark(user_id: uuid.UUID, document_id: uuid.UUID): +async def add_user_bookmark(request: Request, document_id: uuid.UUID): + session_uuid = extract_session_cookie(request) + host = extract_origin_from_request(request) + + user_id, _ = await resolve_user_and_session( + session_uuid=session_uuid, + host=host, + referer=None, + ) try: added_id = await run_in_threadpool(add_user_bookmark_sync, user_id, document_id) return {"added": added_id} diff --git a/src/app/user/utils/__init__.py b/src/app/user/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/app/user/utils/utils.py b/src/app/user/utils/utils.py new file mode 100644 index 0000000..fba368c --- /dev/null +++ b/src/app/user/utils/utils.py @@ -0,0 +1,38 @@ +from uuid import UUID +from typing import Optional +from fastapi.concurrency import run_in_threadpool + +from src.app.services.sql_db.queries_user import ( + get_user_from_session_id, + get_or_create_session_sync, + get_or_create_user_sync, +) +from src.app.utils.logger import logger as logger_utils + +logger = logger_utils(__name__) + + +async def resolve_user_and_session( + session_uuid: Optional[UUID], + host: str, + referer: Optional[str], +) -> tuple[UUID, UUID]: + user_id = await run_in_threadpool(get_user_from_session_id, session_uuid) + + if not user_id: + logger.info("No user found. Creating new user and session.") + user_id = await run_in_threadpool(get_or_create_user_sync, None, referer) + session_uuid = await run_in_threadpool( + get_or_create_session_sync, user_id, None, host, referer + ) + else: + logger.info( + "Existing user found. user_id=%s session_id=%s", + user_id, + session_uuid, + ) + session_uuid = await run_in_threadpool( + get_or_create_session_sync, user_id, session_uuid, host, referer + ) + + return user_id, session_uuid From 46bbeadb5c8a870855b7c3af27dc72aed6663093 Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Mon, 23 Mar 2026 14:44:52 +0100 Subject: [PATCH 2/3] raise error if no session id --- src/app/user/api/router.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/app/user/api/router.py b/src/app/user/api/router.py index f9382b7..3192264 100644 --- a/src/app/user/api/router.py +++ b/src/app/user/api/router.py @@ -10,7 +10,6 @@ get_or_create_session_sync, get_or_create_user_sync, get_user_bookmarks_sync, - get_user_from_session_id, ) from src.app.shared.domain.constants import SESSION_COOKIE_NAME, SESSION_TTL_SECONDS from src.app.shared.utils.requests import ( @@ -87,6 +86,9 @@ async def get_user_bookmarks(request: Request): session_uuid = extract_session_cookie(request) host = extract_origin_from_request(request) + if not session_uuid: + raise HTTPException(status_code=401, detail="Session cookie is missing") + user_id, _ = await resolve_user_and_session( session_uuid=session_uuid, host=host, @@ -94,7 +96,6 @@ async def get_user_bookmarks(request: Request): ) try: bookmarks = await run_in_threadpool(get_user_bookmarks_sync, user_id) - print(f"Bookmarks for user {user_id}: {bookmarks}") return {"bookmarks": bookmarks} except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) @@ -106,6 +107,8 @@ async def get_user_bookmarks(request: Request): @router.delete("/bookmarks", summary="Delete all user bookmarks", response_model=dict) async def delete_user_bookmarks(request: Request): session_uuid = extract_session_cookie(request) + if not session_uuid: + raise HTTPException(status_code=401, detail="Session cookie is missing") host = extract_origin_from_request(request) user_id, _ = await resolve_user_and_session( @@ -130,6 +133,8 @@ async def delete_user_bookmarks(request: Request): ) async def delete_user_bookmark(request: Request, document_id: uuid.UUID): session_uuid = extract_session_cookie(request) + if not session_uuid: + raise HTTPException(status_code=401, detail="Session cookie is missing") host = extract_origin_from_request(request) user_id, _ = await resolve_user_and_session( @@ -154,6 +159,8 @@ async def delete_user_bookmark(request: Request, document_id: uuid.UUID): ) async def add_user_bookmark(request: Request, document_id: uuid.UUID): session_uuid = extract_session_cookie(request) + if not session_uuid: + raise HTTPException(status_code=401, detail="Session cookie is missing") host = extract_origin_from_request(request) user_id, _ = await resolve_user_and_session( From d1e4a65b9931d2201315c5f5d757e00cb705ff0a Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Mon, 23 Mar 2026 16:41:01 +0100 Subject: [PATCH 3/3] adds tests --- .env.example | 2 +- pytest.ini | 1 + src/app/shared/utils/requests.py | 2 + src/app/tests/api/api_v1/test_user.py | 35 +++++++++------- src/app/tests/test_utils.py | 57 +++++++++++++++++++++++++++ src/app/user/api/router.py | 5 ++- src/app/user/utils/utils.py | 5 ++- 7 files changed, 88 insertions(+), 19 deletions(-) create mode 100644 src/app/tests/test_utils.py diff --git a/.env.example b/.env.example index 946201d..e5bf486 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,6 @@ CLIENT_ORIGINS=https://fake-origin.example.com CLIENT_ORIGINS_REGEX="^http://fake-localhost:.*" - +ENV=development ##### MISTRAL AZURE ##### AZURE_MISTRAL_SWEDEN_API_BASE=https://fake-mistral-sweden.example.com/models AZURE_MISTRAL_SWEDEN_API_KEY=FAKE_MISTRAL_SWEDEN_API_KEY diff --git a/pytest.ini b/pytest.ini index 7b32851..e8bd55e 100644 --- a/pytest.ini +++ b/pytest.ini @@ -21,6 +21,7 @@ env = TIKA_URL_BASE=https://tika.example.com USE_CACHED_SETTINGS=True DATA_COLLECTION_ORIGIN_PREFIX=workshop + ENV=test filterwarnings = ignore:.*U.*mode is deprecated:DeprecationWarning diff --git a/src/app/shared/utils/requests.py b/src/app/shared/utils/requests.py index dd986ed..998d26d 100644 --- a/src/app/shared/utils/requests.py +++ b/src/app/shared/utils/requests.py @@ -1,6 +1,8 @@ from typing import Optional from uuid import UUID + from fastapi import Request + from src.app.shared.domain.constants import SESSION_COOKIE_NAME from src.app.utils.logger import logger as utils_logger diff --git a/src/app/tests/api/api_v1/test_user.py b/src/app/tests/api/api_v1/test_user.py index 70cc0ac..5e60317 100644 --- a/src/app/tests/api/api_v1/test_user.py +++ b/src/app/tests/api/api_v1/test_user.py @@ -1,4 +1,5 @@ import unittest +import uuid from unittest import mock from unittest.mock import MagicMock @@ -231,30 +232,36 @@ async def test_get_user_bookmarks_success_empty(self, session_maker_mock, *mocks session.execute.return_value.all.return_value = [] session_maker_mock.return_value.__enter__.return_value = session - user_id = "22222222-2222-2222-2222-222222222222" response = client.get( - f"{settings.API_V1_STR}/user/:user_id/bookmarks", - params={"user_id": user_id}, + f"{settings.API_V1_STR}/user/bookmarks", headers={"X-API-Key": "test"}, + cookies={"x-session-id": "bdb62bb2-1fe5-4d14-92fd-60a041355aea"}, ) self.assertEqual(response.status_code, 200) self.assertEqual(response.json(), {"bookmarks": []}) - @mock.patch("src.app.services.sql_db.queries_user.session_maker") - async def test_add_user_bookmark_success(self, session_maker_mock, *mocks): - """Ajout d'un bookmark""" - session = MagicMock() - session.execute.return_value.first.side_effect = [MagicMock(id="user-1"), None] - session_maker_mock.return_value.__enter__.return_value = session + @mock.patch("src.app.user.api.router.run_in_threadpool") + @mock.patch("src.app.user.api.router.resolve_user_and_session") + async def test_add_user_bookmark_success( + self, resolve_user_and_session_mock, run_in_threadpool_mock, *mocks + ): + """Ajout d'un bookmark - mocks only what is needed""" + # Mock resolve_user_and_session to return user_id and session_id + user_id = uuid.UUID("cfc8072c-a055-442a-9878-b5a73d9141b2") + session_id = uuid.UUID("bdb62bb2-1fe5-4d14-92fd-60a041355aea") + resolve_user_and_session_mock.return_value = (user_id, session_id) - user_id = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + # Mock run_in_threadpool to simulate DB add_user_bookmark_sync document_id = "ffffffff-ffff-ffff-ffff-ffffffffffff" + run_in_threadpool_mock.return_value = document_id + response = client.post( - f"{settings.API_V1_STR}/user/:user_id/bookmarks/:document_id", - params={"user_id": user_id, "document_id": document_id}, + f"{settings.API_V1_STR}/user/bookmarks/:document_id", + params={"document_id": document_id}, headers={"X-API-Key": "test"}, + cookies={"x-session-id": str(session_id)}, ) self.assertEqual(response.status_code, 200) self.assertEqual(response.json(), {"added": document_id}) - session.add.assert_called_once() - session.commit.assert_called_once() + resolve_user_and_session_mock.assert_called_once() + run_in_threadpool_mock.assert_called_once() diff --git a/src/app/tests/test_utils.py b/src/app/tests/test_utils.py new file mode 100644 index 0000000..d9cfa4f --- /dev/null +++ b/src/app/tests/test_utils.py @@ -0,0 +1,57 @@ +import unittest +import uuid +from unittest.mock import patch + +from src.app.user.utils import utils + + +class TestResolveUserAndSession(unittest.IsolatedAsyncioTestCase): + @patch("src.app.user.utils.utils.run_in_threadpool") + async def test_existing_user_and_session(self, run_in_threadpool_mock): + """Should return user_id and session_uuid for existing user and session""" + user_id = uuid.uuid4() + session_uuid = uuid.uuid4() + host = "localhost" + referer = "test" + # First call: get_user_from_session_id returns user_id + # Second call: get_or_create_session_sync returns session_uuid + run_in_threadpool_mock.side_effect = [user_id, session_uuid] + + result_user_id, result_session_uuid = await utils.resolve_user_and_session( + session_uuid, host, referer + ) + self.assertEqual(result_user_id, user_id) + self.assertEqual(result_session_uuid, session_uuid) + self.assertEqual(run_in_threadpool_mock.call_count, 2) + + @patch("src.app.user.utils.utils.run_in_threadpool") + async def test_new_user_and_session(self, run_in_threadpool_mock): + """Should create new user and session if user_id not found""" + session_uuid = uuid.uuid4() + host = "localhost" + referer = "test" + new_user_id = uuid.uuid4() + new_session_uuid = uuid.uuid4() + # First call: get_user_from_session_id returns None + # Second call: get_or_create_user_sync returns new_user_id + # Third call: get_or_create_session_sync returns new_session_uuid + run_in_threadpool_mock.side_effect = [None, new_user_id, new_session_uuid] + + result_user_id, result_session_uuid = await utils.resolve_user_and_session( + session_uuid, host, referer + ) + self.assertEqual(result_user_id, new_user_id) + self.assertEqual(result_session_uuid, new_session_uuid) + self.assertEqual(run_in_threadpool_mock.call_count, 3) + + @patch("src.app.user.utils.utils.run_in_threadpool") + async def test_logger_called(self, run_in_threadpool_mock): + """Should log info for both new and existing user cases""" + user_id = uuid.uuid4() + session_uuid = uuid.uuid4() + host = "localhost" + referer = "test" + run_in_threadpool_mock.side_effect = [user_id, session_uuid] + with patch.object(utils.logger, "info") as logger_info: + await utils.resolve_user_and_session(session_uuid, host, referer) + logger_info.assert_called() diff --git a/src/app/user/api/router.py b/src/app/user/api/router.py index 3192264..3766754 100644 --- a/src/app/user/api/router.py +++ b/src/app/user/api/router.py @@ -12,6 +12,7 @@ get_user_bookmarks_sync, ) from src.app.shared.domain.constants import SESSION_COOKIE_NAME, SESSION_TTL_SECONDS +from src.app.shared.utils.dependencies import get_settings from src.app.shared.utils.requests import ( extract_origin_from_request, extract_session_cookie, @@ -19,9 +20,9 @@ from src.app.user.utils.utils import resolve_user_and_session from src.app.utils.logger import logger as logger_utils - router = APIRouter() logger = logger_utils(__name__) +settings = get_settings() @router.post( @@ -45,7 +46,7 @@ async def handle_user_and_session( max_age=SESSION_TTL_SECONDS, httponly=True, samesite="lax", - secure=False, # True in production (HTTPS) + secure=settings.ENV == "production", ) return {"message": "session created"} diff --git a/src/app/user/utils/utils.py b/src/app/user/utils/utils.py index fba368c..6526bc1 100644 --- a/src/app/user/utils/utils.py +++ b/src/app/user/utils/utils.py @@ -1,11 +1,12 @@ -from uuid import UUID from typing import Optional +from uuid import UUID + from fastapi.concurrency import run_in_threadpool from src.app.services.sql_db.queries_user import ( - get_user_from_session_id, get_or_create_session_sync, get_or_create_user_sync, + get_user_from_session_id, ) from src.app.utils.logger import logger as logger_utils