Skip to content
94 changes: 0 additions & 94 deletions src/earthkit/data/indexing/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,100 +505,6 @@ def _subset(self, indexes):
ds = self.source[tuple(dataset_indexes)]
return self.from_tensor(self, ds, coords)

def make_valid_datetime(self, dims_map, dtype="datetime64[ns]"):
    """Work out the valid-datetime values of the tensor and the dims they span.

    Two strategies are tried in order:

    1. If one of the known valid-time keys is present in ``self.user_coords``,
       its values are parsed with ``datetime.datetime.fromisoformat`` (so they
       are assumed to be ISO-format strings) and returned as a plain list.
       ``dtype`` is not applied on this path.
    2. Otherwise the time-related dims of the tensor are identified, and the
       ``"time.valid_datetime"`` values are read from ``self.source`` and
       reshaped to the sizes of those dims.

    Parameters
    ----------
    dims_map: iterable
        Objects exposing ``name`` (dim role name) and ``key`` (the key used
        for that dim in the tensor).
    dtype: str or numpy dtype
        dtype of the array built on the second path.

    Returns
    -------
    tuple
        ``(dims, values)`` where ``dims`` is a tuple of tensor dim keys and
        ``values`` is either a list of ``datetime.datetime`` (first path) or a
        numpy array shaped after ``dims`` (second path). ``(None, None)`` when
        no usable combination of time dims is found.
    """
    import datetime

    import numpy as np

    # TODO: make it more general, e.g. so that it can be used when allow_holes=True

    for k in ("valid_time", "time.valid_datetime", "metadata.valid_time", "metadata.valid_datetime"):
        if k in self.user_coords:
            return (k,), [datetime.datetime.fromisoformat(x) for x in self.user_coords[k]]

    # in the tensor the dims.coords are GRIB keys
    # dims_map is a mapping from dim names to GRIB keys
    DIM_ROLES = {
        "forecast_reference_time": (
            "forecast_reference_time",
            "time.forecast_reference_time",
            "time.base_datetime",
            "metadata.base_datetime",
            "metadata.indexing_datetime",
            "metadata.indexing_time",
        ),
        "step": (
            "step",
            "time.step",
            "metadata.step_timedelta",
            "metadata.step",
            "metadata.endStep",
            "metadata.stepRange",
        ),
        "date": ("date", "metadata.dataDate"),
        "time": ("time", "metadata.dataTime"),
    }

    # Tried in order; the first combination fully available in the tensor wins.
    DIM_COMBINATIONS = [
        ["forecast_reference_time", "step"],
        ["forecast_reference_time"],
        ["date", "time", "step"],
        ["date", "time"],
        ["date", "step"],
        ["time", "step"],
        ["step"],
    ]

    # Map dim roles to keys available in the tensor. An entry in dims_map takes
    # precedence; the known aliases are only searched in user_dims when
    # dims_map has no entry for the role.
    keys = {}
    for role, aliases in DIM_ROLES.items():
        for d in dims_map:
            if d.name == role:
                keys[role] = d.key
                break
        else:
            for d in self.user_dims:
                if d in aliases:
                    keys[role] = d
                    break

    for combination in DIM_COMBINATIONS:
        if not all(role in keys for role in combination):
            continue

        wanted = [keys[role] for role in combination]
        # use same dim order as in user_dims
        dims = [d for d in self.user_dims if d in wanted]
        if len(dims) != len(wanted):
            # some of the resolved keys are not dims of this tensor
            continue

        source = self.source
        other_dims = [d for d in self.user_dims if d not in dims]
        if other_dims:
            # Valid datetimes are read from a single slice: each remaining dim
            # is pinned to its first coordinate value.
            other_coords = {k: next(iter(self.user_coords[k])) for k in other_dims if k in self.user_coords}
            source = source.sel(**other_coords)

        vals = np.array(list(source.get("time.valid_datetime")), dtype=dtype)
        shape = tuple(self.user_dims[d] for d in dims)
        return tuple(dims), vals.reshape(shape)

    return None, None

def __getstate__(self):
r = {}
r["source"] = self.source
Expand Down
27 changes: 10 additions & 17 deletions src/earthkit/data/xr_engine/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,19 +337,6 @@ def _make_field_coords(self):
r[k] = xarray.Variable(dims, v, self.profile.attrs.coord_attrs.get(k, {}))
return r

def collect_date_coords(self, tensor):
    """Register a ``"valid_time"`` coordinate derived from ``tensor``.

    Nothing is done unless the profile has ``add_valid_time_coord`` set, and
    the coordinate is skipped when ``"valid_time"`` is already a dim of the
    tensor, when ``"valid_datetime"`` is among the tensor's user coords, or
    when ``"valid_time"`` has already been stored in ``self.tensor_coords``.
    Otherwise the values returned by ``tensor.make_valid_datetime(self.dims)``
    are wrapped in a ``Coord`` and stored under ``"valid_time"``, provided
    both the dims and the values are not None.
    """
    if not self.profile.add_valid_time_coord:
        return
    if "valid_time" in tensor.user_dims:
        return
    if "valid_datetime" in tensor.user_coords:
        return
    if "valid_time" in self.tensor_coords:
        return

    from .coord import Coord

    coord_dims, coord_values = tensor.make_valid_datetime(self.dims)
    if coord_dims is None or coord_values is None:
        return

    self.tensor_coords["valid_time"] = Coord.make("valid_time", coord_values, dims=coord_dims)

def collect_aux_coords(self):
from .coord import Coord

Expand Down Expand Up @@ -408,9 +395,6 @@ def collect_aux_coords(self):

def build(self):
if self.profile.allow_holes:
if self.profile.add_valid_time_coord:
raise NotImplementedError("add_valid_time_coord=True not yet supported when allow_holes=True")

global_tensor_dims, self.raw_global_tensor_coords, _ = self.prepare_tensor(
self.ds, self.dims, "<ALL VARIABLES>"
)
Expand All @@ -433,6 +417,16 @@ def build(self):
# From now on, self.tensor_coords is a mapping:
# dimension_name->a Coord object + possibly the same for "valid_time"

# Inject valid_time as an auxiliary coordinate when requested and when valid_time is not a dimension
if (
self.profile.add_valid_time_coord
and "valid_time" not in [d.name for d in self.dims]
and "time.valid_datetime" not in self.tensor_coords
):
time_dim_names = self.profile.dims.active_time_dim_names
if time_dim_names:
self.profile.aux_coords.setdefault("valid_time", ("time.valid_datetime", time_dim_names))

self.collect_aux_coords()

# build variable and global attributes
Expand Down Expand Up @@ -518,7 +512,6 @@ def pre_build_variable(self, ds_var, dims, name):
var_dims.append(k)
var_dims.extend(tensor.field_dims)

self.collect_date_coords(tensor)
data_maker = self.build_values
remapping = self.profile.remapping.build()

Expand Down
4 changes: 3 additions & 1 deletion src/earthkit/data/xr_engine/coord.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def _to_datetime_list(vals):
# datetime64 arrays are already in the required format
if isinstance(vals, np.ndarray):
if not np.issubdtype(vals.dtype, np.datetime64):
return to_datetime_list(vals.tolist())
original_shape = vals.shape
flat = to_datetime_list(vals.flatten().tolist())
return np.array(flat, dtype="datetime64[ns]").reshape(original_shape)
else:
return to_datetime_list(vals)

Expand Down
27 changes: 27 additions & 0 deletions src/earthkit/data/xr_engine/dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def _get_metadata_keys(keys):

DATETIME_KEYS = BASE_DATETIME_KEYS + VALID_DATETIME_KEYS

# All keys from the date, time, step, month, valid-datetime and base-datetime
# key groups, gathered in one set for membership tests.
_TIME_RELATED_KEYS = {
    *DATE_KEYS,
    *TIME_KEYS,
    *STEP_KEYS,
    *MONTH_KEYS,
    *VALID_DATETIME_KEYS,
    *BASE_DATETIME_KEYS,
}

KEYS = (
ENS_KEYS,
LEVEL_KEYS,
Expand Down Expand Up @@ -1015,6 +1017,31 @@ def make_coords(self):
def to_list(self):
    """Return the values of ``self.dims`` as a new list, in mapping order."""
    return [*self.dims.values()]

@property
def active_time_dim_names(self):
    """List the names of the active time-related dims, in the order of ``self.dims``.

    Without ``fixed_dims`` the candidates come from ``self.time_dims``: each
    entry that is a known time role is resolved to a dim name through
    ``self.dim_roles``. With ``fixed_dims`` the candidates are the fixed dims
    whose key is one of the time-related metadata keys.
    """
    if self.fixed_dims:
        # With fixed dims `self.time_dims` is not consulted; a fixed dim is a
        # candidate when its key is a time-related key.
        candidates = {name for name, key in self.fixed_dims.items() if key in _TIME_RELATED_KEYS}
    else:
        candidates = set()
        for role_name in self.time_dims:
            if role_name not in ALL_TIME_ROLES:
                continue
            # Only the resolved dim name is kept; roles without a name are skipped.
            _, resolved = self.dim_roles.role(role_name, raise_error=False)
            if resolved is not None:
                candidates.add(resolved)

    # Filtering self.dims (rather than listing the set) keeps the dim order.
    return [dim.name for dim in self.dims.values() if dim.active and dim.name in candidates]

def get_dims(self, names):
r = []
for name in names:
Expand Down
Loading
Loading