diff --git a/src/earthkit/data/core/field.py b/src/earthkit/data/core/field.py index 698969e0..5c4d5ae7 100644 --- a/src/earthkit/data/core/field.py +++ b/src/earthkit/data/core/field.py @@ -1308,7 +1308,7 @@ def metadata( patch=patch, ) - def set(self, *args, **kwargs): + def set(self, *args, sync=False, **kwargs): """Return a new field with the specified metadata keys set to the given values. Parameters @@ -1359,6 +1359,10 @@ def set(self, *args, **kwargs): Values are assumed to be in hours when the unit is not specified. When the unit is specified it can be either "h", "m" or "s" for hours, minutes or seconds, respectively. + sync: bool + When True, try to ensure the raw metadata is in sync with the modified field's components. + This might not always be possible, in which case an exception can be raised. Same as calling :meth:`sync` on the + new field returned by :meth:`set`. **kwargs: dict Keyword arguments used to specify the metadata keys and values to set. They take precedence over the positional arguments. The same rules for the keys and values @@ -1436,6 +1440,8 @@ def set(self, *args, **kwargs): _kwargs = defaultdict(dict) if not kwargs: + if sync: + return self.sync() return self _components = dict() @@ -1465,7 +1471,10 @@ def set(self, *args, **kwargs): _components[component_name] = s if _components: - return self._from_set(**_components) + f_new = self._from_set(**_components) + if sync: + f_new = f_new.sync() + return f_new elif kwargs: raise ValueError("No valid keys to set in the field.") @@ -1476,8 +1485,14 @@ def _set_values(self, array): return Field.from_field(self, data=data) def sync(self): - """Return a field with the raw metadata in sync with the field's components. + """Return a field with the raw metadata brought into sync with the field's components. + When a field is modified using :meth:`set`, the raw metadata (if any) might become + out of sync with the field's components and some of the raw metadata keys might not be + available anymore. 
This method tries to synchronize the raw metadata with the field's + components. + + At the moment, raw metadata is only available for fields created from GRIB messages. When a field is created from a GRIB message, it stores this associated GRIB message/handle and the raw GRIB metadata is extracted from it e.g. when calling :meth:`get`. When the field's components are modified using :meth:`set`, the GRIB message is copied into the new field but not @@ -1490,7 +1505,7 @@ def sync(self): Returns ------- Field - A field with the raw metadata in sync with the field's components. If the field is not associated with + A field with the raw metadata brought into sync with the field's components. If the field is not associated with a GRIB message or if the raw metadata is already in sync, the original field is returned. Examples diff --git a/tests/grib/test_grib_set.py b/tests/grib/test_grib_set.py index ce373a19..b8ca7f58 100644 --- a/tests/grib/test_grib_set.py +++ b/tests/grib/test_grib_set.py @@ -446,3 +446,40 @@ def test_grib_set_field_sync(fl_type): assert f1.get(("parameter.variable", "metadata.date")) == ("q", 20070101) assert f1.get("labels.my_shape") == (181, 360) assert f1.get("labels.my_name") == "t_500" + + +@pytest.mark.parametrize("fl_type", ["file", "array", "memory"]) +# @pytest.mark.parametrize("fl_type", ["file"]) +def test_grib_set_field_sync_kwarg(fl_type): + ds_ori, _ = load_grib_data("test4.grib", fl_type) + + f = ds_ori[0].set( + { + "parameter.variable": "q", + "vertical.level": 600, + "labels.my_shape": (181, 360), + "labels.my_name": "t_500", + }, + sync=True, + ) + + assert f.get("parameter.variable") == "q" + assert f.get("metadata.shortName") == "q" + assert f.get("vertical.level") == 600 + assert f.get("metadata.levelist") == 600 + assert f.get(("metadata.date", "parameter.variable")) == (20070101, "q") + assert f.get(("parameter.variable", "metadata.date")) == ("q", 20070101) + assert f.get("labels.my_shape") == (181, 360) + assert 
f.get("labels.my_name") == "t_500" + + # repeated sync should not change anything + for _ in range(2): + f1 = f.sync() + assert f1.get("parameter.variable") == "q" + assert f1.get("metadata.shortName") == "q" + assert f1.get("vertical.level") == 600 + assert f1.get("metadata.levelist") == 600 + assert f1.get(("metadata.date", "parameter.variable")) == (20070101, "q") + assert f1.get(("parameter.variable", "metadata.date")) == ("q", 20070101) + assert f1.get("labels.my_shape") == (181, 360) + assert f1.get("labels.my_name") == "t_500"