diff --git a/app/__init__.py b/app/__init__.py index e08e75a..b79cfb0 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,10 +1,20 @@ +from dataclasses import dataclass from logging import Logger +from threading import BoundedSemaphore from fastapi import HTTPException, Request from pyiceberg.catalog import Catalog from rustworkx import PyDiGraph +@dataclass +class GpkgLimiter: + """Per-app concurrency guard for the hydrofabric gpkg endpoint.""" + + semaphore: BoundedSemaphore + queue_timeout_s: float + + def get_catalog(request: Request) -> Catalog: """Gets the pyiceberg catalog reference from the app state @@ -92,3 +102,11 @@ def get_graphs(request: Request) -> PyDiGraph: if not hasattr(request.app.state, "network_graphs") or request.app.state.network_graphs is None: raise HTTPException(status_code=500, detail="network_graphs not loaded") return request.app.state.network_graphs + + +def get_gpkg_limiter(request: Request) -> GpkgLimiter: + """Returns the per-app GpkgLimiter; 500 if not configured in lifespan.""" + limiter = getattr(request.app.state, "gpkg_limiter", None) + if limiter is None: + raise HTTPException(status_code=500, detail="gpkg_limiter not loaded") + return limiter diff --git a/app/main.py b/app/main.py index c421504..431b13a 100644 --- a/app/main.py +++ b/app/main.py @@ -1,9 +1,11 @@ import argparse import logging import os +import threading from contextlib import asynccontextmanager from pathlib import Path +import anyio import uvicorn from fastapi import FastAPI, status from fastapi.staticfiles import StaticFiles @@ -12,6 +14,7 @@ from pyiceberg.exceptions import NoSuchTableError from pyprojroot import here +from app import GpkgLimiter from app.routers.hydrofabric.router import api_router as hydrofabric_api_router from app.routers.nwm_modules.router import ( cfe_router, @@ -118,18 +121,39 @@ async def lifespan(app: FastAPI): """ app.state.main_logger = main_logger app.state.main_logger.info("Application starting up.") + # Cap per-worker sync-handler concurrency. Hydrofabric/ras_xs/nwm handlers + # can spike to hundreds of MB of pandas/geopandas memory per in-flight + # request, so on a t3.large (8 GB / 2 workers) we keep this low to avoid + # OOM. Effective per-instance concurrency = workers * total_tokens. + thread_limiter = anyio.to_thread.current_default_thread_limiter() + thread_limiter.total_tokens = 20 + app.state.main_logger.info(f"AnyIO threadpool limit set to {thread_limiter.total_tokens}") deploy_env = os.environ.get("ICEFABRIC_DEPLOY_ENV") or os.environ.get("ENVIRONMENT") or args.deploy_env deploy_env = deploy_env.lower() load_creds(deploy_env) - if args.cache_catalog == "sql": + if args.cache_catalog == "sql" and not os.environ.get("ICEFABRIC_CACHE_BUILT"): app.state.main_logger.info("Building local SQL cache...") build_cache(set(args.cached_namespaces), deploy_env) + else: + app.state.main_logger.info( + "Skipping local SQL cache build (already built by parent process or disabled)." + ) catalog = load_catalog(args.catalog) cache_catalog = load_catalog(args.cache_catalog) hydrofabric_namespaces = ["conus_hf", "ak_hf", "hi_hf", "prvi_hf"] app.state.catalog = catalog app.state.cache_catalog = cache_catalog app.state.cached_namespaces = {e.split(":")[0] for e in args.cached_namespaces} + # Per-worker concurrency cap for the heavy gpkg endpoint. Tunable via env. + gpkg_concurrency = int(os.environ.get("ICEFABRIC_HF_GPKG_CONCURRENCY", "1")) + gpkg_queue_timeout_s = float(os.environ.get("ICEFABRIC_HF_GPKG_QUEUE_TIMEOUT_S", "300")) + app.state.gpkg_limiter = GpkgLimiter( + semaphore=threading.BoundedSemaphore(gpkg_concurrency), + queue_timeout_s=gpkg_queue_timeout_s, + ) + app.state.main_logger.info( + f"gpkg concurrency cap per worker = {gpkg_concurrency} (queue timeout {gpkg_queue_timeout_s:.0f}s)" + ) try: app.state.network_graphs = load_upstream_json( catalog=catalog, @@ -211,4 +235,46 @@ def get_health() -> HealthCheck: print("INFO: Documentation directory 'static/docs' not found. Docs will not be served.") if __name__ == "__main__": - uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True, log_level="info") + # One-time setup in the parent before forking workers. With workers>1, + # doing this in lifespan races: concurrent SQL cache builds clobber the + # warehouse, and concurrent load_upstream_json() calls have one worker + # reading a partially-written graph JSON (-> EOF). Pre-building here + # means workers only hit the safe "read existing" paths. + _deploy_env = ( + os.environ.get("ICEFABRIC_DEPLOY_ENV") or os.environ.get("ENVIRONMENT") or args.deploy_env + ).lower() + load_creds(_deploy_env) + + if args.cache_catalog == "sql": + main_logger.info("Building local SQL cache (parent process, one-time)...") + build_cache(set(args.cached_namespaces), _deploy_env) + os.environ["ICEFABRIC_CACHE_BUILT"] = "1" + + # Prewarm hydrofabric graph JSON files so workers only read, never write. + _hf_namespaces = ["conus_hf", "ak_hf", "hi_hf", "prvi_hf"] + try: + main_logger.info("Prewarming hydrofabric network graphs (parent process)...") + _prewarm_catalog = load_catalog(args.catalog) + load_upstream_json( + catalog=_prewarm_catalog, + namespaces=_hf_namespaces, + output_path=here() / "data", + ) + except NoSuchTableError: + main_logger.warning( + "Hydrofabric namespaces not reachable at prewarm time; workers will attempt at startup." + ) + + # Recycle each worker after this many requests. Resets per-process RSS + # that otherwise creeps from glibc/numpy fragmentation over time. The + # supervisor respawns the worker; new workers skip the heavy one-time + # setup (cache + graphs are already on disk), so churn is ~seconds. + max_requests_per_worker = int(os.environ.get("ICEFABRIC_MAX_REQUESTS_PER_WORKER", "500")) + uvicorn.run( + "app.main:app", + host="0.0.0.0", + port=8000, + workers=2, + log_level="info", + limit_max_requests=max_requests_per_worker, + ) diff --git a/app/routers/hydrofabric/router.py b/app/routers/hydrofabric/router.py index b5cfb7d..b793d3d 100644 --- a/app/routers/hydrofabric/router.py +++ b/app/routers/hydrofabric/router.py @@ -1,7 +1,9 @@ +import gc import logging import pathlib import sqlite3 import tempfile +import time import uuid import geopandas as gpd @@ -13,7 +15,14 @@ from pyiceberg.expressions import EqualTo from starlette.background import BackgroundTask -from app import get_cache_catalog, get_cached_namespaces, get_catalog, get_graphs +from app import ( + GpkgLimiter, + get_cache_catalog, + get_cached_namespaces, + get_catalog, + get_gpkg_limiter, + get_graphs, +) from icefabric.hydrofabric import subset_hydrofabric, subset_nhf from icefabric.schemas import ( DivideAttributes, @@ -41,8 +50,16 @@ api_router = APIRouter(prefix="/hydrofabric") +def _cleanup_tmp(path: pathlib.Path) -> None: + """Delete a temp file; safe to call multiple times.""" + try: + path.unlink(missing_ok=True) + except OSError as e: # pragma: no cover - defensive + logger.warning(f"cleanup: failed to delete {path}: {e}") + + @api_router.get("/{identifier}/gpkg", tags=["Hydrofabric Services"]) -async def get_hydrofabric_subset_gpkg( +def get_hydrofabric_subset_gpkg( identifier: str = FastAPIPath( ..., description="Identifier to start tracing from (e.g., catchment ID, POI ID, HL_URI)", @@ -81,6 +98,7 @@ async def get_hydrofabric_subset_gpkg( cache_catalog=Depends(get_cache_catalog), cached_namespaces=Depends(get_cached_namespaces), network_graphs=Depends(get_graphs), + gpkg_limiter: GpkgLimiter = Depends(get_gpkg_limiter), ): """ Get hydrofabric subset as a geopackage file (.gpkg) @@ -122,6 +140,31 @@ async def get_hydrofabric_subset_gpkg( # swap catalog for cached catalog if appropriate catalog = cache_catalog if namespace in cached_namespaces else catalog + # Cap concurrent heavy builds per worker. Each CONUS VPU subset peaks + # at hundreds of MB of pandas/geopandas memory, so uncapped concurrency + # can OOM the instance. Queue timeouts surface as 503 rather than a + # silently stalled client. + sem_wait_start = time.monotonic() + if not gpkg_limiter.semaphore.acquire(timeout=gpkg_limiter.queue_timeout_s): + raise HTTPException( + status_code=503, + detail=( + "Hydrofabric service is at capacity. Please retry shortly. " + f"(waited {gpkg_limiter.queue_timeout_s:.0f}s for a slot)" + ), + headers={"Retry-After": "30"}, + ) + sem_held = True + sem_wait_ms = (time.monotonic() - sem_wait_start) * 1000 + if sem_wait_ms > 250: + logger.info(f"gpkg semaphore wait: {sem_wait_ms:.0f} ms for {identifier}") + + def _release_sem() -> None: + nonlocal sem_held + if sem_held: + sem_held = False + gpkg_limiter.semaphore.release() + try: if namespace.is_nhf: if id_type == QueryIdType.VPU_ID: @@ -175,30 +218,50 @@ async def get_hydrofabric_subset_gpkg( tmp_path.parent.mkdir(parents=True, exist_ok=True) + # Partition layers up front so we can pop + free as we write. + # pyogrio handles spatial, sqlite handles tabular (incl. empty ones). + spatial_names: list[str] = [] + nonspatial_names: list[str] = [] + for name, data in output_layers.items(): + if isinstance(data, gpd.GeoDataFrame) and len(data) > 0: + spatial_names.append(name) + elif not isinstance(data, gpd.GeoDataFrame): + nonspatial_names.append(name) + layers_written = 0 - spatial_layers = {} - nonspatial_layers = {} - - # Separate spatial vs non-spatial - for table_name, layer_data in output_layers.items(): - if isinstance(layer_data, gpd.GeoDataFrame) and len(layer_data) > 0: - spatial_layers[table_name] = layer_data - elif not isinstance(layer_data, gpd.GeoDataFrame): - nonspatial_layers[table_name] = layer_data - - # Write spatial layers first with pyogrio - for table_name, layer_data in spatial_layers.items(): - pyogrio.write_dataframe(layer_data, tmp_path, layer=table_name) - layers_written += 1 - logger.info(f"Written spatial layer '{table_name}' with {len(layer_data)} records") - # Then write non-spatial layers with sqlite3 (includes empty layers) - conn = sqlite3.connect(tmp_path) - for table_name, layer_data in nonspatial_layers.items(): - layer_data.to_sql(table_name, conn, if_exists="replace", index=False) + # Stream spatial layers one at a time: pop -> write -> del + gc so + # RSS stays flat across layers instead of accumulating. + for name in spatial_names: + layer_data = output_layers.pop(name) + n_rows = len(layer_data) + pyogrio.write_dataframe(layer_data, tmp_path, layer=name) + del layer_data + gc.collect() layers_written += 1 - logger.info(f"Written non-spatial layer '{table_name}' with {len(layer_data)} records") - conn.close() + logger.info(f"Written spatial layer '{name}' with {n_rows} records") + + # Share one sqlite connection across tabular layers. + if nonspatial_names: + conn = sqlite3.connect(tmp_path) + try: + for name in nonspatial_names: + layer_data = output_layers.pop(name) + n_rows = len(layer_data) + layer_data.to_sql(name, conn, if_exists="replace", index=False) + del layer_data + gc.collect() + layers_written += 1 + logger.info(f"Written non-spatial layer '{name}' with {n_rows} records") + finally: + conn.close() + + # Drop any layers skipped above (empty spatial frames etc.). + output_layers.clear() + gc.collect() + + # Heavy work is done; release the slot before streaming to the client. + _release_sem() if layers_written == 0: raise HTTPException( @@ -234,23 +297,17 @@ async def get_hydrofabric_subset_gpkg( "X-Domain": namespace, "X-Layers-Count": str(layers_written), }, - background=BackgroundTask(lambda: tmp_path.unlink(missing_ok=True)), + background=BackgroundTask(_cleanup_tmp, tmp_path), ) except HTTPException: - # Clean up temp file if it exists and re-raise HTTP exceptions - if tmp_path.exists(): - tmp_path.unlink(missing_ok=True) + _cleanup_tmp(tmp_path) raise except FileNotFoundError as e: - # Clean up temp file if it exists - if tmp_path.exists(): - tmp_path.unlink(missing_ok=True) + _cleanup_tmp(tmp_path) raise HTTPException(status_code=404, detail=f"Required file not found: {str(e)}") from None except ValueError as e: - # Clean up temp file if it exists - if tmp_path.exists(): - tmp_path.unlink(missing_ok=True) + _cleanup_tmp(tmp_path) if "No origin found" in str(e): raise HTTPException( status_code=404, @@ -258,10 +315,13 @@ async def get_hydrofabric_subset_gpkg( ) from None else: raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}") from None + finally: + # Idempotent: no-op if already released on the happy path. + _release_sem() @api_router.get("/history", tags=["Hydrofabric Services"]) -async def get_hydrofabric_history( +def get_hydrofabric_history( domain: str = Query("conus_hf", description="The iceberg namespace used to query the hydrofabric"), catalog=Depends(get_catalog), ): diff --git a/app/routers/nwm_modules/router.py b/app/routers/nwm_modules/router.py index a706855..f368f87 100644 --- a/app/routers/nwm_modules/router.py +++ b/app/routers/nwm_modules/router.py @@ -1,8 +1,18 @@ +import logging +import time + from fastapi import APIRouter, Depends, HTTPException, Query from pydantic.json_schema import SkipJsonSchema from pyiceberg.catalog import Catalog -from app import get_cache_catalog, get_cached_namespaces, get_catalog, get_graphs +from app import ( + GpkgLimiter, + get_cache_catalog, + get_cached_namespaces, + get_catalog, + get_gpkg_limiter, + get_graphs, +) from icefabric.modules import SmpModules, config_mapper, get_parameter_metadata from icefabric.schemas import GeographicDomain, HydrofabricNamespace, HydrofabricSource from icefabric.schemas.modules import ( @@ -20,6 +30,8 @@ TRoute, ) +logger = logging.getLogger(__name__) + def _resolve_module_namespace( domain: GeographicDomain | SkipJsonSchema[None], @@ -59,7 +71,7 @@ def _resolve_module_namespace( @sft_router.get("/", tags=["NWM Modules"]) -async def get_sft_ipes( +def get_sft_ipes( identifier: str = Query( ..., description="Gage ID from which to trace upstream catchments.", @@ -119,7 +131,7 @@ async def get_sft_ipes( @snow17_router.get("/", tags=["NWM Modules"]) -async def get_snow17_ipes( +def get_snow17_ipes( identifier: str = Query( ..., description="Gage ID from which to trace upstream catchments.", @@ -178,7 +190,7 @@ async def get_snow17_ipes( @smp_router.get("/", tags=["NWM Modules"]) -async def get_smp_ipes( +def get_smp_ipes( identifier: str = Query( ..., description="Gage ID from which to trace upstream catchments.", @@ -237,7 +249,7 @@ async def get_smp_ipes( @lstm_router.get("/", tags=["NWM Modules"]) -async def get_lstm_ipes( +def get_lstm_ipes( identifier: str = Query( ..., description="Gage ID from which to trace upstream catchments.", @@ -289,7 +301,7 @@ async def get_lstm_ipes( @lasam_router.get("/", tags=["NWM Modules"]) -async def get_lasam_ipes( +def get_lasam_ipes( identifier: str = Query( ..., description="Gage ID from which to trace upstream catchments.", @@ -357,7 +369,7 @@ async def get_lasam_ipes( @noahowp_router.get("/", tags=["NWM Modules"]) -async def get_noahowp_ipes( +def get_noahowp_ipes( identifier: str = Query( ..., description="Gage ID from which to trace upstream catchments.", @@ -409,7 +421,7 @@ async def get_noahowp_ipes( @sacsma_router.get("/", tags=["NWM Modules"]) -async def get_sacsma_ipes( +def get_sacsma_ipes( identifier: str = Query( ..., description="Gage ID from which to trace upstream catchments.", @@ -468,7 +480,7 @@ async def get_sacsma_ipes( @troute_router.get("/", tags=["NWM Modules"]) -async def get_troute_ipes( +def get_troute_ipes( identifier: str = Query( ..., description="Gage ID from which to trace upstream catchments.", @@ -520,7 +532,7 @@ async def get_troute_ipes( @topmodel_router.get("/", tags=["NWM Modules"]) -async def get_topmodel_ipes( +def get_topmodel_ipes( identifier: str = Query( ..., description="Gage ID from which to trace upstream catchments.", @@ -572,7 +584,7 @@ async def get_topmodel_ipes( @topoflow_router.get("/", tags=["NWM Modules"]) -async def get_topoflow_ipes( +def get_topoflow_ipes( identifier: str = Query( ..., description="Gage ID from which to trace upstream catchments.", @@ -625,7 +637,7 @@ async def get_topoflow_ipes( ''' @topoflow_router.get("/albedo", tags=["NWM Modules"]) -async def get_albedo( +def get_albedo( landcover_state: Albedo = Query( ..., description="The landcover state of a catchment for albedo classification", @@ -649,7 +661,7 @@ async def get_albedo( @ueb_router.get("/", tags=["NWM Modules"]) -async def get_ueb_ipes( +def get_ueb_ipes( identifier: str = Query( ..., description="Gage ID from which to trace upstream catchments.", @@ -708,7 +720,7 @@ async def get_ueb_ipes( @cfe_router.get("/", tags=["NWM Modules"]) -async def get_cfe_ipes( +def get_cfe_ipes( identifier: str = Query( ..., description="Gage ID from which to trace upstream catchments.", @@ -781,7 +793,7 @@ async def get_cfe_ipes( @parameter_metadata_router.get("/", tags=["NWM Modules"]) -async def get_calibratable_parameter_metadata( +def get_calibratable_parameter_metadata( modules: list[str] = Query( ..., description="module name", @@ -791,6 +803,7 @@ async def get_calibratable_parameter_metadata( "LASAM": {"summary": "LASAM", "value": "LASAM"}, "LSTM": {"summary": "LSTM", "value": "LSTM"}, "Noah-OWP-Modular": {"summary": "Noah-OWP-Modular", "value": "Noah-OWP-Modular"}, + "PET": {"summary": "PET", "value": "PET"}, "Sac-SMA": {"summary": "Sac-SMA", "value": "Sac-SMA"}, "SFT": {"summary": "SFT", "value": "SFT"}, "SMP": {"summary": "SMP", "value": "SMP"}, @@ -821,6 +834,7 @@ async def get_calibratable_parameter_metadata( cache_catalog: Catalog = Depends(get_cache_catalog), cached_namespaces=Depends(get_cached_namespaces), network_graphs=Depends(get_graphs), + gpkg_limiter: GpkgLimiter = Depends(get_gpkg_limiter), ): """ An endpoint to return calibratable parameter metadata for a module. @@ -841,6 +855,7 @@ async def get_calibratable_parameter_metadata( "LASAM": "lasam", "LSTM": "lstm", "Noah-OWP-Modular": "noahowp", + "PET": "pet", "Sac-SMA": "sacsma", "SFT": "sft", "SMP": "smp", @@ -858,6 +873,7 @@ async def get_calibratable_parameter_metadata( "lasam": "LASAM", "lstm": "LSTM", "noahowp": "Noah-OWP-Modular", + "pet": "PET", "sacsma": "Sac-SMA", "sft": "SFT", "smp": "SMP", @@ -888,16 +904,43 @@ async def get_calibratable_parameter_metadata( graph = network_graphs[namespace] # Use cache catalog if parameter_metadata namespace is cached - active_catalog = cache_catalog if "parameter_metadata" in cached_namespaces and namespace in cached_namespaces else catalog - - parameter_metadata = get_parameter_metadata( - modules=modules, - catalog=active_catalog, - gage_id=formatted_gage_id, - domain=namespace, - graph=graph, + active_catalog = ( + cache_catalog + if "parameter_metadata" in cached_namespaces and namespace in cached_namespaces + else catalog ) + # With gage_id, runs same hydrofabric subset as /gpkg -> share the limiter. + needs_subset = formatted_gage_id is not None + sem_held = False + if needs_subset: + sem_wait_start = time.monotonic() + if not gpkg_limiter.semaphore.acquire(timeout=gpkg_limiter.queue_timeout_s): + raise HTTPException( + status_code=503, + detail=( + "Hydrofabric subset service is at capacity. Please retry shortly. " + f"(waited {gpkg_limiter.queue_timeout_s:.0f}s for a slot)" + ), + headers={"Retry-After": "30"}, + ) + sem_held = True + sem_wait_ms = (time.monotonic() - sem_wait_start) * 1000 + if sem_wait_ms > 250: + logger.info(f"param_metadata semaphore wait: {sem_wait_ms:.0f} ms for gage {gage_id}") + + try: + parameter_metadata = get_parameter_metadata( + modules=modules, + catalog=active_catalog, + gage_id=formatted_gage_id, + domain=namespace, + graph=graph, + ) + finally: + if sem_held: + gpkg_limiter.semaphore.release() + for module in parameter_metadata: module_name = module["module_name"] NWM_module_name = icefabric_nwm_module_mapping.get(module_name, module_name) diff --git a/app/routers/ras_xs/router.py b/app/routers/ras_xs/router.py index 1fb0f32..f22e731 100644 --- a/app/routers/ras_xs/router.py +++ b/app/routers/ras_xs/router.py @@ -96,7 +96,7 @@ def filesystem_check(tmp_path: pathlib.PosixPath, temp_dir: pathlib.PosixPath): @api_router.get("/{identifier}/", tags=["HEC-RAS XS"]) -async def get_xs_subset_gpkg( +def get_xs_subset_gpkg( identifier: str = Path( ..., description="The flowpath ID from the reference hydrofabric that the current RAS XS aligns is conflated to. Must be numeric.", @@ -170,7 +170,7 @@ async def get_xs_subset_gpkg( @api_router.get("/within", tags=["HEC-RAS XS"]) -async def get_by_geospatial_query( +def get_by_geospatial_query( bbox: BoundingBox = Depends(get_bbox_query_params), schema_type: XsType = Query( XsType.CONFLATED, description="The schema type used to query the cross-sections" diff --git a/app/routers/streamflow_observations/router.py b/app/routers/streamflow_observations/router.py index e446037..2a74d86 100644 --- a/app/routers/streamflow_observations/router.py +++ b/app/routers/streamflow_observations/router.py @@ -122,7 +122,7 @@ def validate_identifier(identifier: str): @api_router.get("/{identifier}/info", tags=["Streamflow Observations"]) -async def get_identifier_info( +def get_identifier_info( identifier: str = Path( ..., description="Station/gauge ID", @@ -167,7 +167,7 @@ async def get_identifier_info( @api_router.get("/{identifier}/{output_format}", tags=["Streamflow Observations"]) -async def get_data_time_range( +def get_data_time_range( identifier: str = Path( ..., description="Station/gauge ID", @@ -252,7 +252,7 @@ async def get_data_time_range( @api_router.get("/history", tags=["Streamflow Observations"]) -async def get_repo_history(): +def get_repo_history(): """ GET Repo History/Snapshots @@ -285,7 +285,7 @@ async def get_repo_history(): @api_router.get("/available", tags=["Streamflow Observations"]) -async def get_available_identifiers( +def get_available_identifiers( limit: int = Query(100, description="Maximum number of IDs to return"), ): """ diff --git a/docker/nginx/default.conf b/docker/nginx/default.conf index db09c36..97038c9 100644 --- a/docker/nginx/default.conf +++ b/docker/nginx/default.conf @@ -24,6 +24,12 @@ server { proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header Upgrade $http_upgrade; # Enable HTTP/2 support proxy_http_version 1.1; # Enable HTTP/1.1 + + # Long-running endpoints (e.g., hydrofabric gpkg exports) can take + # several minutes. Default nginx proxy timeouts are 60s. + proxy_connect_timeout 600s; + proxy_read_timeout 600s; + proxy_send_timeout 600s; } # Route to dashboard diff --git a/pyproject.toml b/pyproject.toml index 6c1682d..ec7d244 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -164,6 +164,7 @@ convention = "numpy" [tool.ruff.lint.per-file-ignores] "docs/*" = ["I"] "tests/*" = ["D"] +"scripts/*" = ["D", "BLE001"] "*/__init__.py" = ["F401"] [tool.mypy] diff --git a/scripts/load_test/analyze.py b/scripts/load_test/analyze.py new file mode 100644 index 0000000..a9d241b --- /dev/null +++ b/scripts/load_test/analyze.py @@ -0,0 +1,86 @@ +"""Summarise stats.csv from monitor.sh — look for memory creep, OOM proximity, CPU saturation.""" + +from __future__ import annotations + +import argparse +import csv +import re +from pathlib import Path + + +def to_bytes(s: str) -> float: + s = s.strip() + m = re.match(r"([0-9.]+)\s*([KMGT]?i?B)", s, re.I) + if not m: + return 0.0 + v = float(m.group(1)) + unit = m.group(2).lower() + mult = { + "b": 1, + "kb": 1e3, + "mb": 1e6, + "gb": 1e9, + "tb": 1e12, + "kib": 1024, + "mib": 1024**2, + "gib": 1024**3, + "tib": 1024**4, + }.get(unit, 1) + return v * mult + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument("--stats", default="scripts/load_test/results/stats.csv") + args = p.parse_args() + path = Path(args.stats) + if not path.exists(): + print(f"no stats file at {path}") + return + + rows = list(csv.DictReader(path.open())) + if not rows: + print("no samples") + return + + cpu_vals = [float(r["cpu_pct"].rstrip("%")) for r in rows if r["cpu_pct"]] + mem_vals = [to_bytes(r["mem_usage"]) for r in rows if r["mem_usage"]] + mem_pct_vals = [float(r["mem_pct"].rstrip("%")) for r in rows if r["mem_pct"]] + + def q(xs, p): + if not xs: + return 0.0 + xs = sorted(xs) + return xs[int((p / 100.0) * (len(xs) - 1))] + + print(f"samples: {len(rows)}") + print( + f"cpu % min={min(cpu_vals):6.1f} mean={sum(cpu_vals) / len(cpu_vals):6.1f} " + f"p95={q(cpu_vals, 95):6.1f} max={max(cpu_vals):6.1f}" + ) + print( + f"mem GiB min={min(mem_vals) / 1024**3:6.2f} mean={sum(mem_vals) / len(mem_vals) / 1024**3:6.2f} " + f"p95={q(mem_vals, 95) / 1024**3:6.2f} max={max(mem_vals) / 1024**3:6.2f}" + ) + print( + f"mem % min={min(mem_pct_vals):6.1f} mean={sum(mem_pct_vals) / len(mem_pct_vals):6.1f} " + f"p95={q(mem_pct_vals, 95):6.1f} max={max(mem_pct_vals):6.1f}" + ) + + # Creep detection: compare first-third vs last-third mean memory + third = len(mem_vals) // 3 + if third >= 3: + head = sum(mem_vals[:third]) / third + tail = sum(mem_vals[-third:]) / third + creep = (tail - head) / max(head, 1) + print(f"mem creep (last-third vs first-third): {creep * 100:+.1f}%") + if creep > 0.25: + print(" ⚠️ > 25% growth — possible leak or unbounded caching") + elif creep > 0.10: + print(" ⚠️ > 10% growth — keep an eye on it over a longer run") + else: + print(" ✅ memory stable") + + +if __name__ == "__main__": + main() diff --git a/scripts/load_test/docker-compose.load.yml b/scripts/load_test/docker-compose.load.yml new file mode 100644 index 0000000..fbb3c38 --- /dev/null +++ b/scripts/load_test/docker-compose.load.yml @@ -0,0 +1,35 @@ +version: "2.4" +services: + api: + build: + context: ../.. + dockerfile: docker/Dockerfile.api + image: icefabric-api:loadtest + container_name: icefabric-api-loadtest + ports: + - "127.0.0.1:8000:8000" + env_file: + - ../../.env + environment: + - ICEFABRIC_DEPLOY_ENV=test + # Curb glibc per-thread arena fragmentation on heavy numpy/pandas use. + - MALLOC_ARENA_MAX=2 + # Stop numerical libs and polars from spawning 1 thread per core; + # we only have ~2 vCPU of budget and already fan out via asyncio + # threadpool + subset_nhf's ThreadPoolExecutor. + - OMP_NUM_THREADS=2 + - OPENBLAS_NUM_THREADS=2 + - MKL_NUM_THREADS=2 + - POLARS_MAX_THREADS=2 + # Emulating m6i.xlarge (4 vCPU, 16 GB) for this run. + cpus: 4.0 + mem_limit: 16g + memswap_limit: 16g + oom_kill_disable: false + restart: "no" + healthcheck: + test: ["CMD", "curl", "-f", "--head", "http://localhost:8000/health"] + interval: 15s + timeout: 10s + retries: 20 + start_period: 300s diff --git a/scripts/load_test/load_test.py b/scripts/load_test/load_test.py new file mode 100644 index 0000000..294b78d --- /dev/null +++ b/scripts/load_test/load_test.py @@ -0,0 +1,354 @@ +""" +Load test for the icefabric API. + +Targets 3 commonly-used endpoints at ~100 requests/minute total, with valid +CONUS NHF identifiers. Designed to run against a locally Dockerized API with +t3.large-equivalent resource caps. + +Usage +----- + python load_test.py --base-url http://localhost:8000 --rpm 100 --duration 300 +""" + +from __future__ import annotations + +import argparse +import asyncio +import csv +import json +import random +import statistics +import time +from collections import Counter, defaultdict +from dataclasses import dataclass, field +from pathlib import Path + +import httpx + +# --------------------------------------------------------------------------- +# Identifier pools (valid CONUS NHF) +# --------------------------------------------------------------------------- +# NHF CONUS VPU IDs. These are the standard National Hydrofabric VPU codes. +VPU_IDS: list[str] = [f"{i:02d}" for i in range(1, 19)] # "01".."18" + +# Known-valid NHF flowpath IDs (integers). Populated at runtime from the +# `/available` / hydrofabric responses when possible; defaults below are from +# the router's documented examples and a safe spread. +FLOWPATH_IDS: list[int] = [3490271] + +# Parameter-metadata modules +MODULES: list[str] = [ + "CFE-S", + "CFE-X", + "LASAM", + "LSTM", + "Noah-OWP-Modular", + "PET", + "Sac-SMA", + "SFT", + "SMP", + "Snow-17", + "T-Route", + "TopModel", + "Topoflow-Glacier", + "UEB", +] + +# Fallback gauge IDs (USGS) — the driver will try to fetch a fresh list from +# /streamflow_observations/available at start; these are the safety net. +FALLBACK_GAGES: list[str] = [ + "01010000", + "01031500", + "02GC002", + "08102730", + "01013500", + "01022500", + "01030500", + "01047000", +] + + +# --------------------------------------------------------------------------- +# Result tracking +# --------------------------------------------------------------------------- +@dataclass +class Sample: + endpoint: str + url: str + status: int + latency_ms: float + error: str | None = None + bytes: int = 0 + started_at: float = 0.0 + + +@dataclass +class Results: + samples: list[Sample] = field(default_factory=list) + + def add(self, s: Sample) -> None: + self.samples.append(s) + + def summary(self) -> dict: + by_ep: dict[str, list[Sample]] = defaultdict(list) + for s in self.samples: + by_ep[s.endpoint].append(s) + + out: dict = {"overall": self._stats(self.samples), "by_endpoint": {}} + for ep, xs in by_ep.items(): + out["by_endpoint"][ep] = self._stats(xs) + return out + + @staticmethod + def _stats(xs: list[Sample]) -> dict: + if not xs: + return {"count": 0} + lats = [s.latency_ms for s in xs] + codes = Counter(s.status for s in xs) + ok = sum(1 for s in xs if 200 <= s.status < 300) + errs = sum(1 for s in xs if s.status >= 500 or s.status == 0) + lats_sorted = sorted(lats) + + def pct(p): + if not lats_sorted: + return 0.0 + k = int(round((p / 100.0) * (len(lats_sorted) - 1))) + return lats_sorted[k] + + return { + "count": len(xs), + "success": ok, + "errors_5xx_or_conn": errs, + "status_codes": dict(codes), + "latency_ms": { + "min": round(min(lats), 1), + "mean": round(statistics.mean(lats), 1), + "p50": round(pct(50), 1), + "p95": round(pct(95), 1), + "p99": round(pct(99), 1), + "max": round(max(lats), 1), + }, + "total_bytes": sum(s.bytes for s in xs), + } + + +# --------------------------------------------------------------------------- +# Request builders +# --------------------------------------------------------------------------- +def build_hydrofabric_request(gages: list[str]) -> tuple[str, str]: + """Hydrofabric NHF subset — realistic mix of VPU and gage requests. + + VPU subsets are the largest (full region); gage subsets vary widely by + basin size. Random uniform sampling over the ~200 gauge pool gives the + natural cost distribution (headwater gauges are cheap, basin-outlet + gauges are expensive). + """ + # ~40% VPU (consistent heavy), ~60% gage (variable cost) + if random.random() < 0.4: + ident = random.choice(VPU_IDS) + id_type = "vpu_id" + else: + ident = random.choice(gages) + id_type = "gage_id" + url = f"/api/v1/hydrofabric/{ident}/gpkg?id_type={id_type}&source=nhf&domain=CONUS" + return "hydrofabric_gpkg", url + + +def build_parameter_metadata_request(gages: list[str]) -> tuple[str, str]: + """Parameter metadata with gage_id — always the heavy subset path.""" + mod = random.choice(MODULES) + gage = random.choice(gages) + url = f"/api/v1/modules/parameter_metadata/?modules={mod}&gage_id={gage}&domain=CONUS&source=hf" + return "parameter_metadata", url + + +def build_streamflow_request(gages: list[str]) -> tuple[str, str]: + gage = random.choice(gages) + url = f"/api/v1/streamflow_observations/{gage}/info" + return "streamflow_info", url + + +def pick_request(gages: list[str], weights: tuple[int, int, int]) -> tuple[str, str]: + """Pick a weighted random endpoint. weights = (hf, param, streamflow).""" + total = sum(weights) + r = random.randint(1, total) + if r <= weights[0]: + return build_hydrofabric_request(gages) + elif r <= weights[0] + weights[1]: + return build_parameter_metadata_request(gages) + else: + return build_streamflow_request(gages) + + +# --------------------------------------------------------------------------- +# Driver +# --------------------------------------------------------------------------- +async def fetch_available_gages(client: httpx.AsyncClient) -> list[str]: + try: + r = await client.get("/api/v1/streamflow_observations/available?limit=200", timeout=60) + if r.status_code == 200: + data = r.json() + # Endpoint may return dict or list; try common shapes + if isinstance(data, dict): + for key in ("identifiers", "ids", "available"): + if key in data and isinstance(data[key], list): + return [str(x) for x in data[key] if x][:200] + elif isinstance(data, list): + return [str(x) for x in data if x][:200] + except Exception as e: + print(f"[discover] couldn't fetch /available: {e}") + return FALLBACK_GAGES + + +async def one_request( + client: httpx.AsyncClient, + ep: str, + url: str, + results: Results, + timeout_s: float, +) -> None: + started = time.perf_counter() + wall = time.time() + status = 0 + err = None + nbytes = 0 + try: + # Stream so we don't buffer entire gpkg into memory on the client side + async with client.stream("GET", url, timeout=timeout_s) as r: + status = r.status_code + async for chunk in r.aiter_bytes(): + nbytes += len(chunk) + except httpx.TimeoutException: + err = "timeout" + except httpx.HTTPError as e: + err = f"http_error:{type(e).__name__}" + except Exception as e: + err = f"exc:{type(e).__name__}:{e}" + latency_ms = (time.perf_counter() - started) * 1000 + results.add( + Sample( + endpoint=ep, + url=url, + status=status, + latency_ms=latency_ms, + error=err, + bytes=nbytes, + started_at=wall, + ) + ) + + +async def run( + base_url: str, + rpm: int, + duration_s: int, + timeout_s: float, + weights: tuple[int, int, int], + out_dir: Path, +) -> Results: + interval = 60.0 / rpm + results = Results() + limits = httpx.Limits(max_connections=64, max_keepalive_connections=32) + async with httpx.AsyncClient(base_url=base_url, limits=limits) as client: + # Warm up / discover valid gages + print("[discover] fetching available gauge IDs...") + gages = await fetch_available_gages(client) + print(f"[discover] using {len(gages)} gauge IDs (sample: {gages[:5]})") + + start = time.time() + tasks: list[asyncio.Task] = [] + issued = 0 + next_fire = start + while time.time() - start < duration_s: + now = time.time() + if now >= next_fire: + ep, url = pick_request(gages, weights) + tasks.append(asyncio.create_task(one_request(client, ep, url, results, timeout_s))) + issued += 1 + next_fire += interval + if issued % 10 == 0: + elapsed = now - start + ok = sum(1 for s in results.samples if 200 <= s.status < 300) + print( + f"[{elapsed:5.0f}s] issued={issued} done={len(results.samples)} " + f"ok={ok} inflight={issued - len(results.samples)}" + ) + else: + await asyncio.sleep(min(0.05, next_fire - now)) + + # Drain + print(f"[drain] waiting for {issued - len(results.samples)} in-flight requests...") + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + # Persist raw samples as CSV + out_dir.mkdir(parents=True, exist_ok=True) + csv_path = out_dir / "samples.csv" + with csv_path.open("w", newline="") as f: + w = csv.writer(f) + w.writerow(["started_at", "endpoint", "status", "latency_ms", "bytes", "error", "url"]) + for s in results.samples: + w.writerow( + [ + f"{s.started_at:.3f}", + s.endpoint, + s.status, + f"{s.latency_ms:.1f}", + s.bytes, + s.error or "", + s.url, + ] + ) + print(f"[io] wrote {csv_path}") + return results + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument("--base-url", default="http://localhost:8000") + p.add_argument("--rpm", type=int, default=100, help="Requests per minute") + p.add_argument("--duration", type=int, default=300, help="Total test duration in seconds") + p.add_argument( + "--timeout", + type=float, + default=300, + help="Per-request timeout seconds (matches ICEFABRIC_HF_GPKG_QUEUE_TIMEOUT_S default).", + ) + p.add_argument( + "--weights", + default="1,2,2", + help=( + "Endpoint weights hf,param,stream. Hydrofabric gpkg is heavy (~130 MB / ~30s " + "per CONUS VPU) so we keep its share modest; a run at 100 rpm with 20% hf " + "weight still exercises ~20 concurrent heavy gpkg builds per minute." + ), + ) + p.add_argument("--out-dir", default="scripts/load_test/results") + args = p.parse_args() + weights = tuple(int(x) for x in args.weights.split(",")) + assert len(weights) == 3, "weights must be 3 ints" + + print(f"[cfg] base_url={args.base_url} rpm={args.rpm} duration={args.duration}s weights={weights}") + results = asyncio.run( + run( + base_url=args.base_url, + rpm=args.rpm, + duration_s=args.duration, + timeout_s=args.timeout, + weights=weights, # type: ignore[arg-type] + out_dir=Path(args.out_dir), + ) + ) + + summary = results.summary() + summary_path = Path(args.out_dir) / "summary.json" + summary_path.write_text(json.dumps(summary, indent=2)) + print("\n" + "=" * 60) + print("SUMMARY") + print("=" * 60) + print(json.dumps(summary, indent=2)) + print(f"\n[io] wrote {summary_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/load_test/monitor.sh b/scripts/load_test/monitor.sh new file mode 100755 index 0000000..9d5c026 --- /dev/null +++ b/scripts/load_test/monitor.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +# Stream `docker stats` into a CSV for later analysis. +# Usage: ./monitor.sh [interval_seconds] +set -euo pipefail + +CONTAINER="${1:-icefabric-api-loadtest}" +OUT="${2:-scripts/load_test/results/stats.csv}" +INTERVAL="${3:-2}" + +mkdir -p "$(dirname "$OUT")" +echo "timestamp,name,cpu_pct,mem_usage,mem_limit,mem_pct,net_io,block_io,pids" > "$OUT" + +while true; do + ts=$(date +%s) + # --no-stream gives one snapshot; we parse its non-header line + docker stats --no-stream --format '{{.Name}},{{.CPUPerc}},{{.MemUsage}},{{.MemPerc}},{{.NetIO}},{{.BlockIO}},{{.PIDs}}' \ + "$CONTAINER" 2>/dev/null | while IFS=, read -r name cpu mem mempct net block pids; do + # Split "123.4MiB / 8GiB" into usage,limit + usage=$(echo "$mem" | awk -F' / ' '{print $1}') + limit=$(echo "$mem" | awk -F' / ' '{print $2}') + printf '%s,%s,%s,%s,%s,%s,"%s","%s",%s\n' \ + "$ts" "$name" "$cpu" "$usage" "$limit" "$mempct" "$net" "$block" "$pids" >> "$OUT" + done || true + sleep "$INTERVAL" +done diff --git a/scripts/load_test/run.sh b/scripts/load_test/run.sh new file mode 100755 index 0000000..ad2c639 --- /dev/null +++ b/scripts/load_test/run.sh @@ -0,0 +1,93 @@ +#!/usr/bin/env bash +# Orchestrate: build → start container → wait healthy → monitor + load test → teardown. +set -euo pipefail + +HERE="$(cd "$(dirname "$0")" && pwd)" +cd "$HERE" + +RESULTS="results" +mkdir -p "$RESULTS" + +DURATION="${DURATION:-300}" # seconds +RPM="${RPM:-100}" +TIMEOUT="${TIMEOUT:-120}" # per-request seconds +CPUS="${CPUS:-2.0}" +MEMORY="${MEMORY:-8g}" + +echo "== config: rpm=$RPM duration=${DURATION}s per-req-timeout=${TIMEOUT}s cpus=$CPUS mem=$MEMORY ==" + +# Use docker compose v2 if available, else fall back to docker-compose v1. +if docker compose version >/dev/null 2>&1; then + DC="docker compose" +else + DC="docker-compose" +fi +echo "== using compose: $DC ==" + +# --- Build --- +echo "== building api image ==" +$DC -f docker-compose.load.yml build api + +# --- Start --- +echo "== starting container ==" +$DC -f docker-compose.load.yml up -d api + +# Apply runtime caps (compose v2 ignores top-level cpus/mem_limit without swarm; +# make it explicit with docker update after start). +docker update --cpus="$CPUS" --memory="$MEMORY" --memory-swap="$MEMORY" \ + icefabric-api-loadtest >/dev/null +echo "== applied --cpus=$CPUS --memory=$MEMORY ==" + +# --- Wait for healthy --- +echo "== waiting for /health (up to 10 min, cache build takes time) ==" +for i in $(seq 1 120); do + if curl -fsS --max-time 3 http://localhost:8000/health >/dev/null 2>&1; then + echo "== api is healthy after $((i*5))s ==" + break + fi + sleep 5 + if [[ $((i % 12)) -eq 0 ]]; then + echo " still waiting... ($((i*5))s elapsed)" + fi +done +if ! curl -fsS --max-time 3 http://localhost:8000/health >/dev/null 2>&1; then + echo "!! api never became healthy. Last 80 lines of container logs:" + docker logs --tail 80 icefabric-api-loadtest || true + exit 1 +fi + +# --- Monitor --- +echo "== starting docker-stats monitor ==" +bash monitor.sh icefabric-api-loadtest "$RESULTS/stats.csv" 2 & +MON_PID=$! +trap 'kill $MON_PID 2>/dev/null || true; $DC -f docker-compose.load.yml logs api > "$RESULTS/container.log" 2>&1 || true' EXIT + +# --- Load test --- +echo "== running load test ==" +python3 load_test.py \ + --base-url http://localhost:8000 \ + --rpm "$RPM" \ + --duration "$DURATION" \ + --timeout "$TIMEOUT" \ + --out-dir "$RESULTS" + +# --- Stop monitor, dump logs, analyze --- +kill $MON_PID 2>/dev/null || true +sleep 2 +$DC -f docker-compose.load.yml logs api > "$RESULTS/container.log" 2>&1 || true + +echo "" +echo "== resource analysis ==" +python3 analyze.py --stats "$RESULTS/stats.csv" + +echo "" +echo "== container state ==" +docker inspect icefabric-api-loadtest \ + --format '{{.State.Status}} OOMKilled={{.State.OOMKilled}} ExitCode={{.State.ExitCode}} RestartCount={{.RestartCount}}' + +echo "" +echo "== grep for errors / OOM / tracebacks in container log ==" +grep -i -E "oom|killed|memoryerror|traceback|error" "$RESULTS/container.log" | head -30 || true + +echo "" +echo "results in: $HERE/$RESULTS" diff --git a/src/icefabric/modules/get_parameter_metadata.py b/src/icefabric/modules/get_parameter_metadata.py index 41c112d..da999bc 100644 --- a/src/icefabric/modules/get_parameter_metadata.py +++ b/src/icefabric/modules/get_parameter_metadata.py @@ -105,7 +105,7 @@ def get_parameter_metadata( check_module_name = module # Modules that are valid but don't have parameter metadata yet - modules_without_metadata = {"troute", "lstm"} + modules_without_metadata = {"troute", "lstm", "pet"} if check_module_name in modules_without_metadata: output = {"module_name": module, "calibratable_parameters": []} diff --git a/src/icefabric/schemas/iceberg_tables/hydrofabric_update.py b/src/icefabric/schemas/iceberg_tables/hydrofabric_update.py index 7a0a5e7..8884f71 100644 --- a/src/icefabric/schemas/iceberg_tables/hydrofabric_update.py +++ b/src/icefabric/schemas/iceberg_tables/hydrofabric_update.py @@ -1700,10 +1700,10 @@ def schema(cls) -> Schema: "Percentage of the length of a flowpath segment that falls inside a buffer around the reference flowpath", ] return Schema( - NestedField(1, "nhd_feature_id", LongType(), required=True, doc=desc[0]), - NestedField(2, "ref_id", LongType(), required=False, doc=desc[1]), + NestedField(1, "nhd_feature_id", LongType(), required=False, doc=desc[0]), + NestedField(2, "ref_id", LongType(), required=True, doc=desc[1]), NestedField(3, "percent_inside", DoubleType(), required=False, doc=desc[2]), - identifier_field_ids=[1], + identifier_field_ids=[2], ) @classmethod @@ -1718,8 +1718,8 @@ def arrow_schema(cls) -> pa.Schema: """ return pa.schema( [ - pa.field("nhd_feature_id", pa.int64(), nullable=False), - pa.field("ref_id", pa.int64(), nullable=True), + pa.field("nhd_feature_id", pa.int64(), nullable=True), + pa.field("ref_id", pa.int64(), nullable=False), pa.field("percent_inside", pa.float64(), nullable=True), ] ) diff --git a/terraform/modules/app_service/templates/user_data.sh.tpl b/terraform/modules/app_service/templates/user_data.sh.tpl index 0ad9ea3..313ba79 100644 --- a/terraform/modules/app_service/templates/user_data.sh.tpl +++ b/terraform/modules/app_service/templates/user_data.sh.tpl @@ -282,6 +282,23 @@ services: - "127.0.0.1:8000:8000" env_file: - ./.env + environment: + # Curb glibc per-thread arena fragmentation (heavy numpy/pandas use). + - MALLOC_ARENA_MAX=2 + # Stop numerical libs and polars from spawning 1 thread per core; + # we only have ~2 vCPU of budget per worker on t3.large. + - OMP_NUM_THREADS=2 + - OPENBLAS_NUM_THREADS=2 + - MKL_NUM_THREADS=2 + - POLARS_MAX_THREADS=2 + # Hydrofabric subset concurrency guard. 1 per worker * 2 workers = 2 + # concurrent heavy builds per EC2. Tune up if you move to a larger + # instance type with more RAM headroom. + - ICEFABRIC_HF_GPKG_CONCURRENCY=1 + - ICEFABRIC_HF_GPKG_QUEUE_TIMEOUT_S=300 + # Recycle each worker after N requests to reset memory creep + # from glibc arena / numpy allocator fragmentation. + - ICEFABRIC_MAX_REQUESTS_PER_WORKER=500 restart: always healthcheck: test: ["CMD", "curl", "-f", "--head", "http://localhost:8000/health"] diff --git a/tests/conftest.py b/tests/conftest.py index 4415193..e5b91d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,6 +19,7 @@ from pyiceberg.expressions import And, EqualTo, GreaterThanOrEqual, In, LessThanOrEqual from pyprojroot import here +from app import GpkgLimiter from app.main import app from icefabric.builds.graph_connectivity import read_edge_attrs, read_node_attrs from icefabric.schemas.icechunk import NGWPCTestLocations @@ -997,9 +998,16 @@ def testing_dir() -> Path: @pytest.fixture(scope="session") def client(): """Create a test client for the FastAPI app with mock catalog.""" + import threading + app.state.catalog = MockCatalog() # defaulting to use the mock catalog app.state.cached_namespaces = {"conus_hf", "ak_hf", "hi_hf", "prvi_hf"} app.state.cache_catalog = MockCatalog() # defaulting to use the mock catalog + # Tests skip lifespan; seed app.state manually. + app.state.gpkg_limiter = GpkgLimiter( + semaphore=threading.BoundedSemaphore(16), + queue_timeout_s=60.0, + ) return TestClient(app) diff --git a/tests/smoke/test_parameter_metadata_smoke.py b/tests/smoke/test_parameter_metadata_smoke.py index f34e511..00252ae 100644 --- a/tests/smoke/test_parameter_metadata_smoke.py +++ b/tests/smoke/test_parameter_metadata_smoke.py @@ -16,6 +16,7 @@ "LASAM", "LSTM", "Noah-OWP-Modular", + "PET", "Sac-SMA", "SFT", "SMP",