diff --git a/src/main.py b/src/main.py index 1f6fd3a..d238212 100644 --- a/src/main.py +++ b/src/main.py @@ -14,7 +14,7 @@ import geopandas as gpd import default_config -from data_service import DataService +from data_service import DataService, DataServiceException from load_config import AppConfig, load_config from metrics_aggregator import MetricsAggregator from nomad_job_manager import NomadJobManager @@ -42,6 +42,7 @@ def __init__( polygon_gdf: gpd.GeoDataFrame, tags: Dict[str, str], outputs_path: str, + aoi_path: str, log_db: Optional[PipelineLogDB] = None, ): self.config = config @@ -49,6 +50,7 @@ def __init__( self.data_svc = data_svc self.polygon_gdf = polygon_gdf self.tags = tags + self.aoi_path = aoi_path self.log_db = log_db # Ensure the temp directory exists temp_dir = "/tmp" @@ -66,11 +68,11 @@ def __init__( self.benchmark_scenarios: Dict[str, Dict[str, List[str]]] = {} self.stac_results: Dict[str, Dict[str, Dict[str, List[str]]]] = {} - async def initialize(self) -> None: - """Query for catchments and flow scenarios.""" + async def initialize(self) -> Optional[Dict[str, Any]]: + """Query for catchments and flow scenarios. Returns early exit info if no data found.""" # Query STAC for flow scenarios (always required) logger.debug("Querying STAC for flow scenarios") - stac_data = await self.data_svc.query_stac_for_flow_scenarios(self.polygon_gdf) + stac_data = await self.data_svc.query_stac_for_flow_scenarios(self.polygon_gdf, self.tags) self.flow_scenarios = stac_data.get("combined_flowfiles", {}) # Extract benchmark rasters from STAC scenarios @@ -92,9 +94,15 @@ async def initialize(self) -> None: if self.flow_scenarios: logger.debug(f"Found {len(self.flow_scenarios)} collections") - - if not self.flow_scenarios: - raise RuntimeError("No flow scenarios found") + else: + logger.warning("No flow scenarios found in STAC query results") + return { + "status": "no_data", + "message": "No flow scenarios found for the given polygon", + "catchment_count": 0, + "total_scenarios_attempted": 0, + "successful_scenarios": 0, + } # Query hand index for catchments logger.debug("Querying hand index for catchments") @@ -102,14 +110,25 @@ async def initialize(self) -> None: self.catchments = data.get("catchments", {}) if not self.catchments: - raise RuntimeError("No catchments found") + logger.warning("No catchments found in hand index query results") + return { + "status": "no_data", + "message": "No catchments found for the given polygon", + "catchment_count": 0, + "total_scenarios_attempted": 0, + "successful_scenarios": 0, + } total_scenarios = sum(len(scenarios) for scenarios in self.flow_scenarios.values()) - logger.info(f"Initialization complete: {len(self.catchments)} catchments, " f"{total_scenarios} flow scenarios") + logger.info(f"Initialization complete: {len(self.catchments)} catchments, {total_scenarios} flow scenarios") + return None async def run(self) -> Dict[str, Any]: """Run the pipeline with stage-based parallelism.""" - await self.initialize() + early_exit = await self.initialize() + if early_exit is not None: + logger.info(f"Pipeline exiting early: {early_exit['message']}") + return early_exit # Build scenario results results = [] @@ -144,13 +163,27 @@ async def run(self) -> Dict[str, Any]: self.tags, self.catchments, ) - mosaic_stage = MosaicStage(self.config, self.nomad, self.data_svc, self.path_factory, self.tags) - agreement_stage = AgreementStage(self.config, self.nomad, self.data_svc, self.path_factory, self.tags) + mosaic_stage = MosaicStage( + self.config, + self.nomad, + self.data_svc, + self.path_factory, + self.tags, + self.aoi_path, + ) + agreement_stage = AgreementStage( + self.config, + self.nomad, + self.data_svc, + self.path_factory, + self.tags, + ) results = await inundation_stage.run(results) results = await mosaic_stage.run(results) results = await agreement_stage.run(results) + # Save results to JSON file if results: try: results_json_path = self.path_factory.results_json_path() @@ -170,6 +203,7 @@ async def run(self) -> Dict[str, Any]: } serializable_results.append(result_dict) + # Write to temporary file first, then copy to final location with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as temp_file: json.dump(serializable_results, temp_file, indent=2) temp_json_path = temp_file.name @@ -177,6 +211,7 @@ async def run(self) -> Dict[str, Any]: await self.data_svc.copy_file_to_uri(temp_json_path, results_json_path) logger.info(f"Results JSON written to {results_json_path}") + # Clean up temp file if os.path.exists(temp_json_path): os.unlink(temp_json_path) @@ -190,8 +225,10 @@ async def run(self) -> Dict[str, Any]: outputs_path=str(self.path_factory.base), stac_results=self.stac_results, data_service=self.data_svc, + flow_scenarios=self.flow_scenarios, + aoi_name=self.path_factory.aoi_name, ) - metrics_path = aggregator.save_results(self.path_factory.metrics_path()) + metrics_path = aggregator.save_results(self.path_factory.results_path()) logger.info(f"Metrics aggregation completed: {metrics_path}") except Exception as e: logger.error(f"Metrics aggregation failed: {e}") @@ -208,6 +245,9 @@ async def run(self) -> Dict[str, Any]: } logger.info(f"Pipeline SUCCESS: {len(successful_results)}/{total_attempted} scenarios completed") return summary + except DataServiceException as e: + logger.error(f"Pipeline FAILED due to data service error: {str(e)}") + return {"status": "failed", "error": str(e), "message": f"Data service error: {str(e)}"} except Exception as e: logger.error(f"Pipeline FAILED: {str(e)}") return { @@ -261,8 +301,8 @@ def parsed_tags(tag_list): if tags: tags_str = ",".join(f"{k}={v}" for k, v in tags.items()) - if len(tags_str) > 120: - raise argparse.ArgumentTypeError(f"Tags exceed 120 character limit ({len(tags_str)} chars): {tags_str}") + if len(tags_str) > 150: + raise argparse.ArgumentTypeError(f"Tags exceed 150 character limit ({len(tags_str)} chars): {tags_str}") return tags @@ -282,7 +322,12 @@ def parsed_tags(tag_list): default=None, help="Comma-separated list of STAC collections to query (e.g., 'ble-collection,nwm-collection'). Defaults to all available sources.", ) - parser.add_argument("--hand_index_path", type=str, required=True, help="Path to HAND index data (required)") + parser.add_argument( + "--hand_index_path", + type=str, + required=True, + help="Path to HAND index data (required)", + ) parser.add_argument( "--tags", @@ -292,6 +337,12 @@ def parsed_tags(tag_list): help="List of key=value pairs for tagging (e.g., --tags batch=my_batch aoi=texas) These tags are included in job_ids that the pipeline will dispatch.", ) + parser.add_argument( + "--aoi_is_item", + action="store_true", + help="If set, treat the aoi_name tag as a STAC item ID for direct querying instead of performing spatial queries", + ) + args = parser.parse_args() if args.tags and args.tags != [""]: @@ -344,8 +395,8 @@ async def _main(): nomad_addr=cfg.nomad.address, namespace=cfg.nomad.namespace, token=cfg.nomad.token, - session=session, log_db=log_db, + max_concurrent_dispatch=cfg.defaults.nomad_max_concurrent_dispatch, ) await nomad.start() @@ -355,7 +406,7 @@ async def _main(): benchmark_collections = [col.strip() for col in args.benchmark_sources.split(",")] logging.info(f"Using benchmark sources: {benchmark_collections}") - data_svc = DataService(cfg, args.hand_index_path, benchmark_collections) + data_svc = DataService(cfg, args.hand_index_path, benchmark_collections, args.aoi_is_item) logging.info(f"Loading polygon from: {args.aoi}") polygon_gdf = data_svc.load_polygon_gdf_from_file(args.aoi) @@ -374,7 +425,16 @@ async def _main(): logging.info(f"Using HAND index path: {args.hand_index_path}") - pipeline = PolygonPipeline(cfg, nomad, data_svc, polygon_gdf, args.tags, outputs_path, log_db) + pipeline = PolygonPipeline( + cfg, + nomad, + data_svc, + polygon_gdf, + args.tags, + outputs_path, + args.aoi, + log_db, + ) logging.info(f"Started pipeline run for {args.aoi} with outputs to {outputs_path}") try: @@ -389,7 +449,7 @@ async def _main(): root_logger.removeHandler(file_handler) final_log_path = pipeline.path_factory.logs_path() - await data_svc.copy_file_to_uri(temp_log_path, final_log_path) + await data_svc.append_file_to_uri(temp_log_path, final_log_path) logging.info(f"Logs written to {final_log_path}") print(json.dumps(summary, indent=2))