Skip to content

Improve lagged ensemble performance and versatility#154

Merged
nbren12 merged 22 commits into
NVIDIA:mainfrom
nbren12:graphcast-debug
Jan 13, 2024
Merged

Improve lagged ensemble performance and versatility#154
nbren12 merged 22 commits into
NVIDIA:mainfrom
nbren12:graphcast-debug

Conversation

@nbren12

@nbren12 nbren12 commented Dec 23, 2023

Copy link
Copy Markdown
Collaborator

Earth-2 MIP Pull Request

Description

I have been trying to score graphcast, but because it is big, I was first saving the forecasts to hindcast directory using earht2mip.time_collection instead of scoring it online with earth2mip.lagged_ensembles. Unfortunately, there is some bug such that at the lead time =0 forecast is not the same as the initial condition.

This error is not present in the lagged ensemble script so I would like to use it to score graphcast directly. The lagged ensemble script currently does not handle the grid information properly and the scoring is very slow. This PR addresses both issues.

I add a memory and compute efficient implementation of CRPS in torch. The code relies on xskillscore/properscoring, but when run on the GPU this used the O(n^2) kernel method and runs out of memory when there are many channels. I did not use the modulus version of CRPS because it lacks an efficient yet exact implementation of CRPS. The kernel method is inefficient and the binned method is inexact. A bonus is that I can remove the xskillscores/properscoring dep. This code could be upstreamed to modulus core.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

@nbren12 nbren12 changed the title Graphcast debug Improve lagged ensemble performance and versatility Dec 24, 2023
all the lags were being collected on all ranks, but this is not
necessary.

Use point to point comms to only send the data needed for a given rank.
@nbren12 nbren12 force-pushed the graphcast-debug branch 2 times, most recently from 0a4be67 to 6696517 Compare December 24, 2023 20:57
@nbren12

nbren12 commented Jan 3, 2024

Copy link
Copy Markdown
Collaborator Author

/blossom-ci

@nbren12 nbren12 marked this pull request as ready for review January 3, 2024 16:52
@nbren12 nbren12 linked an issue Jan 3, 2024 that may be closed by this pull request

@dallasfoster dallasfoster left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few comments, some of which may turn into issues if desired. I think there is in this PR a pattern of code isolation that should be addressed in order to make the repo more modular.

Comment thread earth2mip/crps.py
Comment thread test/test_crps.py

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we include a test that does not depend on properscoring?

There is an explicit calculation we do in modulus that could be added here for sanity. https://github.com/NVIDIA/modulus/blob/abf970c8ecc1e681f727e79b999341a94c92c97a/test/metrics/test_metrics_general.py#L233

(Also thinking out loud here, we should test scalar and translation invariance)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR replaces proper-scoring and so I think it makes sense to validate against it. This ensures our old lagged ensemble score are the same as the new ones.

Perhaps we can address when we upstream to modulus.

@dallasfoster dallasfoster Jan 4, 2024

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that we should validate against it, but I think it is also beneficial to have an additional test that doesn't have it as a dependency. In any case, we can certainly discuss further in the upstreaming PR(s). Perhaps the non-dependency tests can live in modulus (which we would rather not have properscoring as a dependency) and this test can continue to live here.

Comment thread earth2mip/forecasts.py
Comment on lines +73 to +91
class select_channels(Forecast):
def __init__(self, forecast: Forecast, channel_names: list[str]):
self.forecast = forecast
self._channel_names = channel_names
self._index = [self.forecast.channel_names.index(x) for x in self.channel_names]

@property
def channel_names(self):
return self._channel_names

@property
def grid(self):
return self.forecast.grid

async def __getitem__(self, i: int):
async for x in self.forecast[i]:
yield x[:, self._index]


Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the motivation for making this an object, or an object of type Forecast. What is a Forecast and why should a function that selects channels be a type of it? This seems like it should be more of a utility function rather than a class.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I explain the motivation for the Forecast API here: https://nvidia.github.io/earth2mip/userguide/concepts.html#forecast

Many scoring algorithms are most easily expressed as operations over 2D array of states that we call a Forecast Array. The rows of this array correspond to initial times, and the columns to lead times. The size of this array may be unbounded. For example, computing a lead time dependent metrics, such as RMSE corresponds to averaging the square difference of Forecast Arrays of observations and forecasts, and then averaging over the row dimension. This is defined by the earth2mip.forecasts.Forecast interface. Compared to a TimeLoop, a Forecast encapsulates any time handling and initialization logic. One advantage is that an archive of forecasts on disk can be represented as a Forecast (see earth2mip.forecasts.XarrayForecast). This allows using the same code to score both static and streaming forecasts.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I chose to use lower case since this is a Callable[[Forecast,...], Forecast that I intended to be used like a function. See #154 (comment) for more info.

I could make it a class ChannelSubsetForecast and write a select_channels(forecast: Forecast, ...) -> ChannelSubsetForecast, but the effect is the same and this is less work.

Comment thread earth2mip/initial_conditions/__init__.py
Comment thread earth2mip/initial_conditions/__init__.py

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a lot of code to be placed in a file so non-descript as __main__.py. Can we consider renaming and moving some code for clarity, similar to what you did for score.py?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR improves an implementation, I think renaming is a separate PR/issue.

FWIW, I feel the name python3 -m earth2mip.lagged_ensembles is relatively informative.

@NickGeneva NickGeneva Jan 3, 2024

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that some naming clean up can be done but thats ingeneral, and maybe names can be given some prefix / suffix for consistency... +1 for a different PR

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this comment was more addressing the Observations class and lagged_average_simple function. Perhaps these two functions can be relocated so that the only method in this file is main?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, does it make sense for Observations to be located in lagged_ensembles? This can be an issue and another PR too but I think we should think of another place for such a generic notion.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let’s move things in a future PR. This present location of these objects is not something this pr is changing or motivated by.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, this can be resolved with a link to an issue.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is one way to implement the done criteria which is less easily
parallelized. I am leaving it in the code for educational value only.
return buffers # noqa

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code seems to be misplaced in this file. A lot of the functionality here does not seem to rely on the assumption of lagged ensemble. The design principle should be to move these utilities elsewhere, in a distributed utliity focused area (maybe even modulus.distributed) and to generalize sufficiently so that we can reuse code if necessary.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you be more specific about what code you see as general purpose here?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. _convert_ensemble_to_cpu_async seems like a general purpose utility, not specific to lagged ensemble.
  2. scatter also seems to be more general purpose utility. As a side note, if this function is meant to be private then we should mark it so, otherwise it would be helpful to have documentation on the inputs.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let’s defer this. This pr is not defining new general purpose utilities for scattering an ensemble distributed across ranks. Just moving code into helpers as appropriate for improved legibility or the lagged ensemble code.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fine, we can resolve with a link to an issue. Perhaps the issue can be broadly centered on distributed utilities.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on more thought. scatter is specific to the distribution pattern used in this particular script, but I can open an issue for the cpu async.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment thread earth2mip/lagged_ensembles/score.py
Comment on lines +15 to +18
def area_average(grid: earth2mip.grid.LatLonGrid, x: torch.Tensor):
lat = torch.tensor(grid.lat, device=x.device)[:, None]
cos_lat = torch.cos(torch.deg2rad(lat))
return weighted_average(x, cos_lat, dim=[-2, -1])

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this here? Modulus has lat and global averaging utilities.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See reply here: #154 (comment)

Comment thread earth2mip/lagged_ensembles/score.py
verification_torch = initial_conditions.get_data_from_source(
data_source=data_source,
time=valid_time,
channel_names=model.out_channel_names,

@yairchn yairchn Jan 3, 2024

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be model.in_channel_names as in GC they are not the same? see #157

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. This just gets verification data. The outputs are compared against verification...not the inputs.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in graphcast model.out_channel_names includes tp06 that is not in ERA5 so its not clear to me if this would work

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In practice, graphcast is scored against hdf5 data that have this channel. If the output_channels don't exist in the data that is a bug with the data, not with this routine.

Comment thread earth2mip/lagged_ensembles/core.py Outdated
min_lag: int = -2,
n: int = 10,
):
"""Yield centered lagged ensembles

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we update the text here? not only centered

@nbren12

nbren12 commented Jan 3, 2024

Copy link
Copy Markdown
Collaborator Author

@dallasfoster Thanks for the discussion. I feel it is good knowledge transfer, but perhaps a bit out of scope. Not sure if you have seen these docs: https://nvidia.github.io/earth2mip/userguide/concepts.html#forecast. If not, I highly recommend reading them to understand the current design of e2mip. This PR is however not the place to revisit all aspects of the design.

A few comments, some of which may turn into issues if desired. I think there is in this PR a pattern of code isolation that should be addressed in order to make the repo more modular.

Yes, this PR is focused on fixing some bugs and optimizing the lagged ensemble script, and some of the pieces may be useful in general, but that is something we can address later I think.

Let's upstream the crps code to modulus. I am okay with duplicating a 1-line helper for weighted averaging a few places in the code.

Comment thread earth2mip/forecasts.py
Comment thread earth2mip/lagged_ensembles/core.py
@NickGeneva

Copy link
Copy Markdown
Collaborator

@nbren12 looks in general fine to me, although right now I'm not fully familiar with the lagged ensembles. But changes outside seem good enough to get things fixed / discuss more later.

@nbren12

nbren12 commented Jan 4, 2024

Copy link
Copy Markdown
Collaborator Author

/blossom-ci

@nbren12

nbren12 commented Jan 4, 2024

Copy link
Copy Markdown
Collaborator Author

/blossom-ci

@nbren12

nbren12 commented Jan 4, 2024

Copy link
Copy Markdown
Collaborator Author

/blossom-ci

2 similar comments
@nbren12

nbren12 commented Jan 4, 2024

Copy link
Copy Markdown
Collaborator Author

/blossom-ci

@NickGeneva

Copy link
Copy Markdown
Collaborator

/blossom-ci

@nbren12

nbren12 commented Jan 4, 2024

Copy link
Copy Markdown
Collaborator Author

/blossom-ci

@nbren12

nbren12 commented Jan 11, 2024

Copy link
Copy Markdown
Collaborator Author

/blossom-ci

1 similar comment
@nbren12

nbren12 commented Jan 11, 2024

Copy link
Copy Markdown
Collaborator Author

/blossom-ci

@nbren12

nbren12 commented Jan 11, 2024

Copy link
Copy Markdown
Collaborator Author

/blossom-ci

@nbren12

nbren12 commented Jan 11, 2024

Copy link
Copy Markdown
Collaborator Author

/blossom-ci

for some reason pytest-regtest was not installed for the parallel job
@nbren12

nbren12 commented Jan 11, 2024

Copy link
Copy Markdown
Collaborator Author

/blossom-ci

@nbren12

nbren12 commented Jan 11, 2024

Copy link
Copy Markdown
Collaborator Author

The tests are failing when trying to use pytest-regtest: https://github.com/NVIDIA/earth2mip/actions/runs/7493428670/job/20399108089#step:2:1160

I'm not sure why, since the logs show it is being installed.

@NickGeneva Do you have any ideas about what the issue is?

@NickGeneva

Copy link
Copy Markdown
Collaborator

The tests are failing when trying to use pytest-regtest: https://github.com/NVIDIA/earth2mip/actions/runs/7493428670/job/20399108089#step:2:1160

I'm not sure why, since the logs show it is being installed.

@NickGeneva Do you have any ideas about what the issue is?

Not sure, when running locally I had to remove pytest-regtest from the TOML to get pytest to function at all. May need to force a lower version to like 1.5.1. Seems other are having similar issues with the latest release:
https://gitlab.com/uweschmitt/pytest-regtest/-/issues/20

@nbren12

nbren12 commented Jan 13, 2024

Copy link
Copy Markdown
Collaborator Author

/blossom-ci

@nbren12

nbren12 commented Jan 13, 2024

Copy link
Copy Markdown
Collaborator Author

Thanks @NickGeneva for the tip, that fixed it. Merging now.

@nbren12 nbren12 merged commit 225c19a into NVIDIA:main Jan 13, 2024
@nbren12 nbren12 deleted the graphcast-debug branch January 13, 2024 21:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

🐛[BUG]: Graphcast Operational Input/Output Channel Mismatch

4 participants