diff --git a/src/earthkit/data/indexing/tensor.py b/src/earthkit/data/indexing/tensor.py index 18fd2355..91b412d3 100644 --- a/src/earthkit/data/indexing/tensor.py +++ b/src/earthkit/data/indexing/tensor.py @@ -505,100 +505,6 @@ def _subset(self, indexes): ds = self.source[tuple(dataset_indexes)] return self.from_tensor(self, ds, coords) - def make_valid_datetime(self, dims_map, dtype="datetime64[ns]"): - # TODO: make it more general - # PW: TODO: make it more general - it could allow to use it when allow_holes=True - - for k in ["valid_time", "time.valid_datetime", "metadata.valid_time", "metadata.valid_datetime"]: - if k in self.user_coords: - import datetime - - return (k,), [datetime.datetime.fromisoformat(x) for x in self.user_coords[k]] - - # in the tensor the dims.coords are GRIB keys - # dims_map is a mapping from dim names to GRIB keys - DIM_ROLES = { - "forecast_reference_time": ( - "forecast_reference_time", - "time.forecast_reference_time", - "time.base_datetime", - "metadata.base_datetime", - "metadata.indexing_datetime", - "metadata.indexing_time", - ), - "step": ( - "step", - "time.step", - "metadata.step_timedelta", - "metadata.step", - "metadata.endStep", - "metadata.stepRange", - ), - "date": ("date", "metadata.dataDate"), - "time": ("time", "metadata.dataTime"), - } - - # map dim roles to keys available in the tensor - keys = {} - for k in DIM_ROLES: - for d in dims_map: - if d.name == k: - keys[k] = d.key - break - if k not in keys: - for d in self.user_dims: - if d in DIM_ROLES[k]: - keys[k] = d - break - - DIM_COMBINATIONS = [ - ["forecast_reference_time", "step"], - ["forecast_reference_time"], - ["date", "time", "step"], - ["date", "time"], - ["date", "step"], - ["time", "step"], - ["step"], - ] - - for dims in DIM_COMBINATIONS: - if all(d in keys for d in dims): - dims_step = [keys[d] for d in dims] - # use same dim order as in user_dims - dims = [d for d in self.user_dims if d in dims_step] - if len(dims) != len(dims_step): - continue - 
assert len(dims) == len(dims_step), f"{dims=} {dims_step=}" - other_dims = [d for d in self.user_dims if d not in dims] - - if other_dims: - import datetime - - import numpy as np - - other_coords = {k: next(iter(self.user_coords[k])) for k in other_dims if k in self.user_coords} - - vals = np.array( - [x for x in self.source.sel(**other_coords).get("time.valid_datetime")], - dtype=dtype, - ) - - shape = tuple([self.user_dims[d] for d in dims]) - return tuple(dims), vals.reshape(shape) - else: - import datetime - - import numpy as np - - vals = np.array( - [x for x in self.source.get("time.valid_datetime")], - dtype=dtype, - ) - - shape = tuple([self.user_dims[d] for d in dims]) - return tuple(dims), vals.reshape(shape) - return None, None - def __getstate__(self): r = {} r["source"] = self.source diff --git a/src/earthkit/data/xr_engine/builder.py b/src/earthkit/data/xr_engine/builder.py index 09e018fa..5fcb9092 100644 --- a/src/earthkit/data/xr_engine/builder.py +++ b/src/earthkit/data/xr_engine/builder.py @@ -337,19 +337,6 @@ def _make_field_coords(self): r[k] = xarray.Variable(dims, v, self.profile.attrs.coord_attrs.get(k, {})) return r - def collect_date_coords(self, tensor): - if ( - self.profile.add_valid_time_coord - and "valid_time" not in tensor.user_dims - and "valid_datetime" not in tensor.user_coords - and "valid_time" not in self.tensor_coords - ): - from .coord import Coord - - _dims, _vals = tensor.make_valid_datetime(self.dims) - if _dims is not None and _vals is not None: - self.tensor_coords["valid_time"] = Coord.make("valid_time", _vals, dims=_dims) - def collect_aux_coords(self): from .coord import Coord @@ -408,9 +395,6 @@ def collect_aux_coords(self): def build(self): if self.profile.allow_holes: - if self.profile.add_valid_time_coord: - raise NotImplementedError("add_valid_time_coord=True not yet supported when allow_holes=True") - global_tensor_dims, self.raw_global_tensor_coords, _ = self.prepare_tensor( self.ds, self.dims, "" ) @@ -433,6 
+417,16 @@ def build(self): # From now on, self.tensor_coords is a mapping: # dimension_name->a Coord object + possibly the same for "valid_time" + # Inject valid_time as an auxiliary coordinate when requested and when valid_time is not a dimension + if ( + self.profile.add_valid_time_coord + and "valid_time" not in [d.name for d in self.dims] + and "time.valid_datetime" not in self.tensor_coords + ): + time_dim_names = self.profile.dims.active_time_dim_names + if time_dim_names: + self.profile.aux_coords.setdefault("valid_time", ("time.valid_datetime", time_dim_names)) + self.collect_aux_coords() # build variable and global attributes @@ -518,7 +512,6 @@ def pre_build_variable(self, ds_var, dims, name): var_dims.append(k) var_dims.extend(tensor.field_dims) - self.collect_date_coords(tensor) data_maker = self.build_values remapping = self.profile.remapping.build() diff --git a/src/earthkit/data/xr_engine/coord.py b/src/earthkit/data/xr_engine/coord.py index 999fbf56..dfda7463 100644 --- a/src/earthkit/data/xr_engine/coord.py +++ b/src/earthkit/data/xr_engine/coord.py @@ -83,7 +83,9 @@ def _to_datetime_list(vals): # datetime64 arrays are already in the required format if isinstance(vals, np.ndarray): if not np.issubdtype(vals.dtype, np.datetime64): - return to_datetime_list(vals.tolist()) + original_shape = vals.shape + flat = to_datetime_list(vals.flatten().tolist()) + return np.array(flat, dtype="datetime64[ns]").reshape(original_shape) else: return to_datetime_list(vals) diff --git a/src/earthkit/data/xr_engine/dim.py b/src/earthkit/data/xr_engine/dim.py index 714c839f..c29c3e00 100644 --- a/src/earthkit/data/xr_engine/dim.py +++ b/src/earthkit/data/xr_engine/dim.py @@ -98,6 +98,8 @@ def _get_metadata_keys(keys): DATETIME_KEYS = BASE_DATETIME_KEYS + VALID_DATETIME_KEYS +_TIME_RELATED_KEYS = set(DATE_KEYS + TIME_KEYS + STEP_KEYS + MONTH_KEYS + VALID_DATETIME_KEYS + BASE_DATETIME_KEYS) + KEYS = ( ENS_KEYS, LEVEL_KEYS, @@ -1015,6 +1017,31 @@ def make_coords(self): 
def to_list(self):
         return list(self.dims.values())
 
+    @property
+    def active_time_dim_names(self):
+        """Return a list of the active time dimension names, in ``self.dims`` order.
+
+        When no ``fixed_dims`` are given, names are resolved from the time roles in
+        ``time_dims``; otherwise dims whose key is in ``_TIME_RELATED_KEYS`` are used.
+        """
+        time_dim_names = set()
+        if not self.fixed_dims:
+            for role_name in self.time_dims:
+                if role_name in ALL_TIME_ROLES:
+                    # Keep only the resolved dim name; the first value returned by role() is discarded
+                    _, name = self.dim_roles.role(role_name, raise_error=False)
+                    if name is not None:
+                        time_dim_names.add(name)
+        else:
+            # When fixed_dims are used, `self.time_dims` is irrelevant, and we check all `self.fixed_dims`
+            # for time-related keys.
+            for dim_name, dim_key in self.fixed_dims.items():
+                if dim_key in _TIME_RELATED_KEYS:
+                    time_dim_names.add(dim_name)
+
+        # Filter self.dims so the result follows dim order and contains active dims only
+        return [d.name for d in self.dims.values() if d.active and d.name in time_dim_names]
+
     def get_dims(self, names):
         r = []
         for name in names:
diff --git a/tests/xr_engine/test_xr_engine_add_valid_time_coord.py b/tests/xr_engine/test_xr_engine_add_valid_time_coord.py
new file mode 100644
index 00000000..5438e12c
--- /dev/null
+++ b/tests/xr_engine/test_xr_engine_add_valid_time_coord.py
@@ -0,0 +1,261 @@
+#!/usr/bin/env python3
+
+# (C) Copyright 2020 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+# + +"""Tests for add_valid_time_coord=True using the aux_coords-based implementation.""" + +import numpy as np +import pytest + +from earthkit.data import from_source + +# Expected valid_time values for pl.grib with 4 forecast_reference_times x 2 steps +VALID_TIME_FRT_STEP = np.array( + [ + ["2024-06-03T00:00:00", "2024-06-03T06:00:00"], + ["2024-06-03T12:00:00", "2024-06-03T18:00:00"], + ["2024-06-04T00:00:00", "2024-06-04T06:00:00"], + ["2024-06-04T12:00:00", "2024-06-04T18:00:00"], + ], + dtype="datetime64[ns]", +) + +# Expected valid_time values for date x time x step (2x2x2) +VALID_TIME_DATE_TIME_STEP = np.array( + [ + [ + ["2024-06-03T00:00:00", "2024-06-03T06:00:00"], + ["2024-06-03T12:00:00", "2024-06-03T18:00:00"], + ], + [ + ["2024-06-04T00:00:00", "2024-06-04T06:00:00"], + ["2024-06-04T12:00:00", "2024-06-04T18:00:00"], + ], + ], + dtype="datetime64[ns]", +) + + +@pytest.fixture(scope="session") +def pl_fl(): + return from_source("sample", "pl.grib").to_fieldlist() + + +# ------------------------------------------------------------------------- +# dim_name_from_role_name=True vs False +# ------------------------------------------------------------------------- + + +@pytest.mark.parametrize("allow_holes", [False, True]) +@pytest.mark.parametrize("lazy_load", [True, False]) +@pytest.mark.parametrize("dim_name_from_role_name", [True, False]) +def test_dim_name_from_role_name(pl_fl, lazy_load, allow_holes, dim_name_from_role_name): + """valid_time aux coord should work regardless of dim_name_from_role_name.""" + ds = pl_fl.to_xarray( + profile="earthkit", + add_valid_time_coord=True, + dim_name_from_role_name=dim_name_from_role_name, + lazy_load=lazy_load, + allow_holes=allow_holes, + ) + + assert "valid_time" in ds.coords + assert "valid_time" not in ds.sizes + assert ds.coords["valid_time"].dims == ("forecast_reference_time", "step") + assert ds.coords["valid_time"].shape == (4, 2) + np.testing.assert_array_equal(ds.coords["valid_time"].values, 
VALID_TIME_FRT_STEP) + + +# ------------------------------------------------------------------------- +# Different time_dims variants +# ------------------------------------------------------------------------- + + +@pytest.mark.parametrize("allow_holes", [False, True]) +@pytest.mark.parametrize("lazy_load", [True, False]) +def test_time_dims_date_time_step(pl_fl, lazy_load, allow_holes): + """time_dims=['date', 'time', 'step'] produces 3D valid_time.""" + ds = pl_fl.to_xarray( + profile="earthkit", + time_dims=["date", "time", "step"], + add_valid_time_coord=True, + lazy_load=lazy_load, + allow_holes=allow_holes, + ) + + assert "valid_time" in ds.coords + assert "valid_time" not in ds.sizes + assert ds.coords["valid_time"].dims == ("date", "time", "step") + assert ds.coords["valid_time"].shape == (2, 2, 2) + np.testing.assert_array_equal(ds.coords["valid_time"].values, VALID_TIME_DATE_TIME_STEP) + + +@pytest.mark.parametrize("allow_holes", [False, True]) +@pytest.mark.parametrize("lazy_load", [True, False]) +def test_time_dims_valid_time_no_aux(pl_fl, lazy_load, allow_holes): + """When time_dims='valid_time', valid_time is a dimension, not an aux coord.""" + ds = pl_fl.to_xarray( + profile="earthkit", + time_dims="valid_time", + add_valid_time_coord=True, + lazy_load=lazy_load, + allow_holes=allow_holes, + ) + + # valid_time should be a dimension, not an auxiliary coordinate + assert "valid_time" in ds.sizes + assert ds.sizes["valid_time"] == 8 + + +@pytest.mark.parametrize("allow_holes", [False, True]) +@pytest.mark.parametrize("lazy_load", [True, False]) +def test_time_dims_frt_only(pl_fl, lazy_load, allow_holes): + """time_dims=['forecast_reference_time'] with step squeezed out => 1D valid_time.""" + # Select single step to avoid step dimension + fl_single_step = pl_fl.sel({"metadata.step": 0}) + ds = fl_single_step.to_xarray( + profile="earthkit", + time_dims=["forecast_reference_time"], + add_valid_time_coord=True, + lazy_load=lazy_load, + 
allow_holes=allow_holes,
+    )
+
+    assert "valid_time" in ds.coords
+    assert "valid_time" not in ds.sizes
+    assert ds.coords["valid_time"].dims == ("forecast_reference_time",)
+    assert ds.coords["valid_time"].shape == (4,)
+
+
+# -------------------------------------------------------------------------
+# Custom dim_roles (GRIB metadata keys)
+# -------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("dim_name_from_role_name", [True, False])
+@pytest.mark.parametrize("allow_holes", [False, True])
+@pytest.mark.parametrize("lazy_load", [True, False])
+def test_custom_dim_roles(pl_fl, lazy_load, allow_holes, dim_name_from_role_name):
+    """Custom dim_roles mapping time roles to GRIB metadata keys."""
+    ds = pl_fl.to_xarray(
+        profile="earthkit",
+        add_valid_time_coord=True,
+        dim_roles={
+            "forecast_reference_time": "metadata.base_datetime",
+            "step": "metadata.endStep",
+        },
+        lazy_load=lazy_load,
+        allow_holes=allow_holes,
+        dim_name_from_role_name=dim_name_from_role_name,
+    )
+
+    assert "valid_time" in ds.coords
+    assert "valid_time" not in ds.sizes
+    assert ds.coords["valid_time"].dims == (  # `==` must apply to whichever tuple is selected
+        ("forecast_reference_time", "step")
+        if dim_name_from_role_name
+        else ("metadata.base_datetime", "metadata.endStep")  # NOTE(review): never checked before; confirm
+    )
+    assert ds.coords["valid_time"].shape == (4, 2)
+    np.testing.assert_array_equal(ds.coords["valid_time"].values, VALID_TIME_FRT_STEP)
+
+
+# -------------------------------------------------------------------------
+# fixed_dims with mono_variable=True and False
+# -------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("allow_holes", [False, True])
+@pytest.mark.parametrize("lazy_load", [True, False])
+def test_fixed_dims_mono_variable_true(pl_fl, lazy_load, allow_holes):
+    """fixed_dims with mono_variable=True."""
+    ds = pl_fl.to_xarray(
+        fixed_dims=[
+            "parameter.variable",
+            "time.forecast_reference_time",
+            "time.step",
+            "vertical.level",
+        ],
+
mono_variable=True, + add_valid_time_coord=True, + lazy_load=lazy_load, + allow_holes=allow_holes, + ) + + assert "valid_time" in ds.coords + assert "valid_time" not in ds.sizes + assert ds.coords["valid_time"].dims == ("forecast_reference_time", "step") + assert ds.coords["valid_time"].shape == (4, 2) + np.testing.assert_array_equal(ds.coords["valid_time"].values, VALID_TIME_FRT_STEP) + + +@pytest.mark.parametrize("allow_holes", [False, True]) +@pytest.mark.parametrize("lazy_load", [True, False]) +def test_fixed_dims_mono_variable_false(pl_fl, lazy_load, allow_holes): + """fixed_dims with mono_variable=False (default).""" + ds = pl_fl.to_xarray( + fixed_dims=[ + "time.forecast_reference_time", + "metadata.endStep", + "vertical.level", + ], + mono_variable=False, + add_valid_time_coord=True, + lazy_load=lazy_load, + allow_holes=allow_holes, + ) + + assert "valid_time" in ds.coords + assert "valid_time" not in ds.sizes + assert ds.coords["valid_time"].dims == ("forecast_reference_time", "endStep") + assert ds.coords["valid_time"].shape == (4, 2) + np.testing.assert_array_equal(ds.coords["valid_time"].values, VALID_TIME_FRT_STEP) + + +@pytest.mark.parametrize("allow_holes", [False, True]) +@pytest.mark.parametrize("lazy_load", [True, False]) +def test_fixed_dims_different_order(pl_fl, lazy_load, allow_holes): + """fixed_dims with time dims in reversed order.""" + ds = pl_fl.to_xarray( + fixed_dims=[ + "vertical.level", + "metadata.endStep", + "time.forecast_reference_time", + ], + add_valid_time_coord=True, + lazy_load=lazy_load, + allow_holes=allow_holes, + decode_times=False, + decode_timedelta=False, + ) + + assert "valid_time" in ds.coords + assert "valid_time" not in ds.sizes + # Dims should follow the fixed_dims order + assert ds.coords["valid_time"].dims == ("endStep", "forecast_reference_time") + assert ds.coords["valid_time"].shape == (2, 4) + + +# ------------------------------------------------------------------------- +# Edge case: 
add_valid_time_coord=False should not add it +# ------------------------------------------------------------------------- + + +@pytest.mark.parametrize("allow_holes", [False, True]) +def test_add_valid_time_coord_false(pl_fl, allow_holes): + """add_valid_time_coord=False should not add valid_time as aux coord.""" + ds = pl_fl.to_xarray( + profile="earthkit", + add_valid_time_coord=False, + allow_holes=allow_holes, + decode_times=False, + ) + + assert "valid_time" not in ds.coords