Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/earthkit/data/field/component/geography.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,12 +488,8 @@ def set(self, *args, shape_hint=None, **kwargs) -> "GeographyBase":
kwargs = self._normalise_set_kwargs(*args, **kwargs)
keys = set(kwargs.keys())

if keys == {"grid_spec"}:
spec = self.from_grid_spec(self, kwargs["grid_spec"])
return spec
if keys == {"latitudes", "longitudes"}:
spec = self.from_dict(kwargs, shape_hint=shape_hint)
return spec
if keys == {"grid_spec"} or keys == {"latitudes", "longitudes"}:
return self.from_dict(kwargs, shape_hint=shape_hint)

raise ValueError(f"Invalid {keys=} for Geography specification")

Expand Down
6 changes: 3 additions & 3 deletions src/earthkit/data/field/geotiff/geography.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ def grid(self):
def area(self):
pass

@classmethod
def from_dict(d):
raise NotImplementedError("XArrayGeography.form_dict() is not implemented")
# @classmethod
# def from_dict(d):
# raise NotImplementedError("XArrayGeography.form_dict() is not implemented")

def _grid_mapping(self):
if self._ds.rio.grid_mapping == "spatial_ref":
Expand Down
6 changes: 3 additions & 3 deletions src/earthkit/data/field/grib/geography.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ def grid_type(self):
r"""Return the grid type."""
return self.handle.get("gridType", default=None)

@classmethod
def from_dict(*args, **kwargs):
raise NotImplementedError("GribGeography cannot be created from a dictionary")
# @classmethod
# def from_dict(*args, **kwargs):
# raise NotImplementedError("GribGeography cannot be created from a dictionary")

# def to_dict(self):
# return dict()
Expand Down
6 changes: 4 additions & 2 deletions src/earthkit/data/field/handler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,14 @@ def __init__(self, component: Any) -> None:
def from_dict(cls, d: dict, **kwargs) -> "SimpleFieldComponentHandler":
"""Create a SimpleFieldComponent object from a dictionary."""
component = cls.COMPONENT_MAKER(d, **kwargs)
return cls(component)
return cls.from_component(component)

@classmethod
@abstractmethod
def from_component(cls, component: Any) -> "SimpleFieldComponentHandler":
"""Create a SimpleFieldComponent object from a component object."""
return cls(component)
# return cls._from_component(component)
pass

@classmethod
def from_any(cls, data: Any, dict_kwargs=None) -> "SimpleFieldComponentHandler":
Expand Down
4 changes: 4 additions & 0 deletions src/earthkit/data/field/handler/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def get_grib_context(self, context) -> dict:

COLLECTOR.collect(self, context)

@classmethod
def from_component(cls, component: EnsembleBase) -> "EnsembleFieldComponentHandler":
return EnsembleFieldComponentHandler(component)

@classmethod
def create_empty(cls) -> "EnsembleFieldComponentHandler":
return EMPTY_ENSEMBLE_HANDLER
Expand Down
4 changes: 4 additions & 0 deletions src/earthkit/data/field/handler/geography.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def get_grib_context(self, context) -> dict:

COLLECTOR.collect(self, context)

@classmethod
def from_component(cls, component: GeographyBase) -> "GeographyFieldComponentHandler":
return GeographyFieldComponentHandler(component)

@classmethod
def create_empty(cls) -> "GeographyFieldComponentHandler":
return EMPTY_GEOGRAPHY_HANDLER
Expand Down
4 changes: 4 additions & 0 deletions src/earthkit/data/field/handler/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def get_grib_context(self, context) -> dict:

COLLECTOR.collect(self, context)

@classmethod
def from_component(cls, component: ParameterBase) -> "ParameterFieldComponentHandler":
return ParameterFieldComponentHandler(component)

@classmethod
def create_empty(cls) -> "ParameterFieldComponentHandler":
return EMPTY_PARAMETER_HANDLER
Expand Down
4 changes: 4 additions & 0 deletions src/earthkit/data/field/handler/proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def set(self, *args, **kwargs):
spec = self._spec.set(*args, **kwargs)
return ProcFieldComponentHandler(spec)

@classmethod
def from_component(cls, component: ProcBase) -> "ProcFieldComponentHandler":
return ProcFieldComponentHandler(component)

@classmethod
def create_empty(cls) -> "ProcFieldComponentHandler":
return EMPTY_PROC_HANDLER
Expand Down
4 changes: 4 additions & 0 deletions src/earthkit/data/field/handler/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def get_grib_context(self, context) -> dict:

COLLECTOR.collect(self, context)

@classmethod
def from_component(cls, component: TimeBase) -> "TimeFieldComponentHandler":
return TimeFieldComponentHandler(component)

@classmethod
def create_empty(cls) -> "TimeFieldComponentHandler":
return EMPTY_TIME_HANDLER
Expand Down
4 changes: 4 additions & 0 deletions src/earthkit/data/field/handler/vertical.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def get_grib_context(self, context) -> dict:

COLLECTOR.collect(self, context)

@classmethod
def from_component(cls, component: VerticalBase) -> "VerticalFieldComponentHandler":
return VerticalFieldComponentHandler(component)

@classmethod
def create_empty(cls) -> "VerticalFieldComponentHandler":
return EMPTY_VERTICAL_HANDLER
Expand Down
6 changes: 3 additions & 3 deletions src/earthkit/data/field/xarray/geography.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ def grid(self):
r"""Return the area of the grid."""
pass

@classmethod
def from_dict(d):
raise NotImplementedError("XArrayGeography.form_dict() is not implemented")
# @classmethod
# def from_dict(*args, **kwargs):
# raise NotImplementedError("XArrayGeography cannot be created from a dictionary")

def latlons(self, flatten=False, dtype=None):
lat, lon = self.owner.grid.latlons
Expand Down
11 changes: 7 additions & 4 deletions tests/grib/test_grib_set_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

# @pytest.mark.parametrize("fl_type", ["file", "array", "memory"])
@pytest.mark.parametrize("fl_type", ["file"])
def test_grib_set_data(fl_type):
def test_grib_set_data_field(fl_type):
ds_ori, _ = load_grib_data("test4.grib", fl_type)

vals_ori = ds_ori[0].values
Expand Down Expand Up @@ -76,9 +76,12 @@ def test_grib_set_data(fl_type):
# assert grib_md.get("levelist") == 500
# assert grib_md.get("date") == 20070101

# ---------------
# fieldlist
# ---------------

@pytest.mark.parametrize("fl_type", ["file"])
def test_grib_set_data_fieldlist(fl_type):
ds_ori, _ = load_grib_data("test4.grib", fl_type)

vals_ori = ds_ori[0].values

fields = []
for i in range(2):
Expand Down
3 changes: 1 addition & 2 deletions tests/grib/test_grib_set_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

# @pytest.mark.parametrize("fl_type", ["file", "array", "memory"])
@pytest.mark.parametrize("fl_type", ["file"])
@pytest.mark.parametrize("write_method", ["target"])
@pytest.mark.parametrize(
"_kwargs,ref_set,ref_saved,edition_saved",
[
Expand Down Expand Up @@ -336,7 +335,7 @@
# ),
],
)
def test_grib_set_time_1(fl_type, write_method, _kwargs, ref_set, ref_saved, edition_saved):
def test_grib_set_time_1(fl_type, _kwargs, ref_set, ref_saved, edition_saved):
ds_ori, _ = load_grib_data("test4.grib", fl_type)

f = ds_ori[0].set(**_kwargs)
Expand Down
3 changes: 1 addition & 2 deletions tests/grib/test_grib_set_vertical.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

# @pytest.mark.parametrize("fl_type", ["file", "array", "memory"])
@pytest.mark.parametrize("fl_type", ["file"])
@pytest.mark.parametrize("write_method", ["target"])
@pytest.mark.parametrize(
"_kwargs,ref1,grib_ref,ref2",
[
Expand Down Expand Up @@ -79,7 +78,7 @@
# ),
],
)
def test_grib_set_vertical(fl_type, write_method, _kwargs, ref1, grib_ref, ref2):
def test_grib_set_vertical(fl_type, _kwargs, ref1, grib_ref, ref2):
ds_ori, _ = load_grib_data("test4.grib", fl_type)

f = ds_ori[0].set(**_kwargs)
Expand Down
62 changes: 62 additions & 0 deletions tests/list_of_dicts/test_lod_set_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#!/usr/bin/env python3

# (C) Copyright 2020 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#


import numpy as np
import pytest
from lod_fixtures import build_lod_fieldlist # noqa: E402

from earthkit.data import FieldList


@pytest.mark.parametrize("mode", ["list-of-dicts", "loop"])
def test_lod_set_data_field(lod_ll_flat, mode):
ds_ori = build_lod_fieldlist(lod_ll_flat, mode)

vals_ori = ds_ori[0].values

f = ds_ori[0].set(values=vals_ori + 1)

assert f.get("parameter.variable") == "t"
assert f.get("vertical.level") == 500
assert np.allclose(f.values, vals_ori + 1)
assert np.allclose(ds_ori[0].values, vals_ori)

# ---------------------
# field - repeated use
# ---------------------

f = ds_ori[0].set(values=vals_ori + 1)
f = f.set(values=vals_ori + 2)

assert f.get("parameter.variable") == "t"
assert f.get("vertical.level") == 500
assert np.allclose(f.values, vals_ori + 2)
assert np.allclose(ds_ori[0].values, vals_ori)


@pytest.mark.parametrize("mode", ["list-of-dicts", "loop"])
def test_lod_set_data_fieldlist(lod_ll_flat, mode):
ds_ori = build_lod_fieldlist(lod_ll_flat, mode)

vals_ori = ds_ori[0].values

fields = []
for i in range(2):
f = ds_ori[i].set(values=vals_ori + i + 1)
fields.append(f)

ds = FieldList.from_fields(fields)

assert ds.get("parameter.variable") == ["t", "t"]
assert ds.get("vertical.level") == [500, 850]
assert np.allclose(ds[0].values, vals_ori + 1)
assert np.allclose(ds[1].values, vals_ori + 2)
81 changes: 81 additions & 0 deletions tests/list_of_dicts/test_lod_set_geography.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#!/usr/bin/env python3

# (C) Copyright 2020 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import datetime

import numpy as np
import pytest
from lod_fixtures import build_lod_fieldlist # noqa: E402


@pytest.mark.parametrize("mode", ["list-of-dicts", "loop"])
@pytest.mark.parametrize(
"_kwargs",
[
{
"geography.latitudes": np.array([10.0, 20.0, 30.0]),
"geography.longitudes": np.array([0.0, 10.0, 20.0]),
"values": np.array([1.0, 2.0, 3.0]),
},
],
)
def test_lod_set_geo_1(lod_ll_flat, mode, _kwargs):
ds_ori = build_lod_fieldlist(lod_ll_flat, mode)

f = ds_ori[0].set(**_kwargs)
assert np.allclose(f.get("geography.latitudes"), np.array([10.0, 20.0, 30.0]))
assert np.allclose(f.get("geography.longitudes"), np.array([0.0, 10.0, 20.0]))
assert np.allclose(f.values, np.array([1.0, 2.0, 3.0]))
assert f.get("time.base_datetime") == datetime.datetime(2018, 8, 1, 9, 0)
assert f.get("time.valid_datetime") == datetime.datetime(2018, 8, 1, 9, 0)
assert f.get("time.step") == datetime.timedelta(hours=0)
# assert f.get("time_span") == datetime.timedelta(hours=0)

# the original field is unchanged
assert ds_ori[0].get("geography.latitudes").shape == (6,)
assert ds_ori[0].get("geography.longitudes").shape == (6,)
assert ds_ori[0].values.shape == (6,)
assert ds_ori[0].get("time.base_datetime") == datetime.datetime(2018, 8, 1, 9, 0)
assert ds_ori[0].get("time.valid_datetime") == datetime.datetime(2018, 8, 1, 9, 0)
assert ds_ori[0].get("time.step") == datetime.timedelta(hours=0)
# assert ds_ori[0].get("time_span") == datetime.timedelta(hours=0)


@pytest.mark.parametrize("mode", ["list-of-dicts", "loop"])
@pytest.mark.parametrize(
"_kwargs,shape,grid_spec,area_1",
[
(
{
"geography.grid_spec": {"grid": [5, 5]},
"values": np.ones((37, 72)),
},
(37, 72),
{"grid": [5, 5]},
(90, 0, -90, 360),
),
],
)
def test_lod_set_geo_grid_spec(lod_ll_flat, mode, _kwargs, shape, grid_spec, area_1):
# the input is a 1/1 grid
ds_ori = build_lod_fieldlist(lod_ll_flat, mode)

assert ds_ori[0].shape == (6,)

f = ds_ori[0].set(**_kwargs)
assert f.shape == shape
assert f.get("geography.shape") == shape
assert f.get("geography.area") == area_1
assert f.get("geography.grid_spec") == grid_spec
# assert f.get("geography.grid_type") == grid_type
assert f.get("geography.latitudes").shape == shape
assert f.get("geography.longitudes").shape == shape
assert np.allclose(f.to_numpy(), _kwargs["values"])
44 changes: 44 additions & 0 deletions tests/list_of_dicts/test_lod_set_vertical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env python3

# (C) Copyright 2020 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import pytest
from lod_fixtures import build_lod_fieldlist # noqa: E402


@pytest.mark.parametrize("mode", ["list-of-dicts", "loop"])
@pytest.mark.parametrize(
"_kwargs,ref",
[
(
{
"vertical.level": 320,
"vertical.level_type": "potential_temperature",
},
{
"vertical.level": 320,
"vertical.level_type": "potential_temperature",
"vertical.units": "K",
"vertical.abbreviation": "pt",
},
),
],
)
def test_lod_set_vertical(lod_ll_flat, mode, _kwargs, ref):
ds = build_lod_fieldlist(lod_ll_flat, mode)

f = ds[0].set(**_kwargs)

for k, v in ref.items():
assert f.get(k) == v

# the original field is unchanged
assert ds[0].get("vertical.level") == 500
assert ds[0].get("vertical.level_type") == "unknown"
Loading
Loading