diff --git a/src/earthkit/data/wrappers/xarray.py b/src/earthkit/data/wrappers/xarray.py index bdbfb2b2e..fd79680a0 100644 --- a/src/earthkit/data/wrappers/xarray.py +++ b/src/earthkit/data/wrappers/xarray.py @@ -58,14 +58,17 @@ def to_xarray(self, *args, **kwargs): """ return self.data - def to_numpy(self): + def to_numpy(self, flatten=False): """Return a numpy `ndarray` representation of the data. Returns ------- numpy.ndarray """ - return self.data.to_numpy() + arr = self.data.to_numpy() + if flatten: + arr = arr.flatten() + return arr def to_pandas(self, *args, **kwargs): """Return a pandas `dataframe` representation of the data. @@ -91,14 +94,17 @@ class XArrayDatasetWrapper(XArrayDataArrayWrapper): methods. """ - def to_numpy(self): + def to_numpy(self, flatten=False): """Return a numpy `ndarray` representation of the data. Returns ------- numpy.ndarray """ - return self.data.to_array().to_numpy() + arr = self.data.to_array().to_numpy() + if flatten: + arr = arr.flatten() + return arr # def component(self, component): # """ diff --git a/tests/wrappers/test_xarray.py b/tests/wrappers/test_xarray.py index d3e43b55b..d7466cd25 100644 --- a/tests/wrappers/test_xarray.py +++ b/tests/wrappers/test_xarray.py @@ -49,3 +49,21 @@ def test_xarray_lazy_fieldlist_scan(): assert ds._fields is None assert len(ds) == 2 assert len(ds._fields) == 2 + + +@pytest.mark.no_eccodes +def test_xarray_to_numpy(): + import xarray as xr + + data_array = xr.DataArray( + [[1, 2, 3], [4, 5, 6]], + dims=["x", "y"], + coords={"x": [1, 2], "y": [3, 4, 5]}, + ) + ds = from_object(data_array) + + arr_2d = ds.to_numpy() + assert arr_2d.shape == (2, 3) + + arr_1d = ds.to_numpy(flatten=True) + assert arr_1d.shape == (6,)