diff --git a/docs/source/how-tos/xr_engine/index.rst b/docs/source/how-tos/xr_engine/index.rst index 169f4f77..6a5311bd 100644 --- a/docs/source/how-tos/xr_engine/index.rst +++ b/docs/source/how-tos/xr_engine/index.rst @@ -23,6 +23,7 @@ Xarray engine xarray_engine_dims_as_attrs.ipynb xarray_engine_extra_dims.ipynb xarray_engine_remapping.ipynb + xarray_engine_aux_coords.ipynb xarray_engine_holes.ipynb xarray_engine_chunks.ipynb xarray_engine_chunks_on_dask_cluster.ipynb diff --git a/docs/source/how-tos/xr_engine/xarray_engine_aux_coords.ipynb b/docs/source/how-tos/xr_engine/xarray_engine_aux_coords.ipynb new file mode 100644 index 00000000..8165ac2f --- /dev/null +++ b/docs/source/how-tos/xr_engine/xarray_engine_aux_coords.ipynb @@ -0,0 +1,1902 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c2feafcc-430b-4718-983f-554e55dcd54a", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "## Xarray engine: auxiliary coordinates" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "1a6e355d-3fbf-4d92-b32f-a9d7e770f9db", + "metadata": { + "editable": true, + "scrolled": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import earthkit.data as ekd" + ] + }, + { + "cell_type": "markdown", + "id": "f3117255-6cc1-4cf2-ba91-dc3134973b91", + "metadata": {}, + "source": [ + "### Basic example" + ] + }, + { + "cell_type": "markdown", + "id": "e96e8da8-8219-4a79-92ad-515606816919", + "metadata": {}, + "source": [ + "First, we get some GRIB data containing control and perturbed forecasts." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a8f1d8b7-4a3b-4186-a827-17dbb16eaa2b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " " + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
parameter.variabletime.valid_datetimetime.base_datetimetime.stepvertical.levelvertical.level_typeensemble.membergeography.grid_typemetadata.dataType
0t2024-06-03 00:00:002024-06-030 days 00:00:00500pressure0regular_llcf
1t2024-06-03 06:00:002024-06-030 days 06:00:00500pressure0regular_llcf
2t2024-06-03 00:00:002024-06-030 days 00:00:00500pressure1regular_llpf
3t2024-06-03 00:00:002024-06-030 days 00:00:00500pressure2regular_llpf
4t2024-06-03 06:00:002024-06-030 days 06:00:00500pressure1regular_llpf
5t2024-06-03 06:00:002024-06-030 days 06:00:00500pressure2regular_llpf
\n", + "
" + ], + "text/plain": [ + " parameter.variable time.valid_datetime time.base_datetime time.step \\\n", + "0 t 2024-06-03 00:00:00 2024-06-03 0 days 00:00:00 \n", + "1 t 2024-06-03 06:00:00 2024-06-03 0 days 06:00:00 \n", + "2 t 2024-06-03 00:00:00 2024-06-03 0 days 00:00:00 \n", + "3 t 2024-06-03 00:00:00 2024-06-03 0 days 00:00:00 \n", + "4 t 2024-06-03 06:00:00 2024-06-03 0 days 06:00:00 \n", + "5 t 2024-06-03 06:00:00 2024-06-03 0 days 06:00:00 \n", + "\n", + " vertical.level vertical.level_type ensemble.member geography.grid_type \\\n", + "0 500 pressure 0 regular_ll \n", + "1 500 pressure 0 regular_ll \n", + "2 500 pressure 1 regular_ll \n", + "3 500 pressure 2 regular_ll \n", + "4 500 pressure 1 regular_ll \n", + "5 500 pressure 2 regular_ll \n", + "\n", + " metadata.dataType \n", + "0 cf \n", + "1 cf \n", + "2 pf \n", + "3 pf \n", + "4 pf \n", + "5 pf " + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_fl = ekd.from_source(\"sample\", \"ens_cf_pf.grib\").to_fieldlist()\n", + "ds_fl.ls(extra_keys=[\"metadata.dataType\"])" + ] + }, + { + "cell_type": "markdown", + "id": "db15e80f-4beb-441d-b334-9fc1a300d1af", + "metadata": {}, + "source": [ + "Using the Xarray engine keyword `aux_coords` one can declare an auxiliary coordinate `\"forecast_type\"` whose values are derived from the GRIB metadata key `\"dataType\"`and depend on a single dimension `\"member\"`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8856dcff-31ec-4c39-8725-a6f5e37e1065", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 33kB\n",
+       "Dimensions:        (member: 3, step: 2, latitude: 19, longitude: 36)\n",
+       "Coordinates:\n",
+       "  * member         (member) <U1 12B '0' '1' '2'\n",
+       "    forecast_type  (member) <U2 24B 'cf' 'pf' 'pf'\n",
+       "  * step           (step) timedelta64[ns] 16B 00:00:00 06:00:00\n",
+       "  * latitude       (latitude) float64 152B 90.0 80.0 70.0 ... -70.0 -80.0 -90.0\n",
+       "  * longitude      (longitude) float64 288B 0.0 10.0 20.0 ... 330.0 340.0 350.0\n",
+       "Data variables:\n",
+       "    t              (member, step, latitude, longitude) float64 33kB 250.2 ......\n",
+       "Attributes:\n",
+       "    Conventions:  CF-1.8\n",
+       "    institution:  ECMWF
" + ], + "text/plain": [ + " Size: 33kB\n", + "Dimensions: (member: 3, step: 2, latitude: 19, longitude: 36)\n", + "Coordinates:\n", + " * member (member) \n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
metadata.shortNamemetadata.dataDatemetadata.dataTimemetadata.stepRangemetadata.dataTypemetadata.quantilemetadata.numbermetadata.numberOfForecastsInEnsemble
02tp2025120900-168pd1:313
12tp2025120900-168pd1:515
22tp2025120900-168pd1:10110
32tp2025120900-168pd2:323
42tp2025120900-168pd2:525
52tp2025120900-168pd2:10210
62tp2025120900-168pd3:333
72tp2025120900-168pd3:535
82tp2025120900-168pd3:10310
92tp2025120900-168pd4:545
102tp2025120900-168pd4:10410
112tp2025120900-168pd5:555
122tp2025120900-168pd5:10510
132tp2025120900-168pd6:10610
142tp2025120900-168pd7:10710
152tp2025120900-168pd8:10810
162tp2025120900-168pd9:10910
172tp2025120900-168pd10:101010
\n", + "" + ], + "text/plain": [ + " metadata.shortName metadata.dataDate metadata.dataTime \\\n", + "0 2tp 20251209 0 \n", + "1 2tp 20251209 0 \n", + "2 2tp 20251209 0 \n", + "3 2tp 20251209 0 \n", + "4 2tp 20251209 0 \n", + "5 2tp 20251209 0 \n", + "6 2tp 20251209 0 \n", + "7 2tp 20251209 0 \n", + "8 2tp 20251209 0 \n", + "9 2tp 20251209 0 \n", + "10 2tp 20251209 0 \n", + "11 2tp 20251209 0 \n", + "12 2tp 20251209 0 \n", + "13 2tp 20251209 0 \n", + "14 2tp 20251209 0 \n", + "15 2tp 20251209 0 \n", + "16 2tp 20251209 0 \n", + "17 2tp 20251209 0 \n", + "\n", + " metadata.stepRange metadata.dataType metadata.quantile metadata.number \\\n", + "0 0-168 pd 1:3 1 \n", + "1 0-168 pd 1:5 1 \n", + "2 0-168 pd 1:10 1 \n", + "3 0-168 pd 2:3 2 \n", + "4 0-168 pd 2:5 2 \n", + "5 0-168 pd 2:10 2 \n", + "6 0-168 pd 3:3 3 \n", + "7 0-168 pd 3:5 3 \n", + "8 0-168 pd 3:10 3 \n", + "9 0-168 pd 4:5 4 \n", + "10 0-168 pd 4:10 4 \n", + "11 0-168 pd 5:5 5 \n", + "12 0-168 pd 5:10 5 \n", + "13 0-168 pd 6:10 6 \n", + "14 0-168 pd 7:10 7 \n", + "15 0-168 pd 8:10 8 \n", + "16 0-168 pd 9:10 9 \n", + "17 0-168 pd 10:10 10 \n", + "\n", + " metadata.numberOfForecastsInEnsemble \n", + "0 3 \n", + "1 5 \n", + "2 10 \n", + "3 3 \n", + "4 5 \n", + "5 10 \n", + "6 3 \n", + "7 5 \n", + "8 10 \n", + "9 5 \n", + "10 10 \n", + "11 5 \n", + "12 10 \n", + "13 10 \n", + "14 10 \n", + "15 10 \n", + "16 10 \n", + "17 10 " + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_fl2.ls(\n", + " keys=[\n", + " \"metadata.shortName\",\n", + " \"metadata.dataDate\",\n", + " \"metadata.dataTime\",\n", + " \"metadata.stepRange\",\n", + " \"metadata.dataType\",\n", + " \"metadata.quantile\",\n", + " \"metadata.number\",\n", + " \"metadata.numberOfForecastsInEnsemble\",\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "3ea48ddf-3fc0-455b-b381-e3c8b2a3debe", + "metadata": {}, + "source": [ + "Note that, in this context, the usual meaning of the GRIB metadata key ``\"number\"`` (and the related ``\"numberOfForecastsInEnsemble\"``) is overridden by ``\"quantile\"``. As a result, the ensemble dimension normally derived from ``\"number\"`` is no longer applicable.\n", + "\n", + "For this reason, we must:\n", + "- declare the GRIB metadata key ``\"quantile\"`` as an extra dimension, and\n", + "- remove the predefined ensemble dimension ``\"number\"``, since it would otherwise conflict with the ``\"quantile\"`` dimension.\n", + "\n", + "Still, it might be useful to keep the information carried by ``\"number\"`` and ``\"numberOfForecastsInEnsemble\"`` as auxiliary coordinates." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "cd65d5ce-b511-4c12-88f7-f64f5b0c18e7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 13kB\n",
+       "Dimensions:        (quantile: 18, latitude: 7, longitude: 12)\n",
+       "Coordinates:\n",
+       "  * quantile       (quantile) <U5 360B '10:10' '1:10' '1:3' ... '8:10' '9:10'\n",
+       "    quantile_rank  (quantile) <U2 144B '10' '1' '1' '1' '2' ... '6' '7' '8' '9'\n",
+       "    nquantiles     (quantile) int64 144B 10 10 3 5 10 3 5 ... 5 10 5 10 10 10 10\n",
+       "  * latitude       (latitude) float64 56B 90.0 60.0 30.0 0.0 -30.0 -60.0 -90.0\n",
+       "  * longitude      (longitude) float64 96B 0.0 30.0 60.0 ... 270.0 300.0 330.0\n",
+       "Data variables:\n",
+       "    2tp            (quantile, latitude, longitude) float64 12kB 13.37 ... 0.0\n",
+       "Attributes:\n",
+       "    Conventions:  CF-1.8\n",
+       "    institution:  ECMWF
" + ], + "text/plain": [ + " Size: 13kB\n", + "Dimensions: (quantile: 18, latitude: 7, longitude: 12)\n", + "Coordinates:\n", + " * quantile (quantile) a Coord object + possibly the same for "valid_time" + + self.collect_aux_coords() + # build variable and global attributes xr_attrs = self.profile.attrs.builder.build(self.ds, var_builders, rename=True) xr_coords = self.coords() diff --git a/src/earthkit/data/xr_engine/diff.py b/src/earthkit/data/xr_engine/diff.py index 7ca07cf9..f6436b87 100644 --- a/src/earthkit/data/xr_engine/diff.py +++ b/src/earthkit/data/xr_engine/diff.py @@ -7,7 +7,6 @@ # nor does it submit to any jurisdiction. # -import datetime import logging import math @@ -79,26 +78,14 @@ class ListDiff: @staticmethod def _compare(v1, v2): - if isinstance(v1, int) and isinstance(v2, int): - return v1 == v2, ListDiff.VALUE_DIFF - elif isinstance(v1, float) and isinstance(v2, float): + if isinstance(v1, float) and isinstance(v2, float): return math.isclose(v1, v2, rel_tol=1e-9), ListDiff.VALUE_DIFF - elif isinstance(v1, str) and isinstance(v2, str): - return v1 == v2, ListDiff.VALUE_DIFF - elif isinstance(v1, datetime.datetime) and isinstance(v2, datetime.datetime): - return v1 == v2, ListDiff.VALUE_DIFF - elif isinstance(v1, datetime.date) and isinstance(v2, datetime.date): - return v1 == v2, ListDiff.VALUE_DIFF - elif isinstance(v1, datetime.time) and isinstance(v2, datetime.time): - return v1 == v2, ListDiff.VALUE_DIFF - elif isinstance(v1, datetime.timedelta) and isinstance(v2, datetime.timedelta): - return v1 == v2, ListDiff.VALUE_DIFF elif v1 is None and v2 is None: return True, ListDiff.VALUE_DIFF elif type(v1) is not type(v2): return False, ListDiff.TYPE_DIFF else: - raise ValueError(f"Unsupported type: {type(v1)}") + return v1 == v2, ListDiff.VALUE_DIFF @staticmethod def diff(vals1, vals2, name=str()): diff --git a/src/earthkit/data/xr_engine/engine.py b/src/earthkit/data/xr_engine/engine.py index dbd788c8..97d8e397 100644 --- a/src/earthkit/data/xr_engine/engine.py +++ b/src/earthkit/data/xr_engine/engine.py @@ -39,6 +39,7 @@ def open_dataset( add_valid_time_coord=None, decode_times=None, decode_timedelta=None, + aux_coords=None, add_geo_coords=None, attrs_mode=None, attrs=None, @@ -256,6 +257,9 @@ def open_dataset( will have the attribute "units" appropriately set (to "minutes", "hours", etc.). If None (default), assume the same value of ``decode_times`` unless the ``profile`` overwrites it. + aux_coords: dict, None + Mapping from an auxiliary coordinate label to a tuple: + ``(metadata_key: str, dataset_dimension(s): str or iterable of str)``. The default value is None. add_geo_coords: bool, None If True, add geographic coordinates to the dataset when field values are represented by a single "values" dimension. Its default value (None) expands @@ -312,8 +316,8 @@ def open_dataset( Define fill values to metadata keys. Default is None. remapping: dict, None Define new metadata keys for indexing. Any key provided in ``remapping`` may be referenced - when specifying options such as ``variable_key``, ``extra_dims``, ``ensure_dims``, and others. - Default is None. + when specifying options such as ``variable_key``, ``extra_dims``, ``ensure_dims``, ``aux_coords`` + and others. Default is None. lazy_load: bool, None If True, the resulting Dataset will load data lazily from the underlying data source. If False, a DataSet holding all the data in memory diff --git a/src/earthkit/data/xr_engine/profile.py b/src/earthkit/data/xr_engine/profile.py index 459e0852..93738a19 100644 --- a/src/earthkit/data/xr_engine/profile.py +++ b/src/earthkit/data/xr_engine/profile.py @@ -40,6 +40,21 @@ def add(self, remapping, patch=None): self.patch.update(patch) +class AuxCoords(dict): + def __init__(self, aux_coords): + super().__init__() + for coord_label, key_dims in ensure_dict(aux_coords).items(): + try: + key, dims = key_dims + dims = ensure_iterable(dims) + except Exception: + raise ValueError( + f"Auxiliary coordinate {coord_label} has invalid specification: got {key_dims} " + f"while a tuple (, ) is expected" + ) + self[coord_label] = (key, dims) + + class ProfileConf: def __init__(self): self._conf = {} @@ -184,7 +199,7 @@ def check(self, profile): class Profile: - USER_ONLY_OPTIONS = ["remapping", "patch", "fill_metadata"] + USER_ONLY_OPTIONS = ["remapping", "patch", "fill_metadata", "aux_coords"] DEFAULT_PROFILE_NAME = "earthkit" def __init__( @@ -213,6 +228,7 @@ def __init__( patch[k] = v self.remapping = RemappingBuilder(kwargs.pop("remapping", None), patch) + self.aux_coords = AuxCoords(kwargs.pop("aux_coords", None)) # variables mono_variable = kwargs.pop("mono_variable") diff --git a/tests/xr_engine/test_xr_engine_aux_coords.py b/tests/xr_engine/test_xr_engine_aux_coords.py new file mode 100644 index 00000000..1e05d4c6 --- /dev/null +++ b/tests/xr_engine/test_xr_engine_aux_coords.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import numpy as np +import pytest + +from earthkit.data import from_source +from earthkit.data.utils.testing import earthkit_remote_test_data_file + + +@pytest.mark.cache +@pytest.mark.parametrize("allow_holes", [False, True]) +@pytest.mark.parametrize("lazy_load", [True, False]) +def test_xr_engine_aux_coords_simple(lazy_load, allow_holes): + """aux_coords with a single metadata key mapped to a single dimension.""" + fl = from_source("url", earthkit_remote_test_data_file("xr_engine/level/pl_small.grib")).to_fieldlist() + ds = fl.to_xarray( + aux_coords={"centre": ("metadata.centre", "forecast_reference_time")}, + lazy_load=lazy_load, + allow_holes=allow_holes, + ) + + assert "centre" in ds.coords + assert "centre" not in ds.sizes + assert ds["centre"].dims == ("forecast_reference_time",) + assert (ds["centre"] == "ecmf").all() + + +@pytest.mark.cache +@pytest.mark.parametrize("allow_holes", [False, True]) +@pytest.mark.parametrize("lazy_load", [True, False]) +def test_xr_engine_aux_coords_multi_dim(lazy_load, allow_holes): + """aux_coords mapped to multiple dimensions.""" + fl = from_source("url", earthkit_remote_test_data_file("xr_engine/level/pl_small.grib")).to_fieldlist() + ds = fl.to_xarray( + aux_coords={"centre": ("metadata.centre", ("forecast_reference_time", "step"))}, + lazy_load=lazy_load, + allow_holes=allow_holes, + ) + + assert "centre" in ds.coords + assert "centre" not in ds.sizes + assert ds["centre"].dims == ("forecast_reference_time", "step") + assert (ds["centre"] == "ecmf").all() + + +@pytest.mark.cache +@pytest.mark.parametrize("allow_holes", [False, True]) +@pytest.mark.parametrize("lazy_load", [True, False]) +def test_xr_engine_aux_coords_with_remapping(lazy_load, allow_holes): + """aux_coords using a remapped key.""" + ds0 = from_source("url", earthkit_remote_test_data_file("xr_engine/level/pl_small.grib")).to_fieldlist() + ds = ds0.to_xarray( + remapping={"centre_class": "{metadata.centre}_{metadata.class}"}, + aux_coords={"centre_class": ("centre_class", ("forecast_reference_time", "step"))}, + lazy_load=lazy_load, + allow_holes=allow_holes, + ) + + assert "centre_class" in ds.coords + assert "centre_class" not in ds.sizes + assert ds["centre_class"].dims == ("forecast_reference_time", "step") + assert (ds["centre_class"] == "ecmf_od").all() + + +@pytest.mark.cache +@pytest.mark.parametrize("allow_holes", [False, True]) +@pytest.mark.parametrize("lazy_load", [True, False]) +def test_xr_engine_aux_coords_multiple_coords(lazy_load, allow_holes): + """Multiple aux_coords specified at once.""" + ds0 = from_source("url", earthkit_remote_test_data_file("xr_engine/level/pl_small.grib")).to_fieldlist() + ds = ds0.to_xarray( + profile="mars", + aux_coords={ + "centre": ("metadata.centre", "forecast_reference_time"), + "class_coord": ("metadata.class", "forecast_reference_time"), + }, + lazy_load=lazy_load, + allow_holes=allow_holes, + ) + + assert "centre" in ds.coords + assert "class_coord" in ds.coords + assert "centre" not in ds.sizes + assert "class_coord" not in ds.sizes + assert ds["centre"].dims == ("forecast_reference_time",) + assert ds["class_coord"].dims == ("forecast_reference_time",) + assert (ds["centre"] == "ecmf").all() + assert (ds["class_coord"] == "od").all() + + +@pytest.mark.cache +@pytest.mark.parametrize("allow_holes", [False, True]) +def test_xr_engine_aux_coords_unknown_dim(allow_holes): + """aux_coords referencing a non-existent dimension should raise.""" + fl = from_source("url", earthkit_remote_test_data_file("xr_engine/level/pl_small.grib")).to_fieldlist() + with pytest.raises(AssertionError, match="unknown dimension"): + fl.to_xarray( + aux_coords={"centre": ("metadata.centre", "nonexistent_dim")}, + allow_holes=allow_holes, + ) + + +def test_xr_engine_aux_coords_invalid_spec(): + """aux_coords with invalid tuple specification should raise ValueError.""" + from earthkit.data.xr_engine.profile import AuxCoords + + with pytest.raises(ValueError, match="invalid specification"): + AuxCoords({"bad": "not_a_tuple"}) + + +@pytest.mark.cache +@pytest.mark.parametrize("allow_holes", [False, True]) +def test_xr_engine_aux_coords_empty(allow_holes): + """Empty aux_coords should produce no extra coordinates.""" + fl = from_source("url", earthkit_remote_test_data_file("xr_engine/level/pl_small.grib")).to_fieldlist() + ds_no_aux = fl.to_xarray(aux_coords={}, allow_holes=allow_holes) + ds_none = fl.to_xarray(allow_holes=allow_holes) + + assert set(ds_no_aux.coords) == set(ds_none.coords) + + +@pytest.mark.cache +@pytest.mark.parametrize("allow_holes", [False, True]) +@pytest.mark.parametrize("lazy_load", [True, False]) +def test_xr_engine_aux_coords_drop_dim_as_aux(lazy_load, allow_holes): + """Drop a dimension and re-add it as an auxiliary coordinate.""" + fl = from_source("url", earthkit_remote_test_data_file("xr_engine/level/pl_small.grib")).to_fieldlist() + + ds = fl.to_xarray( + time_dims="valid_time", + aux_coords={"step": ("time.step", ("valid_time",))}, + lazy_load=lazy_load, + allow_holes=allow_holes, + ) + + # step should be a coordinate but not a dimension + assert "step" in ds.coords + assert "step" not in ds.sizes + assert "valid_time" in ds.coords["step"].dims + assert (ds.coords["step"] == np.array([0, 6] * 4, dtype="m8[h]")).all() + + +@pytest.mark.cache +@pytest.mark.parametrize("allow_holes", [False, True]) +@pytest.mark.parametrize("lazy_load", [True, False]) +def test_xr_engine_aux_coords_with_mono_variable(lazy_load, allow_holes): + """aux_coords combined with mono_variable mode.""" + fl = from_source("url", earthkit_remote_test_data_file("xr_engine/level/pl_small.grib")).to_fieldlist() + ds = fl.to_xarray( + fixed_dims=["parameter.variable", "time.forecast_reference_time", "time.step", "vertical.level"], + mono_variable=True, + aux_coords={"metadata_paramId": ("metadata.paramId", "parameter.variable")}, + lazy_load=lazy_load, + allow_holes=allow_holes, + ) + assert "metadata_paramId" in ds.coords + assert "metadata_paramId" not in ds.sizes + assert (ds["metadata_paramId"] == [157, 130]).all() + + +@pytest.mark.cache +@pytest.mark.parametrize("allow_holes", [False, True]) +@pytest.mark.parametrize("lazy_load", [True, False]) +def test_xr_engine_aux_coords_conflicting_values_strict(lazy_load, allow_holes): + """With strict=True, conflicting aux_coord values for same dim coords should raise.""" + fl = from_source("url", earthkit_remote_test_data_file("xr_engine/level/mixed_pl_ml_small.grib")).to_fieldlist() + + with pytest.raises(AssertionError, match="Conflicting values"): + _ = fl.to_xarray( + strict=True, + level_dim_mode="level_and_type", + aux_coords={"levtype": ("metadata.levtype", "forecast_reference_time")}, + lazy_load=lazy_load, + allow_holes=allow_holes, + ) + + ds = fl.to_xarray( + strict=True, + level_dim_mode="level_and_type", + aux_coords={"levtype": ("metadata.levtype", "level_and_type")}, + lazy_load=lazy_load, + allow_holes=allow_holes, + ) + assert "levtype" in ds.coords + assert "levtype" not in ds.sizes + assert (ds["levtype"] == ["ml", "pl", "pl", "ml"]).all()