diff --git a/.gitignore b/.gitignore index d932895..1ee0356 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ cluster.yaml external-url-dump.txt app/__pycache__/ *.pyc +doc/ diff --git a/app/config.py b/app/config.py index 9106ab2..39ac7bc 100644 --- a/app/config.py +++ b/app/config.py @@ -3,6 +3,24 @@ """ import os from typing import Optional +from urllib.parse import urlparse + + +def _normalize_path_prefix(value: str) -> str: + prefix = (value or "").strip() + if not prefix or prefix == "/": + return "" + return "/" + prefix.strip("/") + + +def _path_prefix_from_public_base_url(value: str) -> str: + raw = (value or "").strip() + if not raw: + return "" + if raw.startswith("/"): + return urlparse(raw).path + parsed = urlparse(raw if "://" in raw else f"https://{raw}") + return parsed.path INSTANCE_TYPES = { @@ -47,6 +65,31 @@ def AVAILABLE_IMAGES(self) -> list: DATABASE_PATH: str = os.getenv("DATABASE_PATH", "/data/amd-oneclick.db") SESSION_SECRET: str = os.getenv("SESSION_SECRET", "change-me-for-production") + SSO_ENABLED: bool = os.getenv("SSO_ENABLED", "false").lower() in {"1", "true", "yes", "on"} + SSO_PUBLIC_KEY_PEM: Optional[str] = os.getenv("SSO_PUBLIC_KEY_PEM") + SSO_ISSUER: str = os.getenv("SSO_ISSUER", "") + SSO_AUDIENCE: str = os.getenv("SSO_AUDIENCE", "") + SSO_ALGORITHM: str = os.getenv("SSO_ALGORITHM", "RS256") + SSO_ACCESS_COOKIE_NAME: str = os.getenv("SSO_ACCESS_COOKIE_NAME", "sso_access_token") + SSO_REFRESH_COOKIE_NAME: str = os.getenv("SSO_REFRESH_COOKIE_NAME", "sso_refresh_token") + SSO_REFRESH_THRESHOLD_SECONDS: int = int(os.getenv("SSO_REFRESH_THRESHOLD_SECONDS", "300")) + SSO_REFRESH_URL: str = os.getenv("SSO_REFRESH_URL", "/apitest/api/auth/refresh") + SSO_LOGOUT_URL: str = os.getenv("SSO_LOGOUT_URL", "/apitest/api/auth/logout") + SSO_BIND_ENTRY_URL: str = os.getenv("SSO_BIND_ENTRY_URL", "https://aideveloperportal.anruicloud.com/login?returnUrl=") + SSO_BIND_RETURN_QUERY_KEY: str = os.getenv("SSO_BIND_RETURN_QUERY_KEY", "bind") + SSO_AUTO_CREATE_USER: bool = os.getenv("SSO_AUTO_CREATE_USER", "true").lower() in {"1", "true", "yes", "on"} + SSO_DEFAULT_USER_DOMAIN: str = os.getenv("SSO_DEFAULT_USER_DOMAIN", "developer.local") + SSO_CLAIM_NAME_URI: str = os.getenv( + "SSO_CLAIM_NAME_URI", + "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name", + ) + SSO_CLAIM_EMAIL_URI: str = os.getenv( + "SSO_CLAIM_EMAIL_URI", + "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress", + ) + + REDIS_URL: str = os.getenv("REDIS_URL", "") + REDIS_TOKEN_VERSION_KEY_PREFIX: str = os.getenv("REDIS_TOKEN_VERSION_KEY_PREFIX", "auth:user") GITHUB_CLIENT_ID: Optional[str] = os.getenv("GITHUB_CLIENT_ID") GITHUB_CLIENT_SECRET: Optional[str] = os.getenv("GITHUB_CLIENT_SECRET") @@ -109,6 +152,10 @@ def AVAILABLE_IMAGES(self) -> list: SERVICE_HOST: str = os.getenv("SERVICE_HOST", "localhost") PUBLIC_BASE_URL: str = os.getenv("PUBLIC_BASE_URL", "") + PUBLIC_PATH_PREFIX: str = _normalize_path_prefix( + os.getenv("PUBLIC_PATH_PREFIX", "") + or _path_prefix_from_public_base_url(os.getenv("PUBLIC_BASE_URL", "")) + ) NODE_PORT_BASE: int = int(os.getenv("NODE_PORT_BASE", "30000")) PYPI_MIRROR: str = "https://pypi.tuna.tsinghua.edu.cn/simple" @@ -141,3 +188,17 @@ def AVAILABLE_IMAGES(self) -> list: settings = Settings() + + +def validate_settings() -> None: + if not settings.SSO_ENABLED: + return + missing = [] + if not settings.SSO_PUBLIC_KEY_PEM: + missing.append("SSO_PUBLIC_KEY_PEM") + if not settings.SSO_ISSUER: + missing.append("SSO_ISSUER") + if not settings.SSO_AUDIENCE: + missing.append("SSO_AUDIENCE") + if missing: + raise RuntimeError(f"SSO is enabled but missing required env vars: {', '.join(missing)}") diff --git a/app/k8s_client.py b/app/k8s_client.py index e72e0e4..314fe32 100644 --- a/app/k8s_client.py +++ b/app/k8s_client.py @@ -99,6 +99,8 @@ def _get_labels(self, email: str, instance_id: str) -> dict: } def _jupyter_base_url(self, instance_id: str) -> str: + if settings.PUBLIC_PATH_PREFIX: + return f"{settings.PUBLIC_PATH_PREFIX}/instances/{instance_id}/" return f"/instances/{instance_id}/" def _workspace_host_path(self, instance_id: str) -> str: diff --git a/app/main.py b/app/main.py index 21b777e..f4449f4 100644 --- a/app/main.py +++ b/app/main.py @@ -24,7 +24,7 @@ import requests import websockets -from .config import settings, INSTANCE_TYPES +from .config import settings, INSTANCE_TYPES, validate_settings from .models import ( NotebookRequest, NotebookStatus, @@ -43,6 +43,7 @@ from .scheduler import start_scheduler, stop_scheduler from .template_sync import sync_template_preview from .store import ( + bind_user_developer_user_id, clear_template_preview_cache, delete_image, delete_notebook_template, @@ -71,6 +72,7 @@ update_image_sync_status, upsert_image, ) +from .sso_auth import developer_logout, resolve_current_user, resolve_websocket_user_from_cookie_header # Configure logging logging.basicConfig( @@ -85,6 +87,7 @@ async def lifespan(app: FastAPI): """Application lifespan manager""" # Startup logger.info("Starting AMD OneClick Notebook Manager") + validate_settings() init_db() if settings.RUN_SCHEDULER: start_scheduler() @@ -101,9 +104,11 @@ async def lifespan(app: FastAPI): title="AMD OneClick Notebook Manager", description="Kubernetes-based Jupyter Notebook instance management", version="1.0.0", - lifespan=lifespan + lifespan=lifespan, + root_path=settings.PUBLIC_PATH_PREFIX, ) -app.add_middleware(SessionMiddleware, secret_key=settings.SESSION_SECRET) +if not settings.SSO_ENABLED: + app.add_middleware(SessionMiddleware, secret_key=settings.SESSION_SECRET) @app.middleware("http") @@ -125,6 +130,39 @@ async def log_slow_requests(request: Request, call_next): # Templates templates = Jinja2Templates(directory="templates") + +def _public_path(path: str = "/") -> str: + value = str(path or "/") + if value.startswith(("http://", "https://", "mailto:", "tel:", "data:", "blob:", "#")): + return value + if value.startswith("//"): + return value + if not value.startswith("/"): + value = f"/{value}" + + prefix = settings.PUBLIC_PATH_PREFIX + if not prefix: + return value + if value == prefix or value.startswith(f"{prefix}/") or value.startswith(f"{prefix}?"): + return value + if value == "/": + return f"{prefix}/" + return f"{prefix}{value}" + + +def _instance_route_path(instance_id: str, path: str = "") -> str: + suffix = str(path or "").lstrip("/") + base = f"/instances/{instance_id}/" + return f"{base}{suffix}" if suffix else base + + +def _instance_jupyter_path(instance_id: str, path: str = "") -> str: + return _public_path(_instance_route_path(instance_id, path)) + + +templates.env.globals["app_base_path"] = settings.PUBLIC_PATH_PREFIX +templates.env.globals["public_path"] = _public_path + # HTTP Basic Auth for admin security = HTTPBasic() @@ -144,7 +182,13 @@ def verify_admin(credentials: HTTPBasicCredentials = Depends(security)): return credentials.username -def current_user(request: Request) -> dict: +async def current_user(request: Request, response: Response) -> dict: + if settings.SSO_ENABLED: + user = await resolve_current_user(request, response, required=True) + if not user: + raise HTTPException(status_code=401, detail="Login required") + return user + user_id = request.session.get("user_id") if not user_id: raise HTTPException(status_code=401, detail="Login required") @@ -155,7 +199,10 @@ def current_user(request: Request) -> dict: return user -def session_user(request: Request) -> Optional[dict]: +async def session_user(request: Request, response: Response) -> Optional[dict]: + if settings.SSO_ENABLED: + return await resolve_current_user(request, response, required=False) + user_id = request.session.get("user_id") if not user_id: return None @@ -310,11 +357,17 @@ def _request_public_origin(request: Request) -> str: return f"{proto}://{host}".rstrip("/") +def _request_public_base_url(request: Request) -> str: + if settings.PUBLIC_BASE_URL: + return settings.PUBLIC_BASE_URL.rstrip("/") + return f"{_request_public_origin(request)}{settings.PUBLIC_PATH_PREFIX}".rstrip("/") + + def _instance_public_url(request: Request, instance_id: str, notebook_path: Optional[str] = None) -> str: path = f"/instances/{instance_id}/lab" if notebook_path: path += f"/tree/{quote(notebook_path.lstrip('/'), safe='/')}" - return f"{_request_public_origin(request)}{path}?token={settings.NOTEBOOK_TOKEN}" + return f"{_request_public_base_url(request)}{path}?token={settings.NOTEBOOK_TOKEN}" def _validate_resource_profile(profile: Optional[str]) -> str: @@ -416,11 +469,15 @@ def _proxy_headers(headers) -> dict: def _rewrite_location(location: str, instance_id: str, target_base: str) -> str: - public_prefix = f"/instances/{instance_id}/" + public_prefix = _public_path(f"/instances/{instance_id}/") + jupyter_prefix = _instance_jupyter_path(instance_id).rstrip("/") + if location.startswith(f"{target_base}{jupyter_prefix}"): + return location.replace(f"{target_base}{jupyter_prefix}", public_prefix.rstrip("/"), 1) if location.startswith(target_base): - return location.replace(target_base, public_prefix.rstrip("/"), 1) + rest = location[len(target_base):] or "/" + return _public_path(rest) if location.startswith("/"): - return location + return _public_path(location) return location @@ -486,9 +543,8 @@ def _decode_credit_coupon(encrypted_coupon_b64: str) -> dict: # ============================================================================= @app.get("/", response_class=HTMLResponse) -async def index(request: Request): +async def index(request: Request, user: Optional[dict] = Depends(session_user)): """Render the main request page""" - user = get_user(int(request.session["user_id"])) if request.session.get("user_id") else None images = list_images(enabled_only=True) notebook_templates = list_notebook_templates(enabled_only=True) active_instance = _active_instance_context(user, request) @@ -505,6 +561,9 @@ async def index(request: Request): "active_instance_json": json.dumps(active_instance or {}), "workshop_login_enabled": settings.WORKSHOP_LOGIN_ENABLED, "admin_login_enabled": settings.ADMIN_LOGIN_ENABLED, + "sso_enabled": settings.SSO_ENABLED, + "sso_bind_entry_url": settings.SSO_BIND_ENTRY_URL, + "sso_bind_return_query_key": settings.SSO_BIND_RETURN_QUERY_KEY, "resource_profiles_json": json.dumps(RESOURCE_PROFILES), "auto_resource_profile_by_gpu_json": json.dumps(AUTO_RESOURCE_PROFILE_BY_GPU), }, @@ -512,9 +571,8 @@ async def index(request: Request): @app.get("/profile", response_class=HTMLResponse) -async def profile_page(request: Request): +async def profile_page(request: Request, user: Optional[dict] = Depends(session_user)): """Render user profile and login page.""" - user = get_user(int(request.session["user_id"])) if request.session.get("user_id") else None active_instance = _active_instance_context(user, request) return templates.TemplateResponse( request, @@ -524,6 +582,9 @@ async def profile_page(request: Request): "active_instance": active_instance, "github_enabled": bool(settings.GITHUB_CLIENT_ID), "modelscope_enabled": bool(settings.MODELSCOPE_CLIENT_ID), + "sso_enabled": settings.SSO_ENABLED, + "sso_bind_entry_url": settings.SSO_BIND_ENTRY_URL, + "sso_bind_return_query_key": settings.SSO_BIND_RETURN_QUERY_KEY, "coupon_redeem_enabled": settings.COUPON_REDEEM_ENABLED, "coupon_redeem_disabled_message": settings.COUPON_REDEEM_DISABLED_MESSAGE, }, @@ -532,6 +593,8 @@ async def profile_page(request: Request): @app.get("/auth/github/login") async def github_login(request: Request): + if settings.SSO_ENABLED: + raise HTTPException(status_code=404, detail="Local login is disabled") if not settings.GITHUB_CLIENT_ID: raise HTTPException(status_code=500, detail="GitHub OAuth is not configured") params = { @@ -545,6 +608,8 @@ async def github_login(request: Request): @app.get("/auth/github/callback", name="github_callback") async def github_callback(request: Request, code: str = Query(...), state: str = Query("")): + if settings.SSO_ENABLED: + raise HTTPException(status_code=404, detail="Local login is disabled") _validate_oauth_state(request, "github", state) try: async with httpx.AsyncClient(timeout=_oauth_timeout()) as client: @@ -585,11 +650,13 @@ async def github_callback(request: Request, code: str = Query(...), state: str = await report_user_registered_event(user) request.session["user_id"] = user["id"] - return RedirectResponse("/") + return RedirectResponse(_public_path("/")) @app.get("/auth/modelscope/login") async def modelscope_login(request: Request): + if settings.SSO_ENABLED: + raise HTTPException(status_code=404, detail="login is disabled") if not settings.MODELSCOPE_CLIENT_ID: raise HTTPException(status_code=500, detail="ModelScope OAuth is not configured") params = { @@ -604,6 +671,8 @@ async def modelscope_login(request: Request): @app.get("/auth/modelscope/callback", name="modelscope_callback") async def modelscope_callback(request: Request, code: str = Query(...), state: str = Query("")): + if settings.SSO_ENABLED: + raise HTTPException(status_code=404, detail="login is disabled") expected_state = request.session.pop("modelscope_oauth_state", None) if expected_state and state and not secrets.compare_digest(expected_state, state): logger.warning("ModelScope OAuth state mismatch: expected=%s got=%s; continuing because ModelScope may not echo state", expected_state, state) @@ -657,17 +726,23 @@ async def modelscope_callback(request: Request, code: str = Query(...), state: s await report_user_registered_event(user) request.session["user_id"] = user["id"] - return RedirectResponse("/") + return RedirectResponse(_public_path("/")) @app.get("/auth/logout") async def logout(request: Request): + if settings.SSO_ENABLED: + response = RedirectResponse(_public_path("/")) + await developer_logout(request, response) + return response request.session.clear() - return RedirectResponse("/") + return RedirectResponse(_public_path("/")) @app.post("/auth/workshop/login") async def workshop_login(request: Request): + if settings.SSO_ENABLED: + raise HTTPException(status_code=404, detail="login is disabled") if not settings.WORKSHOP_LOGIN_ENABLED: raise HTTPException(status_code=404, detail="Workshop login is not enabled") payload = await request.json() @@ -703,6 +778,20 @@ async def api_me(user: dict = Depends(current_user)): return user +@app.post("/api/account/bind") +async def bind_account(user: dict = Depends(current_user)): + pending_developer_user_id = str(user.get("pending_developer_user_id") or "").strip() + if not pending_developer_user_id: + raise HTTPException(status_code=400, detail="No pending developer account to bind") + try: + updated = bind_user_developer_user_id(int(user["id"]), pending_developer_user_id) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + if not updated: + raise HTTPException(status_code=404, detail="User not found") + return {"user": updated, "bound": True} + + @app.post("/api/credits/redeem") async def redeem_credits(req: CouponRedeemRequest, user: dict = Depends(current_user)): if not settings.COUPON_REDEEM_ENABLED: @@ -781,14 +870,8 @@ async def request_notebook(request: Request, req: NotebookRequest, user: dict = @app.get("/api/notebook/status", response_model=NotebookStatus) -async def check_status(request: Request, email: Optional[str] = Query(None, description="User email")): +async def check_status(request: Request, email: Optional[str] = Query(None, description="User email"), user: dict = Depends(current_user)): """Check the status of a notebook instance""" - user_id = request.session.get("user_id") - if not user_id: - raise HTTPException(status_code=401, detail="Login required") - user = get_user(int(user_id)) - if not user: - raise HTTPException(status_code=401, detail="Login required") email = user["email"].lower() try: @@ -906,8 +989,9 @@ async def profile_delete_template(template_id: int, user: dict = Depends(current @app.get("/api/templates/{template_id}/preview-status") -async def template_preview_status(request: Request, template_id: int): - template = _template_accessible_to_user(template_id, session_user(request)) +async def template_preview_status(request: Request, response: Response, template_id: int): + user = await session_user(request, response) + template = _template_accessible_to_user(template_id, user) if not template: raise HTTPException(status_code=404, detail="Template not found") if not template.get("repo_url") or not template.get("notebook_path"): @@ -932,8 +1016,8 @@ async def profile_sync_template_preview(template_id: int, user: dict = Depends(c @app.get("/templates/{template_id}/preview", response_class=HTMLResponse) -async def preview_notebook_template(request: Request, template_id: int): - user = session_user(request) +async def preview_notebook_template(request: Request, response: Response, template_id: int): + user = await session_user(request, response) template = _template_accessible_to_user(template_id, user) if not template: raise HTTPException(status_code=404, detail="Template not found") @@ -979,15 +1063,16 @@ async def preview_notebook_template(request: Request, template_id: int): { "template_json": json.dumps(template), "notebook_json": json.dumps(notebook), - "asset_base_url": _template_asset_base_path(template_id, template["notebook_path"]), + "asset_base_url": _public_path(_template_asset_base_path(template_id, template["notebook_path"])), "preview_json": json.dumps(cache or {}), }, ) @app.get("/templates/{template_id}/assets/{asset_path:path}") -async def preview_notebook_template_asset(request: Request, template_id: int, asset_path: str): - template = _template_accessible_to_user(template_id, session_user(request)) +async def preview_notebook_template_asset(request: Request, response: Response, template_id: int, asset_path: str): + user = await session_user(request, response) + template = _template_accessible_to_user(template_id, user) if not template: raise HTTPException(status_code=404, detail="Template not found") if not template.get("repo_url") or not template.get("notebook_path"): @@ -1294,10 +1379,14 @@ async def check_github_status(instance_id: str = Query(...)): # ============================================================================= @app.api_route("/instances/{instance_id}/{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"]) -async def proxy_instance_http(instance_id: str, path: str, request: Request): +async def proxy_instance_http(instance_id: str, path: str, request: Request, user: dict = Depends(current_user)): """Proxy HTTP traffic to a Jupyter instance using its path-based base_url.""" + active = get_active_instance_for_user(int(user["id"])) + if not active or active.get("instance_id") != instance_id: + raise HTTPException(status_code=403, detail="Forbidden") + target_base = _instance_service_base(instance_id) - target_url = f"{target_base}/instances/{instance_id}/{path}" + target_url = f"{target_base}{_instance_jupyter_path(instance_id, path)}" if request.url.query: target_url += f"?{request.url.query}" @@ -1326,10 +1415,21 @@ async def proxy_instance_http(instance_id: str, path: str, request: Request): @app.websocket("/instances/{instance_id}/{path:path}") async def proxy_instance_websocket(websocket: WebSocket, instance_id: str, path: str): """Proxy WebSocket traffic for Jupyter terminals/kernels under /instances//.""" + cookie_header = websocket.headers.get("cookie", "") + user = resolve_websocket_user_from_cookie_header(cookie_header) + if not user: + await websocket.close(code=4401) + return + + active = get_active_instance_for_user(int(user["id"])) + if not active or active.get("instance_id") != instance_id: + await websocket.close(code=4403) + return + await websocket.accept() try: target_base = _instance_service_base(instance_id).replace("http://", "ws://") - target_url = f"{target_base}/instances/{instance_id}/{path}" + target_url = f"{target_base}{_instance_jupyter_path(instance_id, path)}" if websocket.url.query: target_url += f"?{websocket.url.query}" diff --git a/app/sso_auth.py b/app/sso_auth.py new file mode 100644 index 0000000..9e9ed01 --- /dev/null +++ b/app/sso_auth.py @@ -0,0 +1,222 @@ +"""SSO authentication helpers for developer-issued JWT cookies.""" + +from __future__ import annotations + +import time +from typing import Optional +import httpx +import jwt +from fastapi import HTTPException, Request, Response +from redis import Redis + +from .config import settings +from .store import get_or_create_sso_user + +_redis_client: Optional[Redis] = None + + +def _redis() -> Optional[Redis]: + global _redis_client + if not settings.REDIS_URL: + return None + if _redis_client is None: + _redis_client = Redis.from_url(settings.REDIS_URL, decode_responses=True) + return _redis_client + + +def _absolute_url(request: Request, value: str) -> str: + if value.startswith("http://") or value.startswith("https://"): + return value + return str(request.base_url).rstrip("/") + "/" + value.lstrip("/") + + +def _decode_jwt(token: str) -> dict: + options = {"require": ["exp", "token_version"]} + claims = jwt.decode( + token, + settings.SSO_PUBLIC_KEY_PEM, + algorithms=[settings.SSO_ALGORITHM], + audience=settings.SSO_AUDIENCE, + issuer=settings.SSO_ISSUER, + options=options, + ) + return _normalize_claims(claims) + + +def _normalize_claims(claims: dict) -> dict: + normalized = dict(claims) + if not normalized.get("sub") and normalized.get("Id"): + normalized["sub"] = str(normalized.get("Id")) + if not normalized.get("username") and normalized.get(settings.SSO_CLAIM_NAME_URI): + normalized["username"] = str(normalized.get(settings.SSO_CLAIM_NAME_URI)) + if not normalized.get("email") and normalized.get(settings.SSO_CLAIM_EMAIL_URI): + normalized["email"] = str(normalized.get(settings.SSO_CLAIM_EMAIL_URI)) + if not normalized.get("sub"): + raise jwt.InvalidTokenError("JWT missing required subject claim") + return normalized + + +def _read_token_version_from_redis(user_id: str) -> Optional[int]: + client = _redis() + if not client: + return None + key = f"{settings.REDIS_TOKEN_VERSION_KEY_PREFIX}:{user_id}:token_version" + raw = client.get(key) + if raw is None: + return None + return int(raw) + + +async def _refresh_tokens(request: Request, response: Response) -> Optional[str]: + refresh_url = _absolute_url(request, settings.SSO_REFRESH_URL) + timeout = httpx.Timeout(settings.OAUTH_READ_TIMEOUT_SECONDS, connect=settings.OAUTH_CONNECT_TIMEOUT_SECONDS) + cookies = {} + access_cookie = request.cookies.get(settings.SSO_ACCESS_COOKIE_NAME) + refresh_cookie = request.cookies.get(settings.SSO_REFRESH_COOKIE_NAME) + if access_cookie: + cookies[settings.SSO_ACCESS_COOKIE_NAME] = access_cookie + if refresh_cookie: + cookies[settings.SSO_REFRESH_COOKIE_NAME] = refresh_cookie + + if not cookies: + return None + + async with httpx.AsyncClient(timeout=timeout, follow_redirects=False) as client: + try: + resp = await client.post(refresh_url, cookies=cookies) + except httpx.HTTPError: + return None + + if resp.status_code >= 400: + return None + + for header, value in resp.headers.multi_items(): + if header.lower() == "set-cookie": + response.headers.append("set-cookie", value) + + return resp.cookies.get(settings.SSO_ACCESS_COOKIE_NAME) or request.cookies.get(settings.SSO_ACCESS_COOKIE_NAME) + + +def _email_from_claims(claims: dict) -> str: + email = str(claims.get("email") or "").strip().lower() + if email: + return email + username = str(claims.get("username") or "").strip().lower() + if username and "@" in username: + return username + if username: + return f"{username}@{settings.SSO_DEFAULT_USER_DOMAIN}" + return f"{claims['sub']}@{settings.SSO_DEFAULT_USER_DOMAIN}" + + +def _build_user_from_claims(claims: dict) -> dict: + user = get_or_create_sso_user( + developer_user_id=str(claims["sub"]), + email=_email_from_claims(claims), + username=str(claims.get("username") or "").strip(), + nickname=str(claims.get("nickname") or "").strip(), + ) + user["sso_sub"] = str(claims["sub"]) + user["sso_username"] = str(claims.get("username") or "") + user["sso_nickname"] = str(claims.get("nickname") or "") + return user + + +def resolve_websocket_user_from_cookie_header(cookie_header: str) -> Optional[dict]: + if not settings.SSO_ENABLED: + return None + if not cookie_header: + return None + + cookies: dict[str, str] = {} + for chunk in cookie_header.split(";"): + part = chunk.strip() + if not part or "=" not in part: + continue + name, value = part.split("=", 1) + cookies[name.strip()] = value.strip() + + access_token = cookies.get(settings.SSO_ACCESS_COOKIE_NAME) + if not access_token: + return None + + try: + claims = _decode_jwt(access_token) + except jwt.PyJWTError: + return None + + token_version_claim = int(claims.get("token_version") or 0) + redis_version = _read_token_version_from_redis(str(claims["sub"])) + if redis_version is not None and redis_version != token_version_claim: + return None + + return _build_user_from_claims(claims) + + +async def resolve_current_user(request: Request, response: Response, required: bool = True) -> Optional[dict]: + if not settings.SSO_ENABLED: + raise HTTPException(status_code=500, detail="SSO is not enabled") + + access_token = request.cookies.get(settings.SSO_ACCESS_COOKIE_NAME) + if not access_token: + if required: + raise HTTPException(status_code=401, detail="Login required") + return None + + claims: Optional[dict] = None + try: + claims = _decode_jwt(access_token) + except jwt.PyJWTError: + claims = None + + now = int(time.time()) + should_refresh = False + if claims is None: + should_refresh = True + else: + exp = int(claims.get("exp") or 0) + should_refresh = (exp - now) <= settings.SSO_REFRESH_THRESHOLD_SECONDS + + if should_refresh: + refreshed_token = await _refresh_tokens(request, response) + if refreshed_token: + try: + claims = _decode_jwt(refreshed_token) + except jwt.PyJWTError: + claims = None + + if claims is None: + if required: + raise HTTPException(status_code=401, detail="Invalid or expired SSO token") + return None + + token_version_claim = int(claims.get("token_version") or 0) + redis_version = _read_token_version_from_redis(str(claims["sub"])) + if redis_version is not None and redis_version != token_version_claim: + if required: + raise HTTPException(status_code=401, detail="Token has been revoked") + return None + + return _build_user_from_claims(claims) + + +async def developer_logout(request: Request, response: Response) -> None: + logout_url = _absolute_url(request, settings.SSO_LOGOUT_URL) + timeout = httpx.Timeout(settings.OAUTH_READ_TIMEOUT_SECONDS, connect=settings.OAUTH_CONNECT_TIMEOUT_SECONDS) + cookies = {} + access_cookie = request.cookies.get(settings.SSO_ACCESS_COOKIE_NAME) + refresh_cookie = request.cookies.get(settings.SSO_REFRESH_COOKIE_NAME) + if access_cookie: + cookies[settings.SSO_ACCESS_COOKIE_NAME] = access_cookie + if refresh_cookie: + cookies[settings.SSO_REFRESH_COOKIE_NAME] = refresh_cookie + + async with httpx.AsyncClient(timeout=timeout, follow_redirects=False) as client: + try: + await client.post(logout_url, cookies=cookies) + except httpx.HTTPError: + # Best effort: local cookies are still cleared. + pass + + response.delete_cookie(settings.SSO_ACCESS_COOKIE_NAME, path="/") + response.delete_cookie(settings.SSO_REFRESH_COOKIE_NAME, path="/") diff --git a/app/store.py b/app/store.py index 75b5deb..d39c430 100644 --- a/app/store.py +++ b/app/store.py @@ -55,6 +55,8 @@ Column("provider", String(64), nullable=False), Column("provider_id", String(255), nullable=False), Column("email", String(255), nullable=False, unique=True), + Column("developer_user_id", String(255), unique=True), + Column("is_bound", Boolean, nullable=False, default=False), Column("name", String(255)), Column("avatar_url", Text), Column("credits", Integer, nullable=False, default=100), @@ -229,6 +231,15 @@ def ensure_schema_columns(conn): user_columns = {col["name"] for col in inspector.get_columns("users")} if "is_editor" not in user_columns: conn.execute(text("ALTER TABLE users ADD COLUMN is_editor BOOLEAN NOT NULL DEFAULT FALSE")) + if "developer_user_id" not in user_columns: + conn.execute(text("ALTER TABLE users ADD COLUMN developer_user_id VARCHAR(255)")) + if "is_bound" not in user_columns: + conn.execute(text("ALTER TABLE users ADD COLUMN is_bound BOOLEAN NOT NULL DEFAULT FALSE")) + + if conn.dialect.name == "postgresql": + conn.execute(text("CREATE UNIQUE INDEX IF NOT EXISTS uq_users_developer_user_id ON users (developer_user_id) WHERE developer_user_id IS NOT NULL")) + elif conn.dialect.name == "sqlite": + conn.execute(text("CREATE UNIQUE INDEX IF NOT EXISTS uq_users_developer_user_id ON users (developer_user_id)")) instance_columns = {col["name"] for col in inspector.get_columns("instance_records")} if "billing_session_id" not in instance_columns: @@ -367,6 +378,97 @@ def get_user(user_id: int) -> Optional[dict]: return row_to_dict(conn.execute(select(users).where(users.c.id == user_id)).mappings().first()) +def get_user_by_email(email: str) -> Optional[dict]: + normalized = (email or "").strip().lower() + if not normalized: + return None + with engine.begin() as conn: + return row_to_dict(conn.execute(select(users).where(users.c.email == normalized)).mappings().first()) + + +def get_user_by_developer_user_id(developer_user_id: str) -> Optional[dict]: + key = (developer_user_id or "").strip() + if not key: + return None + with engine.begin() as conn: + return row_to_dict(conn.execute(select(users).where(users.c.developer_user_id == key)).mappings().first()) + + +def bind_user_developer_user_id(user_id: int, developer_user_id: str) -> Optional[dict]: + key = (developer_user_id or "").strip() + if not key: + raise ValueError("developer_user_id is required") + now = utc_now() + with engine.begin() as conn: + current = conn.execute(select(users).where(users.c.id == user_id).with_for_update()).mappings().first() + if not current: + return None + existing = conn.execute(select(users).where(users.c.developer_user_id == key, users.c.id != user_id)).mappings().first() + if existing: + raise ValueError("developer_user_id is already bound to another user") + conn.execute( + update(users) + .where(users.c.id == user_id) + .values(developer_user_id=key, is_bound=True, updated_at=now) + ) + return row_to_dict(conn.execute(select(users).where(users.c.id == user_id)).mappings().first()) + + +def get_or_create_sso_user(developer_user_id: str, email: str, username: str = "", nickname: str = "") -> dict: + now = utc_now() + key = (developer_user_id or "").strip() + normalized_email = (email or "").strip().lower() + if not key: + raise ValueError("developer_user_id is required") + if not normalized_email: + raise ValueError("email is required") + + with engine.begin() as conn: + by_developer = conn.execute(select(users).where(users.c.developer_user_id == key)).mappings().first() + if by_developer: + conn.execute( + update(users) + .where(users.c.id == by_developer["id"]) + .values( + email=normalized_email, + name=(nickname or username or by_developer.get("name") or "").strip(), + provider="developer", + provider_id=key, + is_bound=True, + updated_at=now, + ) + ) + return row_to_dict(conn.execute(select(users).where(users.c.id == by_developer["id"])).mappings().first()) + + by_email = conn.execute(select(users).where(users.c.email == normalized_email)).mappings().first() + if by_email: + # Keep backward-compatible data; explicit binding is handled by UI flow. + user = row_to_dict(by_email) + user["needs_bind"] = not bool(by_email.get("developer_user_id")) + user["pending_developer_user_id"] = key + return user + + result = conn.execute( + users.insert().values( + provider="developer", + provider_id=key, + email=normalized_email, + developer_user_id=key, + is_bound=True, + name=(nickname or username or normalized_email.split("@")[0]).strip(), + avatar_url="", + credits=10, + created_at=now, + updated_at=now, + ) + ) + user_id = result.inserted_primary_key[0] + conn.execute( + credit_ledger.insert().values(user_id=user_id, delta=10, reason="signup_bonus", created_at=now) + ) + return row_to_dict(conn.execute(select(users).where(users.c.id == user_id)).mappings().first()) + + def ensure_user_min_credits(user_id: int, minimum_credits: int) -> Optional[dict]: with engine.begin() as conn: user = conn.execute(select(users).where(users.c.id == user_id)).mappings().first() diff --git a/k8s-deployment-v2.yaml b/k8s-deployment-v2.yaml index b796d2c..c23bb01 100644 --- a/k8s-deployment-v2.yaml +++ b/k8s-deployment-v2.yaml @@ -165,19 +165,35 @@ data: IDLE_TIMEOUT_MINUTES: "10" MAX_LIFETIME_HOURS: "6" SERVICE_HOST: "36.151.243.69" - PUBLIC_BASE_URL: "https://radeon.anruicloud.com" + PUBLIC_BASE_URL: "https://aideveloperportal.anruicloud.com/radeon" + PUBLIC_PATH_PREFIX: "/radeon" NODE_PORT_BASE: "30000" NOTEBOOK_TOKEN: "amd-oneclick" NOTEBOOK_LABEL_PREFIX: "amd-oneclick-v2" - GITHUB_REDIRECT_URI: "https://radeon.anruicloud.com/auth/github/callback" - MODELSCOPE_REDIRECT_URI: "https://radeon.anruicloud.com/auth/modelscope/callback" + GITHUB_REDIRECT_URI: "https://aideveloperportal.anruicloud.com/radeon/auth/github/callback" + MODELSCOPE_REDIRECT_URI: "https://aideveloperportal.anruicloud.com/radeon/auth/modelscope/callback" TELEMETRY_API_URL: "http://36.150.116.200:30090/v1/metrics" ONECLICK_TELEMETRY_ENABLED: "true" ONECLICK_TELEMETRY_SOURCE: "amd_oneclick" ONECLICK_TELEMETRY_PRODUCT: "radeon_cloud" COUPON_REDEEM_ENABLED: "true" COUPON_REDEEM_DISABLED_MESSAGE: "System maintenance is in progress. Credit redemption is temporarily unavailable. Please contact the administrator if you need credits." - + SSO_ENABLED: "true" + SSO_ISSUER: "AMD.aideveloperportal" + SSO_AUDIENCE: "WebApi" + SSO_ALGORITHM: "RS256" + SSO_ACCESS_COOKIE_NAME: "sso_access_token" + SSO_REFRESH_COOKIE_NAME: "sso_refresh_token" + SSO_REFRESH_THRESHOLD_SECONDS: "300" + SSO_REFRESH_URL: "https://aideveloperportal.anruicloud.com/apitest/api/auth/refresh" + SSO_LOGOUT_URL: "https://aideveloperportal.anruicloud.com/apitest/api/auth/logout" + SSO_BIND_ENTRY_URL: "https://aideveloperportal.anruicloud.com/login?returnUrl=" + SSO_BIND_RETURN_QUERY_KEY: "bind" + SSO_AUTO_CREATE_USER: "true" + SSO_DEFAULT_USER_DOMAIN: "developer.local" + REDIS_TOKEN_VERSION_KEY_PREFIX: "auth:user" + SSO_CLAIM_NAME_URI: "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name" + SSO_CLAIM_EMAIL_URI: "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress" --- apiVersion: v1 kind: Secret @@ -196,6 +212,8 @@ stringData: MODELSCOPE_TOKEN_URL: "https://modelscope.cn/oauth/token" MODELSCOPE_USERINFO_URL: "https://modelscope.cn/api/v1/user" COUPON_PRIVATE_KEY_PEM: "" + SSO_PUBLIC_KEY_PEM: "" + REDIS_URL: "" METRICS_INGEST_API_KEY: "" --- diff --git a/requirements.txt b/requirements.txt index 5a9fa08..43a729d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,5 @@ psycopg2-binary>=2.9.0 httpx>=0.27.0 websockets>=12.0 pycryptodome>=3.23.0 +PyJWT>=2.9.0 +redis>=5.2.0 diff --git a/static/app-path.js b/static/app-path.js new file mode 100644 index 0000000..1a3fe4f --- /dev/null +++ b/static/app-path.js @@ -0,0 +1,24 @@ +(function () { + function normalizeBasePath(value) { + var prefix = String(value || "").trim(); + if (!prefix || prefix === "/") return ""; + return "/" + prefix.replace(/^\/+|\/+$/g, ""); + } + + var appBasePath = normalizeBasePath(window.__APP_BASE_PATH__); + + window.appUrl = function appUrl(path) { + var value = String(path || "/"); + if (/^(https?:|mailto:|tel:|data:|blob:|#)/i.test(value) || value.indexOf("//") === 0) return value; + var normalized = value.indexOf("/") === 0 ? value : "/" + value; + if (!appBasePath) return normalized; + if ( + normalized === appBasePath || + normalized.indexOf(appBasePath + "/") === 0 || + normalized.indexOf(appBasePath + "?") === 0 + ) { + return normalized; + } + return normalized === "/" ? appBasePath + "/" : appBasePath + normalized; + }; +})(); diff --git a/templates/admin.html b/templates/admin.html index 5789fba..de00cda 100644 --- a/templates/admin.html +++ b/templates/admin.html @@ -10,6 +10,8 @@ + +