diff --git a/docs/examples/index.rst b/docs/examples/index.rst index 6b8574d72..65614ad8c 100644 --- a/docs/examples/index.rst +++ b/docs/examples/index.rst @@ -163,6 +163,7 @@ Xarray engine xarray_engine_to_grib.ipynb xarray_engine_split.ipynb xarray_engine_seasonal.ipynb + xarray_engine_chunks.ipynb Targets and encoders +++++++++++++++++++++ diff --git a/docs/examples/xarray_engine_chunks.ipynb b/docs/examples/xarray_engine_chunks.ipynb new file mode 100644 index 000000000..39eb8f6f9 --- /dev/null +++ b/docs/examples/xarray_engine_chunks.ipynb @@ -0,0 +1,1245 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f3568669-9884-491d-8597-5130ad273337", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "## Xarray engine: chunks" + ] + }, + { + "cell_type": "raw", + "id": "b42eccf8-abcc-44a1-8406-f8aa966b1bf5", + "metadata": { + "editable": true, + "raw_mimetype": "text/x-rst", + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "This notebook demonstrates how to use chunking in computations when a GRIB fieldlist is converted to to Xarray with :py:meth:`~data.readers.grib.index.GribFieldList.to_xarray`. Chunking can be used to handle data that does not fit into memory." + ] + }, + { + "cell_type": "markdown", + "id": "8b1ceb8a-967d-4324-9af3-3b6eec468da1", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "First, we get 2m temperature data for a whole year on a low resolution regular latitude-longitude grid. It contains 2 fields per day (at 0 and 12 UTC). This data obviously fit into memory, so only used for demonstration purposes." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "3a4f7dd0-f443-4cda-8725-cd61927d1409", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "98299fdfafa74aa5b8cbc0f95188b8d5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "t2_1_year_hourly.grib: 0%| | 0.00/429k [00:00\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray '2t' (valid_time: 732, latitude: 13, longitude: 24)> Size: 2MB\n",
+       "dask.array<open_dataset-2t, shape=(732, 13, 24), dtype=float64, chunksize=(10, 13, 24), chunktype=numpy.ndarray>\n",
+       "Coordinates:\n",
+       "  * valid_time  (valid_time) datetime64[ns] 6kB 2020-01-01 ... 2020-12-31T06:...\n",
+       "  * latitude    (latitude) float64 104B 90.0 75.0 60.0 ... -60.0 -75.0 -90.0\n",
+       "  * longitude   (longitude) float64 192B 0.0 15.0 30.0 ... 315.0 330.0 345.0\n",
+       "Attributes:\n",
+       "    standard_name:  air_temperature\n",
+       "    long_name:      2 metre temperature\n",
+       "    units:          K
" + ], + "text/plain": [ + " Size: 2MB\n", + "dask.array\n", + "Coordinates:\n", + " * valid_time (valid_time) datetime64[ns] 6kB 2020-01-01 ... 2020-12-31T06:...\n", + " * latitude (latitude) float64 104B 90.0 75.0 60.0 ... -60.0 -75.0 -90.0\n", + " * longitude (longitude) float64 192B 0.0 15.0 30.0 ... 315.0 330.0 345.0\n", + "Attributes:\n", + " standard_name: air_temperature\n", + " long_name: 2 metre temperature\n", + " units: K" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds = ds_fl.to_xarray(time_dim_mode=\"valid_time\", \n", + " chunks={\"valid_time\": 10}, \n", + " add_earthkit_attrs=False)\n", + "ds[\"2t\"]" + ] + }, + { + "cell_type": "markdown", + "id": "d5caa260-5e6c-432b-96b7-ea84cb261432", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "We compute the mean along the temporal dimension. Xarray will load data in chunks for this computation keeping the memory usage low." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "46e0abd9-7866-4e9f-9c89-c9234b372bc2", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray '2t' (latitude: 13, longitude: 24)> Size: 2kB\n",
+       "array([[259.17798273, 259.17798273, 259.17798273, 259.17798273,\n",
+       "        259.17798273, 259.17798273, 259.17798273, 259.17798273,\n",
+       "        259.17798273, 259.17798273, 259.17798273, 259.17798273,\n",
+       "        259.17798273, 259.17798273, 259.17798273, 259.17798273,\n",
+       "        259.17798273, 259.17798273, 259.17798273, 259.17798273,\n",
+       "        259.17798273, 259.17798273, 259.17798273, 259.17798273],\n",
+       "       [273.2611026 , 275.61228088, 275.48984236, 274.29307835,\n",
+       "        268.16812105, 267.89195131, 264.09208792, 262.4144496 ,\n",
+       "        262.67648853, 261.67375629, 261.81749775, 261.75990725,\n",
+       "        261.65672248, 261.12205718, 260.31177713, 259.69160124,\n",
+       "        259.44480308, 258.91397999, 256.69544345, 261.1351634 ,\n",
+       "        263.80255581, 245.3709899 , 246.22366237, 263.91035124],\n",
+       "       [281.80054932, 277.3069957 , 278.84242945, 276.02408075,\n",
+       "        274.49351381, 274.21627678, 274.1331996 , 272.61215281,\n",
+       "        271.49176346, 269.63533416, 273.44181469, 275.6214595 ,\n",
+       "        276.75602234, 275.00730308, 276.87611285, 273.58944106,\n",
+       "        271.92337207, 268.99705718, 266.13354113, 265.23450595,\n",
+       "        271.60276073, 273.63473648, 279.19937105, 281.8119052 ],\n",
+       "       [284.15830206, 283.85715793, 286.20103601, 283.92187788,\n",
+       "        283.76810397, 284.21051346, 282.39472624, 279.20961695,\n",
+       "...\n",
+       "        283.64451278, 283.35801176, 282.91684031, 282.9554759 ,\n",
+       "        282.14711695, 282.26140144, 281.0409011 , 280.42200595],\n",
+       "       [269.24513607, 270.23454864, 271.63341305, 271.82437105,\n",
+       "        271.60942057, 270.65992432, 270.95410978, 271.93656367,\n",
+       "        273.56741237, 274.12688129, 273.18558177, 275.06365554,\n",
+       "        275.63789564, 274.6907901 , 272.28731504, 273.95432323,\n",
+       "        275.29179762, 275.50107016, 275.76251141, 276.15411831,\n",
+       "        273.34336866, 269.42683006, 268.963758  , 268.6785804 ],\n",
+       "       [234.73941544, 228.72784611, 229.33594038, 225.86723307,\n",
+       "        238.04060226, 241.2718608 , 229.15774707, 226.06100868,\n",
+       "        224.74338573, 228.59588744, 232.7254554 , 258.2240039 ,\n",
+       "        257.83465964, 258.4833106 , 262.08999605, 259.26575595,\n",
+       "        256.58991066, 260.79205376, 251.32129728, 250.12530172,\n",
+       "        253.13652952, 256.13648682, 258.81438129, 254.06318594],\n",
+       "       [227.70048102, 227.70048102, 227.70048102, 227.70048102,\n",
+       "        227.70048102, 227.70048102, 227.70048102, 227.70048102,\n",
+       "        227.70048102, 227.70048102, 227.70048102, 227.70048102,\n",
+       "        227.70048102, 227.70048102, 227.70048102, 227.70048102,\n",
+       "        227.70048102, 227.70048102, 227.70048102, 227.70048102,\n",
+       "        227.70048102, 227.70048102, 227.70048102, 227.70048102]])\n",
+       "Coordinates:\n",
+       "  * latitude   (latitude) float64 104B 90.0 75.0 60.0 45.0 ... -60.0 -75.0 -90.0\n",
+       "  * longitude  (longitude) float64 192B 0.0 15.0 30.0 45.0 ... 315.0 330.0 345.0
" + ], + "text/plain": [ + " Size: 2kB\n", + "array([[259.17798273, 259.17798273, 259.17798273, 259.17798273,\n", + " 259.17798273, 259.17798273, 259.17798273, 259.17798273,\n", + " 259.17798273, 259.17798273, 259.17798273, 259.17798273,\n", + " 259.17798273, 259.17798273, 259.17798273, 259.17798273,\n", + " 259.17798273, 259.17798273, 259.17798273, 259.17798273,\n", + " 259.17798273, 259.17798273, 259.17798273, 259.17798273],\n", + " [273.2611026 , 275.61228088, 275.48984236, 274.29307835,\n", + " 268.16812105, 267.89195131, 264.09208792, 262.4144496 ,\n", + " 262.67648853, 261.67375629, 261.81749775, 261.75990725,\n", + " 261.65672248, 261.12205718, 260.31177713, 259.69160124,\n", + " 259.44480308, 258.91397999, 256.69544345, 261.1351634 ,\n", + " 263.80255581, 245.3709899 , 246.22366237, 263.91035124],\n", + " [281.80054932, 277.3069957 , 278.84242945, 276.02408075,\n", + " 274.49351381, 274.21627678, 274.1331996 , 272.61215281,\n", + " 271.49176346, 269.63533416, 273.44181469, 275.6214595 ,\n", + " 276.75602234, 275.00730308, 276.87611285, 273.58944106,\n", + " 271.92337207, 268.99705718, 266.13354113, 265.23450595,\n", + " 271.60276073, 273.63473648, 279.19937105, 281.8119052 ],\n", + " [284.15830206, 283.85715793, 286.20103601, 283.92187788,\n", + " 283.76810397, 284.21051346, 282.39472624, 279.20961695,\n", + "...\n", + " 283.64451278, 283.35801176, 282.91684031, 282.9554759 ,\n", + " 282.14711695, 282.26140144, 281.0409011 , 280.42200595],\n", + " [269.24513607, 270.23454864, 271.63341305, 271.82437105,\n", + " 271.60942057, 270.65992432, 270.95410978, 271.93656367,\n", + " 273.56741237, 274.12688129, 273.18558177, 275.06365554,\n", + " 275.63789564, 274.6907901 , 272.28731504, 273.95432323,\n", + " 275.29179762, 275.50107016, 275.76251141, 276.15411831,\n", + " 273.34336866, 269.42683006, 268.963758 , 268.6785804 ],\n", + " [234.73941544, 228.72784611, 229.33594038, 225.86723307,\n", + " 238.04060226, 241.2718608 , 229.15774707, 226.06100868,\n", + " 224.74338573, 228.59588744, 232.7254554 , 258.2240039 ,\n", + " 257.83465964, 258.4833106 , 262.08999605, 259.26575595,\n", + " 256.58991066, 260.79205376, 251.32129728, 250.12530172,\n", + " 253.13652952, 256.13648682, 258.81438129, 254.06318594],\n", + " [227.70048102, 227.70048102, 227.70048102, 227.70048102,\n", + " 227.70048102, 227.70048102, 227.70048102, 227.70048102,\n", + " 227.70048102, 227.70048102, 227.70048102, 227.70048102,\n", + " 227.70048102, 227.70048102, 227.70048102, 227.70048102,\n", + " 227.70048102, 227.70048102, 227.70048102, 227.70048102,\n", + " 227.70048102, 227.70048102, 227.70048102, 227.70048102]])\n", + "Coordinates:\n", + " * latitude (latitude) float64 104B 90.0 75.0 60.0 45.0 ... -60.0 -75.0 -90.0\n", + " * longitude (longitude) float64 192B 0.0 15.0 30.0 45.0 ... 315.0 330.0 345.0" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "m = ds[\"2t\"].mean(dim=\"valid_time\").load()\n", + "m" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb9d85ea-a52f-4dc6-b081-7688f9c90536", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dev", + "language": "python", + "name": "dev" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/xarray_engine_temporal.ipynb b/docs/examples/xarray_engine_temporal.ipynb index ad8945b35..1b1a76d57 100644 --- a/docs/examples/xarray_engine_temporal.ipynb +++ b/docs/examples/xarray_engine_temporal.ipynb @@ -2651,9 +2651,9 @@ ], "metadata": { "kernelspec": { - "display_name": "dev_ecc", + "display_name": "dev", "language": "python", - "name": "dev_ecc" + "name": "dev" }, "language_info": { "codemirror_mode": { @@ -2665,7 +2665,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.12" } }, "nbformat": 4, diff --git a/src/earthkit/data/core/config.py b/src/earthkit/data/core/config.py index 7464decb7..8ce1c2f71 100644 --- a/src/earthkit/data/core/config.py +++ b/src/earthkit/data/core/config.py @@ -246,6 +246,11 @@ def validate(self, name, value): fieldlists with data on disk. See :doc:`/guide/misc/grib_memory` for more information.""", ), + "grib-file-serialisation-policy": _( + "path", + """GRIB file serialisation policy for fieldlists with data on disk. {validator}""", + validator=ListValidator(["path", "memory"]), + ), } diff --git a/src/earthkit/data/core/fieldlist.py b/src/earthkit/data/core/fieldlist.py index 899c81131..20c858ae5 100644 --- a/src/earthkit/data/core/fieldlist.py +++ b/src/earthkit/data/core/fieldlist.py @@ -1093,6 +1093,7 @@ def _vals(f): first = next(it) is_property = not callable(getattr(first, accessor, None)) vals = _vals(first) + first = None ns = array_namespace(vals) shape = (n, *vals.shape) r = ns.empty(shape, dtype=vals.dtype) diff --git a/src/earthkit/data/indexing/tensor.py b/src/earthkit/data/indexing/tensor.py index 4d38c31c3..533f9d4f6 100644 --- a/src/earthkit/data/indexing/tensor.py +++ b/src/earthkit/data/indexing/tensor.py @@ -404,6 +404,13 @@ def field_indexes(self, indexes): assert len(indexes) == len(self._full_shape) return indexes[len(self._user_shape) :] + def is_full_field(self, indexes): + assert len(indexes) == len(self._field_shape) + for i, s in enumerate(indexes): + if not (s is None or s == slice(None, None, None) or s == slice(0, self._field_shape[i], 1)): + return False + return True + def _subset(self, indexes): """Only allow subsetting for the user coordinates. Indices for the field coordinates are ignored. @@ -496,19 +503,26 @@ def make_valid_datetime(self, dtype="datetime64[ns]"): return tuple(dims), vals.reshape(shape) return None, None + def __getstate__(self): + r = {} + r["source"] = self.source + r["user_coords"] = self.user_coords + r["user_shape"] = self.user_shape + r["user_dims"] = self.user_dims + r["field_coords"] = self.field_coords + r["field_shape"] = self.field_shape + r["field_dims"] = self.field_dims + r["full_shape"] = self.full_shape + r["flatten_values"] = self.flatten_values + return r -# class ArrayTensor(TensorCore): -# def __init__(self, array, coords, field_shape): -# self._array = array -# self._coords = coords -# self._shape = self._array.shape -# self._field_shape = field_shape - -# def to_numpy(self, **kwargs): -# return self._array - -# def _subset(self, indexes): -# coords = self._subset_coords(indexes) -# # print(f"{indexes=}") -# data = self._array[indexes] -# return ArrayTensor(data, coords, self.field_shape) + def __setstate__(self, state): + self.source = state["source"] + self._user_coords = state["user_coords"] + self._user_shape = state["user_shape"] + self._user_dims = state["user_dims"] + self._field_coords = state["field_coords"] + self._field_shape = state["field_shape"] + self._field_dims = state["field_dims"] + self._full_shape = state["full_shape"] + self.flatten_values = state["flatten_values"] diff --git a/src/earthkit/data/readers/grib/codes.py b/src/earthkit/data/readers/grib/codes.py index ed1c4be0f..f7f0ebf88 100644 --- a/src/earthkit/data/readers/grib/codes.py +++ b/src/earthkit/data/readers/grib/codes.py @@ -329,6 +329,21 @@ def message(self): def clone(self, **kwargs): return ClonedGribField(self, **kwargs) + def __getstate__(self): + state = super().__getstate__() + state["path"] = self.path + state["offset"] = self._offset + state["length"] = self._length + state["use_metadata_cache"] = self._use_metadata_cache + return state + + def __setstate__(self, state): + self.path = state["path"] + self._offset = state["offset"] + self._length = state["length"] + self._use_metadata_cache = state["use_metadata_cache"] + self._handle_manager = None + class ClonedGribField(ClonedFieldCore, GribField): def __init__(self, field, **kwargs): diff --git a/src/earthkit/data/readers/grib/file.py b/src/earthkit/data/readers/grib/file.py index 8449a1d90..1acf5a00a 100644 --- a/src/earthkit/data/readers/grib/file.py +++ b/src/earthkit/data/readers/grib/file.py @@ -18,7 +18,7 @@ class GRIBReader(GribFieldListInOneFile, Reader): appendable = True # GRIB messages can be added to the same file - def __init__(self, source, path, parts=None): + def __init__(self, source, path, parts=None, positions=None): _kwargs = {} for k in [ # "array_backend", @@ -34,7 +34,7 @@ def __init__(self, source, path, parts=None): raise KeyError(f"Invalid option {k} in GRIBReader. Option names must not contain '-'.") Reader.__init__(self, source, path) - GribFieldListInOneFile.__init__(self, path, parts=parts, **_kwargs) + GribFieldListInOneFile.__init__(self, path, parts=parts, positions=positions, **_kwargs) def __repr__(self): return "GRIBReader(%s)" % (self.path,) @@ -47,24 +47,42 @@ def is_streamable_file(self): return True def __getstate__(self): - r = {"kwargs": self.source._kwargs, "messages": []} - for f in self: - r["messages"].append(f.message()) + from earthkit.data.core.config import CONFIG + + policy = CONFIG.get("grib-file-serialisation-policy") + r = {"serialisation_policy": policy, "kwargs": self.source._kwargs} + + if policy == "path": + r["path"] = self.path + r["positions"] = self._positions + else: + r["messages"] = [f.message() for f in self] + return r def __setstate__(self, state): - from earthkit.data import from_source - from earthkit.data.core.caching import cache_file - - def _create(path, args): - with open(path, "wb") as f: - for message in state["messages"]: - f.write(message) - - path = cache_file( - "GRIBReader", - _create, - [], - ) - ds = from_source("file", path) - self.__init__(ds.source, path) + policy = state["serialisation_policy"] + if policy == "path": + from earthkit.data import from_source + + path = state["path"] + ds = from_source("file", path, **state["kwargs"]) + self.__init__(ds.source, path, positions=state["positions"]) + elif policy == "memory": + from earthkit.data import from_source + from earthkit.data.core.caching import cache_file + + def _create(path, args): + with open(path, "wb") as f: + for message in state["messages"]: + f.write(message) + + path = cache_file( + "GRIBReader", + _create, + [], + ) + ds = from_source("file", path) + self.__init__(ds.source, path) + else: + raise ValueError(f"Unknown serialisation policy {policy}") diff --git a/src/earthkit/data/readers/grib/index/file.py b/src/earthkit/data/readers/grib/index/file.py index 7e602781c..d8e493bfc 100644 --- a/src/earthkit/data/readers/grib/index/file.py +++ b/src/earthkit/data/readers/grib/index/file.py @@ -25,12 +25,12 @@ class GribFieldListInOneFile(GribFieldListInFiles): def availability_path(self): return os.path.join(self.path, ".availability.pickle") - def __init__(self, path, parts=None, **kwargs): + def __init__(self, path, parts=None, positions=None, **kwargs): assert isinstance(path, str), path self.path = path self._file_parts = parts - self.__positions = None + self.__positions = positions super().__init__(**kwargs) @property diff --git a/src/earthkit/data/readers/grib/memory.py b/src/earthkit/data/readers/grib/memory.py index b72b1b304..5087bf22d 100644 --- a/src/earthkit/data/readers/grib/memory.py +++ b/src/earthkit/data/readers/grib/memory.py @@ -168,6 +168,12 @@ def _release(self): def clone(self, **kwargs): return ClonedGribFieldInMemory(self, **kwargs) + def __getstate__(self): + return {"message": self.message()} + + def __setstate__(self, state): + self.__init__(GribCodesHandle.from_message(state["message"])) + class ClonedGribFieldInMemory(ClonedFieldCore, GribFieldInMemory): def __init__(self, field, **kwargs): diff --git a/src/earthkit/data/utils/message.py b/src/earthkit/data/utils/message.py index d925f179a..e9e2e74d6 100644 --- a/src/earthkit/data/utils/message.py +++ b/src/earthkit/data/utils/message.py @@ -201,6 +201,10 @@ def from_sample(cls, name): def _from_raw_handle(cls, handle): return cls(handle, None, None) + @classmethod + def from_message(cls, message): + return cls(eccodes.codes_new_from_message(message), None, None) + # TODO: just a wrapper around the base class implementation to handle the # s,l,d qualifiers. Once these are implemented in the base class this method can # be removed. md5GridSection is also handled! diff --git a/src/earthkit/data/utils/xarray/builder.py b/src/earthkit/data/utils/xarray/builder.py index 4d0276c82..6cb0b24a1 100644 --- a/src/earthkit/data/utils/xarray/builder.py +++ b/src/earthkit/data/utils/xarray/builder.py @@ -8,7 +8,6 @@ # import logging -import threading from abc import ABCMeta from abc import abstractmethod @@ -33,7 +32,14 @@ class VariableBuilder: def __init__( - self, name, var_dims, data_maker, tensor, remapping, local_attr_keys=None, fixed_local_attrs=None + self, + name, + var_dims, + data_maker, + tensor, + remapping, + local_attr_keys=None, + fixed_local_attrs=None, ): """ Create a builder for a single variable in the dataset. @@ -166,18 +172,21 @@ def attrs(self): class TensorBackendArray(xarray.backends.common.BackendArray): - def __init__(self, tensor, dims, shape, xp, dtype, variable): + def __init__(self, tensor, dims, shape, xp, dtype, var_name): super().__init__() self.tensor = tensor self.dims = dims self.shape = shape + self._var_name = var_name # xp and dtype must be set for xarray self.xp = xp if xp is not None else numpy if dtype is None: dtype = numpy.dtype("float64") self.dtype = xp.dtype(dtype) - self.lock = threading.Lock() + from dask.utils import SerializableLock + + self.lock = SerializableLock() @property def nbytes(self): @@ -206,16 +215,20 @@ def __getitem__(self, key: xarray.core.indexing.ExplicitIndexer): def _raw_indexing_method(self, key: tuple): with self.lock: - # print("_var", self._var) - # print(f"dims: {self.dims} key: {key} shape: {self.shape}") - # print(f"t-coords={self.tensor.user_coords}") + # LOG.debug(f"TensorBackendArray._raw_indexing_method var={self._var_name}") + # LOG.debug(f" dims={self.dims} key={key} shape={self.shape}") + # LOG.debug(f" tensor.user_coords={self.tensor.user_coords}") + r = self.tensor[key] - # print(r.source.ls()) - # print(f"r-shape: {r.user_shape}") + # LOG.debug(f" cubelet user_shape={r.user_shape}") + # LOG.debug(f" {r.user_shape=}") field_index = r.field_indexes(key) - # print(f"field.index={field_index} coords={r.user_coords}") - # result = r.to_numpy(index=field_index).squeeze() + if self.tensor.is_full_field(field_index): + field_index = None + + # LOG.debug(f" {field_index=}") + result = r.to_numpy(index=field_index, dtype=self.dtype) # ensure axes are squeezed when needed @@ -223,15 +236,9 @@ def _raw_indexing_method(self, key: tuple): if singles: result = result.squeeze(axis=tuple(singles)) - # print("result", result.shape) - # result = self.ekds.isel(**isels).to_numpy() - - # print("result", result.shape) - # print(f"Loaded {self.xp.__name__} with shape: {result.shape}") + # LOG.debug(f" {result.shape=}") - # Loading as numpy but then converting. This needs to be changed upstream (eccodes) - # to load directly into cupy. - # Maybe some incompatibilities when trying to copy from FFI to cupy directly + # Loading as numpy but then converting to the target array module if self.xp and self.xp != numpy: result = self.xp.asarray(result) @@ -444,6 +451,10 @@ def pre_build_variables(self): def build_values(self, tensor, var_dims, name): """Generate the data object stored in the xarray variable""" + # There is no need for the extra structures in the wrapped source in the + # tensor any longer. It is replaced by the original unwrapped fieldlist. + tensor.source = tensor.source.unwrap() + backend_array = TensorBackendArray( tensor, var_dims, diff --git a/src/earthkit/data/utils/xarray/fieldlist.py b/src/earthkit/data/utils/xarray/fieldlist.py index c4186da90..474cd62e5 100644 --- a/src/earthkit/data/utils/xarray/fieldlist.py +++ b/src/earthkit/data/utils/xarray/fieldlist.py @@ -103,6 +103,12 @@ def __repr__(self) -> str: class XArrayInputFieldList(FieldList): + """ + A wrapper around a fieldlist that stores unique values. + + Only for internal use for building Xarray datasets. + """ + def __init__(self, fieldlist, keys=None, db=None, remapping=None, scan_only=False, component=True): super().__init__() self.ds = fieldlist @@ -228,6 +234,26 @@ def unique_values(self, names, component=False): else: return indices, None + def unwrap(self): + ds = self.ds + while isinstance(ds, XArrayInputFieldList): + ds = ds.ds + return ds + + def __getstate__(self): + """As a simplification, only serialise the unwrapped fieldlist. + We can assume that when there is a need for serialisation the wrapper + structure can be discarded. + """ + r = {} + r["ds"] = self.unwrap() + return r + + def __setstate__(self, state): + self.ds = state["ds"] + self.db = None + self.remapping = None + class ReleasableField: def __init__(self, field): diff --git a/tests/grib/test_grib_serialise.py b/tests/grib/test_grib_serialise.py index 3eb4a75de..9cf0d47ab 100644 --- a/tests/grib/test_grib_serialise.py +++ b/tests/grib/test_grib_serialise.py @@ -16,6 +16,7 @@ import numpy as np import pytest +from earthkit.data import config from earthkit.data import from_source from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.metadata import StandAloneGribMetadata @@ -25,6 +26,7 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) +from grib_fixtures import FL_FILE # noqa: E402 from grib_fixtures import FL_NUMPY # noqa: E402 from grib_fixtures import load_grib_data # noqa: E402 @@ -215,3 +217,20 @@ def test_grib_serialise_file_parts(): assert len(ds2) == 1 assert ds2[0].metadata(["param", "level"]) == ["u", 1000] + + +@pytest.mark.parametrize("fl_type", FL_FILE) +@pytest.mark.parametrize("representation", ["file", "memory"]) +@pytest.mark.parametrize("policy", ["path", "memory"]) +def test_grib_serialise_policy(fl_type, representation, policy): + ds, _ = load_grib_data("test.grib", fl_type) + + with config.temporary({"grib-file-serialisation-policy": policy}): + ds2 = _pickle(ds, representation) + + assert len(ds2) == len(ds) + assert ds2.values.shape == ds.values.shape + if policy == "path": + assert ds2.path == ds.path + else: + assert ds2.path != ds.path diff --git a/tests/xr_engine/test_xr_builder.py b/tests/xr_engine/test_xr_builder.py new file mode 100644 index 000000000..d004a3cc2 --- /dev/null +++ b/tests/xr_engine/test_xr_builder.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import os +import pickle +import sys + +import pytest + +from earthkit.data import from_source +from earthkit.data.core.temporary import temp_file +from earthkit.data.testing import earthkit_remote_test_data_file + +here = os.path.dirname(__file__) +sys.path.insert(0, here) + +# Testing internal structures in the xarray engine + + +def _pickle(data, representation): + if representation == "file": + with temp_file() as tmp: + with open(tmp, "wb") as f: + pickle.dump(data, f) + + with open(tmp, "rb") as f: + data_res = pickle.load(f) + elif representation == "memory": + pickled_data = pickle.dumps(data) + data_res = pickle.loads(pickled_data) + else: + raise ValueError(f"Invalid representation: {representation}") + return data_res + + +@pytest.mark.cache +@pytest.mark.parametrize("representation", ["file", "memory"]) +def test_xr_engine_builder_fieldlist(representation): + ds_in = from_source("url", earthkit_remote_test_data_file("test-data/xr_engine/level/pl_small.grib")) + + from earthkit.data.utils.xarray.fieldlist import XArrayInputFieldList + + r = XArrayInputFieldList(ds_in) + assert not isinstance(r.ds, XArrayInputFieldList) + assert r.unwrap() is ds_in + r_p = _pickle(r, representation) + assert r_p is not r + assert r_p.ds is not r.ds + assert r_p.ds.metadata("time", astype=int) == [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1200, + 1200, + 1200, + 1200, + 1200, + 1200, + 1200, + 1200, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1200, + 1200, + 1200, + 1200, + 1200, + 1200, + 1200, + 1200, + ] + + r0 = r.sel(param="t", level=500) + assert not isinstance(r0.ds, XArrayInputFieldList) + assert len(r0) == 8 + assert r0.ds.metadata("time", astype=int) == [0, 0, 1200, 1200, 0, 0, 1200, 1200] + r0_uw = r0.unwrap() + assert not isinstance(r0_uw, XArrayInputFieldList) + assert r0_uw.metadata("time", astype=int) == [0, 0, 1200, 1200, 0, 0, 1200, 1200] + r0_p = _pickle(r0, representation) + assert r0_p is not r0 + assert r0_p.ds is not r0.ds + assert r0_p.ds.metadata("time", astype=int) == [0, 0, 1200, 1200, 0, 0, 1200, 1200] + + r1 = r0.order_by("time") + assert not isinstance(r1.ds, XArrayInputFieldList) + assert r1.ds.metadata("time", astype=int) == [0, 0, 0, 0, 1200, 1200, 1200, 1200] + r1_uw = r1.unwrap() + assert not isinstance(r1_uw, XArrayInputFieldList) + assert r1_uw.metadata("time", astype=int) == [0, 0, 0, 0, 1200, 1200, 1200, 1200] + r1_p = _pickle(r1, representation) + assert r1_p is not r1 + assert r1_p.ds is not r1.ds + assert r1_p.ds.metadata("time", astype=int) == [0, 0, 0, 0, 1200, 1200, 1200, 1200] + + r2 = r1.order_by("step") + assert not isinstance(r2.ds, XArrayInputFieldList) + assert r2.ds.metadata("time", astype=int) == [0, 0, 1200, 1200, 0, 0, 1200, 1200] + assert r2.ds.metadata("step", astype=int) == [0, 0, 0, 0, 6, 6, 6, 6] + r2_uw = r2.unwrap() + assert not isinstance(r2_uw, XArrayInputFieldList) + assert r2_uw.metadata("time", astype=int) == [0, 0, 1200, 1200, 0, 0, 1200, 1200] + assert r2_uw.metadata("step", astype=int) == [0, 0, 0, 0, 6, 6, 6, 6] + r2_p = _pickle(r2, representation) + assert r2_p is not r2 + assert r2_p.ds is not r2.ds + assert r2_p.ds.metadata("time", astype=int) == [0, 0, 1200, 1200, 0, 0, 1200, 1200] + assert r2_p.ds.metadata("step", astype=int) == [0, 0, 0, 0, 6, 6, 6, 6] diff --git a/tests/xr_engine/test_xr_chunks.py b/tests/xr_engine/test_xr_chunks.py index b5bf4267d..92bace7ab 100644 --- a/tests/xr_engine/test_xr_chunks.py +++ b/tests/xr_engine/test_xr_chunks.py @@ -96,7 +96,6 @@ def test_xr_engine_chunk_2(_kwargs): assert np.isclose(r.values.mean(), 275.9938876277779) -@pytest.mark.skipif(True, reason="Needs to be fixed") @pytest.mark.cache @pytest.mark.parametrize( "_kwargs", @@ -106,6 +105,7 @@ def test_xr_engine_chunk_2(_kwargs): {"chunks": {"valid_time": 1}}, {"chunks": {"valid_time": 10}}, {"chunks": {"valid_time": (100, 200, 432), "latitude": (4, 5, 4), "longitude": (13, 3, 8)}}, + {"chunks": {"valid_time": 100, "latitude": 4, "longitude": 7}}, {"chunks": -1}, ], )