Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 3 additions & 39 deletions spectf_cloud/deploy/deploy_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,20 @@
import yaml
import rich_click as click
import numpy as np
from osgeo import gdal

import torch
from torch import nn
from torch.utils.data import DataLoader

from spectf.model import SpecTfEncoder
from spectf.dataset import RasterDatasetTOA
from spectf_cloud.deploy.gen_geotiff import make_geotiff
from spectf_cloud.cli import spectf_cloud, MAIN_CALL_ERR_MSG, DEFAULT_DIR

PRECISION = torch.bfloat16
ENV_VAR_PREFIX = 'SPECTF_DEPLOY_'


numpy_to_gdal = {
np.dtype(np.float64): 7,
np.dtype(np.float32): 6,
np.dtype(np.int32): 5,
np.dtype(np.uint32): 4,
np.dtype(np.int16): 3,
np.dtype(np.uint16): 2,
np.dtype(np.uint8): 1,
}

# TODO: Refactor this into the CLI
# Configure logging
logging.basicConfig(
Expand Down Expand Up @@ -79,7 +69,7 @@
"--proba",
is_flag=True,
default=False,
help="Output probability map instead of binary cloud mask.",
help="Output probability map with the binary cloud mask.",
envvar=f"{ENV_VAR_PREFIX}PROBA",
)
@click.option(
Expand Down Expand Up @@ -218,33 +208,7 @@ def deploy_pt(

logging.info("Inference complete.")

# Account for NODATA values and threshold
if proba:
cloud_mask[np.isnan(cloud_mask)] = -9999
else:
cloud_mask[cloud_mask < threshold] = 0
cloud_mask[cloud_mask > 0] = 1
cloud_mask[np.isnan(cloud_mask)] = 255
cloud_mask = cloud_mask.astype(np.uint8)


# Reshape into input shape
cloud_mask = cloud_mask.reshape((dataset.shape[0], dataset.shape[1], 1))

driver = gdal.GetDriverByName('MEM')
ds = driver.Create('', cloud_mask.shape[1], cloud_mask.shape[0], cloud_mask.shape[2], numpy_to_gdal[cloud_mask.dtype])
ds.GetRasterBand(1).WriteArray(cloud_mask[:,:,0])

# Set NODATA value
if proba:
ds.GetRasterBand(1).SetNoDataValue(-9999)
else:
ds.GetRasterBand(1).SetNoDataValue(255)

tiff_driver = gdal.GetDriverByName('GTiff')
_ = tiff_driver.CreateCopy(outfp, ds, options=['COMPRESS=LZW', 'COPY_SRC_OVERVIEWS=YES', 'TILED=YES', 'BLOCKXSIZE=256', 'BLOCKYSIZE=256'])

logging.info("Cloud mask saved to %s", outfp)
make_geotiff(cloud_mask, dataset.shape, outfp, proba, threshold)

if __name__ == "__main__":
print(MAIN_CALL_ERR_MSG % "deploy-pt")
39 changes: 3 additions & 36 deletions spectf_cloud/deploy/deploy_trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from spectf.model import BandConcat
from spectf.dataset import RasterDatasetTOA
from spectf_cloud.deploy.gen_geotiff import make_geotiff
from spectf_cloud.cli import spectf_cloud, MAIN_CALL_ERR_MSG, DEFAULT_DIR

from spectf_cloud.deploy.tensor_rt_model import load_model_network_engine
Expand All @@ -32,15 +33,6 @@
PRECISION = torch.bfloat16
ENV_VAR_PREFIX = 'SPECTF_DEPLOY_'

numpy_to_gdal = {
np.dtype(np.float64): 7,
np.dtype(np.float32): 6,
np.dtype(np.int32): 5,
np.dtype(np.uint32): 4,
np.dtype(np.int16): 3,
np.dtype(np.uint16): 2,
np.dtype(np.uint8): 1,
}

# TODO: Refactor this into the CLI
# Configure logging
Expand Down Expand Up @@ -243,34 +235,9 @@ def deploy_trt(

logging.info("Inference complete.")

# Account for NODATA values and threshold
if proba:
cloud_mask[np.isnan(cloud_mask)] = -9999
else:
cloud_mask[cloud_mask < threshold] = 0
cloud_mask[cloud_mask > 0] = 1
cloud_mask[np.isnan(cloud_mask)] = 255
cloud_mask = cloud_mask.astype(np.uint8)
make_geotiff(cloud_mask, dataset.shape, outfp, proba, threshold)

# Reshape into input shape
cloud_mask = cloud_mask.reshape((dataset.shape[0], dataset.shape[1], 1))

driver = gdal.GetDriverByName('MEM')
ds = driver.Create('', cloud_mask.shape[1], cloud_mask.shape[0], cloud_mask.shape[2], numpy_to_gdal[cloud_mask.dtype])
ds.GetRasterBand(1).WriteArray(cloud_mask[:,:,0])

# Set NODATA value
if proba:
ds.GetRasterBand(1).SetNoDataValue(-9999)
else:
ds.GetRasterBand(1).SetNoDataValue(255)

tiff_driver = gdal.GetDriverByName('GTiff')
_ = tiff_driver.CreateCopy(outfp, ds, options=['COMPRESS=LZW', 'COPY_SRC_OVERVIEWS=YES', 'TILED=YES', 'BLOCKXSIZE=256', 'BLOCKYSIZE=256'])

logging.info("Cloud mask saved to %s", outfp)

def pad_batch(b: torch.tensor, target_bsz:int):
def pad_batch(b: torch.Tensor, target_bsz:int):
# Pad w/ zeros
padded_shape = (target_bsz,) + b.shape[1:]
padded_batch = torch.zeros(
Expand Down
52 changes: 52 additions & 0 deletions spectf_cloud/deploy/gen_geotiff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import numpy as np
from osgeo import gdal
import logging
from pathlib import Path

BINARY_NO_DATA_VAL = 255
Comment thread
michaelkiper marked this conversation as resolved.
PROBA_NO_DATA_VAL = -9999

numpy_to_gdal = {
np.dtype(np.float64): 7,
np.dtype(np.float32): 6,
np.dtype(np.int32): 5,
np.dtype(np.uint32): 4,
np.dtype(np.int16): 3,
np.dtype(np.uint16): 2,
np.dtype(np.uint8): 1,
}

def make_geotiff(cloud_mask: np.ndarray, dataset_shape: tuple, outfp: str, proba: bool, threshold: float):
nan_mask = np.isnan(cloud_mask)

mem_driver, tiff_driver = gdal.GetDriverByName('MEM'), gdal.GetDriverByName('GTiff')
opts = ['COMPRESS=LZW', 'COPY_SRC_OVERVIEWS=YES', 'TILED=YES', 'BLOCKXSIZE=256', 'BLOCKYSIZE=256']

if proba:
old_shape = cloud_mask.shape
cloud_mask[nan_mask] = PROBA_NO_DATA_VAL
cloud_mask = cloud_mask.reshape((dataset_shape[0], dataset_shape[1], 1))
ds = mem_driver.Create('', cloud_mask.shape[1], cloud_mask.shape[0], cloud_mask.shape[2], numpy_to_gdal[cloud_mask.dtype])
ds.GetRasterBand(1).WriteArray(cloud_mask[:,:,0])
ds.GetRasterBand(1).SetNoDataValue(PROBA_NO_DATA_VAL)

_op = Path(outfp)
_op_s = str(_op.with_name(f"{_op.stem}_prob{_op.suffix}"))
_ = tiff_driver.CreateCopy(_op_s, ds, options=opts)

logging.info("Probability cloud mask saved to %s", _op_s)
cloud_mask = cloud_mask.reshape(old_shape)
Comment thread
michaelkiper marked this conversation as resolved.

cloud_mask[cloud_mask < threshold] = 0
cloud_mask[cloud_mask > 0] = 1
cloud_mask[nan_mask] = BINARY_NO_DATA_VAL
cloud_mask = cloud_mask.astype(np.uint8)

# Reshape into input shape
cloud_mask = cloud_mask.reshape((dataset_shape[0], dataset_shape[1], 1))

ds = mem_driver.Create('', cloud_mask.shape[1], cloud_mask.shape[0], cloud_mask.shape[2], numpy_to_gdal[cloud_mask.dtype])
ds.GetRasterBand(1).WriteArray(cloud_mask[:,:,0])
ds.GetRasterBand(1).SetNoDataValue(BINARY_NO_DATA_VAL)

_ = tiff_driver.CreateCopy(outfp, ds, options=opts)