Improve lagged ensemble performance and versatility#154
Conversation
fixes bug when in_channels!=out_channels
also fix the lagged ensemble integration test.
all the lags were being collected on all ranks, but this is not necessary. Use point to point comms to only send the data needed for a given rank.
0a4be67 to
6696517
Compare
|
/blossom-ci |
dallasfoster
left a comment
There was a problem hiding this comment.
A few comments, some of which may turn into issues if desired. I think there is in this PR a pattern of code isolation that should be addressed in order to make the repo more modular.
There was a problem hiding this comment.
Can we include a test that does not depend on properscoring?
There is an explicit calculation we do in modulus that could be added here for sanity. https://github.com/NVIDIA/modulus/blob/abf970c8ecc1e681f727e79b999341a94c92c97a/test/metrics/test_metrics_general.py#L233
(Also thinking out loud here, we should test scalar and translation invariance)
There was a problem hiding this comment.
This PR replaces proper-scoring and so I think it makes sense to validate against it. This ensures our old lagged ensemble score are the same as the new ones.
Perhaps we can address when we upstream to modulus.
There was a problem hiding this comment.
I agree that we should validate against it, but I think it is also beneficial to have an additional test that doesn't have it as a dependency. In any case, we can certainly discuss further in the upstreaming PR(s). Perhaps the non-dependency tests can live in modulus (which we would rather not have properscoring as a dependency) and this test can continue to live here.
| class select_channels(Forecast): | ||
| def __init__(self, forecast: Forecast, channel_names: list[str]): | ||
| self.forecast = forecast | ||
| self._channel_names = channel_names | ||
| self._index = [self.forecast.channel_names.index(x) for x in self.channel_names] | ||
|
|
||
| @property | ||
| def channel_names(self): | ||
| return self._channel_names | ||
|
|
||
| @property | ||
| def grid(self): | ||
| return self.forecast.grid | ||
|
|
||
| async def __getitem__(self, i: int): | ||
| async for x in self.forecast[i]: | ||
| yield x[:, self._index] | ||
|
|
||
|
|
There was a problem hiding this comment.
I don't understand the motivation for making this an object, or an object of type Forecast. What is a Forecast and why should a function that selects channels be a type of it? This seems like it should be more of a utility function rather than a class.
There was a problem hiding this comment.
I explain the motivation for the Forecast API here: https://nvidia.github.io/earth2mip/userguide/concepts.html#forecast
Many scoring algorithms are most easily expressed as operations over 2D array of states that we call a Forecast Array. The rows of this array correspond to initial times, and the columns to lead times. The size of this array may be unbounded. For example, computing a lead time dependent metrics, such as RMSE corresponds to averaging the square difference of Forecast Arrays of observations and forecasts, and then averaging over the row dimension. This is defined by the earth2mip.forecasts.Forecast interface. Compared to a TimeLoop, a Forecast encapsulates any time handling and initialization logic. One advantage is that an archive of forecasts on disk can be represented as a Forecast (see earth2mip.forecasts.XarrayForecast). This allows using the same code to score both static and streaming forecasts.
There was a problem hiding this comment.
I chose to use lower case since this is a Callable[[Forecast,...], Forecast that I intended to be used like a function. See #154 (comment) for more info.
I could make it a class ChannelSubsetForecast and write a select_channels(forecast: Forecast, ...) -> ChannelSubsetForecast, but the effect is the same and this is less work.
There was a problem hiding this comment.
This is a lot of code to be placed in a file so non-descript as __main__.py. Can we consider renaming and moving some code for clarity, similar to what you did for score.py?
There was a problem hiding this comment.
This PR improves an implementation, I think renaming is a separate PR/issue.
FWIW, I feel the name python3 -m earth2mip.lagged_ensembles is relatively informative.
There was a problem hiding this comment.
I agree that some naming clean up can be done but thats ingeneral, and maybe names can be given some prefix / suffix for consistency... +1 for a different PR
There was a problem hiding this comment.
I think this comment was more addressing the Observations class and lagged_average_simple function. Perhaps these two functions can be relocated so that the only method in this file is main?
There was a problem hiding this comment.
Also, does it make sense for Observations to be located in lagged_ensembles? This can be an issue and another PR too but I think we should think of another place for such a generic notion.
There was a problem hiding this comment.
Let’s move things in a future PR. This present location of these objects is not something this pr is changing or motivated by.
There was a problem hiding this comment.
Sure, this can be resolved with a link to an issue.
| This is one way to implement the done criteria which is less easily | ||
| parallelized. I am leaving it in the code for educational value only. | ||
| return buffers # noqa | ||
|
|
There was a problem hiding this comment.
This code seems to be misplaced in this file. A lot of the functionality here does not seem to rely on the assumption of lagged ensemble. The design principle should be to move these utilities elsewhere, in a distributed utliity focused area (maybe even modulus.distributed) and to generalize sufficiently so that we can reuse code if necessary.
There was a problem hiding this comment.
Could you be more specific about what code you see as general purpose here?
There was a problem hiding this comment.
_convert_ensemble_to_cpu_asyncseems like a general purpose utility, not specific to lagged ensemble.scatteralso seems to be more general purpose utility. As a side note, if this function is meant to be private then we should mark it so, otherwise it would be helpful to have documentation on the inputs.
There was a problem hiding this comment.
Let’s defer this. This pr is not defining new general purpose utilities for scattering an ensemble distributed across ranks. Just moving code into helpers as appropriate for improved legibility or the lagged ensemble code.
There was a problem hiding this comment.
That's fine, we can resolve with a link to an issue. Perhaps the issue can be broadly centered on distributed utilities.
There was a problem hiding this comment.
on more thought. scatter is specific to the distribution pattern used in this particular script, but I can open an issue for the cpu async.
| def area_average(grid: earth2mip.grid.LatLonGrid, x: torch.Tensor): | ||
| lat = torch.tensor(grid.lat, device=x.device)[:, None] | ||
| cos_lat = torch.cos(torch.deg2rad(lat)) | ||
| return weighted_average(x, cos_lat, dim=[-2, -1]) |
There was a problem hiding this comment.
Why is this here? Modulus has lat and global averaging utilities.
| verification_torch = initial_conditions.get_data_from_source( | ||
| data_source=data_source, | ||
| time=valid_time, | ||
| channel_names=model.out_channel_names, |
There was a problem hiding this comment.
should this be model.in_channel_names as in GC they are not the same? see #157
There was a problem hiding this comment.
No. This just gets verification data. The outputs are compared against verification...not the inputs.
There was a problem hiding this comment.
in graphcast model.out_channel_names includes tp06 that is not in ERA5 so its not clear to me if this would work
There was a problem hiding this comment.
In practice, graphcast is scored against hdf5 data that have this channel. If the output_channels don't exist in the data that is a bug with the data, not with this routine.
| min_lag: int = -2, | ||
| n: int = 10, | ||
| ): | ||
| """Yield centered lagged ensembles |
There was a problem hiding this comment.
should we update the text here? not only centered
|
@dallasfoster Thanks for the discussion. I feel it is good knowledge transfer, but perhaps a bit out of scope. Not sure if you have seen these docs: https://nvidia.github.io/earth2mip/userguide/concepts.html#forecast. If not, I highly recommend reading them to understand the current design of e2mip. This PR is however not the place to revisit all aspects of the design.
Yes, this PR is focused on fixing some bugs and optimizing the lagged ensemble script, and some of the pieces may be useful in general, but that is something we can address later I think. Let's upstream the crps code to modulus. I am okay with duplicating a 1-line helper for weighted averaging a few places in the code. |
|
@nbren12 looks in general fine to me, although right now I'm not fully familiar with the lagged ensembles. But changes outside seem good enough to get things fixed / discuss more later. |
This flag subselects the channels
6696517 to
fc3d3a6
Compare
|
/blossom-ci |
|
/blossom-ci |
|
/blossom-ci |
2 similar comments
|
/blossom-ci |
|
/blossom-ci |
|
/blossom-ci |
|
/blossom-ci |
1 similar comment
|
/blossom-ci |
|
/blossom-ci |
|
/blossom-ci |
for some reason pytest-regtest was not installed for the parallel job
|
/blossom-ci |
|
The tests are failing when trying to use pytest-regtest: https://github.com/NVIDIA/earth2mip/actions/runs/7493428670/job/20399108089#step:2:1160 I'm not sure why, since the logs show it is being installed. @NickGeneva Do you have any ideas about what the issue is? |
Not sure, when running locally I had to remove |
|
/blossom-ci |
|
Thanks @NickGeneva for the tip, that fixed it. Merging now. |
Earth-2 MIP Pull Request
Description
I have been trying to score graphcast, but because it is big, I was first saving the forecasts to hindcast directory using earht2mip.time_collection instead of scoring it online with earth2mip.lagged_ensembles. Unfortunately, there is some bug such that at the lead time =0 forecast is not the same as the initial condition.
This error is not present in the lagged ensemble script so I would like to use it to score graphcast directly. The lagged ensemble script currently does not handle the grid information properly and the scoring is very slow. This PR addresses both issues.
I add a memory and compute efficient implementation of CRPS in torch. The code relies on xskillscore/properscoring, but when run on the GPU this used the O(n^2) kernel method and runs out of memory when there are many channels. I did not use the modulus version of CRPS because it lacks an efficient yet exact implementation of CRPS. The kernel method is inefficient and the binned method is inexact. A bonus is that I can remove the xskillscores/properscoring dep. This code could be upstreamed to modulus core.
Checklist
Dependencies