Changes to inference_medium_range.py #157

Closed
pgarg7 wants to merge 1 commit into NVIDIA:main from pgarg7:fix_io_gc

Conversation

pgarg7 commented Dec 29, 2023

Earth-2 MIP Pull Request

Description

I changed the run_forecast function in the inference_medium_range.py module so that deterministic scores are calculated only for variables that appear in both the model's input and output channels. The new code finds the indexes of the model's out_channels whose names also appear in its in_channels, then calls the metric on only those indexes. This avoids the input/output shape mismatch error, which occurs in particular with the Graphcast_Operational model.

This PR closes #156.
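
The index matching described above can be illustrated with a minimal, self-contained sketch. It is not the code from this PR: the helper name `shared_channel_indexes` and the channel lists are hypothetical, and in earth2mip the names would come from `model.in_channel_names` and `model.out_channel_names`.

```python
def shared_channel_indexes(in_names, out_names):
    """Pair up channels that appear in both lists.

    Returns two lists of equal length: positions in the input list and
    positions in the output list, ordered by output position.
    """
    in_pos = {name: i for i, name in enumerate(in_names)}
    in_idx, out_idx = [], []
    for j, name in enumerate(out_names):
        if name in in_pos:
            in_idx.append(in_pos[name])
            out_idx.append(j)
    return in_idx, out_idx


# Hypothetical channel lists, chosen so the two sides differ in order and content.
in_channel_names = ["t2m", "u10m", "v10m", "tp06"]
out_channel_names = ["u10m", "t2m", "z500", "v10m"]

in_idx, out_idx = shared_channel_indexes(in_channel_names, out_channel_names)
shared_names = [out_channel_names[j] for j in out_idx]
print(shared_names)
```

Indexing the verification data with `in_idx` and the forecast with `out_idx` would give two arrays with the same channel axis, which is what a per-channel metric needs.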

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

NickGeneva added the "3 - Ready for Review" (Ready for review by team) label Dec 29, 2023
@NickGeneva

Collaborator

/blossom-ci

#outputs = metric.call(verification_torch, data)
for name, tensor in zip(metric.output_names, outputs):
v = tensor.cpu().numpy()
for c_idx in range(len(model.out_channel_names)):

Collaborator

I believe out_channels is correct here, since we are operating on v, which is output data.

Instead, we should try to fix the problem upstream, at the data source step where the error is raised.

nlat = len(model.grid.lat)
channels = [
- data_source.channel_names.index(name) for name in model.out_channel_names
+ data_source.channel_names.index(name) for name in model.in_channel_names

Collaborator

Rather than switching to input channels, we should perhaps select only the channels that are present in the data source (I'm not even sure why we do this line here...).

Honestly, eliminating lines 123-125 outright may be the right move. I'll check this today and circle back.
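
As a rough illustration of selecting only the channels present in the data source, here is a minimal sketch. It is not code from earth2mip: the helper `select_available_channels` and both channel lists are hypothetical. It assumes the channel names are held in plain Python lists, in which case `list.index` raises `ValueError` for a name that is absent, so the sketch looks names up in a dictionary instead.

```python
def select_available_channels(source_names, wanted_names):
    """Split the wanted channels into those the source provides and those it lacks.

    Returns (indices, missing): indices into source_names for every wanted
    channel that is available, and the names of the channels that are not.
    """
    source_pos = {name: i for i, name in enumerate(source_names)}
    indices = [source_pos[name] for name in wanted_names if name in source_pos]
    missing = [name for name in wanted_names if name not in source_pos]
    return indices, missing


# Hypothetical lists: the source lacks one channel that the model asks for.
source_channel_names = ["z500", "t2m", "u10m", "v10m"]
model_channel_names = ["t2m", "u10m", "v10m", "tp06"]

indices, missing = select_available_channels(source_channel_names, model_channel_names)
print(indices, missing)
```

Returning the missing names alongside the indices lets the caller decide whether to skip those channels or report them, rather than failing on the first lookup.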

verification_torch = verification_torch[:, -1]
for metric in metrics:
outputs = metric.call(verification_torch, data)
indexes = [i for i,name in enumerate(model.out_channel_names) if name in model.in_channel_names]

Collaborator

I think this line needs to be formatted. When installing earth2mip use:

pip install .[dev] or pip install -e .[dev]

This will install dependencies for development. You can manually run formatting / linting using make format and make lint.

For future reference, we also use pre-commit to run these checks ahead of time. Inside your earth2mip repo, run:

pip install pre-commit
pre-commit install

@NickGeneva

Collaborator

Hi @pgarg7

Appreciate the bug report and the PR. I'll take an additional look at this today and hopefully run some tests. Left some initial comments but no rush.

nbren12 commented Jan 2, 2024
Collaborator

I think some of this is fixed in #154.

nbren12 commented Jan 13, 2024
Collaborator

#154 fixed this bug, so this PR is no longer needed. Graphcast has an input, tp06, that is not available in most data sources. There is an experimental PR (#94) which adds it to cds.DataSource, but I didn't think that implementation was robust enough to merge yet.

nbren12 closed this Jan 13, 2024

Labels

3 - Ready for Review Ready for review by team

Development

Successfully merging this pull request may close these issues.

🐛[BUG]: Graphcast Operational Input/Output Channel Mismatch

3 participants