From b7c3e5a89edbacd787e7650997ac811f7abae895 Mon Sep 17 00:00:00 2001 From: makiper Date: Wed, 6 Aug 2025 18:48:16 -0700 Subject: [PATCH 1/3] Added support for separate geotiff --- spectf_cloud/deploy/deploy_pt.py | 42 ++---------------------- spectf_cloud/deploy/deploy_trt.py | 39 ++--------------------- spectf_cloud/deploy/gen_geotiff.py | 51 ++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 75 deletions(-) create mode 100644 spectf_cloud/deploy/gen_geotiff.py diff --git a/spectf_cloud/deploy/deploy_pt.py b/spectf_cloud/deploy/deploy_pt.py index daedffa..e8c77c6 100644 --- a/spectf_cloud/deploy/deploy_pt.py +++ b/spectf_cloud/deploy/deploy_pt.py @@ -14,7 +14,6 @@ import yaml import rich_click as click import numpy as np -from osgeo import gdal import torch from torch import nn @@ -22,22 +21,13 @@ from spectf.model import SpecTfEncoder from spectf.dataset import RasterDatasetTOA +from spectf_cloud.deploy.gen_geotiff import make_geotiff from spectf_cloud.cli import spectf_cloud, MAIN_CALL_ERR_MSG, DEFAULT_DIR PRECISION = torch.bfloat16 ENV_VAR_PREFIX = 'SPECTF_DEPLOY_' -numpy_to_gdal = { - np.dtype(np.float64): 7, - np.dtype(np.float32): 6, - np.dtype(np.int32): 5, - np.dtype(np.uint32): 4, - np.dtype(np.int16): 3, - np.dtype(np.uint16): 2, - np.dtype(np.uint8): 1, -} - # TODO: Refactor this into the CLI # Configure logging logging.basicConfig( @@ -79,7 +69,7 @@ "--proba", is_flag=True, default=False, - help="Output probability map instead of binary cloud mask.", + help="Output probability map with the binary cloud mask.", envvar=f"{ENV_VAR_PREFIX}PROBA", ) @click.option( @@ -213,33 +203,7 @@ def deploy_pt( logging.info("Inference complete.") - # Account for NODATA values and threshold - if proba: - cloud_mask[np.isnan(cloud_mask)] = -9999 - else: - cloud_mask[cloud_mask < threshold] = 0 - cloud_mask[cloud_mask > 0] = 1 - cloud_mask[np.isnan(cloud_mask)] = 255 - cloud_mask = cloud_mask.astype(np.uint8) - - - # Reshape into input shape - cloud_mask = cloud_mask.reshape((dataset.shape[0], dataset.shape[1], 1)) - - driver = gdal.GetDriverByName('MEM') - ds = driver.Create('', cloud_mask.shape[1], cloud_mask.shape[0], cloud_mask.shape[2], numpy_to_gdal[cloud_mask.dtype]) - ds.GetRasterBand(1).WriteArray(cloud_mask[:,:,0]) - - # Set NODATA value - if proba: - ds.GetRasterBand(1).SetNoDataValue(-9999) - else: - ds.GetRasterBand(1).SetNoDataValue(255) - - tiff_driver = gdal.GetDriverByName('GTiff') - _ = tiff_driver.CreateCopy(outfp, ds, options=['COMPRESS=LZW', 'COPY_SRC_OVERVIEWS=YES', 'TILED=YES', 'BLOCKXSIZE=256', 'BLOCKYSIZE=256']) - - logging.info("Cloud mask saved to %s", outfp) + make_geotiff(cloud_mask, dataset.shape, outfp, proba, threshold) if __name__ == "__main__": print(MAIN_CALL_ERR_MSG % "deploy-pt") \ No newline at end of file diff --git a/spectf_cloud/deploy/deploy_trt.py b/spectf_cloud/deploy/deploy_trt.py index 241275c..ad62405 100644 --- a/spectf_cloud/deploy/deploy_trt.py +++ b/spectf_cloud/deploy/deploy_trt.py @@ -22,6 +22,7 @@ from spectf.model import BandConcat from spectf.dataset import RasterDatasetTOA +from spectf_cloud.deploy.gen_geotiff import make_geotiff from spectf_cloud.cli import spectf_cloud, MAIN_CALL_ERR_MSG, DEFAULT_DIR from spectf_cloud.deploy.tensor_rt_model import load_model_network_engine @@ -32,15 +33,6 @@ PRECISION = torch.bfloat16 ENV_VAR_PREFIX = 'SPECTF_DEPLOY_' -numpy_to_gdal = { - np.dtype(np.float64): 7, - np.dtype(np.float32): 6, - np.dtype(np.int32): 5, - np.dtype(np.uint32): 4, - np.dtype(np.int16): 3, - np.dtype(np.uint16): 2, - np.dtype(np.uint8): 1, -} # TODO: Refactor this into the CLI # Configure logging @@ -238,34 +230,9 @@ def deploy_trt( logging.info("Inference complete.") - # Account for NODATA values and threshold - if proba: - cloud_mask[np.isnan(cloud_mask)] = -9999 - else: - cloud_mask[cloud_mask < threshold] = 0 - cloud_mask[cloud_mask > 0] = 1 - cloud_mask[np.isnan(cloud_mask)] = 255 - cloud_mask = cloud_mask.astype(np.uint8) + make_geotiff(cloud_mask, dataset.shape, outfp, proba, threshold) - # Reshape into input shape - cloud_mask = cloud_mask.reshape((dataset.shape[0], dataset.shape[1], 1)) - - driver = gdal.GetDriverByName('MEM') - ds = driver.Create('', cloud_mask.shape[1], cloud_mask.shape[0], cloud_mask.shape[2], numpy_to_gdal[cloud_mask.dtype]) - ds.GetRasterBand(1).WriteArray(cloud_mask[:,:,0]) - - # Set NODATA value - if proba: - ds.GetRasterBand(1).SetNoDataValue(-9999) - else: - ds.GetRasterBand(1).SetNoDataValue(255) - - tiff_driver = gdal.GetDriverByName('GTiff') - _ = tiff_driver.CreateCopy(outfp, ds, options=['COMPRESS=LZW', 'COPY_SRC_OVERVIEWS=YES', 'TILED=YES', 'BLOCKXSIZE=256', 'BLOCKYSIZE=256']) - - logging.info("Cloud mask saved to %s", outfp) - -def pad_batch(b: torch.tensor, target_bsz:int): +def pad_batch(b: torch.Tensor, target_bsz:int): # Pad w/ zeros padded_shape = (target_bsz,) + b.shape[1:] padded_batch = torch.zeros( diff --git a/spectf_cloud/deploy/gen_geotiff.py b/spectf_cloud/deploy/gen_geotiff.py new file mode 100644 index 0000000..aa2f199 --- /dev/null +++ b/spectf_cloud/deploy/gen_geotiff.py @@ -0,0 +1,51 @@ +import numpy as np +from osgeo import gdal +import logging + +BINARY_NO_DATA_VAL = 255 +PROBA_NO_DATA_VAL = -9999 + +numpy_to_gdal = { + np.dtype(np.float64): 7, + np.dtype(np.float32): 6, + np.dtype(np.int32): 5, + np.dtype(np.uint32): 4, + np.dtype(np.int16): 3, + np.dtype(np.uint16): 2, + np.dtype(np.uint8): 1, +} + +def make_geotiff(cloud_mask: np.ndarray, dataset_shape: tuple, outfp: str, proba: bool, threshold: float): + nan_mask = np.isnan(cloud_mask) + + mem_driver, tiff_driver = gdal.GetDriverByName('MEM'), gdal.GetDriverByName('GTiff') + opts = ['COMPRESS=LZW', 'COPY_SRC_OVERVIEWS=YES', 'TILED=YES', 'BLOCKXSIZE=256', 'BLOCKYSIZE=256'] + + if proba: + old_shape = cloud_mask.shape + cloud_mask[nan_mask] = PROBA_NO_DATA_VAL + cloud_mask = cloud_mask.reshape((dataset_shape[0], dataset_shape[1], 1)) + ds = mem_driver.Create('', cloud_mask.shape[1], cloud_mask.shape[0], cloud_mask.shape[2], numpy_to_gdal[cloud_mask.dtype]) + ds.GetRasterBand(1).WriteArray(cloud_mask[:,:,0]) + ds.GetRasterBand(1).SetNoDataValue(PROBA_NO_DATA_VAL) + + sp = str(outfp).split('.') + sp[-1] = '_proba'+sp[-1] + _ = tiff_driver.CreateCopy('.'.join(sp), ds, options=opts) + + logging.info("Probability cloud mask saved to %s", '.'.join(sp)) + cloud_mask = cloud_mask.reshape(old_shape) + + cloud_mask[cloud_mask < threshold] = 0 + cloud_mask[cloud_mask > 0] = 1 + cloud_mask[nan_mask] = BINARY_NO_DATA_VAL + cloud_mask = cloud_mask.astype(np.uint8) + + # Reshape into input shape + cloud_mask = cloud_mask.reshape((dataset_shape[0], dataset_shape[1], 1)) + + ds = mem_driver.Create('', cloud_mask.shape[1], cloud_mask.shape[0], cloud_mask.shape[2], numpy_to_gdal[cloud_mask.dtype]) + ds.GetRasterBand(1).WriteArray(cloud_mask[:,:,0]) + ds.GetRasterBand(1).SetNoDataValue(BINARY_NO_DATA_VAL) + + _ = tiff_driver.CreateCopy(outfp, ds, options=opts) From 022a4d1521d6760a817a8b3a0b681f784d1b7894 Mon Sep 17 00:00:00 2001 From: makiper Date: Thu, 7 Aug 2025 13:26:24 -0700 Subject: [PATCH 2/3] minor file name change --- spectf_cloud/deploy/gen_geotiff.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spectf_cloud/deploy/gen_geotiff.py b/spectf_cloud/deploy/gen_geotiff.py index aa2f199..979f11a 100644 --- a/spectf_cloud/deploy/gen_geotiff.py +++ b/spectf_cloud/deploy/gen_geotiff.py @@ -30,7 +30,7 @@ def make_geotiff(cloud_mask: np.ndarray, dataset_shape: tuple, outfp: str, proba ds.GetRasterBand(1).SetNoDataValue(PROBA_NO_DATA_VAL) sp = str(outfp).split('.') - sp[-1] = '_proba'+sp[-1] + sp[-1] = 'prob.'+sp[-1] _ = tiff_driver.CreateCopy('.'.join(sp), ds, options=opts) logging.info("Probability cloud mask saved to %s", '.'.join(sp)) From f31059e006340aca7f056a8bdb622358546ac301 Mon Sep 17 00:00:00 2001 From: makiper Date: Fri, 22 Aug 2025 12:01:11 -0700 Subject: [PATCH 3/3] added pathlib support --- spectf_cloud/deploy/gen_geotiff.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/spectf_cloud/deploy/gen_geotiff.py b/spectf_cloud/deploy/gen_geotiff.py index 979f11a..817beec 100644 --- a/spectf_cloud/deploy/gen_geotiff.py +++ b/spectf_cloud/deploy/gen_geotiff.py @@ -1,6 +1,7 @@ import numpy as np from osgeo import gdal import logging +from pathlib import Path BINARY_NO_DATA_VAL = 255 PROBA_NO_DATA_VAL = -9999 @@ -29,11 +30,11 @@ def make_geotiff(cloud_mask: np.ndarray, dataset_shape: tuple, outfp: str, proba ds.GetRasterBand(1).WriteArray(cloud_mask[:,:,0]) ds.GetRasterBand(1).SetNoDataValue(PROBA_NO_DATA_VAL) - sp = str(outfp).split('.') - sp[-1] = 'prob.'+sp[-1] - _ = tiff_driver.CreateCopy('.'.join(sp), ds, options=opts) + _op = Path(outfp) + _op_s = str(_op.with_name(f"{_op.stem}_prob{_op.suffix}")) + _ = tiff_driver.CreateCopy(_op_s, ds, options=opts) - logging.info("Probability cloud mask saved to %s", '.'.join(sp)) + logging.info("Probability cloud mask saved to %s", _op_s) cloud_mask = cloud_mask.reshape(old_shape) cloud_mask[cloud_mask < threshold] = 0