Changes to inference_medium_range.py#157
Conversation
|
/blossom-ci |
| #outputs = metric.call(verification_torch, data) | ||
| for name, tensor in zip(metric.output_names, outputs): | ||
| v = tensor.cpu().numpy() | ||
| for c_idx in range(len(model.out_channel_names)): |
There was a problem hiding this comment.
I believe out_channels is correct here since we are operating on v which is output data.
Rather we should attempt to fix the problem up at the datasource part where it errors.
| nlat = len(model.grid.lat) | ||
| channels = [ | ||
| data_source.channel_names.index(name) for name in model.out_channel_names | ||
| data_source.channel_names.index(name) for name in model.in_channel_names |
There was a problem hiding this comment.
Rather than switching to input channels, we should maybe instead only select the channels that are in the datasource (I'm not even sure why we do this line here...)
Honestly, maybe just out right eliminating likes 123-125 is the move. I'll check this today and circle back.
| verification_torch = verification_torch[:, -1] | ||
| for metric in metrics: | ||
| outputs = metric.call(verification_torch, data) | ||
| indexes = [i for i,name in enumerate(model.out_channel_names) if name in model.in_channel_names] |
There was a problem hiding this comment.
I think this line needs to be formatted. When installing earth2mip use:
pip install .[dev] or pip install -e .[dev]
This will install dependencies for development. You can manually run formatting / linting using make format and make lint.
Just for future reference. We also use pre-commit to run these checks ahead of time, inside your earth2mip repo:
pip install pre-commit
pre-commit install|
Hi @pgarg7 Appreciate the bug report and the PR. I'll take an additional look at this today and hopefully run some tests. Left some initial comments but no rush. |
|
I think some of this is fixed in #154. |
Earth-2 MIP Pull Request
Description
I have changed the run_forecast function within the inference_medium_range.py module in order to make sure that deterministic scores are only calculated for variables which are same in model's input and output channels. I added some code in which it finds the indexes of model's in_channels in model's out_channels and then calls the metric on only those indexes. In this way, it's not raising an error related to input/output shape mismatch specially for Graphcast_Operational model.
-- This PR "closes #156"
Checklist
Dependencies