diff --git a/spectf_cloud/deploy/deploy_pt.py b/spectf_cloud/deploy/deploy_pt.py index 1f54631..0d1a295 100644 --- a/spectf_cloud/deploy/deploy_pt.py +++ b/spectf_cloud/deploy/deploy_pt.py @@ -14,7 +14,6 @@ import yaml import rich_click as click import numpy as np -from osgeo import gdal import torch from torch import nn @@ -22,22 +21,13 @@ from spectf.model import SpecTfEncoder from spectf.dataset import RasterDatasetTOA +from spectf_cloud.deploy.gen_geotiff import make_geotiff from spectf_cloud.cli import spectf_cloud, MAIN_CALL_ERR_MSG, DEFAULT_DIR PRECISION = torch.bfloat16 ENV_VAR_PREFIX = 'SPECTF_DEPLOY_' -numpy_to_gdal = { - np.dtype(np.float64): 7, - np.dtype(np.float32): 6, - np.dtype(np.int32): 5, - np.dtype(np.uint32): 4, - np.dtype(np.int16): 3, - np.dtype(np.uint16): 2, - np.dtype(np.uint8): 1, -} - # TODO: Refactor this into the CLI # Configure logging logging.basicConfig( @@ -79,7 +69,7 @@ "--proba", is_flag=True, default=False, - help="Output probability map instead of binary cloud mask.", + help="Output probability map with the binary cloud mask.", envvar=f"{ENV_VAR_PREFIX}PROBA", ) @click.option( @@ -218,33 +208,7 @@ def deploy_pt( logging.info("Inference complete.") - # Account for NODATA values and threshold - if proba: - cloud_mask[np.isnan(cloud_mask)] = -9999 - else: - cloud_mask[cloud_mask < threshold] = 0 - cloud_mask[cloud_mask > 0] = 1 - cloud_mask[np.isnan(cloud_mask)] = 255 - cloud_mask = cloud_mask.astype(np.uint8) - - - # Reshape into input shape - cloud_mask = cloud_mask.reshape((dataset.shape[0], dataset.shape[1], 1)) - - driver = gdal.GetDriverByName('MEM') - ds = driver.Create('', cloud_mask.shape[1], cloud_mask.shape[0], cloud_mask.shape[2], numpy_to_gdal[cloud_mask.dtype]) - ds.GetRasterBand(1).WriteArray(cloud_mask[:,:,0]) - - # Set NODATA value - if proba: - ds.GetRasterBand(1).SetNoDataValue(-9999) - else: - ds.GetRasterBand(1).SetNoDataValue(255) - - tiff_driver = gdal.GetDriverByName('GTiff') - _ = tiff_driver.CreateCopy(outfp, ds, options=['COMPRESS=LZW', 'COPY_SRC_OVERVIEWS=YES', 'TILED=YES', 'BLOCKXSIZE=256', 'BLOCKYSIZE=256']) - - logging.info("Cloud mask saved to %s", outfp) + make_geotiff(cloud_mask, dataset.shape, outfp, proba, threshold) if __name__ == "__main__": print(MAIN_CALL_ERR_MSG % "deploy-pt") \ No newline at end of file diff --git a/spectf_cloud/deploy/deploy_trt.py b/spectf_cloud/deploy/deploy_trt.py index f8f9764..2538ede 100644 --- a/spectf_cloud/deploy/deploy_trt.py +++ b/spectf_cloud/deploy/deploy_trt.py @@ -22,6 +22,7 @@ from spectf.model import BandConcat from spectf.dataset import RasterDatasetTOA +from spectf_cloud.deploy.gen_geotiff import make_geotiff from spectf_cloud.cli import spectf_cloud, MAIN_CALL_ERR_MSG, DEFAULT_DIR from spectf_cloud.deploy.tensor_rt_model import load_model_network_engine @@ -32,15 +33,6 @@ PRECISION = torch.bfloat16 ENV_VAR_PREFIX = 'SPECTF_DEPLOY_' -numpy_to_gdal = { - np.dtype(np.float64): 7, - np.dtype(np.float32): 6, - np.dtype(np.int32): 5, - np.dtype(np.uint32): 4, - np.dtype(np.int16): 3, - np.dtype(np.uint16): 2, - np.dtype(np.uint8): 1, -} # TODO: Refactor this into the CLI # Configure logging @@ -243,34 +235,9 @@ def deploy_trt( logging.info("Inference complete.") - # Account for NODATA values and threshold - if proba: - cloud_mask[np.isnan(cloud_mask)] = -9999 - else: - cloud_mask[cloud_mask < threshold] = 0 - cloud_mask[cloud_mask > 0] = 1 - cloud_mask[np.isnan(cloud_mask)] = 255 - cloud_mask = cloud_mask.astype(np.uint8) + make_geotiff(cloud_mask, dataset.shape, outfp, proba, threshold) - # Reshape into input shape - cloud_mask = cloud_mask.reshape((dataset.shape[0], dataset.shape[1], 1)) - - driver = gdal.GetDriverByName('MEM') - ds = driver.Create('', cloud_mask.shape[1], cloud_mask.shape[0], cloud_mask.shape[2], numpy_to_gdal[cloud_mask.dtype]) - ds.GetRasterBand(1).WriteArray(cloud_mask[:,:,0]) - - # Set NODATA value - if proba: - ds.GetRasterBand(1).SetNoDataValue(-9999) - else: - ds.GetRasterBand(1).SetNoDataValue(255) - - tiff_driver = gdal.GetDriverByName('GTiff') - _ = tiff_driver.CreateCopy(outfp, ds, options=['COMPRESS=LZW', 'COPY_SRC_OVERVIEWS=YES', 'TILED=YES', 'BLOCKXSIZE=256', 'BLOCKYSIZE=256']) - - logging.info("Cloud mask saved to %s", outfp) - -def pad_batch(b: torch.tensor, target_bsz:int): +def pad_batch(b: torch.Tensor, target_bsz:int): # Pad w/ zeros padded_shape = (target_bsz,) + b.shape[1:] padded_batch = torch.zeros( diff --git a/spectf_cloud/deploy/gen_geotiff.py b/spectf_cloud/deploy/gen_geotiff.py new file mode 100644 index 0000000..817beec --- /dev/null +++ b/spectf_cloud/deploy/gen_geotiff.py @@ -0,0 +1,52 @@ +import numpy as np +from osgeo import gdal +import logging +from pathlib import Path + +BINARY_NO_DATA_VAL = 255 +PROBA_NO_DATA_VAL = -9999 + +numpy_to_gdal = { + np.dtype(np.float64): 7, + np.dtype(np.float32): 6, + np.dtype(np.int32): 5, + np.dtype(np.uint32): 4, + np.dtype(np.int16): 3, + np.dtype(np.uint16): 2, + np.dtype(np.uint8): 1, +} + +def make_geotiff(cloud_mask: np.ndarray, dataset_shape: tuple, outfp: str, proba: bool, threshold: float): + nan_mask = np.isnan(cloud_mask) + + mem_driver, tiff_driver = gdal.GetDriverByName('MEM'), gdal.GetDriverByName('GTiff') + opts = ['COMPRESS=LZW', 'COPY_SRC_OVERVIEWS=YES', 'TILED=YES', 'BLOCKXSIZE=256', 'BLOCKYSIZE=256'] + + if proba: + old_shape = cloud_mask.shape + cloud_mask[nan_mask] = PROBA_NO_DATA_VAL + cloud_mask = cloud_mask.reshape((dataset_shape[0], dataset_shape[1], 1)) + ds = mem_driver.Create('', cloud_mask.shape[1], cloud_mask.shape[0], cloud_mask.shape[2], numpy_to_gdal[cloud_mask.dtype]) + ds.GetRasterBand(1).WriteArray(cloud_mask[:,:,0]) + ds.GetRasterBand(1).SetNoDataValue(PROBA_NO_DATA_VAL) + + _op = Path(outfp) + _op_s = str(_op.with_name(f"{_op.stem}_prob{_op.suffix}")) + _ = tiff_driver.CreateCopy(_op_s, ds, options=opts) + + logging.info("Probability cloud mask saved to %s", _op_s) + cloud_mask = cloud_mask.reshape(old_shape) + + cloud_mask[cloud_mask < threshold] = 0 + cloud_mask[cloud_mask > 0] = 1 + cloud_mask[nan_mask] = BINARY_NO_DATA_VAL + cloud_mask = cloud_mask.astype(np.uint8) + + # Reshape into input shape + cloud_mask = cloud_mask.reshape((dataset_shape[0], dataset_shape[1], 1)) + + ds = mem_driver.Create('', cloud_mask.shape[1], cloud_mask.shape[0], cloud_mask.shape[2], numpy_to_gdal[cloud_mask.dtype]) + ds.GetRasterBand(1).WriteArray(cloud_mask[:,:,0]) + ds.GetRasterBand(1).SetNoDataValue(BINARY_NO_DATA_VAL) + + _ = tiff_driver.CreateCopy(outfp, ds, options=opts)