diff --git a/.gitignore b/.gitignore index d81caa02a..867570ea6 100644 --- a/.gitignore +++ b/.gitignore @@ -34,13 +34,9 @@ Cargo.lock site example_eo example_mri -new_test_venv_polytope -new_venv_polytope -new_polytope_venv -polytope_venv -polytope_venv_latest -new_updated_numpy_venv -newest-polytope-venv +.mypy_cache +*.req +polytope_python.egg-info serializedTree new_polytope_venv newest-polytope-venv @@ -49,23 +45,17 @@ polytope_feature/datacube/quadtree/target venv_python3_11 venv_gj_iterator *.json -venv_python3_11 *.txt tests/data -venv_gj_iterator -target -rust_deployment_venv *.so -*.nc - -test_w_qubed -rust_venv -new_rust_test_venv -*.so -rust_deployment_venv -*.lock -**/target +**/icon_grids +polytope_feature/_version.py **/build +**.lock +**/target +**/*venv*/ +test_w_qubed +performance_plots_qubed _version.py *.DS_Store icon_grids diff --git a/polytope_feature/datacube/backends/datacube.py b/polytope_feature/datacube/backends/datacube.py index 6c98f6495..6cd29c380 100644 --- a/polytope_feature/datacube/backends/datacube.py +++ b/polytope_feature/datacube/backends/datacube.py @@ -32,6 +32,7 @@ def __init__(self, axis_options=None, compressed_axes_options=[], grid_online_pa self.unwanted_path = {} self.compressed_axes = compressed_axes_options self.grid_md5_hash = None + self.datacube_transformations = [] self.grid_online_path = grid_online_path self.grid_local_directory = grid_local_directory @@ -92,6 +93,8 @@ def _create_axes(self, name, values, transformation_type_key, transformation_opt if transformation not in self._axes[axis_name].transformations: # Avoids duplicates being stored self._axes[axis_name].transformations.append(transformation) + if transformation not in self.datacube_transformations: + self.datacube_transformations.append(transformation) else: # Means we have an unsliceable axis since we couln't transform values to desired type if self._axes is None or axis_name not in self._axes.keys(): @@ -111,6 +114,20 @@ def _check_and_add_axes(self, options, name, values): if self._axes is None or name not in self._axes.keys(): DatacubeAxis.create_standard(name, values, self) + def _add_all_type_change_transformation_axes(self, options, name, values): + for transformation_type_key in options.transformations: + if transformation_type_key == "type_change": + self._create_axes(name, values, transformation_type_key, options) + else: + DatacubeAxis.create_standard(name, values, self) + + def _check_and_readd_axes(self, options, name, values): + if options is not None: + self._add_all_transformation_axes(options, name, values) + else: + if self._axes is None or name not in self._axes.keys(): + DatacubeAxis.create_standard(name, values, self) + def has_index(self, path: DatacubePath, axis, index): "Given a path to a subset of the datacube, checks if the index exists on that sub-datacube axis" path = self.fit_path(path) @@ -163,6 +180,7 @@ def create( alternative_axes=[], grid_online_path="", grid_local_directory="", + datacube_axes={}, context=None, ): # TODO: get the configs as None for pre-determined value and change them to empty dictionary inside the function @@ -189,6 +207,17 @@ def create( return fdbdatacube if type(datacube).__name__ == "MockDatacube": return datacube + if type(datacube).__name__ == "Qube": + from ..datacube_axis import _str_to_axis + from .qubed import QubedDatacube + + actual_datacube_axes = {} + for key, value in datacube_axes.items(): + actual_datacube_axes[key] = _str_to_axis[value] + qubed_datacube = QubedDatacube( + datacube, actual_datacube_axes, config, axis_options, compressed_axes_options, alternative_axes, context + ) + return qubed_datacube def check_branching_axes(self, request): pass diff --git a/polytope_feature/datacube/backends/qubed.py b/polytope_feature/datacube/backends/qubed.py new file mode 100644 index 000000000..d104ad7ab --- /dev/null +++ b/polytope_feature/datacube/backends/qubed.py @@ -0,0 +1,487 @@ +import logging +import operator +from copy import deepcopy +from itertools import product + +import numpy as np +import pygribjump as pygj +from qubed.value_types import QEnum + +from ...utility.exceptions import BadGridError, GribJumpNoIndexError +from ...utility.geometry import nearest_pt +from ...utility.metadata_handling import find_metadata, flatten_metadata +from .datacube import Datacube, TensorIndexTree + + +class QubedDatacube(Datacube): + def __init__( + self, + q, + datacube_axes, + config=None, + axis_options=None, + compressed_axes_options=[], + alternative_axes=[], + context=None, + ): + if config is None: + config = {} + if axis_options is None: + axis_options = {} + + self.q = q + self.datacube_axes = datacube_axes + # TODO: should the gj object be passed in instead? + self.gj = pygj.GribJump() + super().__init__(axis_options, compressed_axes_options) + + # TODO: where do these come from and are they right? + self.unwanted_path = {} + + self.axis_options = axis_options + # Find values in the level 3 FDB datacube + + self.fdb_coordinates = {} + + # TODO: we instead now have a list of axes with the actual axes types... + # TODO: here use the qubed to find all axes names and then get the values from the first val of the qubed and + # then apply transformations to get the actual right axis type... + for axis_name in datacube_axes: + axis = datacube_axes[axis_name] + self.fdb_coordinates[axis_name] = [axis.type_eg] + + self.fdb_coordinates["values"] = [] + for name, values in self.fdb_coordinates.items(): + options = None + for opt in self.axis_options: + if opt.axis_name == name: + options = opt + + self._check_and_add_axes(options, name, values) + self.treated_axes.append(name) + self.complete_axes.append(name) + + # # add other options to axis which were just created above like "lat" for the mapper transformations for eg + for name in self._axes: + if name not in self.treated_axes: + options = None + for opt in self.axis_options: + if opt.axis_name == name: + options = opt + + val = self._axes[name].type_eg + self._check_and_add_axes(options, name, val) + + # TODO: actually should separate axis creation with types from the transformations... + # TODO: we should create all axes here first maybe? + # TODO: otherwise, we need to somehow get the axis type information/objects when we transform the polytope + # points into continuous types? + # TODO: Also, if we don't have the right axis types from the start here, then when we pre-process the polytopes, + # it will be wrong... + + def add_axes_dynamically(self, qube_node): + # TODO: here look if the options have changed and we need to modify the transformations + changed_options = False + if not len(qube_node.metadata.items()) == 0: + changed_options = True + + if changed_options: + if len(qube_node.children) == 0: + axis_name = "values" + vals = [] + else: + axis_name = qube_node.key + self._axes.pop(axis_name, None) + vals = list(qube_node.values) + + options = None + + for opt in self.axis_options: + if opt.axis_name == axis_name: + options = opt + + # NOTE: be sure to remove the "fake" additional grid axes + if len(qube_node.children) == 0: + axes_names = list(self._axes.keys()) + + for name in axes_names: + if name not in self.treated_axes: + self._axes.pop(name, None) + + self._check_and_readd_axes(options, axis_name, vals) + + # NOTE: now if we have created the additional grid axes, readd the additional transformations + # associated to them + new_axes_names = list(self._axes.keys()) + for name in new_axes_names: + if name not in self.treated_axes: + options = None + for opt in self.axis_options: + if opt.axis_name == name: + options = opt + + val = [self._axes[name].type_eg] + self._check_and_readd_axes(options, name, val) + + # TODO: will this work?? How do we make sure we add the grid axes which come from the values + # transformation here?? + # TODO: we can't do a "difference" of axes like before since we don't a priori have the final axes + # set available at once?? + pass + + def datacube_natural_indexes(self, qube_node): + if qube_node is not None: + return np.asarray(list(qube_node.values)) + else: + return [] + + def find_point_cloud(self): + # find the point cloud of irregular grid if it exists + if self.grid_transformation.is_irregular: + return self.grid_transformation._final_transformation.grid_latlon_points() + + def get_indices(self, path, path_node, axis, lower, upper, method=None): + """ + Given a path to a subset of the datacube, return the discrete indexes which exist between + two non-discrete values (lower, upper) for a particular axis (given by label) + If lower and upper are equal, returns the index which exactly matches that value (if it exists) + e.g. returns integer discrete points between two floats + """ + indexes = axis.find_indexes_node(path_node, self, path) + + idx_between = axis.find_indices_between(indexes, lower, upper, self, method) + + logging.debug(f"For axis {axis.name} between {lower} and {upper}, found indices {idx_between}") + + if path_node: + indexes = [indexes.index(item) for item in idx_between] + else: + indexes = None + + return (idx_between, indexes) + + def get(self, requests, context=None): + """ + We have a compressed tree of requests, which we need to decompress completely with its metadata indexes. + BUT the last two axes, we would like to "ignore" in the decompression and instead, + we create grid index ranges from them. + WHILE we decompress, we need to keep some kind of map from decompressed request + grid index ranges + to corresponding tree node. + This mapping will map potentially several decompressed request + grid index ranges tuples + to the same tree nodes. + + ADDED DIFFICULTY: the grid index ranges MUST be ordered (which is not guaranteed from the tree) so we need + to sort them, while also sorting the tree nodes so that they match up. + + UNTIL NOW, we had a map of compressed request + grid index ranges tuples to tree nodes. + We could just add a map of compressed request -> list of decompressed request + associated metadata + """ + + if context is None: + context = {} + if len(requests.children) == 0: + return requests + fdb_requests = [] + fdb_requests_decoding_info = [] + self.get_fdb_requests(requests, fdb_requests, fdb_requests_decoding_info) + + # here, loop through the fdb requests and request from gj and directly add to the nodes + complete_list_complete_uncompressed_requests = [] + complete_fdb_decoding_info = [] + for j, compressed_request in enumerate(fdb_requests): + compressed_metadata = compressed_request[2] + # Need to determine the possible decompressed requests + # First, find the possible combinations of compressed indices + interm_branch_tuple_values = [] + for key in compressed_request[0].keys(): + interm_branch_tuple_values.append(compressed_request[0][key]) + request_combis = product(*interm_branch_tuple_values) + + index_combis_raw = list(product(*[range(len(lst)) for lst in interm_branch_tuple_values])) + index_combis = [(0, *comb) for comb in index_combis_raw] + + # Need to extract the possible requests and add them to the right nodes + + for i, combi in enumerate(request_combis): + metadata_idxs = index_combis[i] + actual_metadata = find_metadata(metadata_idxs, compressed_metadata) + + path = flatten_metadata(actual_metadata["path"]) + scheme = flatten_metadata(actual_metadata["scheme"]) + offset = flatten_metadata(actual_metadata["offset"]) + host = flatten_metadata(actual_metadata["host"]) + port = flatten_metadata(actual_metadata["port"]) + + gj_extraction_request = pygj.PathExtractionRequest( + path, scheme, offset, host, port, compressed_request[1], self.grid_md5_hash + ) + + complete_list_complete_uncompressed_requests.append(gj_extraction_request) + complete_fdb_decoding_info.append(fdb_requests_decoding_info[j]) + + if logging.root.level <= logging.DEBUG: + printed_list_to_gj = complete_list_complete_uncompressed_requests[::1000] + logging.debug("The requests we give GribJump are: %s", printed_list_to_gj) + logging.info("Requests given to GribJump extract for %s", context) + try: + output_values = self.gj.extract_from_paths(complete_list_complete_uncompressed_requests, context) + except Exception as e: + if "BadValue: Grid hash mismatch" in str(e): + logging.info("Error is: %s", e) + raise BadGridError() + if "Missing JumpInfo" in str(e): + logging.info("Error is: %s", e) + raise GribJumpNoIndexError() + else: + raise e + + logging.info("Requests extracted from GribJump for %s", context) + if logging.root.level <= logging.DEBUG: + printed_output_values = output_values[::1000] + logging.debug("GribJump outputs: %s", printed_output_values) + self.assign_fdb_output_to_nodes(output_values, complete_fdb_decoding_info) + + def get_fdb_requests( + self, + requests, + fdb_requests=[], + fdb_requests_decoding_info=[], + leaf_path=None, + leaf_metadata=None, + ): + # TODO: collect leaf metadata from qube here too + if leaf_path is None: + leaf_path = {} + + if leaf_metadata is None: + leaf_metadata = {} + + # First when request node is root, go to its children + if requests.key == "root": + logging.debug("Looking for data for the tree") + + for c in requests.children: + self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info) + # If request node has no children, we have a leaf so need to assign fdb values to it + else: + key_value_path = {requests.key: requests.values} + ax = self._axes[requests.key] + (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( + key_value_path, leaf_path, self.unwanted_path + ) + # TODO: change to use the datacube trasnformations instead... + if requests.key == "time": + new_vals = [] + for val in key_value_path[requests.key]: + new_vals.append(val[7:9] + val[10:12]) + key_value_path[requests.key] = new_vals + if requests.key == "date": + new_vals = [] + for val in key_value_path[requests.key]: + new_vals.append(val[:4] + val[5:7] + val[8:10]) + key_value_path[requests.key] = new_vals + leaf_path.update(key_value_path) + # TODO: in the leaf metadata, try to instead store mapping leaf_path -> list(individual metadata dicts for + # each uncompressed value of tree) + leaf_metadata.update(requests.metadata) + if len(requests.children[0].children[0].children) == 0: + # find the fdb_requests and associated nodes to which to add results + (path, current_start_idxs, fdb_node_ranges, lat_length) = self.get_2nd_last_values(requests, leaf_path) + ( + original_indices, + sorted_request_ranges, + fdb_node_ranges, + ) = self.sort_fdb_request_ranges(current_start_idxs, lat_length, fdb_node_ranges) + fdb_requests.append((path, sorted_request_ranges, deepcopy(leaf_metadata))) + # TODO: did we need to deepcopy the leaf metadata?? + fdb_requests_decoding_info.append((original_indices, fdb_node_ranges)) + + # Otherwise remap the path for this key and iterate again over children + else: + for c in requests.children: + self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info, leaf_path, leaf_metadata) + + def remove_duplicates_in_request_ranges(self, fdb_node_ranges, current_start_idxs): + # seen_indices = set() + # for i, idxs_list in enumerate(current_start_idxs): + # for k, sub_lat_idxs in enumerate(idxs_list): + # actual_fdb_node = fdb_node_ranges[i][k] + # original_fdb_node_range_vals = [] + # new_current_start_idx = [] + # for j, idx in enumerate(sub_lat_idxs): + # if idx not in seen_indices: + # # NOTE: need to remove it from the values in the corresponding tree node + # # NOTE: need to read just the range we give to gj + # original_fdb_node_range_vals.append(list(actual_fdb_node[0].values)[j]) + # seen_indices.add(idx) + # new_current_start_idx.append(idx) + # if original_fdb_node_range_vals != []: + # actual_fdb_node[0].values = tuple(original_fdb_node_range_vals) + # else: + # # there are no values on this node anymore so can remove it + # actual_fdb_node[0].remove_branch() + # if len(new_current_start_idx) == 0: + # current_start_idxs[i].pop(k) + # else: + # current_start_idxs[i][k] = new_current_start_idx + return (fdb_node_ranges, current_start_idxs) + + def nearest_lat_lon_search(self, requests): + if len(self.nearest_search) != 0: + first_ax_name = requests.children[0].key + second_ax_name = requests.children[0].children[0].key + + axes_in_nearest_search = [ + first_ax_name not in self.nearest_search.keys(), + second_ax_name not in self.nearest_search.keys(), + ] + + if all(not item for item in axes_in_nearest_search): + raise Exception("nearest point search axes are wrong") + + second_ax = self._axes[requests.children[0].children[0].key] + + nearest_pts = self.nearest_search.get((first_ax_name, second_ax_name), None) + if nearest_pts is None: + nearest_pts = self.nearest_search.get((second_ax_name, first_ax_name), None) + for i, pt in enumerate(nearest_pts): + nearest_pts[i] = [pt[1], pt[0]] + + transformed_nearest_pts = [] + for point in nearest_pts: + transformed_nearest_pts.append([point[0], second_ax._remap_val_to_axis_range(point[1])]) + + found_latlon_pts = [] + for lat_child in requests.children: + for lon_child in lat_child.children: + found_latlon_pts.append([lat_child.values, lon_child.values]) + + # now find the nearest lat lon to the points requested + nearest_latlons = [] + for pt in transformed_nearest_pts: + nearest_latlon = nearest_pt(found_latlon_pts, pt) + nearest_latlons.append(nearest_latlon) + + # need to remove the branches that do not fit + lat_children_values = [child.values for child in requests.children] + for i in range(len(lat_children_values)): + lat_child_val = lat_children_values[i] + lat_child = [child for child in requests.children if child.values == lat_child_val][0] + if lat_child.values not in [(latlon[0],) for latlon in nearest_latlons]: + lat_child.remove_branch() + else: + possible_lons = [latlon[1] for latlon in nearest_latlons if (latlon[0],) == lat_child.values] + lon_children_values = [child.values for child in lat_child.children] + for j in range(len(lon_children_values)): + lon_child_val = lon_children_values[j] + lon_child = [child for child in lat_child.children if child.values == lon_child_val][0] + for value in lon_child.values: + if value not in possible_lons: + lon_child.remove_compressed_branch(value) + + def get_2nd_last_values(self, requests, leaf_path=None): + if leaf_path is None: + leaf_path = {} + # In this function, we recursively loop over the last two layers of the tree and store the indices of the + # request ranges in those layers + self.nearest_lat_lon_search(requests) + + lat_length = len(requests.children) + current_start_idxs = [False] * lat_length + fdb_node_ranges = [False] * lat_length + for i in range(len(requests.children)): + lat_child = requests.children[i] + lon_length = len(lat_child.children) + current_start_idxs[i] = [None] * lon_length + fdb_node_ranges[i] = [[TensorIndexTree.root for y in range(lon_length)] for x in range(lon_length)] + current_start_idx = deepcopy(current_start_idxs[i]) + fdb_range_nodes = deepcopy(fdb_node_ranges[i]) + key_value_path = {lat_child.key: list(lat_child.values)} + ax = self._axes[lat_child.key] + (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( + key_value_path, leaf_path, self.unwanted_path + ) + leaf_path.update(key_value_path) + (current_start_idxs[i], fdb_node_ranges[i]) = self.get_last_layer_before_leaf( + lat_child, leaf_path, current_start_idx, fdb_range_nodes + ) + + leaf_path_copy = deepcopy(leaf_path) + leaf_path_copy.pop("values", None) + return (leaf_path_copy, current_start_idxs, fdb_node_ranges, lat_length) + + def get_last_layer_before_leaf(self, requests, leaf_path, current_idx, fdb_range_n): + current_idx = [[] for i in range(len(requests.children))] + fdb_range_n = [[] for i in range(len(requests.children))] + for i, c in enumerate(requests.children): + # now c are the leaves of the initial tree + key_value_path = {c.key: list(c.values)} + ax = self._axes[c.key] + (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( + key_value_path, leaf_path, self.unwanted_path + ) + # TODO: change this to accommodate non consecutive indexes being compressed too + current_idx[i].extend(key_value_path["values"]) + fdb_range_n[i].append(c) + return (current_idx, fdb_range_n) + + def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info): + for k, request_output_values in enumerate(output_values): + ( + original_indices, + fdb_node_ranges, + ) = fdb_requests_decoding_info[k] + sorted_fdb_range_nodes = [fdb_node_ranges[i] for i in original_indices] + for i in range(len(sorted_fdb_range_nodes)): + n = sorted_fdb_range_nodes[i][0] + if len(request_output_values.values) == 0: + # If we are here, no data was found for this path in the fdb + none_array = [None] * len(n.values) + if n.metadata.get("result", None) is None: + n.metadata["result"] = [] + n.metadata["result"].extend(none_array) + else: + if n.metadata.get("result", None) is None: + n.metadata["result"] = [] + n.metadata["result"].extend(request_output_values.values[i]) + + def sort_fdb_request_ranges(self, current_start_idx, lat_length, fdb_node_ranges): + (new_fdb_node_ranges, new_current_start_idx) = self.remove_duplicates_in_request_ranges( + fdb_node_ranges, current_start_idx + ) + interm_request_ranges = [] + # TODO: modify the start indexes to have as many arrays as the request ranges + new_fdb_node_ranges = [] + for i in range(lat_length): + interm_fdb_nodes = fdb_node_ranges[i] + old_interm_start_idx = current_start_idx[i] + for j in range(len(old_interm_start_idx)): + # TODO: if we sorted the cyclic values in increasing order on the tree too, + # then we wouldn't have to sort here? + sorted_list = sorted(enumerate(old_interm_start_idx[j]), key=lambda x: x[1]) + original_indices_idx, interm_start_idx = zip(*sorted_list) + for interm_fdb_nodes_obj in interm_fdb_nodes[j]: + interm_fdb_nodes_obj.values = QEnum( + tuple([list(interm_fdb_nodes_obj.values)[k] for k in original_indices_idx]) + ) + if abs(interm_start_idx[-1] + 1 - interm_start_idx[0]) <= len(interm_start_idx): + current_request_ranges = (interm_start_idx[0], interm_start_idx[-1] + 1) + interm_request_ranges.append(current_request_ranges) + new_fdb_node_ranges.append(interm_fdb_nodes[j]) + else: + jumps = list(map(operator.sub, interm_start_idx[1:], interm_start_idx[:-1])) + last_idx = 0 + for k, jump in enumerate(jumps): + if jump > 1: + current_request_ranges = (interm_start_idx[last_idx], interm_start_idx[k] + 1) + new_fdb_node_ranges.append(interm_fdb_nodes[j]) + last_idx = k + 1 + interm_request_ranges.append(current_request_ranges) + if k == len(interm_start_idx) - 2: + current_request_ranges = (interm_start_idx[last_idx], interm_start_idx[-1] + 1) + interm_request_ranges.append(current_request_ranges) + new_fdb_node_ranges.append(interm_fdb_nodes[j]) + request_ranges_with_idx = list(enumerate(interm_request_ranges)) + sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0]) + original_indices, sorted_request_ranges = zip(*sorted_list) + return (original_indices, sorted_request_ranges, new_fdb_node_ranges) diff --git a/polytope_feature/datacube/datacube_axis.py b/polytope_feature/datacube/datacube_axis.py index fbf8a70cb..b2032a5b6 100644 --- a/polytope_feature/datacube/datacube_axis.py +++ b/polytope_feature/datacube/datacube_axis.py @@ -78,6 +78,22 @@ def find_indexes(self, path, datacube): indexes = transformation.find_modified_indexes(indexes, path, datacube, self) return indexes + def find_standard_indexes_node(self, path_node, datacube): + return datacube.datacube_natural_indexes(path_node) + + def find_indexes_node(self, path_node, datacube, path): + indexes = self.find_standard_indexes_node(path_node, datacube) + # path = {self.name: tuple(path_node.values)} + if not path: + if path_node: + path = {path_node.key: tuple(path_node.values)} + else: + path = {self.name: tuple()} + for transformation in self.transformations[::-1]: + indexes = transformation.find_modified_indexes(indexes, path, datacube, self) + # print(indexes) + return indexes + def offset(self, value): offset = 0 for transformation in self.transformations[::-1]: @@ -85,6 +101,8 @@ def offset(self, value): return offset def unmap_path_key(self, key_value_path, leaf_path, unwanted_path): + # print("WHAT ARE THE AXIS TRANSFORMATIONS??") + # print(self.transformations) for transformation in self.transformations[::-1]: (key_value_path, leaf_path, unwanted_path) = transformation.unmap_path_key( key_value_path, leaf_path, unwanted_path, self @@ -196,6 +214,7 @@ def __init__(self): # TODO: Maybe here, store transformations as a dico instead self.transformations = [] self.type = 0 + self.type_eg = 0 self.can_round = True def parse(self, value: Any) -> Any: @@ -218,6 +237,7 @@ def __init__(self): self.range = None self.transformations = [] self.type = 0.0 + self.type_eg = 0.0 self.can_round = True def parse(self, value: Any) -> Any: @@ -240,6 +260,7 @@ def __init__(self): self.range = None self.transformations = [] self.type = pd.Timestamp("2000-01-01T00:00:00") + self.type_eg = "20000101T000000" self.can_round = False def parse(self, value: Any) -> Any: @@ -270,6 +291,7 @@ def __init__(self): self.range = None self.transformations = [] self.type = np.timedelta64(0, "s") + self.type_eg = "0000" self.can_round = False def parse(self, value: Any) -> Any: @@ -300,6 +322,8 @@ def __init__(self): self.range = None self.transformations = [] self.can_round = False + self.type = "" + self.type_eg = "" def parse(self, value: Any) -> Any: return value @@ -330,3 +354,11 @@ def serialize(self, value): int: IntDatacubeAxis(), float: FloatDatacubeAxis(), } + +_str_to_axis = { + "FloatDatacubeAxis": FloatDatacubeAxis(), + "IntDatacubeAxis": IntDatacubeAxis(), + "UnsliceableDatacubeAxis": UnsliceableDatacubeAxis(), + "PandasTimedeltaDatacubeAxis": PandasTimedeltaDatacubeAxis(), + "PandasTimestampDatacubeAxis": PandasTimestampDatacubeAxis(), +} diff --git a/polytope_feature/datacube/transformations/datacube_reverse/datacube_reverse.py b/polytope_feature/datacube/transformations/datacube_reverse/datacube_reverse.py index baa38009d..7a4da9bc4 100644 --- a/polytope_feature/datacube/transformations/datacube_reverse/datacube_reverse.py +++ b/polytope_feature/datacube/transformations/datacube_reverse/datacube_reverse.py @@ -1,3 +1,5 @@ +import numpy as np + from ....utility.list_tools import bisect_left_cmp, bisect_right_cmp from ..datacube_transformations import DatacubeAxisTransformation @@ -24,12 +26,17 @@ def unwanted_axes(self): def find_modified_indexes(self, indexes, path, datacube, axis): if axis.name in datacube.complete_axes: + # if isinstance(indexes, list): + # indexes.sort() + # ordered_indices = indexes + # else: ordered_indices = indexes.sort_values() else: ordered_indices = indexes return ordered_indices def find_indices_between(self, indexes, low, up, datacube, method, indexes_between_ranges, axis): + # indexes = np.asarray(indexes) indexes_between_ranges = [] if axis.name == self.name: if axis.name in datacube.complete_axes: diff --git a/polytope_feature/datacube/transformations/datacube_type_change/datacube_type_change.py b/polytope_feature/datacube/transformations/datacube_type_change/datacube_type_change.py index a71579bf2..8d09d8e16 100644 --- a/polytope_feature/datacube/transformations/datacube_type_change/datacube_type_change.py +++ b/polytope_feature/datacube/transformations/datacube_type_change/datacube_type_change.py @@ -85,14 +85,17 @@ def __init__(self, axis_name, new_type): def transform_type(self, value): try: - return pd.Timestamp(value) + return pd.Timestamp(str(value)) except ValueError: return None def make_str(self, value): values = [] for val in value: - values.append(val.strftime("%Y%m%d")) + if isinstance(val, str): + values.append(val) + else: + values.append(val.strftime("%Y%m%d")) return tuple(values) @@ -112,9 +115,12 @@ def transform_type(self, value): def make_str(self, value): values = [] for val in value: - hours = int(val.total_seconds() // 3600) - mins = int((val.total_seconds() % 3600) // 60) - values.append(f"{hours:02d}{mins:02d}") + if isinstance(val, str): + values.append(val) + else: + hours = int(val.total_seconds() // 3600) + mins = int((val.total_seconds() % 3600) // 60) + values.append(f"{hours:02d}{mins:02d}") return tuple(values) diff --git a/polytope_feature/engine/engine.py b/polytope_feature/engine/engine.py index d9674aae1..deae8af98 100644 --- a/polytope_feature/engine/engine.py +++ b/polytope_feature/engine/engine.py @@ -1,10 +1,11 @@ +import math from abc import abstractmethod from typing import List from ..datacube.backends.datacube import Datacube from ..datacube.datacube_axis import UnsliceableDatacubeAxis -from ..datacube.tensor_index_tree import TensorIndexTree -from ..shapes import ConvexPolytope +from ..shapes import ConvexPolytope, Product +from ..utility.list_tools import unique class Engine: @@ -19,9 +20,9 @@ def __init__(self, engine_options=None): self.remapped_vals = {} self.compressed_axes = [] - def extract(self, datacube: Datacube, polytopes: List[ConvexPolytope]) -> TensorIndexTree: - # Delegate to the right slicer that the axes within the polytopes need to use - pass + # def extract(self, datacube: Datacube, polytopes: List[ConvexPolytope]) -> TensorIndexTree: + # # Delegate to the right slicer that the axes within the polytopes need to use + # pass def check_slicer(self, ax): # Return the slicer instance if ax is sliceable. @@ -41,3 +42,65 @@ def default(): @abstractmethod def _build_branch(self, ax, node, datacube, next_nodes, api): pass + + def _unique_continuous_points(self, p: ConvexPolytope, datacube: Datacube): + for i, ax in enumerate(p._axes): + mapper = datacube.get_mapper(ax) + if self.ax_is_unsliceable.get(ax, None) is None: + self.ax_is_unsliceable[ax] = isinstance(mapper, UnsliceableDatacubeAxis) + if self.ax_is_unsliceable[ax]: + break + for j, val in enumerate(p.points): + p.points[j][i] = mapper.to_float(mapper.parse(p.points[j][i])) + # Remove duplicate points + unique(p.points) + + def remap_values(self, ax, value): + remapped_val = self.remapped_vals.get((value, ax.name), None) + if remapped_val is None: + remapped_val = value + if ax.is_cyclic: + remapped_val_interm = ax.remap([value, value])[0] + remapped_val = (remapped_val_interm[0] + remapped_val_interm[1]) / 2 + if ax.can_round: + remapped_val = round(remapped_val, int(-math.log10(ax.tol))) + self.remapped_vals[(value, ax.name)] = remapped_val + return remapped_val + + def pre_process_polytopes(self, datacube, polytopes): + for p in polytopes: + if isinstance(p, Product): + for poly in p.polytope(): + self._unique_continuous_points(poly, datacube) + else: + self._unique_continuous_points(p, datacube) + + def find_compressed_axes(self, datacube, polytopes): + # First determine compressable axes from input polytopes + compressable_axes = [] + for polytope in polytopes: + if polytope.is_orthogonal: + for ax in polytope.axes(): + compressable_axes.append(ax) + # Cross check this list with list of compressable axis from datacube + # (should not include any merged or coupled axes) + for compressed_axis in compressable_axes: + if compressed_axis in datacube.compressed_axes: + self.compressed_axes.append(compressed_axis) + # add the last axis of the grid always (longitude) as a compressed axis + k, last_value = _, datacube.axes[k] = datacube.axes.popitem() + self.compressed_axes.append(k) + + def remove_compressed_axis_in_union(self, polytopes): + for p in polytopes: + if p.is_in_union: + for axis in p.axes(): + if axis == self.compressed_axes[-1]: + self.compressed_axes.remove(axis) + + def extract(self, datacube: Datacube, polytopes: List[ConvexPolytope]): + self.find_compressed_axes(datacube, polytopes) + self.remove_compressed_axis_in_union(polytopes) + self.pre_process_polytopes(datacube, polytopes) + tree = self.build_tree(polytopes, datacube) + return tree diff --git a/polytope_feature/engine/hullslicer.py b/polytope_feature/engine/hullslicer.py index 68c1cc24f..c546bb4b0 100644 --- a/polytope_feature/engine/hullslicer.py +++ b/polytope_feature/engine/hullslicer.py @@ -1,6 +1,8 @@ import math from copy import copy +from ..datacube.tensor_index_tree import TensorIndexTree +from ..utility.combinatorics import find_polytope_combinations, group, tensor_product from ..utility.exceptions import UnsliceableShapeError from .engine import Engine from .slicing_tools import slice @@ -154,3 +156,36 @@ def _build_branch(self, ax, node, datacube, next_nodes, api): ) del node["unsliced_polytopes"] + + def slice_tree(self, datacube, final_polys): + r = TensorIndexTree() + r["unsliced_polytopes"] = set(final_polys) + current_nodes = [r] + for ax in datacube.axes.values(): + next_nodes = [] + interm_next_nodes = [] + for node in current_nodes: + self._build_branch(ax, node, datacube, interm_next_nodes) + next_nodes.extend(interm_next_nodes) + interm_next_nodes = [] + current_nodes = next_nodes + return r + + def build_tree(self, polytopes, datacube): + groups, input_axes = group(polytopes) + datacube.validate(input_axes) + tree = TensorIndexTree() + combinations = tensor_product(groups) + + sub_trees = [] + + for c in combinations: + r = TensorIndexTree() + final_polys = find_polytope_combinations(c) + r = self.slice_tree(datacube, final_polys) + + sub_trees.append(r) + + for sub_tree in sub_trees[0:]: + tree.merge(sub_tree) + return tree diff --git a/polytope_feature/engine/qubed_slicer.py b/polytope_feature/engine/qubed_slicer.py new file mode 100644 index 000000000..d76caeb9f --- /dev/null +++ b/polytope_feature/engine/qubed_slicer.py @@ -0,0 +1,334 @@ +from copy import deepcopy + +import numpy as np +import pandas as pd +from qubed import Qube +from qubed.value_types import QEnum + +from ..datacube.datacube_axis import UnsliceableDatacubeAxis +from ..datacube.transformations.datacube_mappers.datacube_mappers import DatacubeMapper +from ..utility.combinatorics import ( + find_polytope_combinations, + find_polytopes_on_axis, + group, + tensor_product, +) +from .engine import Engine +from .slicing_tools import slice + + +class QubedSlicer(Engine): + def __init__(self): + self.ax_is_unsliceable = {} + self.compressed_axes = [] + self.remapped_vals = {} + + def find_datacube_vals(): + # TODO + pass + + def find_values_between(self, polytope, ax, node, datacube, lower, upper, path=None): + if isinstance(ax, UnsliceableDatacubeAxis): + filtered = [(i, v) for i, v in enumerate(node.values) if lower <= v <= upper] + indices, values = zip(*filtered) if filtered else ([], []) + return values, indices + + tol = ax.tol + lower = ax.from_float(lower - tol) + upper = ax.from_float(upper + tol) + method = polytope.method + values, indexes = datacube.get_indices(path, node, ax, lower, upper, method) + return values, indexes + + def get_sliced_polys(self, found_vals, ax, child_name, poly, slice_axis_idx): + sliced_polys = [] + for val in found_vals: + if not isinstance(ax, UnsliceableDatacubeAxis): + fval = ax.to_float(val) + # slice polytope along the value and add sliced polytope to list of polytopes in memory + sliced_poly = slice(poly, child_name, fval, slice_axis_idx) + sliced_polys.append(sliced_poly) + return sliced_polys + + def find_children_polytopes(self, polytopes, poly, sliced_polys): + child_polytopes = [p for p in polytopes if p != poly] + child_polytopes.extend([sliced_poly_ for sliced_poly_ in sliced_polys if sliced_poly_ is not None]) + return child_polytopes + + def find_new_vals(self, found_vals, ax): + new_found_vals = [] + for found_val in found_vals: + found_val = self.remap_values(ax, found_val) + # TODO: use unmap_path_key here with the transformations instead + if isinstance(found_val, pd.Timedelta) or isinstance(found_val, pd.Timestamp): + new_found_vals.append(str(found_val)) + else: + new_found_vals.append(found_val) + return new_found_vals + + def build_branch( + self, + real_uncompressed_axis, + found_vals, + sliced_polys, + polytopes, + poly, + child, + datacube, + datacube_transformations, + ax, + idxs=None, + metadata_idx_stack=None, + ): + final_children_and_vals = [] + if real_uncompressed_axis: + for i, found_val in enumerate(found_vals): + if i < len(sliced_polys): + sliced_polys_ = [sliced_polys[i]] + else: + sliced_polys_ = sliced_polys + child_polytopes = self.find_children_polytopes(polytopes, poly, sliced_polys_) + if idxs: + current_metadata_idx_stack = metadata_idx_stack + [[idxs[i]]] + children = self._slice( + child, child_polytopes, datacube, datacube_transformations, current_metadata_idx_stack + ) + # If this node used to have children but now has none due to filtering, skip it. + if child.children and not children: + continue + + new_found_vals = self.find_new_vals([found_val], ax) + + if idxs: + request_child_val = (children, new_found_vals, current_metadata_idx_stack) + else: + request_child_val = (children, new_found_vals) + final_children_and_vals.append(request_child_val) + else: + # if it's compressed, then can add all found values in a single node + child_polytopes = self.find_children_polytopes(polytopes, poly, sliced_polys) + # create children + if idxs: + current_metadata_idx_stack = metadata_idx_stack + [idxs] + children = self._slice( + child, child_polytopes, datacube, datacube_transformations, current_metadata_idx_stack + ) + # If this node used to have children but now has none due to filtering, skip it. + if child.children and not children: + return None + + new_found_vals = self.find_new_vals(found_vals, ax) + if idxs: + request_child_val = (children, new_found_vals, current_metadata_idx_stack) + else: + request_child_val = (children, new_found_vals) + final_children_and_vals.append(request_child_val) + + if len(final_children_and_vals) == 0: + return None + return final_children_and_vals + + def _slice(self, q: Qube, polytopes, datacube, datacube_transformations, metadata_idx_stack=None) -> list[Qube]: + result = [] + + if metadata_idx_stack is None: + metadata_idx_stack = [[0]] + + for i, child in enumerate(q.children): + # find polytopes which are defined on axis child.key + polytopes_on_axis = find_polytopes_on_axis(child.key, polytopes) + + # TODO: here add the axes to datacube backend with transformations for child.key + # TODO: update the datacube axis_options before we dynamically change the axes + + # TODO: this is slow... will need to make it faster and only do this when we need... + datacube.add_axes_dynamically(child) + + # here now first change the values in the polytopes on the axis to reflect the axis type + for poly in polytopes_on_axis: + ax = datacube._axes[child.key] + # find extents of polytope on child.key + lower, upper, slice_axis_idx = poly.extents(child.key) + + # find values on child that are within extents + found_vals, idxs = self.find_values_between(poly, ax, child, datacube, lower, upper) + + # TODO: find the indexes of the found_vals wrt child.values, + # to extract the right metadata that we want to keep inside self.build_branch + + if len(found_vals) == 0: + continue + + sliced_polys = self.get_sliced_polys(found_vals, ax, child.key, poly, slice_axis_idx) + # decide if axis should be compressed or not according to polytope + axis_compressed = child.key in self.compressed_axes + real_uncompressed_axis = not axis_compressed and len(found_vals) > 1 + final_children_and_vals = self.build_branch( + real_uncompressed_axis, + found_vals, + sliced_polys, + polytopes, + poly, + child, + datacube, + datacube_transformations, + ax, + idxs, + metadata_idx_stack, + ) + + if final_children_and_vals is None: + continue + + def find_metadata(metadata_idx): + metadata = {} + for k, vs in child.metadata.items(): + metadata_depth = len(vs.shape) + relevant_metadata_idxs = metadata_idx[:metadata_depth] + ix = np.ix_(*relevant_metadata_idxs) + metadata[k] = vs[ix] + return metadata + + for children, new_found_vals, current_metadata_idxs in final_children_and_vals: + metadata = find_metadata(current_metadata_idxs) + qube_node = Qube.make_node( + key=child.key, values=QEnum(new_found_vals), metadata=metadata, children=children + ) + if not children: + # We've reached the end of the qube + # qube_node.sliced_polys = sliced_polys + qube_node.sliced_polys = polytopes + result.append(qube_node) + + return result + + def slice_grid_axes(self, q: Qube, datacube, datacube_transformations): + # TODO: here, instead of checking if the qube is at the leaves, we instead give it the sub-tree + # and go to its leaves + # TODO: we then find the remaining sliced_polys to continue slicing and slice along lat + lon like before + # TODO: we then return the completed tree + compressed_leaves = [leaf for leaf in q.compressed_leaf_nodes()] + actual_leaves = deepcopy(compressed_leaves) + for j, leaf in enumerate(actual_leaves): + # for leaf in q.compressed_leaf_nodes(): + result = [] + mapper_transformation = None + for transformation in datacube_transformations: + if isinstance(transformation, DatacubeMapper): + mapper_transformation = transformation + if not mapper_transformation: + # There is no grid mapping + pass + else: + grid_axes = mapper_transformation._mapped_axes() + polytopes = leaf.sliced_polys + + # Handle first grid axis + polytopes_on_axis = find_polytopes_on_axis(grid_axes[0], polytopes) + + for poly in polytopes_on_axis: + ax = datacube._axes[grid_axes[0]] + lower, upper, slice_axis_idx = poly.extents(grid_axes[0]) + + found_vals, _ = self.find_values_between(poly, ax, None, datacube, lower, upper) + + if len(found_vals) == 0: + continue + + sliced_polys = self.get_sliced_polys(found_vals, ax, grid_axes[0], poly, slice_axis_idx) + # decide if axis should be compressed or not according to polytope + # NOTE: actually the first grid axis will never be compressed + + # if it's not compressed, need to separate into different nodes to append to the tree + for i, found_val in enumerate(found_vals): + found_val = self.remap_values(ax, found_val) + child_polytopes = [p for p in polytopes if p != poly] + if sliced_polys[i]: + child_polytopes.append(sliced_polys[i]) + + second_axis_vals = mapper_transformation.second_axis_vals([found_val]) + flattened_path = {grid_axes[0]: (found_val,)} + # get second axis children through slicing + children = self._slice_second_grid_axis( + grid_axes[1], + child_polytopes, + datacube, + datacube_transformations, + second_axis_vals, + flattened_path, + ) + # If this node used to have children but now has none due to filtering, skip it. + if not children: + continue + + qube_node = Qube.make_node( + key=grid_axes[0], values=QEnum([found_val]), metadata={}, children=children + ) + result.append(qube_node) + # leaf.children = result + compressed_leaves[j].children = result + + def _slice_second_grid_axis( + self, axis_name, polytopes, datacube, datacube_transformations, second_axis_vals, path + ) -> list[Qube]: + result = [] + polytopes_on_axis = find_polytopes_on_axis(axis_name, polytopes) + + for poly in polytopes_on_axis: + ax = datacube._axes[axis_name] + lower, upper, slice_axis_idx = poly.extents(axis_name) + + found_vals, _ = self.find_values_between(poly, ax, None, datacube, lower, upper, path) + + if len(found_vals) == 0: + continue + + # decide if axis should be compressed or not according to polytope + # NOTE: actually the second grid axis will always be compressed + + # if it's not compressed, need to separate into different nodes to append to the tree + + new_found_vals = [] + for found_val in found_vals: + found_val = self.remap_values(ax, found_val) + if isinstance(found_val, pd.Timedelta) or isinstance(found_val, pd.Timestamp): + new_found_vals.append(str(found_val)) + else: + new_found_vals.append(found_val) + + # NOTE this was the last axis so we do not have children... + + result.extend( + [Qube.make_node(key=axis_name, values=QEnum(new_found_vals), metadata={"result": []}, children={})] + ) + return result + + def slice_tree(self, datacube, final_polys): + q = datacube.q + datacube_transformations = datacube.datacube_transformations + # create sub-qube without grid first + request_qube = Qube.make_root(self._slice(q, final_polys, datacube, datacube_transformations)) + # recompress this sub-qube + request_qube = request_qube.compress_w_leaf_attrs("sliced_polys") + # complete the qube with grid axes and return it + self.slice_grid_axes(request_qube, datacube, datacube_transformations) + return request_qube + + def build_tree(self, polytopes_to_slice, datacube): + groups, input_axes = group(polytopes_to_slice) + combinations = tensor_product(groups) + + sub_trees = [] + + for c in combinations: + final_polys = find_polytope_combinations(c) + + # Get the sliced Qube for each combi + r = self.slice_tree(datacube, final_polys) + sub_trees.append(r) + + final_tree = sub_trees[0] + + for sub_tree in sub_trees[1:]: + final_tree | sub_tree + return final_tree diff --git a/polytope_feature/engine/slicing_tools.py b/polytope_feature/engine/slicing_tools.py index d69ac4cee..0190f71bb 100644 --- a/polytope_feature/engine/slicing_tools.py +++ b/polytope_feature/engine/slicing_tools.py @@ -9,64 +9,6 @@ from ..utility.list_tools import argmax, argmin -def slice_in_two(polytope: ConvexPolytope, value, slice_axis_idx): - if polytope is None: - return (None, None) - else: - assert len(polytope.points[0]) == 2 - - x_lower, x_upper, _ = polytope.extents(polytope._axes[slice_axis_idx]) - - intersects = _find_intersects(polytope, slice_axis_idx, value) - - if len(intersects) == 0: - if x_upper <= value: - # The vertical slicing line does not intersect the polygon, which is on the left of the line - # So we keep the same polygon for now since it is unsliced - left_polygon = polytope - right_polygon = None - if value < x_lower: - left_polygon = None - right_polygon = polytope - else: - left_points = [p for p in polytope.points if p[slice_axis_idx] <= value] - right_points = [p for p in polytope.points if p[slice_axis_idx] >= value] - left_points.extend(intersects) - right_points.extend(intersects) - # find left polygon - try: - hull = scipy.spatial.ConvexHull(left_points) - vertices = hull.vertices - except scipy.spatial.qhull.QhullError as e: - if "less than" or "is flat" in str(e): - # NOTE: this happens when we slice a polygon that has a border which coincides with the quadrant - # line and we slice this additional border with the quadrant line again. - # This is not actually a polygon we want to consider so we ignore it - vertices = None - - if vertices is not None: - left_polygon = ConvexPolytope(polytope._axes, [left_points[i] for i in vertices]) - else: - left_polygon = None - - try: - hull = scipy.spatial.ConvexHull(right_points) - vertices = hull.vertices - except scipy.spatial.qhull.QhullError as e: - # NOTE: this happens when we slice a polygon that has a border which coincides with the quadrant - # line and we slice this additional border with the quadrant line again. - # This is not actually a polygon we want to consider so we ignore it - if "less than" or "is flat" in str(e): - vertices = None - - if vertices is not None: - right_polygon = ConvexPolytope(polytope._axes, [right_points[i] for i in vertices]) - else: - right_polygon = None - - return (left_polygon, right_polygon) - - def _find_intersects(polytope, slice_axis_idx, value): intersects = [] # Find all points above and below slice axis @@ -97,7 +39,6 @@ def _reduce_dimension(intersects, slice_axis_idx): def slice(polytope: ConvexPolytope, axis, value, slice_axis_idx): - # TODO: maybe these functions should go in the slicing tools? if polytope.is_flat: if value in chain(*polytope.points): intersects = [[value]] @@ -134,3 +75,61 @@ def slice(polytope: ConvexPolytope, axis, value, slice_axis_idx): return ConvexPolytope(axes, intersects) # Sliced result is simply the convex hull return ConvexPolytope(axes, [intersects[i] for i in vertices]) + + +def slice_in_two(polytope: ConvexPolytope, value, slice_axis_idx): + if polytope is None: + return (None, None) + else: + assert len(polytope.points[0]) == 2 + + x_lower, x_upper, _ = polytope.extents(polytope._axes[slice_axis_idx]) + + intersects = _find_intersects(polytope, slice_axis_idx, value) + + if len(intersects) == 0: + if x_upper <= value: + # The vertical slicing line does not intersect the polygon, which is on the left of the line + # So we keep the same polygon for now since it is unsliced + left_polygon = polytope + right_polygon = None + if value < x_lower: + left_polygon = None + right_polygon = polytope + else: + left_points = [p for p in polytope.points if p[slice_axis_idx] <= value] + right_points = [p for p in polytope.points if p[slice_axis_idx] >= value] + left_points.extend(intersects) + right_points.extend(intersects) + # find left polygon + try: + hull = scipy.spatial.ConvexHull(left_points) + vertices = hull.vertices + except scipy.spatial.qhull.QhullError as e: + if "less than" or "is flat" in str(e): + # NOTE: this happens when we slice a polygon that has a border which coincides with the quadrant + # line and we slice this additional border with the quadrant line again. + # This is not actually a polygon we want to consider so we ignore it + vertices = None + + if vertices is not None: + left_polygon = ConvexPolytope(polytope._axes, [left_points[i] for i in vertices]) + else: + left_polygon = None + + try: + hull = scipy.spatial.ConvexHull(right_points) + vertices = hull.vertices + except scipy.spatial.qhull.QhullError as e: + # NOTE: this happens when we slice a polygon that has a border which coincides with the quadrant + # line and we slice this additional border with the quadrant line again. + # This is not actually a polygon we want to consider so we ignore it + if "less than" or "is flat" in str(e): + vertices = None + + if vertices is not None: + right_polygon = ConvexPolytope(polytope._axes, [right_points[i] for i in vertices]) + else: + right_polygon = None + + return (left_polygon, right_polygon) diff --git a/polytope_feature/options.py b/polytope_feature/options.py index 402404823..44b0e8908 100644 --- a/polytope_feature/options.py +++ b/polytope_feature/options.py @@ -40,7 +40,6 @@ class MapperConfig(TransformationConfig): Latin1InRadians: Optional[float] = None Latin2InRadians: Optional[float] = None LaDInRadians: Optional[float] = None - # points: Optional[List[List[float]]] = None points: Optional[List[Tuple[float, float]]] = None uuid: Optional[str] = None @@ -84,6 +83,7 @@ class Config(ConfigModel): alternative_axes: Optional[List[GribJumpAxesConfig]] = [] grid_online_path: Optional[str] = "" grid_local_directory: Optional[str] = "" + datacube_axes: Optional[Dict[str, str]] = {} class PolytopeOptions(ABC): @@ -99,5 +99,14 @@ def get_polytope_options(options): alternative_axes = config_options.alternative_axes grid_online_path = config_options.grid_online_path grid_local_directory = config_options.grid_local_directory - - return (axis_config, compressed_axes_config, pre_path, alternative_axes, grid_online_path, grid_local_directory) + datacube_axes = config_options.datacube_axes + + return ( + axis_config, + compressed_axes_config, + pre_path, + alternative_axes, + grid_online_path, + grid_local_directory, + datacube_axes, + ) diff --git a/polytope_feature/polytope.py b/polytope_feature/polytope.py index 133e7528d..34affade6 100644 --- a/polytope_feature/polytope.py +++ b/polytope_feature/polytope.py @@ -9,6 +9,7 @@ from .engine.optimised_quadtree_slicer import OptimisedQuadTreeSlicer from .engine.point_in_polygon_slicer import PointInPolygonSlicer from .engine.quadtree_slicer import QuadTreeSlicer +from .engine.qubed_slicer import QubedSlicer from .options import PolytopeOptions from .shapes import ConvexPolytope, Product from .utility.combinatorics import group, tensor_product @@ -64,6 +65,7 @@ def __init__( engine_options = {} self.compressed_axes = [] + self.context = context ( @@ -73,6 +75,7 @@ def __init__( alternative_axes, grid_online_path, grid_local_directory, + datacube_axes, ) = PolytopeOptions.get_polytope_options(options) self.datacube = Datacube.create( datacube, @@ -82,11 +85,16 @@ def __init__( alternative_axes, grid_online_path, grid_local_directory, + datacube_axes, self.context, ) if engine_options == {}: for ax_name in self.datacube._axes.keys(): engine_options[ax_name] = "hullslicer" + if engine_options == "qubed": + engine_options = {} + for ax_name in self.datacube._axes.keys(): + engine_options[ax_name] = "qubed" self.engine_options = engine_options self.engines = self.create_engines() self.ax_is_unsliceable = {} @@ -110,6 +118,8 @@ def create_engines(self): if "optimised_point_in_polygon" in engine_types: points = self.datacube.find_point_cloud() engines["optimised_point_in_polygon"] = OptimisedPointInPolygonSlicer(points) + if "qubed" in engine_types: + engines["qubed"] = QubedSlicer() return engines def _unique_continuous_points(self, p: ConvexPolytope, datacube: Datacube): diff --git a/polytope_feature/utility/combinatorics.py b/polytope_feature/utility/combinatorics.py index 4040c40db..39335ca44 100644 --- a/polytope_feature/utility/combinatorics.py +++ b/polytope_feature/utility/combinatorics.py @@ -2,7 +2,7 @@ from collections import Counter from typing import List -from ..shapes import ConvexPolytope +from ..shapes import ConvexPolytope, Product from .exceptions import AxisNotFoundError, AxisOverdefinedError, AxisUnderdefinedError @@ -46,3 +46,29 @@ def validate_axes(actual_axes, test_axes): raise AxisNotFoundError(ax) return True + + +def find_polytopes_on_axis(axis_name, polytopes): + polytopes_on_axis = [] + for poly in polytopes: + if axis_name in poly._axes: + polytopes_on_axis.append(poly) + return polytopes_on_axis + + +def find_polytope_combinations(c): + new_c = [] + for combi in c: + if isinstance(combi, list): + new_c.extend(combi) + else: + new_c.append(combi) + # NOTE TODO: here some of the polys in new_c can be a Product shape instead of a ConvexPolytope + # -> need to go through the polytopes in new_c and replace the Products with their sub-ConvexPolytopes + final_polys = [] + for poly in new_c: + if isinstance(poly, Product): + final_polys.extend(poly.polytope()) + else: + final_polys.append(poly) + return final_polys diff --git a/polytope_feature/utility/metadata_handling.py b/polytope_feature/utility/metadata_handling.py new file mode 100644 index 000000000..6081af431 --- /dev/null +++ b/polytope_feature/utility/metadata_handling.py @@ -0,0 +1,14 @@ +import numpy as np + + +def flatten_metadata(value): + return value[0] if isinstance(value, np.ndarray) else value + + +def find_metadata(metadata_idx, compressed_metadata): + metadata = {} + for k, vs in compressed_metadata.items(): + metadata_depth = len(vs.shape) + relevant_metadata_dxs = metadata_idx[:metadata_depth] + metadata[k] = vs[relevant_metadata_dxs] + return metadata diff --git a/tests/test_ecmwf_oper_data_fdb.py b/tests/test_ecmwf_oper_data_fdb.py index 74454e27a..2c4374482 100644 --- a/tests/test_ecmwf_oper_data_fdb.py +++ b/tests/test_ecmwf_oper_data_fdb.py @@ -55,6 +55,7 @@ def test_fdb_datacube(self): Select("stream", ["oper"]), Select("type", ["fc"]), Box(["latitude", "longitude"], [0, 0], [0.2, 0.2]), + # Box(["latitude", "longitude"], [0, 0], [80, 80]), ) self.fdbdatacube = gj.GribJump() self.API = Polytope( diff --git a/tests/test_qubed_extraction_engine.py b/tests/test_qubed_extraction_engine.py new file mode 100644 index 000000000..861dd9ed1 --- /dev/null +++ b/tests/test_qubed_extraction_engine.py @@ -0,0 +1,238 @@ +import time + +import pandas as pd +import pygribjump as gj +import requests +from qubed import Qube + +from polytope_feature.datacube.backends.qubed import QubedDatacube +from polytope_feature.datacube.datacube_axis import ( + PandasTimedeltaDatacubeAxis, + PandasTimestampDatacubeAxis, + UnsliceableDatacubeAxis, +) +from polytope_feature.datacube.transformations.datacube_mappers.mapper_types.healpix_nested import ( + NestedHealpixGridMapper, +) +from polytope_feature.datacube.transformations.datacube_type_change.datacube_type_change import ( + TypeChangeStrToTimedelta, + TypeChangeStrToTimestamp, +) +from polytope_feature.engine.hullslicer import HullSlicer +from polytope_feature.engine.qubed_slicer import QubedSlicer +from polytope_feature.polytope import Polytope, Request +from polytope_feature.shapes import ConvexPolytope, Select + + +def find_relevant_subcube_from_request(request, qube_url): + # NOTE: final url we want is like: + # "https://qubed.lumi.apps.dte.destination-earth.eu/api/v1/select/climate-dt/?class=d1&dataset=climate-dt" + + for shape in request.shapes: + if isinstance(shape, Select): + qube_url += shape.axis + "=" + for i, val in enumerate(shape.values): + qube_url += str(val) + if i < len(shape.values) - 1: + qube_url += "," + qube_url += "&" + # TODO: remove last unnecessary & + qube_url = qube_url[:-1] + return qube_url + + +fdb_tree = Qube.from_json( + requests.get("https://github.com/ecmwf/qubed/raw/refs/heads/main/tests/example_qubes/climate-dt.json").json() +) + + +combi_polytopes = [ + ConvexPolytope(["param"], [["164"]]), + ConvexPolytope(["time"], [[pd.Timedelta(hours=0, minutes=0)], [pd.Timedelta(hours=12, minutes=0)]]), + ConvexPolytope(["resolution"], [["high"]]), + ConvexPolytope(["type"], [["fc"]]), + ConvexPolytope(["model"], [["ifs-nemo"]]), + ConvexPolytope(["stream"], [["clte"]]), + ConvexPolytope(["realization"], ["1"]), + ConvexPolytope(["expver"], [["0001"]]), + ConvexPolytope(["experiment"], [["ssp3-7.0"]]), + ConvexPolytope(["generation"], [["1"]]), + ConvexPolytope(["levtype"], [["sfc"]]), + ConvexPolytope(["activity"], [["scenariomip"]]), + ConvexPolytope(["dataset"], [["climate-dt"]]), + ConvexPolytope(["class"], [["d1"]]), + ConvexPolytope(["date"], [[pd.Timestamp("20220811")], [pd.Timestamp("20220912")]]), + ConvexPolytope(["latitude", "longitude"], [[0, 0], [0.5, 0.5], [0, 0.5]]), +] + +# TODO: add lat and lon axes +datacube_axes = { + "param": UnsliceableDatacubeAxis(), + "time": PandasTimedeltaDatacubeAxis(), + "resolution": UnsliceableDatacubeAxis(), + "type": UnsliceableDatacubeAxis(), + "model": UnsliceableDatacubeAxis(), + "stream": UnsliceableDatacubeAxis(), + "realization": UnsliceableDatacubeAxis(), + "expver": UnsliceableDatacubeAxis(), + "experiment": UnsliceableDatacubeAxis(), + "generation": UnsliceableDatacubeAxis(), + "levtype": UnsliceableDatacubeAxis(), + "activity": UnsliceableDatacubeAxis(), + "dataset": UnsliceableDatacubeAxis(), + "class": UnsliceableDatacubeAxis(), + "date": PandasTimestampDatacubeAxis(), +} + +time_val = pd.Timedelta(hours=0, minutes=0) +date_val = pd.Timestamp("20300101T000000") + + +# TODO: add grid axis transformation +datacube_transformations = { + "time": TypeChangeStrToTimedelta("time", time_val), + "date": TypeChangeStrToTimestamp("date", date_val), + "values": NestedHealpixGridMapper("values", ["latitude", "longitude"], 1024), +} + + +options = { + "axis_config": [ + {"axis_name": "step", "transformations": [{"name": "type_change", "type": "int"}]}, + {"axis_name": "number", "transformations": [{"name": "type_change", "type": "int"}]}, + {"axis_name": "date", "transformations": [{"name": "type_change", "type": "date"}]}, + {"axis_name": "time", "transformations": [{"name": "type_change", "type": "time"}]}, + { + "axis_name": "values", + "transformations": [ + {"name": "mapper", "type": "healpix_nested", "resolution": 1024, "axes": ["latitude", "longitude"]} + ], + }, + {"axis_name": "latitude", "transformations": [{"name": "reverse", "is_reverse": True}]}, + {"axis_name": "longitude", "transformations": [{"name": "cyclic", "range": [0, 360]}]}, + ], + "compressed_axes_config": [ + "longitude", + "latitude", + "levtype", + "step", + "date", + "domain", + "expver", + "param", + "class", + "stream", + "type", + ], + "pre_path": {"class": "od", "expver": "0001", "levtype": "sfc", "stream": "oper"}, + "datacube_axes": { + "param": "UnsliceableDatacubeAxis", + "time": "PandasTimedeltaDatacubeAxis", + "resolution": "UnsliceableDatacubeAxis", + "type": "UnsliceableDatacubeAxis", + "model": "UnsliceableDatacubeAxis", + "stream": "UnsliceableDatacubeAxis", + "realization": "UnsliceableDatacubeAxis", + "expver": "UnsliceableDatacubeAxis", + "experiment": "UnsliceableDatacubeAxis", + "generation": "UnsliceableDatacubeAxis", + "levtype": "UnsliceableDatacubeAxis", + "activity": "UnsliceableDatacubeAxis", + "dataset": "UnsliceableDatacubeAxis", + "class": "UnsliceableDatacubeAxis", + "date": "PandasTimestampDatacubeAxis", + }, +} + + +request = Request( + ConvexPolytope(["param"], [[164]]), + ConvexPolytope(["time"], [[pd.Timedelta(hours=1, minutes=0)], [pd.Timedelta(hours=3, minutes=0)]]), + ConvexPolytope(["resolution"], [["high"]]), + ConvexPolytope(["type"], [["fc"]]), + ConvexPolytope(["model"], [["ifs-nemo"]]), + ConvexPolytope(["stream"], [["clte"]]), + ConvexPolytope(["realization"], [[1]]), + ConvexPolytope(["expver"], [["0001"]]), + ConvexPolytope(["experiment"], [["ssp3-7.0"]]), + ConvexPolytope(["generation"], [[1]]), + ConvexPolytope(["levtype"], [["sfc"]]), + ConvexPolytope(["activity"], [["scenariomip"]]), + ConvexPolytope(["dataset"], [["climate-dt"]]), + ConvexPolytope(["class"], [["d1"]]), + ConvexPolytope(["date"], [[pd.Timestamp("20220811")]]), + ConvexPolytope(["latitude", "longitude"], [[0, 0], [0.5, 0.5], [0, 0.5]]), +) + +qubeddatacube = QubedDatacube(fdb_tree, datacube_axes, datacube_transformations) +slicer = QubedSlicer() +self_API = Polytope( + datacube=fdb_tree, + # engine=slicer, + engine_options="qubed", + options=options, +) +time1 = time.time() +# result = self_API.retrieve(request) +result = self_API.slice(self_API.datacube, request.polytopes()) +time2 = time.time() + +print(result) + +print("TIME EXTRACTING USING QUBED") +print(time2 - time1) + +# USING NORMAL GJ + + +options = { + "axis_config": [ + {"axis_name": "step", "transformations": [{"name": "type_change", "type": "int"}]}, + {"axis_name": "number", "transformations": [{"name": "type_change", "type": "int"}]}, + {"axis_name": "date", "transformations": [{"name": "type_change", "type": "date"}]}, + {"axis_name": "time", "transformations": [{"name": "type_change", "type": "time"}]}, + { + "axis_name": "values", + "transformations": [ + {"name": "mapper", "type": "healpix_nested", "resolution": 1024, "axes": ["latitude", "longitude"]} + ], + }, + {"axis_name": "latitude", "transformations": [{"name": "reverse", "is_reverse": True}]}, + {"axis_name": "longitude", "transformations": [{"name": "cyclic", "range": [0, 360]}]}, + ], + "compressed_axes_config": [ + "longitude", + ], + "pre_path": {"class": "d1", "model": "ifs-nemo", "resolution": "high"}, +} + +fdbdatacube = gj.GribJump() +slicer = HullSlicer() + + +request = Request( + ConvexPolytope(["param"], [["164"]]), + ConvexPolytope(["time"], [[pd.Timedelta(hours=1, minutes=0)], [pd.Timedelta(hours=3, minutes=0)]]), + ConvexPolytope(["resolution"], [["high"]]), + ConvexPolytope(["type"], [["fc"]]), + ConvexPolytope(["model"], [["ifs-nemo"]]), + ConvexPolytope(["stream"], [["clte"]]), + ConvexPolytope(["realization"], ["1"]), + ConvexPolytope(["expver"], [["0001"]]), + ConvexPolytope(["experiment"], [["ssp3-7.0"]]), + ConvexPolytope(["generation"], [["1"]]), + ConvexPolytope(["levtype"], [["sfc"]]), + ConvexPolytope(["activity"], [["scenariomip"]]), + ConvexPolytope(["dataset"], [["climate-dt"]]), + ConvexPolytope(["class"], [["d1"]]), + ConvexPolytope(["date"], [[pd.Timestamp("20220811")]]), + ConvexPolytope(["latitude", "longitude"], [[0, 0], [0.5, 0.5], [0, 0.5]]), +) + +time3 = time.time() +# result = self_API.retrieve(request) +# result = self_API.slice(request.polytopes()) +time4 = time.time() + +print("TIME EXTRACTING USING GJ NORMAL") +print(time4 - time3) diff --git a/tests/test_qubed_extraction_engine_w_metadata.py b/tests/test_qubed_extraction_engine_w_metadata.py new file mode 100644 index 000000000..918dc5c26 --- /dev/null +++ b/tests/test_qubed_extraction_engine_w_metadata.py @@ -0,0 +1,209 @@ +import time + +import pandas as pd +import requests +from qubed import Qube + +from polytope_feature.datacube.backends.qubed import QubedDatacube +from polytope_feature.datacube.datacube_axis import ( + PandasTimedeltaDatacubeAxis, + PandasTimestampDatacubeAxis, + UnsliceableDatacubeAxis, +) +from polytope_feature.datacube.transformations.datacube_mappers.mapper_types.healpix_nested import ( + NestedHealpixGridMapper, +) +from polytope_feature.datacube.transformations.datacube_type_change.datacube_type_change import ( + TypeChangeStrToTimedelta, + TypeChangeStrToTimestamp, +) +from polytope_feature.engine.qubed_slicer import QubedSlicer +from polytope_feature.polytope import Polytope, Request +from polytope_feature.shapes import ConvexPolytope, Select + + +def find_relevant_subcube_from_request(request, qube_url): + # NOTE: final url we want is like: + # "https://qubed.lumi.apps.dte.destination-earth.eu/api/v1/select/climate-dt/?class=d1&dataset=climate-dt" + + for shape in request.shapes: + if isinstance(shape, Select): + qube_url += shape.axis + "=" + for i, val in enumerate(shape.values): + qube_url += str(val) + if i < len(shape.values) - 1: + qube_url += "," + qube_url += "&" + # TODO: remove last unnecessary & + qube_url = qube_url[:-1] + return qube_url + + +fdb_tree = Qube.from_json( + requests.get( + "https://github.com/ecmwf/qubed/raw/refs/heads/main/tests/example_qubes/extremes-dt_with_metadata.json" + ).json() +) + +# TODO: add lat and lon axes +datacube_axes = { + "param": UnsliceableDatacubeAxis(), + "time": PandasTimedeltaDatacubeAxis(), + "resolution": UnsliceableDatacubeAxis(), + "type": UnsliceableDatacubeAxis(), + "model": UnsliceableDatacubeAxis(), + "stream": UnsliceableDatacubeAxis(), + "realization": UnsliceableDatacubeAxis(), + "expver": UnsliceableDatacubeAxis(), + "experiment": UnsliceableDatacubeAxis(), + "generation": UnsliceableDatacubeAxis(), + "levtype": UnsliceableDatacubeAxis(), + "activity": UnsliceableDatacubeAxis(), + "dataset": UnsliceableDatacubeAxis(), + "class": UnsliceableDatacubeAxis(), + "date": PandasTimestampDatacubeAxis(), +} + +time_val = pd.Timedelta(hours=0, minutes=0) +date_val = pd.Timestamp("20300101T000000") + + +# TODO: add grid axis transformation +datacube_transformations = { + "time": TypeChangeStrToTimedelta("time", time_val), + "date": TypeChangeStrToTimestamp("date", date_val), + "values": NestedHealpixGridMapper("values", ["latitude", "longitude"], 1024), +} + + +options = { + "axis_config": [ + {"axis_name": "step", "transformations": [{"name": "type_change", "type": "int"}]}, + {"axis_name": "number", "transformations": [{"name": "type_change", "type": "int"}]}, + {"axis_name": "date", "transformations": [{"name": "type_change", "type": "date"}]}, + {"axis_name": "time", "transformations": [{"name": "type_change", "type": "time"}]}, + { + "axis_name": "values", + "transformations": [ + {"name": "mapper", "type": "octahedral", "resolution": 2560, "axes": ["latitude", "longitude"]} + ], + }, + {"axis_name": "latitude", "transformations": [{"name": "reverse", "is_reverse": True}]}, + {"axis_name": "longitude", "transformations": [{"name": "cyclic", "range": [0, 360]}]}, + ], + "compressed_axes_config": [ + "longitude", + "latitude", + "levtype", + "step", + "date", + "domain", + "expver", + "param", + "class", + "stream", + "type", + ], + "pre_path": {"class": "od", "expver": "0001", "levtype": "sfc", "stream": "oper"}, + "datacube_axes": { + "param": "UnsliceableDatacubeAxis", + "time": "PandasTimedeltaDatacubeAxis", + "type": "UnsliceableDatacubeAxis", + "stream": "UnsliceableDatacubeAxis", + "expver": "UnsliceableDatacubeAxis", + "levtype": "UnsliceableDatacubeAxis", + "dataset": "UnsliceableDatacubeAxis", + "class": "UnsliceableDatacubeAxis", + "date": "PandasTimestampDatacubeAxis", + "step": "IntDatacubeAxis", + }, +} + + +request = Request( + ConvexPolytope(["param"], [["167"]]), + ConvexPolytope(["time"], [[pd.Timedelta(hours=0, minutes=0)], [pd.Timedelta(hours=3, minutes=0)]]), + ConvexPolytope(["type"], [["fc"]]), + ConvexPolytope(["stream"], [["oper"]]), + ConvexPolytope(["expver"], [["0001"]]), + ConvexPolytope(["levtype"], [["sfc"]]), + ConvexPolytope(["dataset"], [["extremes-dt"]]), + ConvexPolytope(["class"], [["d1"]]), + ConvexPolytope(["step"], [[1]]), + ConvexPolytope(["date"], [[pd.Timestamp("20240407")]]), + ConvexPolytope(["latitude", "longitude"], [[0, 0], [0.5, 0.5], [0, 0.5]]), +) + +print(fdb_tree) + +qubeddatacube = QubedDatacube(fdb_tree, datacube_axes, datacube_transformations) +slicer = QubedSlicer() +self_API = Polytope( + datacube=fdb_tree, + # engine=slicer, + engine_options="qubed", + options=options, +) +time1 = time.time() +result = self_API.retrieve(request) +time2 = time.time() + +print(result) + +print("TIME EXTRACTING USING QUBED") +print(time2 - time1) + +# USING NORMAL GJ + + +# options = { +# "axis_config": [ +# {"axis_name": "step", "transformations": [{"name": "type_change", "type": "int"}]}, +# {"axis_name": "number", "transformations": [{"name": "type_change", "type": "int"}]}, +# {"axis_name": "date", "transformations": [{"name": "type_change", "type": "date"}]}, +# {"axis_name": "time", "transformations": [{"name": "type_change", "type": "time"}]}, +# { +# "axis_name": "values", +# "transformations": [ +# {"name": "mapper", "type": "healpix_nested", "resolution": 1024, "axes": ["latitude", "longitude"]} +# ], +# }, +# {"axis_name": "latitude", "transformations": [{"name": "reverse", "is_reverse": True}]}, +# {"axis_name": "longitude", "transformations": [{"name": "cyclic", "range": [0, 360]}]}, +# ], +# "compressed_axes_config": [ +# "longitude", +# ], +# "pre_path": {"class": "d1", "model": "ifs-nemo", "resolution": "high"}, +# } + +# fdbdatacube = gj.GribJump() +# slicer = HullSlicer() + + +# request = Request( +# ConvexPolytope(["param"], [["164"]]), +# ConvexPolytope(["time"], [[pd.Timedelta(hours=1, minutes=0)], [pd.Timedelta(hours=3, minutes=0)]]), +# ConvexPolytope(["resolution"], [["high"]]), +# ConvexPolytope(["type"], [["fc"]]), +# ConvexPolytope(["model"], [["ifs-nemo"]]), +# ConvexPolytope(["stream"], [["clte"]]), +# ConvexPolytope(["realization"], ["1"]), +# ConvexPolytope(["expver"], [["0001"]]), +# ConvexPolytope(["experiment"], [["ssp3-7.0"]]), +# ConvexPolytope(["generation"], [["1"]]), +# ConvexPolytope(["levtype"], [["sfc"]]), +# ConvexPolytope(["activity"], [["scenariomip"]]), +# ConvexPolytope(["dataset"], [["climate-dt"]]), +# ConvexPolytope(["class"], [["d1"]]), +# ConvexPolytope(["date"], [[pd.Timestamp("20220811")]]), +# ConvexPolytope(["latitude", "longitude"], [[0, 0], [0.5, 0.5], [0, 0.5]]), +# ) + +# time3 = time.time() +# result = self_API.retrieve(request) +# # result = self_API.slice(request.polytopes()) +# time4 = time.time() + +# print("TIME EXTRACTING USING GJ NORMAL") +# print(time4 - time3) diff --git a/tests/test_qubed_extraction_service.py b/tests/test_qubed_extraction_service.py new file mode 100644 index 000000000..45eb28198 --- /dev/null +++ b/tests/test_qubed_extraction_service.py @@ -0,0 +1,256 @@ +import time + +import pandas as pd +import requests +from qubed import Qube + +from polytope_feature.datacube.backends.qubed import QubedDatacube +from polytope_feature.datacube.datacube_axis import ( + PandasTimedeltaDatacubeAxis, + PandasTimestampDatacubeAxis, + UnsliceableDatacubeAxis, +) +from polytope_feature.datacube.transformations.datacube_mappers.mapper_types.healpix_nested import ( + NestedHealpixGridMapper, +) +from polytope_feature.datacube.transformations.datacube_type_change.datacube_type_change import ( + TypeChangeStrToTimedelta, + TypeChangeStrToTimestamp, +) +from polytope_feature.engine.qubed_slicer import QubedSlicer +from polytope_feature.polytope import Polytope, Request +from polytope_feature.shapes import ConvexPolytope, Select + + +def find_relevant_subcube_from_request(request, qube_url): + # NOTE: final url we want is like: + # "https://qubed.lumi.apps.dte.destination-earth.eu/api/v1/select/climate-dt/?class=d1&dataset=climate-dt" + + for shape in request.shapes: + if isinstance(shape, Select): + qube_url += shape.axis + "=" + for i, val in enumerate(shape.values): + qube_url += str(val) + if i < len(shape.values) - 1: + qube_url += "," + qube_url += "&" + # TODO: remove last unnecessary & + qube_url = qube_url[:-1] + return qube_url + + +def get_fdb_tree(request): + qube_url_start = "https://qubed.lumi.apps.dte.destination-earth.eu/api/v1/select/climate-dt/?" + qube_url = find_relevant_subcube_from_request(request, qube_url_start) + fdb_tree = Qube.from_json(requests.get(qube_url).json()) + return fdb_tree + + +# fdb_tree = Qube.from_json( +# requests.get("https://github.com/ecmwf/qubed/raw/refs/heads/main/tests/example_qubes/climate-dt.json").json() +# ) + +fdb_tree = Qube.from_json( + requests.get("https://github.com/ecmwf/qubed/raw/refs/heads/main/tests/example_qubes/climate-dt.json").json() +) + + +# TODO: add lat and lon axes +datacube_axes = { + "param": UnsliceableDatacubeAxis(), + "time": PandasTimedeltaDatacubeAxis(), + "resolution": UnsliceableDatacubeAxis(), + "type": UnsliceableDatacubeAxis(), + "model": UnsliceableDatacubeAxis(), + "stream": UnsliceableDatacubeAxis(), + "realization": UnsliceableDatacubeAxis(), + "expver": UnsliceableDatacubeAxis(), + "experiment": UnsliceableDatacubeAxis(), + "generation": UnsliceableDatacubeAxis(), + "levtype": UnsliceableDatacubeAxis(), + "activity": UnsliceableDatacubeAxis(), + "dataset": UnsliceableDatacubeAxis(), + "class": UnsliceableDatacubeAxis(), + "date": PandasTimestampDatacubeAxis(), +} + +time_val = pd.Timedelta(hours=0, minutes=0) +date_val = pd.Timestamp("20300101T000000") + + +# TODO: add grid axis transformation +datacube_transformations = { + "time": TypeChangeStrToTimedelta("time", time_val), + "date": TypeChangeStrToTimestamp("date", date_val), + "values": NestedHealpixGridMapper("values", ["latitude", "longitude"], 1024), +} + + +options = { + "axis_config": [ + {"axis_name": "step", "transformations": [{"name": "type_change", "type": "int"}]}, + {"axis_name": "expver", "transformations": [{"name": "type_change", "type": "int"}]}, + {"axis_name": "realization", "transformations": [{"name": "type_change", "type": "int"}]}, + {"axis_name": "generation", "transformations": [{"name": "type_change", "type": "int"}]}, + {"axis_name": "number", "transformations": [{"name": "type_change", "type": "int"}]}, + {"axis_name": "date", "transformations": [{"name": "type_change", "type": "date"}]}, + {"axis_name": "time", "transformations": [{"name": "type_change", "type": "time"}]}, + { + "axis_name": "values", + "transformations": [ + {"name": "mapper", "type": "healpix_nested", "resolution": 1024, "axes": ["latitude", "longitude"]} + ], + }, + {"axis_name": "latitude", "transformations": [{"name": "reverse", "is_reverse": True}]}, + {"axis_name": "longitude", "transformations": [{"name": "cyclic", "range": [0, 360]}]}, + ], + "compressed_axes_config": [ + "longitude", + "latitude", + "levtype", + "step", + "date", + "domain", + "expver", + "param", + "class", + "stream", + "type", + ], + "pre_path": {"class": "od", "expver": "0001", "levtype": "sfc", "stream": "oper"}, + "datacube_axes": { + "param": "UnsliceableDatacubeAxis", + "time": "PandasTimedeltaDatacubeAxis", + "resolution": "UnsliceableDatacubeAxis", + "type": "UnsliceableDatacubeAxis", + "model": "UnsliceableDatacubeAxis", + "stream": "UnsliceableDatacubeAxis", + "realization": "IntDatacubeAxis", + "expver": "IntDatacubeAxis", + "experiment": "UnsliceableDatacubeAxis", + "generation": "IntDatacubeAxis", + "levtype": "UnsliceableDatacubeAxis", + "activity": "UnsliceableDatacubeAxis", + "dataset": "UnsliceableDatacubeAxis", + "class": "UnsliceableDatacubeAxis", + "date": "PandasTimestampDatacubeAxis", + }, +} + + +request = Request( + Select("param", ["165"]), + ConvexPolytope(["time"], [[pd.Timedelta(hours=0, minutes=0)], [pd.Timedelta(hours=3, minutes=0)]]), + ConvexPolytope(["resolution"], [["high"]]), + ConvexPolytope(["type"], [["fc"]]), + Select("model", ["icon"]), + Select("stream", ["clte"]), + ConvexPolytope(["realization"], ["1"]), + ConvexPolytope(["expver"], [["0001"]]), + ConvexPolytope(["experiment"], [["ssp3-7.0"]]), + Select("generation", [1]), + ConvexPolytope(["levtype"], [["sfc"]]), + Select("activity", ["scenariomip"]), + ConvexPolytope(["dataset"], [["climate-dt"]]), + ConvexPolytope(["class"], [["d1"]]), + ConvexPolytope(["date"], [[pd.Timestamp("20200908")]]), + ConvexPolytope(["latitude", "longitude"], [[0, 0], [5, 5], [0, 5]]), +) + +# NOTE: this qube was deleted + +path_to_qube = "../qubed/" +full_qube_path = path_to_qube + "tests/example_qubes/climate-dt_with_metadata.json" +fdb_tree = Qube.load(full_qube_path) + +# fdb_tree = Qube.from_json( +# requests.get( +# "https://github.com/ecmwf/qubed/raw/refs/heads/main/tests/example_qubes/climate-dt_with_metadata.json" +# ).json() +# ) + +qubeddatacube = QubedDatacube(fdb_tree, datacube_axes, datacube_transformations) +slicer = QubedSlicer() +self_API = Polytope( + datacube=fdb_tree, + engine=slicer, + options=options, +) +time1 = time.time() +result = self_API.retrieve(request) +time2 = time.time() + + +print("TIME EXTRACTING USING QUBED") +print(time2 - time1) + +# # USING NORMAL GJ + + +# options = { +# "axis_config": [ +# {"axis_name": "step", "transformations": [{"name": "type_change", "type": "int"}]}, +# {"axis_name": "number", "transformations": [{"name": "type_change", "type": "int"}]}, +# # { +# # "axis_name": "date", +# # "transformations": [{"name": "merge", "other_axis": "time", "linkers": ["T", "00"]}], +# # }, +# {"axis_name": "date", "transformations": [{"name": "type_change", "type": "date"}]}, +# {"axis_name": "time", "transformations": [{"name": "type_change", "type": "time"}]}, +# { +# "axis_name": "values", +# "transformations": [ +# {"name": "mapper", "type": "healpix_nested", "resolution": 1024, "axes": ["latitude", "longitude"]} +# ], +# }, +# {"axis_name": "latitude", "transformations": [{"name": "reverse", "is_reverse": True}]}, +# {"axis_name": "longitude", "transformations": [{"name": "cyclic", "range": [0, 360]}]}, +# ], +# "compressed_axes_config": [ +# "longitude", +# # "latitude", +# # "levtype", +# # "step", +# # "date", +# # "domain", +# # "expver", +# # "param", +# # "class", +# # "stream", +# # "type", +# ], +# "pre_path": {"class": "d1", "model": "ifs-nemo", "resolution": "high"}, +# } + +# fdbdatacube = gj.GribJump() +# slicer = HullSlicer() +# self_API = Polytope( +# datacube=fdbdatacube, +# engine=slicer, +# options=options, +# ) + + +# request = Request(ConvexPolytope(["param"], [["164"]]), +# ConvexPolytope(["time"], [[pd.Timedelta(hours=1, minutes=0)], [pd.Timedelta(hours=3, minutes=0)]]), +# ConvexPolytope(["resolution"], [["high"]]), +# ConvexPolytope(["type"], [["fc"]]), +# ConvexPolytope(["model"], [['ifs-nemo']]), +# ConvexPolytope(["stream"], [["clte"]]), +# ConvexPolytope(["realization"], ["1"]), +# ConvexPolytope(["expver"], [['0001']]), +# ConvexPolytope(["experiment"], [['ssp3-7.0']]), +# ConvexPolytope(["generation"], [["1"]]), +# ConvexPolytope(["levtype"], [["sfc"]]), +# ConvexPolytope(["activity"], [["scenariomip"]]), +# ConvexPolytope(["dataset"], [["climate-dt"]]), +# ConvexPolytope(["class"], [["d1"]]), +# ConvexPolytope(["date"], [[pd.Timestamp("20220811")]]), +# ConvexPolytope(["latitude", "longitude"], [[0, 0], [5, 5], [0, 5]])) + +# time3 = time.time() +# result = self_API.retrieve(request) +# time4 = time.time() + +# print("TIME EXTRACTING USING GJ NORMAL") +# print(time4 - time3)