Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/earthkit/data/indexing/fieldlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class _C(PandasMixIn, SimpleFieldList):
def to_xarray(self, *args, **kwargs):
# TODO make it generic
if len(self) > 0:
if self[0]._metadata.data_format() == "grib":
if self[0]._metadata.data_format() in ("grib", "dict"):
from earthkit.data.readers.grib.xarray import XarrayMixIn

class _C(XarrayMixIn, SimpleFieldList):
Expand Down
19 changes: 18 additions & 1 deletion src/earthkit/data/utils/metadata/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ def mars_area(self):
def mars_grid(self):
raise NotImplementedError("mars_grid is not implemented for this geography")

def grid_type(self):
return "none"


class UserGeography(Geography):
def __init__(self, metadata, shape=None):
Expand Down Expand Up @@ -252,6 +255,9 @@ def mars_area(self):
def mars_grid(self):
raise NotImplementedError("mars_grid is not implemented for this geography")

def grid_type(self):
return "_unstructured"


class DistinctLLGeography(UserGeography):
def __init__(self, metadata):
Expand Down Expand Up @@ -298,9 +304,11 @@ def shape(self):
Ni = len(self._distinct_longitudes())
return (Nj, Ni)

def grid_type(self):
return "_distinct_ll"

class RegularDistinctLLGeography(DistinctLLGeography):

class RegularDistinctLLGeography(DistinctLLGeography):
def dx(self):
x = self.metadata.get("DxInDegrees", None)
if x is None:
Expand All @@ -326,6 +334,9 @@ def resolution(self):
def mars_grid(self):
return [self.dx(), self.dy()]

def grid_type(self):
return "_regular_ll"


class UserMetadata(Metadata):
ALIASES = [
Expand All @@ -342,6 +353,7 @@ class UserMetadata(Metadata):
"valid_datetime": "valid_datetime",
"step_timedelta": "step_timedelta",
"param_level": "param_level",
"_grid_type": "gridType",
}

LS_KEYS = ["param", "level", "base_datetime", "valid_datetime", "step", "number"]
Expand Down Expand Up @@ -443,6 +455,11 @@ def _datetime(self, date_key, time_key):
def param_level(self):
return f"{self.get('param')}{self.get('level', default='')}"

def _grid_type(self):
if "gridType" in self._data:
return self._data["gridType"]
return self.geography.grid_type()

def _get_one(self, keys):
for k in keys:
if k in self._data:
Expand Down
25 changes: 15 additions & 10 deletions src/earthkit/data/utils/xarray/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,13 @@ def __init__(

def build(self, add_earthkit_attrs=True):
if add_earthkit_attrs:
md = self.tensor.source[0].metadata().override()
attrs = {
"message": md._handle.get_buffer(),
"bitsPerValue": md.get("bitsPerValue", 0),
}
self._attrs["_earthkit"] = attrs
if hasattr(self.tensor.source[0], "handle"):
md = self.tensor.source[0].metadata().override()
attrs = {
"message": md._handle.get_buffer(),
"bitsPerValue": md.get("bitsPerValue", 0),
}
self._attrs["_earthkit"] = attrs

self._attrs.update(self.fixed_local_attrs)
data = self.data_maker(self.tensor, self.var_dims, self.name)
Expand Down Expand Up @@ -567,15 +568,19 @@ def parse(self, ds, profile=None, full=False):
def grid(self, ds):
grids = ds.index("md5GridSection")

if len(grids) != 1:
raise ValueError(f"Expected one grid, got {len(grids)}")
grid = grids[0]
if not grids:
grid = "_custom_" + str(id(ds))
else:
# if len(grids) != 1:
# raise ValueError(f"Expected one grid, got {len(grids)}")
grid = grids[0]

key = (grid, self.profile.flatten_values)

if key not in self.grids:
from .grid import TensorGrid

self.grids = {key: TensorGrid(ds[0], self.profile.flatten_values)}
self.grids[key] = TensorGrid(ds[0], self.profile.flatten_values)
return self.grids[key]


Expand Down
21 changes: 18 additions & 3 deletions src/earthkit/data/utils/xarray/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,26 @@
LOG = logging.getLogger(__name__)


# TODO: refactor this when earthkit.geo grid support is implemented
class Grid:
def __init__(self, field):
self.field = field

@staticmethod
def make(field):
# NOTE: underscore grid types are coming from UserMetadata
grid_type = field.metadata("gridType", default=None)
if grid_type == "regular_ll":

if grid_type in ["regular_ll", "_regular_ll"]:
return RegularLLGrid(field)
elif grid_type in ["regular_gg", "mercator"]:
elif grid_type in ["regular_gg", "mercator", "_rectified_ll"]:
return RectifiedLLGrid(field)
elif grid_type in ["sh"]:
return SpectralGrid(field)
elif grid_type is None or grid_type == "none":
return NonGrid(field)
elif grid_type == "_unstructured":
return Grid(field)
else:
return Grid(field)

Expand Down Expand Up @@ -71,13 +78,21 @@ def to_distinct_latlon(self, field_shape):


class SpectralGrid(Grid):
def to_latlon(self):
def to_latlon(self, field_shape=None):
return None, None

def is_spectral(self):
return True


class NonGrid(Grid):
def to_latlon(self, field_shape=None):
return None, None

def is_spectral(self):
return False


class TensorGrid:
def __init__(self, field, flatten_values=False):
self.dims, self.coords, self.coords_dim = self.build(field, flatten_values)
Expand Down
121 changes: 121 additions & 0 deletions tests/xr_engine/test_xr_lod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#!/usr/bin/env python3

# (C) Copyright 2020 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import os
import sys

import numpy as np
import pytest

from earthkit.data import from_source

here = os.path.dirname(__file__)
sys.path.insert(0, here)
from xr_engine_fixtures import compare_dims # noqa: E402


@pytest.fixture
def xr_lod_latlon():
prototype = {
"latitudes": [10.0, 0.0, -10.0],
"longitudes": [20, 40.0],
"values": [1, 2, 3, 4, 5, 6],
"valid_datetime": "2018-08-01T09:00:00Z",
}

d = [
{"param": "t", "level": 500, **prototype},
{"param": "t", "level": 850, **prototype},
{"param": "u", "level": 500, **prototype},
{"param": "u", "level": 850, **prototype},
]
ds = from_source("list-of-dicts", d)
return ds


@pytest.fixture
def xr_lod_nongeo():
prototype = {
"values": [1, 2, 3, 4, 5, 6],
"valid_datetime": "2018-08-01T09:00:00Z",
}

d = [
{"param": "t", "level": 500, **prototype},
{"param": "t", "level": 850, **prototype},
{"param": "u", "level": 500, **prototype},
{"param": "u", "level": 850, **prototype},
]
ds = from_source("list-of-dicts", d)
return ds


@pytest.fixture
def xr_lod_forecast():
prototype = {
"latitudes": [10.0, 0.0, -10.0],
"longitudes": [20, 40.0],
"values": [1, 2, 3, 4, 5, 6],
"base_datetime": "2018-08-01T09:00:00Z",
}

d = [
{"param": "t", "level": 500, "step": 0, **prototype},
{"param": "t", "level": 500, "step": 6, **prototype},
{"param": "u", "level": 500, "step": 0, **prototype},
{"param": "u", "level": 500, "step": 6, **prototype},
]
ds = from_source("list-of-dicts", d)
return ds


def test_xr_engine_lod_latlon(xr_lod_latlon):
ds_in = xr_lod_latlon
ds = ds_in.to_xarray(time_dim_mode="raw")

assert ds is not None
assert ds["t"].shape == (2, 3, 2)
assert ds["u"].shape == (2, 3, 2)
assert np.allclose(ds["latitude"].values, np.array([10.0, 0.0, -10.0]))
assert np.allclose(ds["longitude"].values, np.array([20.0, 40.0]))


def test_xr_engine_lod_nongeo(xr_lod_nongeo):
ds_in = xr_lod_nongeo
ds = ds_in.to_xarray(time_dim_mode="raw")

assert ds is not None
assert ds["t"].shape == (2, 6)
assert ds["u"].shape == (2, 6)

ref = np.array(
[
[1, 2, 3, 4, 5, 6],
[1, 2, 3, 4, 5, 6],
]
)
assert np.allclose(ds["t"].values, ref)
assert np.allclose(ds["u"].values, ref)


def test_xr_engine_lod_forecast(xr_lod_forecast):
ds_in = xr_lod_forecast
ds = ds_in.to_xarray(time_dim_mode="forecast")

assert ds is not None
assert ds["t"].shape == (2, 3, 2)
assert ds["u"].shape == (2, 3, 2)

dims = {"step": [np.timedelta64(0, "h"), np.timedelta64(6, "h")]}
compare_dims(ds, dims, order_ref_var="t")

assert np.allclose(ds["latitude"].values, np.array([10.0, 0.0, -10.0]))
assert np.allclose(ds["longitude"].values, np.array([20.0, 40.0]))
Loading