From 5cb5a425b60b12fb0d06eb1812a6b3865c8ed83c Mon Sep 17 00:00:00 2001 From: Dylan Lee Date: Thu, 28 Aug 2025 12:25:50 -0400 Subject: [PATCH 1/2] Refresh main.py Add aoi_is_item flag Add more exhaustive pipeline failure reporting and make it so that more things make a pipeline explicitly fail. Early exit logic when no flow scenarios or catchments are found --- src/main.py | 187 +++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 148 insertions(+), 39 deletions(-) diff --git a/src/main.py b/src/main.py index 1f6fd3a..b1b4809 100644 --- a/src/main.py +++ b/src/main.py @@ -14,7 +14,7 @@ import geopandas as gpd import default_config -from data_service import DataService +from data_service import DataService, DataServiceException from load_config import AppConfig, load_config from metrics_aggregator import MetricsAggregator from nomad_job_manager import NomadJobManager @@ -42,6 +42,7 @@ def __init__( polygon_gdf: gpd.GeoDataFrame, tags: Dict[str, str], outputs_path: str, + aoi_path: str, log_db: Optional[PipelineLogDB] = None, ): self.config = config @@ -49,6 +50,7 @@ def __init__( self.data_svc = data_svc self.polygon_gdf = polygon_gdf self.tags = tags + self.aoi_path = aoi_path self.log_db = log_db # Ensure the temp directory exists temp_dir = "/tmp" @@ -66,11 +68,13 @@ def __init__( self.benchmark_scenarios: Dict[str, Dict[str, List[str]]] = {} self.stac_results: Dict[str, Dict[str, Dict[str, List[str]]]] = {} - async def initialize(self) -> None: - """Query for catchments and flow scenarios.""" + async def initialize(self) -> Optional[Dict[str, Any]]: + """Query for catchments and flow scenarios. 
Returns early exit info if no data found.""" # Query STAC for flow scenarios (always required) logger.debug("Querying STAC for flow scenarios") - stac_data = await self.data_svc.query_stac_for_flow_scenarios(self.polygon_gdf) + stac_data = await self.data_svc.query_stac_for_flow_scenarios( + self.polygon_gdf, self.tags + ) self.flow_scenarios = stac_data.get("combined_flowfiles", {}) # Extract benchmark rasters from STAC scenarios @@ -92,9 +96,15 @@ async def initialize(self) -> None: if self.flow_scenarios: logger.debug(f"Found {len(self.flow_scenarios)} collections") - - if not self.flow_scenarios: - raise RuntimeError("No flow scenarios found") + else: + logger.warning("No flow scenarios found in STAC query results") + return { + "status": "no_data", + "message": "No flow scenarios found for the given polygon", + "catchment_count": 0, + "total_scenarios_attempted": 0, + "successful_scenarios": 0, + } # Query hand index for catchments logger.debug("Querying hand index for catchments") @@ -102,20 +112,38 @@ async def initialize(self) -> None: self.catchments = data.get("catchments", {}) if not self.catchments: - raise RuntimeError("No catchments found") + logger.warning("No catchments found in hand index query results") + return { + "status": "no_data", + "message": "No catchments found for the given polygon", + "catchment_count": 0, + "total_scenarios_attempted": 0, + "successful_scenarios": 0, + } - total_scenarios = sum(len(scenarios) for scenarios in self.flow_scenarios.values()) - logger.info(f"Initialization complete: {len(self.catchments)} catchments, " f"{total_scenarios} flow scenarios") + total_scenarios = sum( + len(scenarios) for scenarios in self.flow_scenarios.values() + ) + logger.info( + f"Initialization complete: {len(self.catchments)} catchments, " + f"{total_scenarios} flow scenarios" + ) + return None async def run(self) -> Dict[str, Any]: """Run the pipeline with stage-based parallelism.""" - await self.initialize() + early_exit = await 
self.initialize() + if early_exit is not None: + logger.info(f"Pipeline exiting early: {early_exit['message']}") + return early_exit # Build scenario results results = [] for collection, flows in self.flow_scenarios.items(): for scenario, flowfile_path in flows.items(): - benchmark_rasters = self.benchmark_scenarios.get(collection, {}).get(scenario, []) + benchmark_rasters = self.benchmark_scenarios.get( + collection, {} + ).get(scenario, []) result = PipelineResult( scenario_id=f"{collection}-{scenario}", collection_name=collection, @@ -133,7 +161,9 @@ async def run(self) -> Dict[str, Any]: ) results.append(result) - logger.debug(f"Processing {len(results)} scenarios with stage-based parallelism") + logger.debug( + f"Processing {len(results)} scenarios with stage-based parallelism" + ) try: inundation_stage = InundationStage( @@ -144,13 +174,27 @@ async def run(self) -> Dict[str, Any]: self.tags, self.catchments, ) - mosaic_stage = MosaicStage(self.config, self.nomad, self.data_svc, self.path_factory, self.tags) - agreement_stage = AgreementStage(self.config, self.nomad, self.data_svc, self.path_factory, self.tags) + mosaic_stage = MosaicStage( + self.config, + self.nomad, + self.data_svc, + self.path_factory, + self.tags, + self.aoi_path, + ) + agreement_stage = AgreementStage( + self.config, + self.nomad, + self.data_svc, + self.path_factory, + self.tags, + ) results = await inundation_stage.run(results) results = await mosaic_stage.run(results) results = await agreement_stage.run(results) + # Save results to JSON file if results: try: results_json_path = self.path_factory.results_json_path() @@ -170,13 +214,19 @@ async def run(self) -> Dict[str, Any]: } serializable_results.append(result_dict) - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as temp_file: + # Write to temporary file first, then copy to final location + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as temp_file: 
json.dump(serializable_results, temp_file, indent=2) temp_json_path = temp_file.name - await self.data_svc.copy_file_to_uri(temp_json_path, results_json_path) + await self.data_svc.copy_file_to_uri( + temp_json_path, results_json_path + ) logger.info(f"Results JSON written to {results_json_path}") + # Clean up temp file if os.path.exists(temp_json_path): os.unlink(temp_json_path) @@ -190,9 +240,15 @@ async def run(self) -> Dict[str, Any]: outputs_path=str(self.path_factory.base), stac_results=self.stac_results, data_service=self.data_svc, + flow_scenarios=self.flow_scenarios, + aoi_name=self.path_factory.aoi_name, + ) + metrics_path = aggregator.save_results( + self.path_factory.results_path() + ) + logger.info( + f"Metrics aggregation completed: {metrics_path}" ) - metrics_path = aggregator.save_results(self.path_factory.metrics_path()) - logger.info(f"Metrics aggregation completed: {metrics_path}") except Exception as e: logger.error(f"Metrics aggregation failed: {e}") @@ -206,8 +262,13 @@ async def run(self) -> Dict[str, Any]: "successful_scenarios": len(successful_results), "message": f"Pipeline completed successfully with {len(successful_results)}/{total_attempted} scenarios", } - logger.info(f"Pipeline SUCCESS: {len(successful_results)}/{total_attempted} scenarios completed") + logger.info( + f"Pipeline SUCCESS: {len(successful_results)}/{total_attempted} scenarios completed" + ) return summary + except DataServiceException as e: + logger.error(f"Pipeline FAILED due to data service error: {str(e)}") + return {"status": "failed", "error": str(e), "message": f"Data service error: {str(e)}"} except Exception as e: logger.error(f"Pipeline FAILED: {str(e)}") return { @@ -230,7 +291,9 @@ def parsed_tags(tag_list): for tag in tag_list: if "=" not in tag: - raise argparse.ArgumentTypeError(f"Invalid tag format: '{tag}'. Expected key=value.") + raise argparse.ArgumentTypeError( + f"Invalid tag format: '{tag}'. Expected key=value." 
+ ) key, value = tag.split("=", 1) for char in forbidden_chars: @@ -261,28 +324,39 @@ def parsed_tags(tag_list): if tags: tags_str = ",".join(f"{k}={v}" for k, v in tags.items()) - if len(tags_str) > 120: - raise argparse.ArgumentTypeError(f"Tags exceed 120 character limit ({len(tags_str)} chars): {tags_str}") + if len(tags_str) > 150: + raise argparse.ArgumentTypeError( + f"Tags exceed 150 character limit ({len(tags_str)} chars): {tags_str}" + ) return tags if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run one PolygonPipeline in isolation") + parser = argparse.ArgumentParser( + description="Run one PolygonPipeline in isolation" + ) parser.add_argument( "--aoi", type=str, required=True, help="File path to a GPKG containing a single polygon. If more than one layer/feature, only the first is used.", ) - parser.add_argument("--outputs_path", type=str, required=True, help="Output directory path") + parser.add_argument( + "--outputs_path", type=str, required=True, help="Output directory path" + ) parser.add_argument( "--benchmark_sources", type=str, default=None, help="Comma-separated list of STAC collections to query (e.g., 'ble-collection,nwm-collection'). 
Defaults to all available sources.", ) - parser.add_argument("--hand_index_path", type=str, required=True, help="Path to HAND index data (required)") + parser.add_argument( + "--hand_index_path", + type=str, + required=True, + help="Path to HAND index data (required)", + ) parser.add_argument( "--tags", @@ -292,6 +366,12 @@ def parsed_tags(tag_list): help="List of key=value pairs for tagging (e.g., --tags batch=my_batch aoi=texas) These tags are included in job_ids that the pipeline will dispatch.", ) + parser.add_argument( + "--aoi_is_item", + action="store_true", + help="If set, treat the aoi_name tag as a STAC item ID for direct querying instead of performing spatial queries", + ) + args = parser.parse_args() if args.tags and args.tags != [""]: @@ -311,13 +391,17 @@ def parsed_tags(tag_list): format="%(asctime)s %(levelname)s %(message)s", ) - temp_log_file = tempfile.NamedTemporaryFile(mode="w", delete=False, suffix="_pipeline.log") + temp_log_file = tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix="_pipeline.log" + ) temp_log_path = temp_log_file.name temp_log_file.close() file_handler = logging.FileHandler(temp_log_path) file_handler.setLevel(os.environ.get("LOG_LEVEL", "INFO")) - file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s")) + file_handler.setFormatter( + logging.Formatter("%(asctime)s %(levelname)s %(message)s") + ) root_logger = logging.getLogger() root_logger.addHandler(file_handler) @@ -335,8 +419,12 @@ async def _main(): raise ValueError(f"Output directory already exists: {outputs_path}") timeout = aiohttp.ClientTimeout(total=160, connect=40, sock_read=60) - connector = aiohttp.TCPConnector(limit=cfg.defaults.http_connection_limit) - async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: + connector = aiohttp.TCPConnector( + limit=cfg.defaults.http_connection_limit + ) + async with aiohttp.ClientSession( + timeout=timeout, connector=connector + ) as session: log_db = 
PipelineLogDB("pipeline_log.db") await log_db.initialize() @@ -344,18 +432,24 @@ async def _main(): nomad_addr=cfg.nomad.address, namespace=cfg.nomad.namespace, token=cfg.nomad.token, - session=session, log_db=log_db, + max_concurrent_dispatch=cfg.defaults.nomad_max_concurrent_dispatch, ) await nomad.start() # if no benchmark collections provided all collections queried benchmark_collections = None if args.benchmark_sources: - benchmark_collections = [col.strip() for col in args.benchmark_sources.split(",")] - logging.info(f"Using benchmark sources: {benchmark_collections}") + benchmark_collections = [ + col.strip() for col in args.benchmark_sources.split(",") + ] + logging.info( + f"Using benchmark sources: {benchmark_collections}" + ) - data_svc = DataService(cfg, args.hand_index_path, benchmark_collections) + data_svc = DataService( + cfg, args.hand_index_path, benchmark_collections, args.aoi_is_item + ) logging.info(f"Loading polygon from: {args.aoi}") polygon_gdf = data_svc.load_polygon_gdf_from_file(args.aoi) @@ -365,17 +459,32 @@ async def _main(): raise ValueError("GPKG file contains no features") if len(polygon_gdf) > 1: - logging.warning(f"Found {len(polygon_gdf)} features in {args.aoi}, using only the first one") + logging.warning( + f"Found {len(polygon_gdf)} features in {args.aoi}, using only the first one" + ) polygon_gdf = polygon_gdf.iloc[[0]] geom = polygon_gdf.geometry.iloc[0] if geom.geom_type != "Polygon": - raise ValueError(f"Feature must be POLYGON type, got: {geom.geom_type}") + raise ValueError( + f"Feature must be POLYGON type, got: {geom.geom_type}" + ) logging.info(f"Using HAND index path: {args.hand_index_path}") - pipeline = PolygonPipeline(cfg, nomad, data_svc, polygon_gdf, args.tags, outputs_path, log_db) - logging.info(f"Started pipeline run for {args.aoi} with outputs to {outputs_path}") + pipeline = PolygonPipeline( + cfg, + nomad, + data_svc, + polygon_gdf, + args.tags, + outputs_path, + args.aoi, + log_db, + ) + logging.info( + 
f"Started pipeline run for {args.aoi} with outputs to {outputs_path}" + ) try: summary = await pipeline.run() @@ -389,7 +498,7 @@ async def _main(): root_logger.removeHandler(file_handler) final_log_path = pipeline.path_factory.logs_path() - await data_svc.copy_file_to_uri(temp_log_path, final_log_path) + await data_svc.append_file_to_uri(temp_log_path, final_log_path) logging.info(f"Logs written to {final_log_path}") print(json.dumps(summary, indent=2)) From bf29a294a19a6a43f4aa160f4108875141c50bd4 Mon Sep 17 00:00:00 2001 From: Dylan Lee Date: Wed, 10 Sep 2025 15:10:20 -0400 Subject: [PATCH 2/2] Reformat line lengths --- src/main.py | 97 +++++++++++++---------------------------------------- 1 file changed, 24 insertions(+), 73 deletions(-) diff --git a/src/main.py b/src/main.py index b1b4809..d238212 100644 --- a/src/main.py +++ b/src/main.py @@ -72,9 +72,7 @@ async def initialize(self) -> Optional[Dict[str, Any]]: """Query for catchments and flow scenarios. Returns early exit info if no data found.""" # Query STAC for flow scenarios (always required) logger.debug("Querying STAC for flow scenarios") - stac_data = await self.data_svc.query_stac_for_flow_scenarios( - self.polygon_gdf, self.tags - ) + stac_data = await self.data_svc.query_stac_for_flow_scenarios(self.polygon_gdf, self.tags) self.flow_scenarios = stac_data.get("combined_flowfiles", {}) # Extract benchmark rasters from STAC scenarios @@ -121,13 +119,8 @@ async def initialize(self) -> Optional[Dict[str, Any]]: "successful_scenarios": 0, } - total_scenarios = sum( - len(scenarios) for scenarios in self.flow_scenarios.values() - ) - logger.info( - f"Initialization complete: {len(self.catchments)} catchments, " - f"{total_scenarios} flow scenarios" - ) + total_scenarios = sum(len(scenarios) for scenarios in self.flow_scenarios.values()) + logger.info(f"Initialization complete: {len(self.catchments)} catchments, {total_scenarios} flow scenarios") return None async def run(self) -> Dict[str, Any]: @@ -141,9 
+134,7 @@ async def run(self) -> Dict[str, Any]: results = [] for collection, flows in self.flow_scenarios.items(): for scenario, flowfile_path in flows.items(): - benchmark_rasters = self.benchmark_scenarios.get( - collection, {} - ).get(scenario, []) + benchmark_rasters = self.benchmark_scenarios.get(collection, {}).get(scenario, []) result = PipelineResult( scenario_id=f"{collection}-{scenario}", collection_name=collection, @@ -161,9 +152,7 @@ async def run(self) -> Dict[str, Any]: ) results.append(result) - logger.debug( - f"Processing {len(results)} scenarios with stage-based parallelism" - ) + logger.debug(f"Processing {len(results)} scenarios with stage-based parallelism") try: inundation_stage = InundationStage( @@ -215,15 +204,11 @@ async def run(self) -> Dict[str, Any]: serializable_results.append(result_dict) # Write to temporary file first, then copy to final location - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as temp_file: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as temp_file: json.dump(serializable_results, temp_file, indent=2) temp_json_path = temp_file.name - await self.data_svc.copy_file_to_uri( - temp_json_path, results_json_path - ) + await self.data_svc.copy_file_to_uri(temp_json_path, results_json_path) logger.info(f"Results JSON written to {results_json_path}") # Clean up temp file @@ -243,12 +228,8 @@ async def run(self) -> Dict[str, Any]: flow_scenarios=self.flow_scenarios, aoi_name=self.path_factory.aoi_name, ) - metrics_path = aggregator.save_results( - self.path_factory.results_path() - ) - logger.info( - f"Metrics aggregation completed: {metrics_path}" - ) + metrics_path = aggregator.save_results(self.path_factory.results_path()) + logger.info(f"Metrics aggregation completed: {metrics_path}") except Exception as e: logger.error(f"Metrics aggregation failed: {e}") @@ -262,9 +243,7 @@ async def run(self) -> Dict[str, Any]: "successful_scenarios": 
len(successful_results), "message": f"Pipeline completed successfully with {len(successful_results)}/{total_attempted} scenarios", } - logger.info( - f"Pipeline SUCCESS: {len(successful_results)}/{total_attempted} scenarios completed" - ) + logger.info(f"Pipeline SUCCESS: {len(successful_results)}/{total_attempted} scenarios completed") return summary except DataServiceException as e: logger.error(f"Pipeline FAILED due to data service error: {str(e)}") @@ -291,9 +270,7 @@ def parsed_tags(tag_list): for tag in tag_list: if "=" not in tag: - raise argparse.ArgumentTypeError( - f"Invalid tag format: '{tag}'. Expected key=value." - ) + raise argparse.ArgumentTypeError(f"Invalid tag format: '{tag}'. Expected key=value.") key, value = tag.split("=", 1) for char in forbidden_chars: @@ -325,26 +302,20 @@ def parsed_tags(tag_list): if tags: tags_str = ",".join(f"{k}={v}" for k, v in tags.items()) if len(tags_str) > 150: - raise argparse.ArgumentTypeError( - f"Tags exceed 150 character limit ({len(tags_str)} chars): {tags_str}" - ) + raise argparse.ArgumentTypeError(f"Tags exceed 150 character limit ({len(tags_str)} chars): {tags_str}") return tags if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Run one PolygonPipeline in isolation" - ) + parser = argparse.ArgumentParser(description="Run one PolygonPipeline in isolation") parser.add_argument( "--aoi", type=str, required=True, help="File path to a GPKG containing a single polygon. 
If more than one layer/feature, only the first is used.", ) - parser.add_argument( - "--outputs_path", type=str, required=True, help="Output directory path" - ) + parser.add_argument("--outputs_path", type=str, required=True, help="Output directory path") parser.add_argument( "--benchmark_sources", type=str, @@ -391,17 +362,13 @@ def parsed_tags(tag_list): format="%(asctime)s %(levelname)s %(message)s", ) - temp_log_file = tempfile.NamedTemporaryFile( - mode="w", delete=False, suffix="_pipeline.log" - ) + temp_log_file = tempfile.NamedTemporaryFile(mode="w", delete=False, suffix="_pipeline.log") temp_log_path = temp_log_file.name temp_log_file.close() file_handler = logging.FileHandler(temp_log_path) file_handler.setLevel(os.environ.get("LOG_LEVEL", "INFO")) - file_handler.setFormatter( - logging.Formatter("%(asctime)s %(levelname)s %(message)s") - ) + file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s")) root_logger = logging.getLogger() root_logger.addHandler(file_handler) @@ -419,12 +386,8 @@ async def _main(): raise ValueError(f"Output directory already exists: {outputs_path}") timeout = aiohttp.ClientTimeout(total=160, connect=40, sock_read=60) - connector = aiohttp.TCPConnector( - limit=cfg.defaults.http_connection_limit - ) - async with aiohttp.ClientSession( - timeout=timeout, connector=connector - ) as session: + connector = aiohttp.TCPConnector(limit=cfg.defaults.http_connection_limit) + async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: log_db = PipelineLogDB("pipeline_log.db") await log_db.initialize() @@ -440,16 +403,10 @@ async def _main(): # if no benchmark collections provided all collections queried benchmark_collections = None if args.benchmark_sources: - benchmark_collections = [ - col.strip() for col in args.benchmark_sources.split(",") - ] - logging.info( - f"Using benchmark sources: {benchmark_collections}" - ) + benchmark_collections = [col.strip() for col in 
args.benchmark_sources.split(",")] + logging.info(f"Using benchmark sources: {benchmark_collections}") - data_svc = DataService( - cfg, args.hand_index_path, benchmark_collections, args.aoi_is_item - ) + data_svc = DataService(cfg, args.hand_index_path, benchmark_collections, args.aoi_is_item) logging.info(f"Loading polygon from: {args.aoi}") polygon_gdf = data_svc.load_polygon_gdf_from_file(args.aoi) @@ -459,16 +416,12 @@ async def _main(): raise ValueError("GPKG file contains no features") if len(polygon_gdf) > 1: - logging.warning( - f"Found {len(polygon_gdf)} features in {args.aoi}, using only the first one" - ) + logging.warning(f"Found {len(polygon_gdf)} features in {args.aoi}, using only the first one") polygon_gdf = polygon_gdf.iloc[[0]] geom = polygon_gdf.geometry.iloc[0] if geom.geom_type != "Polygon": - raise ValueError( - f"Feature must be POLYGON type, got: {geom.geom_type}" - ) + raise ValueError(f"Feature must be POLYGON type, got: {geom.geom_type}") logging.info(f"Using HAND index path: {args.hand_index_path}") @@ -482,9 +435,7 @@ async def _main(): args.aoi, log_db, ) - logging.info( - f"Started pipeline run for {args.aoi} with outputs to {outputs_path}" - ) + logging.info(f"Started pipeline run for {args.aoi} with outputs to {outputs_path}") try: summary = await pipeline.run()