diff --git a/polytope_feature/datacube/backends/datacube.py b/polytope_feature/datacube/backends/datacube.py index 217671a18..e1785f531 100644 --- a/polytope_feature/datacube/backends/datacube.py +++ b/polytope_feature/datacube/backends/datacube.py @@ -17,6 +17,7 @@ def __init__( self, axis_options=None, compressed_axes_options=[], + return_indexes=False, ): if axis_options is None: self.axis_options = {} @@ -36,6 +37,7 @@ def __init__( self.unwanted_path = {} self.compressed_axes = compressed_axes_options self.grid_md5_hash = None + self.return_indexes = return_indexes @abstractmethod def get(self, requests: TensorIndexTree, context: Dict) -> Any: @@ -164,6 +166,7 @@ def create( compressed_axes_options=[], alternative_axes=[], use_catalogue=False, + return_indexes=False, context=None, ): # TODO: get the configs as None for pre-determined value and change them to empty dictionary inside the function @@ -174,6 +177,7 @@ def create( datacube, axis_options, compressed_axes_options, + return_indexes, context, ) return xadatacube @@ -186,8 +190,9 @@ def create( axis_options, compressed_axes_options, alternative_axes, - context, use_catalogue, + return_indexes, + context, ) return fdbdatacube if type(datacube).__name__ == "MockDatacube": diff --git a/polytope_feature/datacube/backends/fdb.py b/polytope_feature/datacube/backends/fdb.py index 5997438ac..098d2f461 100644 --- a/polytope_feature/datacube/backends/fdb.py +++ b/polytope_feature/datacube/backends/fdb.py @@ -17,8 +17,9 @@ def __init__( axis_options=None, compressed_axes_options=[], alternative_axes=[], - context=None, use_catalogue=False, + return_indexes=False, + context=None, ): self.use_catalogue = use_catalogue if config is None: @@ -26,10 +27,7 @@ def __init__( if context is None: context = {} - super().__init__( - axis_options, - compressed_axes_options, - ) + super().__init__(axis_options, compressed_axes_options, return_indexes) logging.info("Created an FDB datacube with options: " + str(axis_options)) @@ -339,6 +337,8 @@ def get_last_layer_before_leaf(self, requests, leaf_path, current_idx, fdb_range ) # TODO: change this to accommodate non consecutive indexes being compressed too current_idx[i].extend(key_value_path["values"]) + if self.return_indexes: + c.indexes = key_value_path["values"] fdb_range_n[i].append(c) return (current_idx, fdb_range_n) diff --git a/polytope_feature/datacube/backends/xarray.py b/polytope_feature/datacube/backends/xarray.py index 55504d119..ddeac2e2f 100644 --- a/polytope_feature/datacube/backends/xarray.py +++ b/polytope_feature/datacube/backends/xarray.py @@ -14,12 +14,10 @@ def __init__( dataarray: xr.DataArray, axis_options=None, compressed_axes_options=[], + return_indexes=False, context=None, ): - super().__init__( - axis_options, - compressed_axes_options, - ) + super().__init__(axis_options, compressed_axes_options, return_indexes) if axis_options is None: axis_options = {} diff --git a/polytope_feature/options.py b/polytope_feature/options.py index d1f2f8f98..4e52df15e 100644 --- a/polytope_feature/options.py +++ b/polytope_feature/options.py @@ -84,6 +84,7 @@ class Config(ConfigModel): alternative_axes: Optional[List[GribJumpAxesConfig]] = [] use_catalogue: Optional[bool] = False engine_options: Optional[Dict[str, str]] = {} + return_indexes: Optional[bool] = False class PolytopeOptions(ABC): @@ -99,6 +100,7 @@ def get_polytope_options(options): alternative_axes = config_options.alternative_axes use_catalogue = config_options.use_catalogue engine_options = config_options.engine_options + return_indexes = config_options.return_indexes return ( axis_config, compressed_axes_config, @@ -106,4 +108,5 @@ def get_polytope_options(options): alternative_axes, use_catalogue, engine_options, + return_indexes, ) diff --git a/polytope_feature/polytope.py b/polytope_feature/polytope.py index 3c5f137cb..93270d10d 100644 --- a/polytope_feature/polytope.py +++ b/polytope_feature/polytope.py @@ -70,6 +70,7 @@ def __init__( alternative_axes, use_catalogue, engine_options, + return_indexes, ) = PolytopeOptions.get_polytope_options(options) self.datacube = Datacube.create( datacube, @@ -78,6 +79,7 @@ def __init__( compressed_axes_options, alternative_axes, use_catalogue, + return_indexes, self.context, ) if engine_options == {}: diff --git a/tests/test_fdb_return_idx.py b/tests/test_fdb_return_idx.py new file mode 100644 index 000000000..535219a45 --- /dev/null +++ b/tests/test_fdb_return_idx.py @@ -0,0 +1,96 @@ +import pandas as pd +import pytest + +from polytope_feature.polytope import Polytope, Request +from polytope_feature.shapes import Box, Select + + +class TestSlicingFDBDatacube: + def setup_method(self, method): + # Create a dataarray with 3 labelled axes using different index types + self.options = { + "axis_config": [ + {"axis_name": "step", "transformations": [{"name": "type_change", "type": "int"}]}, + { + "axis_name": "date", + "transformations": [{"name": "merge", "other_axis": "time", "linkers": ["T", "00"]}], + }, + { + "axis_name": "values", + "transformations": [ + {"name": "mapper", "type": "octahedral", "resolution": 1280, "axes": ["latitude", "longitude"]} + ], + }, + {"axis_name": "latitude", "transformations": [{"name": "reverse", "is_reverse": True}]}, + {"axis_name": "longitude", "transformations": [{"name": "cyclic", "range": [0, 360]}]}, + ], + "compressed_axes_config": [ + "longitude", + "latitude", + "levtype", + "step", + "date", + "domain", + "expver", + "param", + "class", + "stream", + "type", + ], + "pre_path": {"class": "od", "expver": "0001", "levtype": "sfc", "type": "fc", "stream": "oper"}, + } + + @pytest.mark.fdb + def test_fdb_datacube(self): + import pygribjump as gj + + request = Request( + Select("step", [0]), + Select("levtype", ["sfc"]), + Select("date", [pd.Timestamp("20240103T0000")]), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["167"]), + Select("class", ["od"]), + Select("stream", ["oper"]), + Select("type", ["fc"]), + Box(["latitude", "longitude"], [0, 0], [0.2, 0.2]), + ) + self.fdbdatacube = gj.GribJump() + self.API = Polytope( + datacube=self.fdbdatacube, + options=self.options, + ) + result = self.API.retrieve(request) + result.pprint() + assert len(result.leaves) == 3 + assert len(result.leaves[0].result) == 3 + assert len(result.leaves[0].indexes) == 0 + + @pytest.mark.fdb + def test_fdb_datacube_return_idx(self): + import pygribjump as gj + + request = Request( + Select("step", [0]), + Select("levtype", ["sfc"]), + Select("date", [pd.Timestamp("20240103T0000")]), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["167"]), + Select("class", ["od"]), + Select("stream", ["oper"]), + Select("type", ["fc"]), + Box(["latitude", "longitude"], [0, 0], [0.2, 0.2]), + ) + self.fdbdatacube = gj.GribJump() + self.options["return_indexes"] = True + self.API = Polytope( + datacube=self.fdbdatacube, + options=self.options, + ) + result = self.API.retrieve(request) + result.pprint() + assert len(result.leaves) == 3 + assert len(result.leaves[0].result) == 3 + assert len(result.leaves[0].indexes) == 3