From 0c2789e200c11ccb5b792451de1400489f1f875e Mon Sep 17 00:00:00 2001 From: James Varndell Date: Mon, 28 Apr 2025 10:26:31 +0100 Subject: [PATCH 1/4] Adds flatten argument to to_numpy methods --- src/earthkit/data/wrappers/xarray.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/earthkit/data/wrappers/xarray.py b/src/earthkit/data/wrappers/xarray.py index bdbfb2b2e..fd79680a0 100644 --- a/src/earthkit/data/wrappers/xarray.py +++ b/src/earthkit/data/wrappers/xarray.py @@ -58,14 +58,17 @@ def to_xarray(self, *args, **kwargs): """ return self.data - def to_numpy(self): + def to_numpy(self, flatten=False): """Return a numpy `ndarray` representation of the data. Returns ------- numpy.ndarray """ - return self.data.to_numpy() + arr = self.data.to_numpy() + if flatten: + arr = arr.flatten() + return arr def to_pandas(self, *args, **kwargs): """Return a pandas `dataframe` representation of the data. @@ -91,14 +94,17 @@ class XArrayDatasetWrapper(XArrayDataArrayWrapper): methods. """ - def to_numpy(self): + def to_numpy(self, flatten=False): """Return a numpy `ndarray` representation of the data. Returns ------- numpy.ndarray """ - return self.data.to_array().to_numpy() + arr = self.data.to_array().to_numpy() + if flatten: + arr = arr.flatten() + return arr # def component(self, component): # """ From ec3e2e5a9d53d96c366076c279dfbf5e254b824f Mon Sep 17 00:00:00 2001 From: James Varndell Date: Tue, 29 Apr 2025 11:28:53 +0100 Subject: [PATCH 2/4] Adds test for flatten arg --- tests/wrappers/test_xarray.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/wrappers/test_xarray.py b/tests/wrappers/test_xarray.py index d3e43b55b..fdf7f4cc6 100644 --- a/tests/wrappers/test_xarray.py +++ b/tests/wrappers/test_xarray.py @@ -49,3 +49,22 @@ def test_xarray_lazy_fieldlist_scan(): assert ds._fields is None assert len(ds) == 2 assert len(ds._fields) == 2 + + +@pytest.mark.no_eccodes +def test_xarray_to_numpy(): + import xarray as xr + + data_array = xr.DataArray( + [[1, 2, 3], [4, 5, 6]], + dims=["x", "y"], + coords={"x": [1, 2], "y": [3, 4, 5]}, + ) + ds = from_object(data_array) + + arr_2d = ds.to_numpy() + assert arr_2d.shape == (2, 3) + + arr_1d = ds.to_numpy(flatten=True) + assert arr_1d.shape == (6,) + \ No newline at end of file From 96aca12eeb5a7a4118bd846471ed18455f074e0a Mon Sep 17 00:00:00 2001 From: James Varndell Date: Tue, 29 Apr 2025 13:02:11 +0100 Subject: [PATCH 3/4] QA tweak --- tests/wrappers/test_xarray.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/wrappers/test_xarray.py b/tests/wrappers/test_xarray.py index fdf7f4cc6..3c5efd506 100644 --- a/tests/wrappers/test_xarray.py +++ b/tests/wrappers/test_xarray.py @@ -67,4 +67,3 @@ def test_xarray_to_numpy(): arr_1d = ds.to_numpy(flatten=True) assert arr_1d.shape == (6,) - \ No newline at end of file From f5a80dbeca0f71031aafe774fc5b98f818649990 Mon Sep 17 00:00:00 2001 From: James Varndell Date: Tue, 29 Apr 2025 13:03:54 +0100 Subject: [PATCH 4/4] QA tweak --- tests/wrappers/test_xarray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/wrappers/test_xarray.py b/tests/wrappers/test_xarray.py index 3c5efd506..d7466cd25 100644 --- a/tests/wrappers/test_xarray.py +++ b/tests/wrappers/test_xarray.py @@ -61,9 +61,9 @@ def test_xarray_to_numpy(): coords={"x": [1, 2], "y": [3, 4, 5]}, ) ds = from_object(data_array) - + arr_2d = ds.to_numpy() assert arr_2d.shape == (2, 3) - + arr_1d = ds.to_numpy(flatten=True) assert arr_1d.shape == (6,)