diff --git a/src/earthkit/data/xr_engine/coord.py b/src/earthkit/data/xr_engine/coord.py index dfda7463..59f46e9f 100644 --- a/src/earthkit/data/xr_engine/coord.py +++ b/src/earthkit/data/xr_engine/coord.py @@ -21,12 +21,12 @@ class Coord: LOOKUP_NAME = None - def __init__(self, name, vals, dims=None, ds=None): - self.name = name + def __init__(self, name, vals, dims=None, simple_name=None, ds=None): + self.name = simple_name if simple_name is not None else name self.vals = vals self.dims = dims if not self.dims: - self.dims = (self.name,) + self.dims = (name,) @staticmethod def make(name, *args, **kwargs): @@ -69,9 +69,6 @@ def attrs(self, profile): attrs = profile.attrs.coord_attrs.get(self.name, {}) if not attrs and self.LOOKUP_NAME: attrs = profile.attrs.coord_attrs.get(self.LOOKUP_NAME, {}) - # PW: TODO: need to replace this somehow? - # if profile.add_earthkit_attrs and self.component: - # attrs["_earthkit"] = {"keys": self.component[0], "values": self.component[1]} return attrs @staticmethod @@ -98,7 +95,6 @@ def convert(self, profile): def attrs(self, profile): attrs = profile.attrs.coord_attrs.get(self.name, {}) - # PW: TODO: need to replace this somehow? # if self.component: # attrs["_earthkit"] = {"keys": self.component[0]} return attrs diff --git a/src/earthkit/data/xr_engine/dim.py b/src/earthkit/data/xr_engine/dim.py index c29c3e00..7d8dfb3e 100644 --- a/src/earthkit/data/xr_engine/dim.py +++ b/src/earthkit/data/xr_engine/dim.py @@ -268,7 +268,7 @@ def as_coord(self, values, source): if key not in self.coords: from .coord import Coord - self.coords[key] = Coord.make(key, values, ds=source) + self.coords[key] = Coord.make(key, values, simple_name=self.get_simple_name(), ds=source) return key, self.coords[key] def remapping_keys(self): diff --git a/src/earthkit/data/xr_engine/profiles/defaults.yaml b/src/earthkit/data/xr_engine/profiles/defaults.yaml index f235852a..6a588969 100644 --- a/src/earthkit/data/xr_engine/profiles/defaults.yaml +++ b/src/earthkit/data/xr_engine/profiles/defaults.yaml @@ -51,6 +51,9 @@ dim_roles: level_type: vertical.level_type dim_name_from_role_name: true coord_attrs: + member: + standard_name: realization + long_name: ensemble member id latitude: units: degrees_north standard_name: latitude diff --git a/tests/xr_engine/test_xr_engine_attrs.py b/tests/xr_engine/test_xr_engine_attrs.py index a1225b57..1a285258 100644 --- a/tests/xr_engine/test_xr_engine_attrs.py +++ b/tests/xr_engine/test_xr_engine_attrs.py @@ -467,3 +467,131 @@ def test_xr_engine_global_attrs(allow_holes, lazy_load): "centre_fixed": "_ecmf_", } assert ds.attrs == ref_global_attrs + + +@pytest.mark.cache +@pytest.mark.parametrize("lazy_load", [True, False]) +@pytest.mark.parametrize( + "kwargs,expected_coord_attrs", + [ + # Default coord_attrs from defaults.yaml profile for latitude/longitude/step + ( + { + # "profile": "mars", + "time_dims": ["forecast_reference_time", "step"], + "dim_name_from_role_name": True, + }, + { + "member": { + "standard_name": "realization", + "long_name": "ensemble member id", + }, + "latitude": { + "units": "degrees_north", + "standard_name": "latitude", + "long_name": "latitude", + }, + "longitude": { + "units": "degrees_east", + "standard_name": "longitude", + "long_name": "longitude", + }, + "forecast_reference_time": { + "standard_name": "forecast_reference_time", + "long_name": "initial time of forecast", + }, + "step": { + "standard_name": "forecast_period", + "long_name": "time since forecast_reference_time", + }, + }, + ), + # Custom coord_attrs override via kwargs + ( + { + "profile": "mars", + "time_dims": ["forecast_reference_time", "step"], + "dim_name_from_role_name": True, + "coord_attrs": { + "latitude": {"units": "degrees_north", "axis": "Y"}, + "longitude": {"units": "degrees_east", "axis": "X"}, + "step": {"units": "hours", "axis": "T"}, + }, + }, + { + "latitude": {"units": "degrees_north", "axis": "Y"}, + "longitude": {"units": "degrees_east", "axis": "X"}, + "step": {"units": "hours", "axis": "T"}, + "forecast_reference_time": {}, + "member": {}, + }, + ), + # valid_time coord attrs + ( + { + "profile": "mars", + "time_dims": ["valid_time"], + "dim_name_from_role_name": True, + }, + { + "member": { + "standard_name": "realization", + "long_name": "ensemble member id", + }, + "valid_time": { + "standard_name": "time", + "long_name": "valid_time", + }, + "latitude": { + "units": "degrees_north", + "standard_name": "latitude", + "long_name": "latitude", + }, + "longitude": { + "units": "degrees_east", + "standard_name": "longitude", + "long_name": "longitude", + }, + }, + ), + ], +) +def test_xr_engine_coord_attrs(lazy_load, kwargs, expected_coord_attrs): + """Test that coordinate variables have the expected attributes.""" + ds0 = from_source("url", earthkit_remote_test_data_file("xr_engine", "level", "pl_small.grib")).to_fieldlist() + + ds = ds0.to_xarray(lazy_load=lazy_load, ensure_dims="member", **kwargs) + + for coord_name, expected_attrs in expected_coord_attrs.items(): + assert coord_name in ds.coords, f"Coordinate '{coord_name}' not found in dataset" + actual_attrs = dict(ds.coords[coord_name].attrs) + for attr_key, attr_val in expected_attrs.items(): + assert attr_key in actual_attrs, ( + f"Attribute '{attr_key}' not found in coordinate '{coord_name}' attrs: {actual_attrs}" + ) + assert actual_attrs[attr_key] == attr_val, ( + f"Coordinate '{coord_name}' attr '{attr_key}': expected {attr_val!r}, got {actual_attrs[attr_key]!r}" + ) + + +@pytest.mark.cache +@pytest.mark.parametrize("lazy_load", [True, False]) +def test_xr_engine_level_coord_attrs_grib_profile(lazy_load): + """Test that level coordinate has standard CF attributes from the grib profile.""" + ds0 = from_source("url", earthkit_remote_test_data_file("xr_engine", "level", "pl_small.grib")).to_fieldlist() + + ds = ds0.to_xarray( + profile="grib", + level_dim_mode="level", + dim_name_from_role_name=True, + lazy_load=lazy_load, + ) + + # Level coordinate should have attrs derived from the level type (pressure levels) + level_attrs = dict(ds.coords["level"].attrs) + assert "standard_name" in level_attrs, f"Missing 'standard_name' in level coord attrs: {level_attrs}" + assert "units" in level_attrs, f"Missing 'units' in level coord attrs: {level_attrs}" + assert "long_name" in level_attrs, f"Missing 'long_name' in level coord attrs: {level_attrs}" + assert "positive" in level_attrs, f"Missing 'positive' in level coord attrs: {level_attrs}" + assert level_attrs["units"] == "hectopascal" + assert level_attrs["positive"] == "down" diff --git a/tests/xr_engine/test_xr_engine_core.py b/tests/xr_engine/test_xr_engine_core.py index ea5adfe8..0a5801a3 100644 --- a/tests/xr_engine/test_xr_engine_core.py +++ b/tests/xr_engine/test_xr_engine_core.py @@ -111,6 +111,28 @@ def test_xr_engine_detailed_check_1(allow_holes, lazy_load, api): "longitude": lons, } + # coordinate variable attributes + assert ds.coords["latitude"].attrs == { + "standard_name": "latitude", + "long_name": "latitude", + "units": "degrees_north", + } + assert ds.coords["longitude"].attrs == { + "standard_name": "longitude", + "long_name": "longitude", + "units": "degrees_east", + } + assert ds.coords["step"].attrs == { + "standard_name": "forecast_period", + "long_name": "time since forecast_reference_time", + } + assert ds.coords["levelist"].attrs == { + "standard_name": "air_pressure", + "long_name": "pressure", + "units": "hectopascal", + "positive": "down", + } + dims_ref_full = { "date": 2, "time": 2,