diff --git a/backend/app/exceptions.py b/backend/app/exceptions.py new file mode 100644 index 0000000..8861320 --- /dev/null +++ b/backend/app/exceptions.py @@ -0,0 +1,130 @@ +from enum import Enum +from fastapi import HTTPException, Request, status +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +import logging + +logger = logging.getLogger("ai_assistant.exceptions") + +class ErrorCode(str, Enum): + INVALID_TOKEN = "invalid_token" + AUTHENTICATION_REQUIRED = "authentication_required" + INVALID_CREDENTIALS = "invalid_credentials" + EMAIL_ALREADY_EXISTS = "email_already_exists" + HISTORY_NOT_FOUND = "history_not_found" + FAVORITE_NOT_FOUND = "favorite_not_found" + SHARED_RESULT_NOT_FOUND = "shared_result_not_found" + UNSUPPORTED_FILE_TYPE = "unsupported_file_type" + PAYLOAD_TOO_LARGE = "payload_too_large" + RATE_LIMITED = "rate_limited" + VALIDATION_ERROR = "validation_error" + INTERNAL_SERVER_ERROR = "internal_server_error" + BAD_REQUEST = "bad_request" + FORBIDDEN = "forbidden" + ALREADY_SUBSCRIBED = "already_subscribed" + SUBSCRIPTION_NOT_FOUND = "subscription_not_found" + +class APIException(HTTPException): + def __init__(self, status_code: int, error_code: ErrorCode, detail: str, headers: dict | None = None): + super().__init__(status_code=status_code, detail=detail, headers=headers) + self.error_code = error_code + +def map_http_exception_to_code(status_code: int, detail: str) -> str: + detail_lower = detail.lower() + + # 401 Unauthorized + if status_code == 401: + if "authentication required" in detail_lower: + return ErrorCode.AUTHENTICATION_REQUIRED + if "invalid token" in detail_lower or "user not found" in detail_lower: + return ErrorCode.INVALID_TOKEN + if "invalid credentials" in detail_lower: + return ErrorCode.INVALID_CREDENTIALS + return "unauthorized" + + # 403 Forbidden + if status_code == 403: + if "invalid unsubscribe token" in detail_lower: + return ErrorCode.INVALID_TOKEN + return ErrorCode.FORBIDDEN + + # 404 Not Found + if status_code == 404: + if "history" in detail_lower: + return ErrorCode.HISTORY_NOT_FOUND + if "favorite" in detail_lower: + return ErrorCode.FAVORITE_NOT_FOUND + if "shared result" in detail_lower or "share" in detail_lower: + return ErrorCode.SHARED_RESULT_NOT_FOUND + if "subscription" in detail_lower: + return ErrorCode.SUBSCRIPTION_NOT_FOUND + return "not_found" + + # 409 Conflict + if status_code == 409: + if "already subscribed" in detail_lower: + return ErrorCode.ALREADY_SUBSCRIBED + if "email already exists" in detail_lower: + return ErrorCode.EMAIL_ALREADY_EXISTS + return "conflict" + + # 413 Content Too Large / Payload Too Large + if status_code == 413: + return ErrorCode.PAYLOAD_TOO_LARGE + + # 415 Unsupported Media Type + if status_code == 415: + return ErrorCode.UNSUPPORTED_FILE_TYPE + + # 429 Too Many Requests + if status_code == 429: + return ErrorCode.RATE_LIMITED + + # 400 Bad Request + if status_code == 400: + if "only .zip" in detail_lower: + return ErrorCode.UNSUPPORTED_FILE_TYPE + return ErrorCode.BAD_REQUEST + + # 500 Internal Server Error + if status_code == 500: + return ErrorCode.INTERNAL_SERVER_ERROR + + return ErrorCode.BAD_REQUEST if status_code < 500 else ErrorCode.INTERNAL_SERVER_ERROR + +async def api_exception_handler(request: Request, exc: APIException): + return JSONResponse( + status_code=exc.status_code, + content={ + "error": exc.error_code, + "detail": exc.detail, + }, + headers=exc.headers, + ) + +async def http_exception_handler(request: Request, exc: HTTPException): + error_code = map_http_exception_to_code(exc.status_code, exc.detail) + return JSONResponse( + status_code=exc.status_code, + content={ + "error": error_code, + "detail": exc.detail, + }, + headers=exc.headers, + ) + +async def validation_exception_handler(request: Request, exc: RequestValidationError): + errors = exc.errors() + details = [] + for err in errors: + loc = " -> ".join(str(l) for l in err["loc"]) + details.append(f"{loc}: {err['msg']}") + detail_str = "; ".join(details) + + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + content={ + "error": ErrorCode.VALIDATION_ERROR, + "detail": detail_str, + }, + ) diff --git a/backend/app/main.py b/backend/app/main.py index e868905..4a6a0b3 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -3,7 +3,8 @@ FastAPI application with advanced middleware, rate limiting, and full analysis engine. """ -from fastapi import FastAPI, Request +from fastapi import FastAPI, Request, HTTPException +from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from fastapi.responses import JSONResponse @@ -14,6 +15,13 @@ import logging from contextlib import asynccontextmanager +from .exceptions import ( + APIException, + api_exception_handler, + http_exception_handler, + validation_exception_handler, +) + from .routers import ( analyze, auth, @@ -83,6 +91,10 @@ async def lifespan(app: FastAPI): lifespan=lifespan, ) +app.add_exception_handler(APIException, api_exception_handler) +app.add_exception_handler(HTTPException, http_exception_handler) +app.add_exception_handler(RequestValidationError, validation_exception_handler) + # ── Middleware ──────────────────────────────────────────────────────────────── app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware( @@ -112,6 +124,7 @@ async def add_process_time_header(request: Request, call_next): return JSONResponse( status_code=429, content={ + "error": "rate_limited", "detail": f"Rate limit exceeded. Max {RATE_LIMIT} requests/minute." }, headers=headers, @@ -223,5 +236,8 @@ async def global_exception_handler(request: Request, exc: Exception): logging.exception("Unhandled error") return JSONResponse( status_code=500, - content={"detail": "Internal server error. Please try again."}, + content={ + "error": "internal_server_error", + "detail": "Internal server error. Please try again.", + }, ) diff --git a/backend/app/schemas.py b/backend/app/schemas.py index cfde6d9..30e80e3 100644 --- a/backend/app/schemas.py +++ b/backend/app/schemas.py @@ -371,3 +371,9 @@ class ChatMessageResponse(BaseModel): model: str mode: str reply: str + + +class ErrorResponse(BaseModel): + error: str + detail: str + diff --git a/backend/tests/test_auth_endpoints.py b/backend/tests/test_auth_endpoints.py index eeece4a..278c766 100644 --- a/backend/tests/test_auth_endpoints.py +++ b/backend/tests/test_auth_endpoints.py @@ -1,113 +1,117 @@ -"""Integration tests for auth routes""" - -import os -import sys - -import pytest -from fastapi.testclient import TestClient -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from sqlalchemy.pool import StaticPool - -sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) - -from app.database import Base, get_db -from app.main import app as fastapi_app - - -TEST_ENGINE = create_engine( - "sqlite:///:memory:", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, -) -TEST_SESSION_LOCAL = sessionmaker(bind=TEST_ENGINE) - - -def _override_db(): - db = TEST_SESSION_LOCAL() - try: - yield db - finally: - db.close() - - -@pytest.fixture -def client(): - previous_override = fastapi_app.dependency_overrides.get(get_db) - fastapi_app.dependency_overrides[get_db] = _override_db - with TestClient(fastapi_app) as test_client: - yield test_client - if previous_override is None: - fastapi_app.dependency_overrides.pop(get_db, None) - else: - fastapi_app.dependency_overrides[get_db] = previous_override - - -@pytest.fixture(autouse=True) -def _recreate_tables(): - Base.metadata.create_all(bind=TEST_ENGINE) - yield - Base.metadata.drop_all(bind=TEST_ENGINE) - - -def test_auth_routes_are_exposed_in_openapi(client): - response = client.get("/openapi.json") - assert response.status_code == 200 - - paths = response.json()["paths"] - assert "/auth/signup" in paths - assert "/auth/login" in paths - assert "/auth/me" in paths - - -def test_signup_login_and_me_happy_path(client): - signup_response = client.post( - "/auth/signup", - json={"email": "new.user@example.com", "password": "StrongPass123!"}, - ) - assert signup_response.status_code == 200 - - signup_data = signup_response.json() - assert signup_data["email"] == "new.user@example.com" - assert signup_data["user_id"] > 0 - assert signup_data["access_token"] - - login_response = client.post( - "/auth/login", - json={"email": "new.user@example.com", "password": "StrongPass123!"}, - ) - assert login_response.status_code == 200 - - token = login_response.json()["access_token"] - me_response = client.get( - "/auth/me", - headers={"Authorization": f"Bearer {token}"}, - ) - assert me_response.status_code == 200 - assert me_response.json() == { - "user_id": signup_data["user_id"], - "email": "new.user@example.com", - } - - -def test_signup_duplicate_email_returns_409(client): - payload = {"email": "dup@example.com", "password": "StrongPass123!"} - first_response = client.post("/auth/signup", json=payload) - assert first_response.status_code == 200 - - duplicate_response = client.post("/auth/signup", json=payload) - assert duplicate_response.status_code == 409 - assert "already exists" in duplicate_response.json()["detail"].lower() - - -def test_me_rejects_missing_and_invalid_token(client): - missing_token_response = client.get("/auth/me") - assert missing_token_response.status_code == 401 - assert "authentication required" in missing_token_response.json()["detail"].lower() - - invalid_token_response = client.get( - "/auth/me", - headers={"Authorization": "Bearer not-a-real-token"}, - ) - assert invalid_token_response.status_code == 401 - assert "invalid token" in invalid_token_response.json()["detail"].lower() \ No newline at end of file +"""Integration tests for auth routes""" + +import os +import sys + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from app.database import Base, get_db +from app.main import app as fastapi_app + + +TEST_ENGINE = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, +) +TEST_SESSION_LOCAL = sessionmaker(bind=TEST_ENGINE) + + +def _override_db(): + db = TEST_SESSION_LOCAL() + try: + yield db + finally: + db.close() + + +@pytest.fixture +def client(): + previous_override = fastapi_app.dependency_overrides.get(get_db) + fastapi_app.dependency_overrides[get_db] = _override_db + with TestClient(fastapi_app) as test_client: + yield test_client + if previous_override is None: + fastapi_app.dependency_overrides.pop(get_db, None) + else: + fastapi_app.dependency_overrides[get_db] = previous_override + + +@pytest.fixture(autouse=True) +def _recreate_tables(): + Base.metadata.create_all(bind=TEST_ENGINE) + yield + Base.metadata.drop_all(bind=TEST_ENGINE) + + +def test_auth_routes_are_exposed_in_openapi(client): + response = client.get("/openapi.json") + assert response.status_code == 200 + + paths = response.json()["paths"] + assert "/auth/signup" in paths + assert "/auth/login" in paths + assert "/auth/me" in paths + + +def test_signup_login_and_me_happy_path(client): + signup_response = client.post( + "/auth/signup", + json={"email": "new.user@example.com", "password": "StrongPass123!"}, + ) + assert signup_response.status_code == 200 + + signup_data = signup_response.json() + assert signup_data["email"] == "new.user@example.com" + assert signup_data["user_id"] > 0 + assert signup_data["access_token"] + + login_response = client.post( + "/auth/login", + json={"email": "new.user@example.com", "password": "StrongPass123!"}, + ) + assert login_response.status_code == 200 + + token = login_response.json()["access_token"] + me_response = client.get( + "/auth/me", + headers={"Authorization": f"Bearer {token}"}, + ) + assert me_response.status_code == 200 + assert me_response.json() == { + "user_id": signup_data["user_id"], + "email": "new.user@example.com", + } + + +def test_signup_duplicate_email_returns_409(client): + payload = {"email": "dup@example.com", "password": "StrongPass123!"} + first_response = client.post("/auth/signup", json=payload) + assert first_response.status_code == 200 + + duplicate_response = client.post("/auth/signup", json=payload) + assert duplicate_response.status_code == 409 + assert "already exists" in duplicate_response.json()["detail"].lower() + assert duplicate_response.json()["error"] == "email_already_exists" + + + +def test_me_rejects_missing_and_invalid_token(client): + missing_token_response = client.get("/auth/me") + assert missing_token_response.status_code == 401 + assert "authentication required" in missing_token_response.json()["detail"].lower() + assert missing_token_response.json()["error"] == "authentication_required" + + invalid_token_response = client.get( + "/auth/me", + headers={"Authorization": "Bearer not-a-real-token"}, + ) + assert invalid_token_response.status_code == 401 + assert "invalid token" in invalid_token_response.json()["detail"].lower() + assert invalid_token_response.json()["error"] == "invalid_token" \ No newline at end of file diff --git a/backend/tests/test_endpoints.py b/backend/tests/test_endpoints.py index c4d776b..28875d5 100644 --- a/backend/tests/test_endpoints.py +++ b/backend/tests/test_endpoints.py @@ -186,6 +186,8 @@ def test_rate_limit_returns_429_with_retry_after_header(): assert r.headers["Retry-After"] == str(app_main.RATE_LIMIT_WINDOW_SECONDS) assert r.headers["X-RateLimit-Limit"] == str(app_main.RATE_LIMIT) assert r.headers["X-RateLimit-Remaining"] == "0" + assert r.json()["error"] == "rate_limited" + assert "Rate limit exceeded" in r.json()["detail"] # ── Explanation ─────────────────────────────────────────────────────────────── @@ -231,11 +233,15 @@ def test_explanation_accepts_rust_hint_alias(): def test_explanation_empty_code(): r = client.post("/explanation/", json={"code": " "}) assert r.status_code == 422 + assert r.json()["error"] == "validation_error" + assert "code" in r.json()["detail"] def test_explanation_too_long(): r = client.post("/explanation/", json={"code": "x" * 60000}) assert r.status_code == 422 + assert r.json()["error"] == "validation_error" + assert "code" in r.json()["detail"] def test_explanation_typescript(): @@ -681,6 +687,7 @@ def test_full_analyze_all_languages(): def test_missing_code_field(): r = client.post("/analyze/", json={}) assert r.status_code == 422 + assert r.json()["error"] == "validation_error" def test_unicode_code(): @@ -751,3 +758,5 @@ def test_get_stream_with_language_hint(): def test_get_stream_empty_code_rejected(): r = client.get("/analyze/stream", params={"code": " "}) assert r.status_code in (400, 422) + assert r.json()["error"] in ("validation_error", "bad_request") + assert "code" in r.json()["detail"] diff --git a/backend/tests/test_file_upload.py b/backend/tests/test_file_upload.py index edba244..9019b95 100644 --- a/backend/tests/test_file_upload.py +++ b/backend/tests/test_file_upload.py @@ -94,10 +94,10 @@ def test_upload_blocked_files( ) assert response.status_code == 415 - data = response.json() - assert "Executable files are not allowed" in data["detail"] + assert data["error"] == "unsupported_file_type" + # ========================================================= @@ -119,10 +119,10 @@ def test_invalid_mime_type(): print(response.json()) assert response.status_code == 415 - data = response.json() - assert "Invalid MIME type" in data["detail"] + assert data["error"] == "unsupported_file_type" + # ========================================================= @@ -143,10 +143,10 @@ def test_double_extension(): ) assert response.status_code == 415 - data = response.json() - assert "Executable files are not allowed" in data["detail"] + assert data["error"] == "unsupported_file_type" + # ========================================================= @@ -158,6 +158,8 @@ def test_no_file_uploaded(): response = client.post("/upload/validate") assert response.status_code in [400, 422] + assert response.json()["error"] in ("bad_request", "validation_error") + # ========================================================= @@ -179,4 +181,5 @@ def test_large_file(): } ) - assert response.status_code == 413 \ No newline at end of file + assert response.status_code == 413 + assert response.json()["error"] == "payload_too_large" \ No newline at end of file diff --git a/backend/tests/test_history.py b/backend/tests/test_history.py index 0b5f371..21851f6 100644 --- a/backend/tests/test_history.py +++ b/backend/tests/test_history.py @@ -67,6 +67,9 @@ def test_delete_history(): def test_delete_nonexistent(): r = client.delete("/history/999999") assert r.status_code == 404 + assert r.json()["error"] == "history_not_found" + assert "History entry not found" in r.json()["detail"] + def test_history_entry_fields(): diff --git a/backend/tests/test_share.py b/backend/tests/test_share.py index c7edef5..a5e06ba 100644 --- a/backend/tests/test_share.py +++ b/backend/tests/test_share.py @@ -1,95 +1,96 @@ -from __future__ import annotations - -from datetime import UTC, datetime, timedelta - -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - -from app import database -from app.database import Base -from app.main import app -from app.models import SharedSnippet - - -def _configure_test_db(monkeypatch, tmp_path): - db_path = tmp_path / "share-tests.db" - - engine = create_engine( - f"sqlite:///{db_path}", - connect_args={"check_same_thread": False}, - ) - - session_local = sessionmaker( - autocommit=False, - autoflush=False, - bind=engine, - ) - - monkeypatch.setattr(database, "engine", engine) - monkeypatch.setattr(database, "SessionLocal", session_local) - - Base.metadata.drop_all(bind=engine) - Base.metadata.create_all(bind=engine) - - return session_local - - -def test_create_and_fetch_share(monkeypatch, tmp_path): - _configure_test_db(monkeypatch, tmp_path) - - from fastapi.testclient import TestClient - - client = TestClient(app) - - payload = { - "code": "print('hello')", - "result": { - "provider": "rule-based", - "explanation": {"summary": "ok"}, - }, - } - - create_resp = client.post("/share/", json=payload) - - assert create_resp.status_code == 200 - - share_id = create_resp.json()["id"] - - assert share_id - - fetch_resp = client.get(f"/share/{share_id}") - - assert fetch_resp.status_code == 200 - - data = fetch_resp.json() - - assert data["id"] == share_id - assert data["code"] == payload["code"] - assert data["result"] == payload["result"] - assert "created_at" in data - - -def test_expired_share_returns_404(monkeypatch, tmp_path): - session_local = _configure_test_db(monkeypatch, tmp_path) - - from fastapi.testclient import TestClient - - client = TestClient(app) - - db = session_local() - - record = SharedSnippet( - token="expired123", - code="print('old')", - result_json='{"ok": true}', - created_at=datetime.now(UTC) - timedelta(days=8), - ) - - db.add(record) - db.commit() - db.close() - - resp = client.get("/share/expired123") - - assert resp.status_code == 404 - assert "expired" in resp.json()["detail"].lower() \ No newline at end of file +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from app import database +from app.database import Base +from app.main import app +from app.models import SharedSnippet + + +def _configure_test_db(monkeypatch, tmp_path): + db_path = tmp_path / "share-tests.db" + + engine = create_engine( + f"sqlite:///{db_path}", + connect_args={"check_same_thread": False}, + ) + + session_local = sessionmaker( + autocommit=False, + autoflush=False, + bind=engine, + ) + + monkeypatch.setattr(database, "engine", engine) + monkeypatch.setattr(database, "SessionLocal", session_local) + + Base.metadata.drop_all(bind=engine) + Base.metadata.create_all(bind=engine) + + return session_local + + +def test_create_and_fetch_share(monkeypatch, tmp_path): + _configure_test_db(monkeypatch, tmp_path) + + from fastapi.testclient import TestClient + + client = TestClient(app) + + payload = { + "code": "print('hello')", + "result": { + "provider": "rule-based", + "explanation": {"summary": "ok"}, + }, + } + + create_resp = client.post("/share/", json=payload) + + assert create_resp.status_code == 200 + + share_id = create_resp.json()["id"] + + assert share_id + + fetch_resp = client.get(f"/share/{share_id}") + + assert fetch_resp.status_code == 200 + + data = fetch_resp.json() + + assert data["id"] == share_id + assert data["code"] == payload["code"] + assert data["result"] == payload["result"] + assert "created_at" in data + + +def test_expired_share_returns_404(monkeypatch, tmp_path): + session_local = _configure_test_db(monkeypatch, tmp_path) + + from fastapi.testclient import TestClient + + client = TestClient(app) + + db = session_local() + + record = SharedSnippet( + token="expired123", + code="print('old')", + result_json='{"ok": true}', + created_at=datetime.now(UTC) - timedelta(days=8), + ) + + db.add(record) + db.commit() + db.close() + + resp = client.get("/share/expired123") + + assert resp.status_code == 404 + assert "expired" in resp.json()["detail"].lower() + assert resp.json()["error"] == "shared_result_not_found" \ No newline at end of file diff --git a/backend/tests/test_zip_dos.py b/backend/tests/test_zip_dos.py index ec0fa08..1063505 100644 --- a/backend/tests/test_zip_dos.py +++ b/backend/tests/test_zip_dos.py @@ -19,6 +19,8 @@ def test_analyze_zip_too_large_via_header(): response = client.post("/analyze/zip/", files=files, headers={"Content-Length": str(15 * 1024 * 1024)}) assert response.status_code == 413 assert "ZIP file too large" in response.json()["detail"] + assert response.json()["error"] == "payload_too_large" + def test_analyze_zip_too_large_via_stream(): # Simulate a stream that exceeds the limit @@ -30,6 +32,8 @@ def test_analyze_zip_too_large_via_stream(): response = client.post("/analyze/zip/", files=files, headers={"Content-Length": "100"}) assert response.status_code == 413 assert "ZIP file exceeds size limit during upload" in response.json()["detail"] + assert response.json()["error"] == "payload_too_large" + def test_analyze_zip_valid(): # Create a real small ZIP