diff --git a/src/earthkit/data/field/component/geography.py b/src/earthkit/data/field/component/geography.py index 0a71eed15..cf6aa1079 100644 --- a/src/earthkit/data/field/component/geography.py +++ b/src/earthkit/data/field/component/geography.py @@ -488,12 +488,8 @@ def set(self, *args, shape_hint=None, **kwargs) -> "GeographyBase": kwargs = self._normalise_set_kwargs(*args, **kwargs) keys = set(kwargs.keys()) - if keys == {"grid_spec"}: - spec = self.from_grid_spec(self, kwargs["grid_spec"]) - return spec - if keys == {"latitudes", "longitudes"}: - spec = self.from_dict(kwargs, shape_hint=shape_hint) - return spec + if keys == {"grid_spec"} or keys == {"latitudes", "longitudes"}: + return self.from_dict(kwargs, shape_hint=shape_hint) raise ValueError(f"Invalid {keys=} for Geography specification") diff --git a/src/earthkit/data/field/geotiff/geography.py b/src/earthkit/data/field/geotiff/geography.py index 4451f5a46..33e608a14 100644 --- a/src/earthkit/data/field/geotiff/geography.py +++ b/src/earthkit/data/field/geotiff/geography.py @@ -86,9 +86,9 @@ def grid(self): def area(self): pass - @classmethod - def from_dict(d): - raise NotImplementedError("XArrayGeography.form_dict() is not implemented") + # @classmethod + # def from_dict(d): + # raise NotImplementedError("XArrayGeography.form_dict() is not implemented") def _grid_mapping(self): if self._ds.rio.grid_mapping == "spatial_ref": diff --git a/src/earthkit/data/field/grib/geography.py b/src/earthkit/data/field/grib/geography.py index 196886438..20528c313 100644 --- a/src/earthkit/data/field/grib/geography.py +++ b/src/earthkit/data/field/grib/geography.py @@ -170,9 +170,9 @@ def grid_type(self): r"""Return the grid type.""" return self.handle.get("gridType", default=None) - @classmethod - def from_dict(*args, **kwargs): - raise NotImplementedError("GribGeography cannot be created from a dictionary") + # @classmethod + # def from_dict(*args, **kwargs): + # raise NotImplementedError("GribGeography cannot be created from a dictionary") # def to_dict(self): # return dict() diff --git a/src/earthkit/data/field/handler/core.py b/src/earthkit/data/field/handler/core.py index 1c51a8bb0..e5a584ff8 100644 --- a/src/earthkit/data/field/handler/core.py +++ b/src/earthkit/data/field/handler/core.py @@ -217,12 +217,14 @@ def __init__(self, component: Any) -> None: def from_dict(cls, d: dict, **kwargs) -> "SimpleFieldComponentHandler": """Create a SimpleFieldComponent object from a dictionary.""" component = cls.COMPONENT_MAKER(d, **kwargs) - return cls(component) + return cls.from_component(component) @classmethod + @abstractmethod def from_component(cls, component: Any) -> "SimpleFieldComponentHandler": """Create a SimpleFieldComponent object from a component object.""" - return cls(component) + # return cls._from_component(component) + pass @classmethod def from_any(cls, data: Any, dict_kwargs=None) -> "SimpleFieldComponentHandler": diff --git a/src/earthkit/data/field/handler/ensemble.py b/src/earthkit/data/field/handler/ensemble.py index 76c3c628d..cdbc6b845 100644 --- a/src/earthkit/data/field/handler/ensemble.py +++ b/src/earthkit/data/field/handler/ensemble.py @@ -24,6 +24,10 @@ def get_grib_context(self, context) -> dict: COLLECTOR.collect(self, context) + @classmethod + def from_component(cls, component: EnsembleBase) -> "EnsembleFieldComponentHandler": + return EnsembleFieldComponentHandler(component) + @classmethod def create_empty(cls) -> "EnsembleFieldComponentHandler": return EMPTY_ENSEMBLE_HANDLER diff --git a/src/earthkit/data/field/handler/geography.py b/src/earthkit/data/field/handler/geography.py index 8deb789e3..f079244a3 100644 --- a/src/earthkit/data/field/handler/geography.py +++ b/src/earthkit/data/field/handler/geography.py @@ -24,6 +24,10 @@ def get_grib_context(self, context) -> dict: COLLECTOR.collect(self, context) + @classmethod + def from_component(cls, component: GeographyBase) -> "GeographyFieldComponentHandler": + return GeographyFieldComponentHandler(component) + @classmethod def create_empty(cls) -> "GeographyFieldComponentHandler": return EMPTY_GEOGRAPHY_HANDLER diff --git a/src/earthkit/data/field/handler/parameter.py b/src/earthkit/data/field/handler/parameter.py index e8b652f8d..a315b8eb8 100644 --- a/src/earthkit/data/field/handler/parameter.py +++ b/src/earthkit/data/field/handler/parameter.py @@ -24,6 +24,10 @@ def get_grib_context(self, context) -> dict: COLLECTOR.collect(self, context) + @classmethod + def from_component(cls, component: ParameterBase) -> "ParameterFieldComponentHandler": + return ParameterFieldComponentHandler(component) + @classmethod def create_empty(cls) -> "ParameterFieldComponentHandler": return EMPTY_PARAMETER_HANDLER diff --git a/src/earthkit/data/field/handler/proc.py b/src/earthkit/data/field/handler/proc.py index 51d317e1c..506611a8e 100644 --- a/src/earthkit/data/field/handler/proc.py +++ b/src/earthkit/data/field/handler/proc.py @@ -27,6 +27,10 @@ def set(self, *args, **kwargs): spec = self._spec.set(*args, **kwargs) return ProcFieldComponentHandler(spec) + @classmethod + def from_component(cls, component: ProcBase) -> "ProcFieldComponentHandler": + return ProcFieldComponentHandler(component) + @classmethod def create_empty(cls) -> "ProcFieldComponentHandler": return EMPTY_PROC_HANDLER diff --git a/src/earthkit/data/field/handler/time.py b/src/earthkit/data/field/handler/time.py index e86c3fc15..097478c30 100644 --- a/src/earthkit/data/field/handler/time.py +++ b/src/earthkit/data/field/handler/time.py @@ -25,6 +25,10 @@ def get_grib_context(self, context) -> dict: COLLECTOR.collect(self, context) + @classmethod + def from_component(cls, component: TimeBase) -> "TimeFieldComponentHandler": + return TimeFieldComponentHandler(component) + @classmethod def create_empty(cls) -> "TimeFieldComponentHandler": return EMPTY_TIME_HANDLER diff --git a/src/earthkit/data/field/handler/vertical.py b/src/earthkit/data/field/handler/vertical.py index 417397f85..aba8b185c 100644 --- a/src/earthkit/data/field/handler/vertical.py +++ b/src/earthkit/data/field/handler/vertical.py @@ -24,6 +24,10 @@ def get_grib_context(self, context) -> dict: COLLECTOR.collect(self, context) + @classmethod + def from_component(cls, component: VerticalBase) -> "VerticalFieldComponentHandler": + return VerticalFieldComponentHandler(component) + @classmethod def create_empty(cls) -> "VerticalFieldComponentHandler": return EMPTY_VERTICAL_HANDLER diff --git a/src/earthkit/data/field/xarray/geography.py b/src/earthkit/data/field/xarray/geography.py index dd87629bf..db1919156 100644 --- a/src/earthkit/data/field/xarray/geography.py +++ b/src/earthkit/data/field/xarray/geography.py @@ -81,9 +81,9 @@ def grid(self): r"""Return the area of the grid.""" pass - @classmethod - def from_dict(d): - raise NotImplementedError("XArrayGeography.form_dict() is not implemented") + # @classmethod + # def from_dict(*args, **kwargs): + # raise NotImplementedError("XArrayGeography cannot be created from a dictionary") def latlons(self, flatten=False, dtype=None): lat, lon = self.owner.grid.latlons diff --git a/tests/grib/test_grib_set_data.py b/tests/grib/test_grib_set_data.py index dca494858..c349c6913 100644 --- a/tests/grib/test_grib_set_data.py +++ b/tests/grib/test_grib_set_data.py @@ -20,7 +20,7 @@ # @pytest.mark.parametrize("fl_type", ["file", "array", "memory"]) @pytest.mark.parametrize("fl_type", ["file"]) -def test_grib_set_data(fl_type): +def test_grib_set_data_field(fl_type): ds_ori, _ = load_grib_data("test4.grib", fl_type) vals_ori = ds_ori[0].values @@ -76,9 +76,12 @@ def test_grib_set_data(fl_type): # assert grib_md.get("levelist") == 500 # assert grib_md.get("date") == 20070101 - # --------------- - # fieldlist - # --------------- + +@pytest.mark.parametrize("fl_type", ["file"]) +def test_grib_set_data_fieldlist(fl_type): + ds_ori, _ = load_grib_data("test4.grib", fl_type) + + vals_ori = ds_ori[0].values fields = [] for i in range(2): diff --git a/tests/grib/test_grib_set_time.py b/tests/grib/test_grib_set_time.py index afcdc9238..b2ce8ba18 100644 --- a/tests/grib/test_grib_set_time.py +++ b/tests/grib/test_grib_set_time.py @@ -20,7 +20,6 @@ # @pytest.mark.parametrize("fl_type", ["file", "array", "memory"]) @pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("write_method", ["target"]) @pytest.mark.parametrize( "_kwargs,ref_set,ref_saved,edition_saved", [ @@ -336,7 +335,7 @@ # ), ], ) -def test_grib_set_time_1(fl_type, write_method, _kwargs, ref_set, ref_saved, edition_saved): +def test_grib_set_time_1(fl_type, _kwargs, ref_set, ref_saved, edition_saved): ds_ori, _ = load_grib_data("test4.grib", fl_type) f = ds_ori[0].set(**_kwargs) diff --git a/tests/grib/test_grib_set_vertical.py b/tests/grib/test_grib_set_vertical.py index c7a2a3e80..e49aa2d18 100644 --- a/tests/grib/test_grib_set_vertical.py +++ b/tests/grib/test_grib_set_vertical.py @@ -19,7 +19,6 @@ # @pytest.mark.parametrize("fl_type", ["file", "array", "memory"]) @pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("write_method", ["target"]) @pytest.mark.parametrize( "_kwargs,ref1,grib_ref,ref2", [ @@ -79,7 +78,7 @@ # ), ], ) -def test_grib_set_vertical(fl_type, write_method, _kwargs, ref1, grib_ref, ref2): +def test_grib_set_vertical(fl_type, _kwargs, ref1, grib_ref, ref2): ds_ori, _ = load_grib_data("test4.grib", fl_type) f = ds_ori[0].set(**_kwargs) diff --git a/tests/list_of_dicts/test_lod_set_data.py b/tests/list_of_dicts/test_lod_set_data.py new file mode 100644 index 000000000..fdcc99b5d --- /dev/null +++ b/tests/list_of_dicts/test_lod_set_data.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + + +import numpy as np +import pytest +from lod_fixtures import build_lod_fieldlist # noqa: E402 + +from earthkit.data import FieldList + + +@pytest.mark.parametrize("mode", ["list-of-dicts", "loop"]) +def test_lod_set_data_field(lod_ll_flat, mode): + ds_ori = build_lod_fieldlist(lod_ll_flat, mode) + + vals_ori = ds_ori[0].values + + f = ds_ori[0].set(values=vals_ori + 1) + + assert f.get("parameter.variable") == "t" + assert f.get("vertical.level") == 500 + assert np.allclose(f.values, vals_ori + 1) + assert np.allclose(ds_ori[0].values, vals_ori) + + # --------------------- + # field - repeated use + # --------------------- + + f = ds_ori[0].set(values=vals_ori + 1) + f = f.set(values=vals_ori + 2) + + assert f.get("parameter.variable") == "t" + assert f.get("vertical.level") == 500 + assert np.allclose(f.values, vals_ori + 2) + assert np.allclose(ds_ori[0].values, vals_ori) + + +@pytest.mark.parametrize("mode", ["list-of-dicts", "loop"]) +def test_lod_set_data_fieldlist(lod_ll_flat, mode): + ds_ori = build_lod_fieldlist(lod_ll_flat, mode) + + vals_ori = ds_ori[0].values + + fields = [] + for i in range(2): + f = ds_ori[i].set(values=vals_ori + i + 1) + fields.append(f) + + ds = FieldList.from_fields(fields) + + assert ds.get("parameter.variable") == ["t", "t"] + assert ds.get("vertical.level") == [500, 850] + assert np.allclose(ds[0].values, vals_ori + 1) + assert np.allclose(ds[1].values, vals_ori + 2) diff --git a/tests/list_of_dicts/test_lod_set_geography.py b/tests/list_of_dicts/test_lod_set_geography.py new file mode 100644 index 000000000..adb245e45 --- /dev/null +++ b/tests/list_of_dicts/test_lod_set_geography.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import datetime + +import numpy as np +import pytest +from lod_fixtures import build_lod_fieldlist # noqa: E402 + + +@pytest.mark.parametrize("mode", ["list-of-dicts", "loop"]) +@pytest.mark.parametrize( + "_kwargs", + [ + { + "geography.latitudes": np.array([10.0, 20.0, 30.0]), + "geography.longitudes": np.array([0.0, 10.0, 20.0]), + "values": np.array([1.0, 2.0, 3.0]), + }, + ], +) +def test_lod_set_geo_1(lod_ll_flat, mode, _kwargs): + ds_ori = build_lod_fieldlist(lod_ll_flat, mode) + + f = ds_ori[0].set(**_kwargs) + assert np.allclose(f.get("geography.latitudes"), np.array([10.0, 20.0, 30.0])) + assert np.allclose(f.get("geography.longitudes"), np.array([0.0, 10.0, 20.0])) + assert np.allclose(f.values, np.array([1.0, 2.0, 3.0])) + assert f.get("time.base_datetime") == datetime.datetime(2018, 8, 1, 9, 0) + assert f.get("time.valid_datetime") == datetime.datetime(2018, 8, 1, 9, 0) + assert f.get("time.step") == datetime.timedelta(hours=0) + # assert f.get("time_span") == datetime.timedelta(hours=0) + + # the original field is unchanged + assert ds_ori[0].get("geography.latitudes").shape == (6,) + assert ds_ori[0].get("geography.longitudes").shape == (6,) + assert ds_ori[0].values.shape == (6,) + assert ds_ori[0].get("time.base_datetime") == datetime.datetime(2018, 8, 1, 9, 0) + assert ds_ori[0].get("time.valid_datetime") == datetime.datetime(2018, 8, 1, 9, 0) + assert ds_ori[0].get("time.step") == datetime.timedelta(hours=0) + # assert ds_ori[0].get("time_span") == datetime.timedelta(hours=0) + + +@pytest.mark.parametrize("mode", ["list-of-dicts", "loop"]) +@pytest.mark.parametrize( + "_kwargs,shape,grid_spec,area_1", + [ + ( + { + "geography.grid_spec": {"grid": [5, 5]}, + "values": np.ones((37, 72)), + }, + (37, 72), + {"grid": [5, 5]}, + (90, 0, -90, 360), + ), + ], +) +def test_lod_set_geo_grid_spec(lod_ll_flat, mode, _kwargs, shape, grid_spec, area_1): + # the input is a 1/1 grid + ds_ori = build_lod_fieldlist(lod_ll_flat, mode) + + assert ds_ori[0].shape == (6,) + + f = ds_ori[0].set(**_kwargs) + assert f.shape == shape + assert f.get("geography.shape") == shape + assert f.get("geography.area") == area_1 + assert f.get("geography.grid_spec") == grid_spec + # assert f.get("geography.grid_type") == grid_type + assert f.get("geography.latitudes").shape == shape + assert f.get("geography.longitudes").shape == shape + assert np.allclose(f.to_numpy(), _kwargs["values"]) diff --git a/tests/list_of_dicts/test_lod_set_vertical.py b/tests/list_of_dicts/test_lod_set_vertical.py new file mode 100644 index 000000000..7c0688ce9 --- /dev/null +++ b/tests/list_of_dicts/test_lod_set_vertical.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import pytest +from lod_fixtures import build_lod_fieldlist # noqa: E402 + + +@pytest.mark.parametrize("mode", ["list-of-dicts", "loop"]) +@pytest.mark.parametrize( + "_kwargs,ref", + [ + ( + { + "vertical.level": 320, + "vertical.level_type": "potential_temperature", + }, + { + "vertical.level": 320, + "vertical.level_type": "potential_temperature", + "vertical.units": "K", + "vertical.abbreviation": "pt", + }, + ), + ], +) +def test_lod_set_vertical(lod_ll_flat, mode, _kwargs, ref): + ds = build_lod_fieldlist(lod_ll_flat, mode) + + f = ds[0].set(**_kwargs) + + for k, v in ref.items(): + assert f.get(k) == v + + # the original field is unchanged + assert ds[0].get("vertical.level") == 500 + assert ds[0].get("vertical.level_type") == "unknown" diff --git a/tests/netcdf/test_netcdf_set_data.py b/tests/netcdf/test_netcdf_set_data.py new file mode 100644 index 000000000..a0826f73f --- /dev/null +++ b/tests/netcdf/test_netcdf_set_data.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + + +import numpy as np +import pytest + +from earthkit.data import FieldList +from earthkit.data.utils.testing import earthkit_test_data_file, load_nc_or_xr_source + + +@pytest.mark.parametrize("mode", ["nc", "xr"]) +def test_netcdf_set_data_field(mode): + ds_ori = load_nc_or_xr_source(earthkit_test_data_file("test4.nc"), mode) + + vals_ori = ds_ori[0].values + + f = ds_ori[0].set(values=vals_ori + 1) + + assert f.get("parameter.variable") == "t" + assert f.get("vertical.level") == 500 + assert f.get(("metadata.date", "parameter.variable")) == (None, "t") + assert f.get(("parameter.variable", "metadata.date")) == ("t", None) + assert np.allclose(f.values, vals_ori + 1) + assert np.allclose(ds_ori[0].values, vals_ori) + + # --------------------- + # field - repeated use + # --------------------- + + f = ds_ori[0].set(values=vals_ori + 1) + f = f.set(values=vals_ori + 2) + + assert f.get("parameter.variable") == "t" + assert f.get("vertical.level") == 500 + assert f.get(("metadata.date", "parameter.variable")) == (None, "t") + assert f.get(("parameter.variable", "metadata.date")) == ("t", None) + assert np.allclose(f.values, vals_ori + 2) + assert np.allclose(ds_ori[0].values, vals_ori) + + +@pytest.mark.parametrize("mode", ["nc", "xr"]) +def test_netcdf_set_data_fieldlist(mode): + ds_ori = load_nc_or_xr_source(earthkit_test_data_file("test4.nc"), mode) + + vals_ori = ds_ori[0].values + + fields = [] + for i in range(2): + f = ds_ori[i].set(values=vals_ori + i + 1) + fields.append(f) + + ds = FieldList.from_fields(fields) + + assert ds.get("parameter.variable") == ["t", "t"] + assert ds.get("vertical.level") == [500, 850] + assert np.allclose(ds[0].values, vals_ori + 1) + assert np.allclose(ds[1].values, vals_ori + 2) diff --git a/tests/netcdf/test_netcdf_set_geography.py b/tests/netcdf/test_netcdf_set_geography.py new file mode 100644 index 000000000..7f5a599bc --- /dev/null +++ b/tests/netcdf/test_netcdf_set_geography.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import datetime + +import numpy as np +import pytest + +from earthkit.data.utils.testing import earthkit_test_data_file, load_nc_or_xr_source + + +@pytest.mark.parametrize("mode", ["nc", "xr"]) +@pytest.mark.parametrize( + "_kwargs", + [ + { + "geography.latitudes": np.array([10.0, 20.0, 30.0]), + "geography.longitudes": np.array([0.0, 10.0, 20.0]), + "values": np.array([1.0, 2.0, 3.0]), + }, + ], +) +def test_netcdf_set_geo_1(mode, _kwargs): + ds_ori = load_nc_or_xr_source(earthkit_test_data_file("test4.nc"), mode) + + f = ds_ori[0].set(**_kwargs) + assert np.allclose(f.get("geography.latitudes"), np.array([10.0, 20.0, 30.0])) + assert np.allclose(f.get("geography.longitudes"), np.array([0.0, 10.0, 20.0])) + assert np.allclose(f.values, np.array([1.0, 2.0, 3.0])) + assert f.get("time.base_datetime") == datetime.datetime(2007, 1, 1, 12) + assert f.get("time.valid_datetime") == datetime.datetime(2007, 1, 1, 12) + assert f.get("time.step") == datetime.timedelta(hours=0) + # assert f.get("time_span") == datetime.timedelta(hours=0) + + # the original field is unchanged + assert ds_ori[0].get("geography.latitudes").shape == (181, 360) + assert ds_ori[0].get("geography.longitudes").shape == (181, 360) + assert ds_ori[0].values.shape == (65160,) + assert ds_ori[0].get("time.base_datetime") == datetime.datetime(2007, 1, 1, 12) + assert ds_ori[0].get("time.valid_datetime") == datetime.datetime(2007, 1, 1, 12) + assert ds_ori[0].get("time.step") == datetime.timedelta(hours=0) + # assert ds_ori[0].get("time_span") == datetime.timedelta(hours=0) + + +@pytest.mark.parametrize("mode", ["nc", "xr"]) +@pytest.mark.parametrize( + "_kwargs,shape,grid_spec,area_1", + [ + ( + { + "geography.grid_spec": {"grid": [5, 5]}, + "values": np.ones((37, 72)), + }, + (37, 72), + {"grid": [5, 5]}, + (90, 0, -90, 360), + ), + ], +) +def test_netcdf_set_geo_grid_spec(mode, _kwargs, shape, grid_spec, area_1): + # the input is a 1/1 grid + ds_ori = load_nc_or_xr_source(earthkit_test_data_file("test4.nc"), mode) + + assert ds_ori[0].shape == (181, 360) + + f = ds_ori[0].set(**_kwargs) + assert f.shape == shape + assert f.get("geography.shape") == shape + assert f.get("geography.area") == area_1 + assert f.get("geography.grid_spec") == grid_spec + # assert f.get("geography.grid_type") == grid_type + assert f.get("geography.latitudes").shape == shape + assert f.get("geography.longitudes").shape == shape + assert np.allclose(f.to_numpy(), _kwargs["values"]) diff --git a/tests/netcdf/test_netcdf_set_parameter.py b/tests/netcdf/test_netcdf_set_parameter.py new file mode 100644 index 000000000..41382b3f9 --- /dev/null +++ b/tests/netcdf/test_netcdf_set_parameter.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + + +import pytest + +from earthkit.data.utils.testing import earthkit_examples_file, earthkit_test_data_file, load_nc_or_xr_source + + +@pytest.mark.parametrize("mode", ["nc", "xr"]) +def test_netcdf_set_parameter_1(mode): + ds = load_nc_or_xr_source(earthkit_test_data_file("test_single.nc"), mode) + f = ds[0] + + assert f.parameter.variable() == "t2m" + assert f.parameter.standard_name() == "unknown" + assert f.parameter.long_name() == "2 metre temperature" + assert f.parameter.param() == "t2m" + assert f.parameter.units() == "K" + + +@pytest.mark.parametrize("mode", ["nc", "xr"]) +def test_netcdf_set_parameter_2(mode): + ds = load_nc_or_xr_source(earthkit_examples_file("tuv_pl.grib"), mode) + f = ds[0] + + assert f.parameter.variable() == "t" + assert f.parameter.standard_name() == "air_temperature" + assert f.parameter.long_name() == "Temperature" + assert f.parameter.param() == "t" + assert f.parameter.units() == "K" diff --git a/tests/netcdf/test_netcdf_set_time.py b/tests/netcdf/test_netcdf_set_time.py new file mode 100644 index 000000000..7f3332a90 --- /dev/null +++ b/tests/netcdf/test_netcdf_set_time.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import datetime + +import pytest + +from earthkit.data.utils.testing import earthkit_test_data_file, load_nc_or_xr_source + + +@pytest.mark.parametrize("mode", ["nc", "xr"]) +@pytest.mark.parametrize( + "_kwargs,ref_set", + [ + ( + {"time.base_datetime": "2025-08-24T12:00:00", "time.step": 6}, + { + "time.base_datetime": datetime.datetime(2025, 8, 24, 12), + "time.valid_datetime": datetime.datetime(2025, 8, 24, 18), + "time.step": datetime.timedelta(hours=6), + # "time_span": TimeSpan(), + }, + ), + ( + { + "time.base_datetime": datetime.datetime(2025, 8, 24, 12), + "time.step": datetime.timedelta(hours=6), + }, + { + "time.base_datetime": datetime.datetime(2025, 8, 24, 12), + "time.valid_datetime": datetime.datetime(2025, 8, 24, 18), + "time.step": datetime.timedelta(hours=6), + # "time_span": TimeSpan(), + }, + ), + ( + {"time.valid_datetime": "2025-08-24T18:00:00", "time.step": datetime.timedelta(hours=6)}, + { + "time.base_datetime": datetime.datetime(2025, 8, 24, 12), + "time.valid_datetime": datetime.datetime(2025, 8, 24, 18), + "time.step": datetime.timedelta(hours=6), + # "time_span": TimeSpan(), + }, + ), + ( + { + "time.valid_datetime": datetime.datetime(2025, 8, 24, 18), + "time.step": datetime.timedelta(hours=6), + }, + { + "time.base_datetime": datetime.datetime(2025, 8, 24, 12), + "time.valid_datetime": datetime.datetime(2025, 8, 24, 18), + "time.step": datetime.timedelta(hours=6), + # "time_span": TimeSpan(), + }, + ), + ( + {"time.base_datetime": "2025-08-24T12:00:00"}, + { + "time.base_datetime": datetime.datetime(2025, 8, 24, 12), + "time.valid_datetime": datetime.datetime(2025, 8, 24, 12), + "time.step": datetime.timedelta(hours=0), + # "time_span": TimeSpan(), + }, + ), + ( + {"time.base_datetime": datetime.datetime(2025, 8, 24, 12)}, + { + "time.base_datetime": datetime.datetime(2025, 8, 24, 12), + "time.valid_datetime": datetime.datetime(2025, 8, 24, 12), + "time.step": datetime.timedelta(hours=0), + # "time_span": TimeSpan(), + }, + ), + ( + {"time.valid_datetime": "2025-08-24T12:00:00", "time.step": 0}, + { + "time.base_datetime": datetime.datetime(2025, 8, 24, 12), + "time.valid_datetime": datetime.datetime(2025, 8, 24, 12), + "time.step": datetime.timedelta(hours=0), + # "time_span": TimeSpan(), + }, + ), + ( + { + "time.valid_datetime": datetime.datetime(2025, 8, 24, 12), + "time.step": datetime.timedelta(hours=0), + }, + { + "time.base_datetime": datetime.datetime(2025, 8, 24, 12), + "time.valid_datetime": datetime.datetime(2025, 8, 24, 12), + "time.step": datetime.timedelta(hours=0), + # "time_span": TimeSpan(), + }, + ), + ( + {"time.valid_datetime": "2007-01-03T18:00:00"}, + { + "time.base_datetime": datetime.datetime(2007, 1, 1, 12), + "time.valid_datetime": datetime.datetime(2007, 1, 3, 18), + "time.step": datetime.timedelta(hours=54), + # "time_span": TimeSpan(), + }, + ), + ( + {"time.valid_datetime": datetime.datetime(2007, 1, 3, 18)}, + { + "time.base_datetime": datetime.datetime(2007, 1, 1, 12), + "time.valid_datetime": datetime.datetime(2007, 1, 3, 18), + "time.step": datetime.timedelta(hours=54), + # "time_span": TimeSpan(), + }, + ), + ( + { + "time.valid_datetime": datetime.datetime(2007, 1, 3, 18), + # "time_span": TimeSpan(1, TimeSpanMethod.AVERAGE), + }, + { + "time.base_datetime": datetime.datetime(2007, 1, 1, 12), + "time.valid_datetime": datetime.datetime(2007, 1, 3, 18), + "time.step": datetime.timedelta(hours=54), + # "time_span": TimeSpan(datetime.timedelta(hours=1), TimeSpanMethod.AVERAGE), + }, + ), + ( + {"time.step": datetime.timedelta(hours=6)}, + { + "time.base_datetime": datetime.datetime(2007, 1, 1, 12), + "time.valid_datetime": datetime.datetime(2007, 1, 1, 18), + "time.step": datetime.timedelta(hours=6), + # "time_span": TimeSpan(), + }, + ), + ( + {"time.step": datetime.timedelta(hours=6, minutes=30)}, + { + "time.base_datetime": datetime.datetime(2007, 1, 1, 12), + "time.valid_datetime": datetime.datetime(2007, 1, 1, 18, 30), + "time.step": datetime.timedelta(hours=6, minutes=30), + # "time_span": TimeSpan(), + }, + ), + ( + { + "time.step": datetime.timedelta(hours=36), + # "time_span": TimeSpan(datetime.timedelta(days=1), TimeSpanMethod.AVERAGE), + }, + { + "time.base_datetime": datetime.datetime(2007, 1, 1, 12), + "time.valid_datetime": datetime.datetime(2007, 1, 3, 0), + "time.step": datetime.timedelta(hours=36), + # "time_span": TimeSpan(datetime.timedelta(days=1), TimeSpanMethod.AVERAGE), + }, + ), + ], +) +def test_netcdf_set_time_1(mode, _kwargs, ref_set): + ds_ori = load_nc_or_xr_source(earthkit_test_data_file("test4.nc"), mode) + + f = ds_ori[0].set(**_kwargs) + + for k, v in ref_set.items(): + assert f.get(k) == v, f"key {k} expected {v} got {f.get(k)}" + + # original message is unchanged + assert ds_ori[0].get("time.base_datetime") == datetime.datetime(2007, 1, 1, 12) + assert ds_ori[0].get("time.valid_datetime") == datetime.datetime(2007, 1, 1, 12) + assert ds_ori[0].get("time.step") == datetime.timedelta(hours=0) + # assert ds_ori[0].get("time_span") == datetime.timedelta(hours=0) diff --git a/tests/netcdf/test_netcdf_set_vertical.py b/tests/netcdf/test_netcdf_set_vertical.py new file mode 100644 index 000000000..68fe7bda8 --- /dev/null +++ b/tests/netcdf/test_netcdf_set_vertical.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + + +import pytest + +from earthkit.data.utils.testing import earthkit_test_data_file, load_nc_or_xr_source + + +@pytest.mark.parametrize("mode", ["nc", "xr"]) +@pytest.mark.parametrize( + "_kwargs,ref", + [ + ( + { + "vertical.level": 320, + "vertical.level_type": "potential_temperature", + }, + { + "vertical.level": 320, + "vertical.level_type": "potential_temperature", + "vertical.units": "K", + "vertical.abbreviation": "pt", + }, + ), + ], +) +def test_netcdf_set_vertical(mode, _kwargs, ref): + ds = load_nc_or_xr_source(earthkit_test_data_file("test4.nc"), mode) + + f = ds[0].set(**_kwargs) + + for k, v in ref.items(): + assert f.get(k) == v + + # the original field is unchanged + assert ds[0].get("vertical.level") == 500 + assert ds[0].get("vertical.level_type") == "pressure"